我在keras中的CNN代码如下:
from keras.models import Sequential
from keras.layers import Convolution2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
from keras.layers import Dropout
classifier = Sequential()
#1st Conv layer
classifier.add(Convolution2D(64, (9, 9), input_shape=(64, 64, 3), activation='relu'))
classifier.add(MaxPooling2D(pool_size=(4,4)))
#2nd Conv layer
classifier.add(Convolution2D(32, (3, 3), activation='relu'))
classifier.add(MaxPooling2D(pool_size=(2,2)))
#Flattening
classifier.add(Flatten())
# Step 4 - Full connection
classifier.add(Dense(units = 128, activation = 'relu'))
classifier.add(Dropout(0.1))
classifier.add(Dense(units = 128, activation = 'relu'))
classifier.add(Dropout(0.2))
classifier.add(Dense(units = 128, activation = 'relu'))
classifier.add(Dense(units = 2, activation = 'softmax'))
classifier.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])
#Fitting dataset
from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale = 1./255,
shear_range = 0.2,
zoom_range = 0.2,
horizontal_flip = True)
test_datagen = ImageDataGenerator(rescale = 1./255)
training_set = train_datagen.flow_from_directory('dataset/training_set',
target_size = (64, 64),
batch_size = 32,
class_mode = 'categorical')
test_set = test_datagen.flow_from_directory('dataset/test_set',
target_size = (64, 64),
batch_size = 32,
class_mode = 'categorical')
classifier.fit_generator(
training_set,
steps_per_epoch=(1341+3875)/32,
epochs=15,
validation_data=test_set,
validation_steps=(234+390)/32)
无论在哪里看到sklearn.metrics中使用roc_curve的地方,它都带有x_train,y_train,x_test,y_test之类的参数,我知道它们可以是pandas DataFrames,但就我而言并非如此。我如何绘制ROC曲线并获得AUC分数,用于像这样的CNN模型训练?
我知道了 我所要做的就是匹配来自获得preds的数据类型preds = classifier.predict(test_set)
与我从拿到true_labels labels = test_set
。Preds基本上是一个numpy.ndarray,包含具有np.float32值的单个元素列表。将标签转换为相同的格式和形状后,即可完成roc_curve的工作。
另外,我必须在其中添加第三个变量阈值,fpr, tpr, threshold = roc_curve(true_labels, preds)
以免出现ValueError:弹出太多无法解包错误的值。