My code follows a pattern similar to the TensorFlow 2.0 tutorial. I want my dataset object to reshuffle on every epoch.
dataset = tf.data.Dataset.from_tensor_slices(['a','b','c','d'])
dataset = dataset.shuffle(100)
for epoch in range(10):
    for d in dataset:
        print(d)
Result:
tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
...
It seems the dataset doesn't shuffle for each epoch. Should I call .shuffle() for each epoch?
Yes, you should call .shuffle in the inner loop, so that it runs once per epoch. Moreover, it is better not to mix Python code and TensorFlow code when a pure tf.* equivalent of the Python statement is available.
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c", "d"])
# dataset = dataset.shuffle(2)

@tf.function
def loop():
    for epoch in tf.range(10):
        for d in dataset.shuffle(2):
            tf.print(d)

loop()
The loop call now yields the values in a different order every time (and tf.print prints the content of the tf.Tensor, unlike print, which prints the object itself).
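As an alternative to calling .shuffle inside the loop, here is a minimal sketch that relies on the reshuffle_each_iteration argument of Dataset.shuffle. It assumes a reasonably recent TensorFlow 2.x release running in eager mode; the helper name epoch_orders and the buffer size of 4 are my own choices for illustration, not something from the question. A buffer at least as large as the dataset is used so that every element can land in any position.

```python
import tensorflow as tf


def epoch_orders(num_epochs=3):
    """Return the element order observed in each epoch (hypothetical helper)."""
    dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c", "d"])
    # buffer_size >= number of elements lets any element come out first;
    # reshuffle_each_iteration=True requests a new order on every pass.
    dataset = dataset.shuffle(buffer_size=4, reshuffle_each_iteration=True)

    orders = []
    for _ in range(num_epochs):
        # Each `for` over the dataset creates a fresh iterator.
        orders.append([d.numpy().decode() for d in dataset])
    return orders


for epoch, order in enumerate(epoch_orders()):
    print(epoch, order)
```

If the order still repeats across epochs on your version, calling .shuffle per epoch as shown above remains the fallback.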