在tensorflow.keras的正式文件中,
validation_data可以是:Numpy数组的元组(x_val,y_val)或Numpy数组数据集的张量元组(x_val,y_val,val_sample_weights)对于前两种情况,必须提供batch_size。对于最后一种情况,可以提供validation_steps。
它没有提到生成器是否可以充当validation_data。所以我想知道validation_data是否可以作为数据生成器?类似于以下代码:
net.fit_generator(train_it.generator(), epoch_iterations * batch_size, nb_epoch=nb_epoch, verbose=1,
validation_data=val_it.generator(), nb_val_samples=3,
callbacks=[checker, tb, stopper, saver])
更新:在keras的正式文件中,内容相同,但又添加了另一句话:
- 数据集或数据集迭代器
考虑到
数据集对于前两种情况,必须提供batch_size。对于最后一种情况,可以提供validation_steps。
我认为应该有3种情况。Keras的文件是正确的。因此,我将在tensorflow.keras中发布问题以更新文档。
是的,可以,这很奇怪,它不在文档中,但其工作原理与x
参数完全相同,您也可以使用a keras.Sequence
或a generator
。在我的项目中,我经常使用keras.Sequence
它,就像发电机一样
最小工作示例,表明它可以工作:
import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten
def generator(batch_size): # Create empty arrays to contain batch of features and labels
batch_features = np.zeros((batch_size, 1000))
batch_labels = np.zeros((batch_size,1))
while True:
for i in range(batch_size):
yield batch_features, batch_labels
model = Sequential()
model.add(Dense(125, input_shape=(1000,), activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
train_generator = generator(64)
validation_generator = generator(64)
model.fit(train_generator, validation_data=validation_generator, validation_steps=100, epochs=100, steps_per_epoch=100)
100/100 [=============================]-1s 13ms / step-损耗:0.6689-精度:1.0000-val_loss :0.6448-val_accuracy:1.000时间段2/100 100/100 [==============================]-0s 4ms / step -损失:0.6223-准确性:1.000-val损失:0.6000-val_accuracy:1.0000时期3/100 100/100 [========================= ====]-0s 4ms /步-损耗:0.5792-精度:1.0000-val_loss:0.5586-val_accuracy:1.0000时元4/100 100/100 [================= =============]-0s 4ms / step-损耗:0.5393-精度:1.0000-val_loss:0.5203-val_accuracy:1.0000
validation_data可以是GeneratorEnqueuer吗?从源代码tensorflow.keras我发现validation_data可能是
iterator_ops.Iterator
,iterator_ops.OwnedIterator
,dataset_ops.DatasetV2
,data_utils.Sequence
不完全是,您需要
GeneratorEnqueuer
在生成器(ge = GeneratorEnqueuer(generator(64))
)上启动,然后start
在您的工作人员或其他任何对象上启动,ge.start()
最后get
将生成的生成器output_generator = ge.get()
放置在模型上fit
。fit(output_generator,epochs = 5,steps_per_epoch = 5)正如您所说,我知道train_dataset可以是GeneratorEnqueuer。我在这里问的是,validation_dataset是否可以是GeneratorEnqueuer?我尝试了一下,发现对于train_data来说,它可以工作。但是对于validation_data,它不起作用。