This article is reproduced from stackoverflow.com.
tensorflow tensorflow-datasets tensorflow2.0

TensorFlow dataset.shuffle seems not to shuffle without repeat()

Published 2020-03-27 10:28:02

My code follows a similar pattern to the TensorFlow 2.0 tutorial. I want my dataset object to be reshuffled in every epoch.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(['a', 'b', 'c', 'd'])
dataset = dataset.shuffle(100)  # shuffle once, before the epoch loop

for epoch in range(10):
    for d in dataset:
        print(d)

Result:

tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
...

It seems the dataset is not reshuffled in each epoch. Should I call .shuffle() once per epoch?

Questioner: EyesBear
Views: 124

Answer by nessuno, 2019-07-03 23:13

Yes, you should call .shuffle inside the epoch loop, so that a new shuffled dataset is created for every epoch. Moreover, it is better not to mix Python code and TensorFlow code when pure tf.* equivalents of the Python statements are available.
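Before moving the loop into a tf.function, the core advice can be sketched in plain eager mode, assuming TensorFlow 2.x: a fresh shuffle() is created at the start of every epoch. The helper name epoch_orders and the buffer size of 4 are illustrative choices, not part of the original answer. A buffer at least as large as the dataset shuffles all elements uniformly, whereas a buffer of 2 (as in the code below) only mixes nearby elements; for example, the first element produced is always 'a' or 'b'.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c", "d"])


def epoch_orders(num_epochs=10, buffer_size=4):
    """Return the order of elements seen in each epoch (hypothetical helper)."""
    orders = []
    for _ in range(num_epochs):
        # Build a new shuffled dataset for this epoch only.
        shuffled = dataset.shuffle(buffer_size)
        # String tensors hold bytes; decode them for readable output.
        orders.append([d.numpy().decode() for d in shuffled])
    return orders


for order in epoch_orders():
    print(order)
```

Each printed list is a permutation of the four elements, and the lists are expected to vary from epoch to epoch.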

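import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c", "d"])
# No shuffle here: it is applied inside the epoch loop instead.


@tf.function
def loop():
    # AutoGraph converts a for loop over tf.range into a graph while loop.
    for epoch in tf.range(10):
        # shuffle() is applied per epoch, inside the loop.
        # buffer_size=2 only partially shuffles the 4 elements.
        for d in dataset.shuffle(2):
            # tf.print runs inside the graph and prints the tensor's value.
            tf.print(d)


loop()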

Each call to loop() produces different output (and tf.print prints the contents of the tf.Tensor, whereas print prints the tensor object itself).
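Since the title mentions repeat(), here is a minimal sketch of the shuffle-then-repeat pattern, assuming a recent TensorFlow 2.x release. There, shuffle() takes a reshuffle_each_iteration argument that defaults to True, so every pass over the dataset (including each pass made by repeat()) should get a new order. It is passed explicitly below for clarity. The names NUM_EPOCHS and elements are illustrative, and the final batch() only groups the stream back into one batch per epoch so the orders are easy to inspect. If your version still shows identical orders, as in the question, the per-epoch shuffle from the answer is the fallback.

```python
import tensorflow as tf

NUM_EPOCHS = 10
elements = ["a", "b", "c", "d"]

dataset = (
    tf.data.Dataset.from_tensor_slices(elements)
    # Shuffle *before* repeat so each epoch is a full permutation; a buffer
    # at least as large as the dataset makes the shuffle uniform.
    .shuffle(len(elements), reshuffle_each_iteration=True)
    .repeat(NUM_EPOCHS)
    # One batch per epoch, only to make the per-epoch order visible.
    .batch(len(elements))
)

# Each batch is a NumPy array of bytes; decode to str for printing.
epochs = [[s.decode() for s in batch.numpy()] for batch in dataset]
for order in epochs:
    print(order)
```

Putting shuffle() before repeat() keeps epoch boundaries intact. Putting repeat() first would let elements from neighbouring epochs mix inside the shuffle buffer.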