This article is reproduced from serverfault.com.

Keras: how to modify the input of a Keras model for every batch during training

Posted on 2020-12-02 06:20:22

I am trying to train a neural network with Keras and TensorFlow, and I would like to randomly zoom in or out of the images in a batch, applying the same amount to every image in that batch. The following code runs, but the transformation stays fixed after the first batch: the values of rand_int and p never change, instead of being drawn randomly for each new batch.

from random import randrange

import tensorflow as tf
from tensorflow.keras import layers

X_input = layers.Input((240, 240, 1), name='input')

def field_of_view(X_input):
    rand_int = randrange(2)

    if rand_int == 0:
        # Add padding to batch of images and resize to 240x240. This will effectively scale down the image.
        p = randrange(40)
        paddings = tf.constant([[0, 0], [p, p], [p, p], [0, 0]])
        X_input = tf.pad(X_input, paddings, "CONSTANT")
        X_input = tf.image.resize(X_input, [240,240])

    if rand_int == 1:
        # Crop batch of images and resize back to 240x240
        X_input = tf.image.central_crop(X_input, 0.85)
        X_input = tf.image.resize(X_input, [240,240])

    return X_input

X_input = field_of_view(X_input)
    
...

model.fit(X, y, epochs=10, batch_size=20, validation_split=0.1)
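The symptom described above can be reproduced in isolation. Below is a minimal sketch, assuming TensorFlow 2.x, that uses tf.function as a stand-in for Keras building the model graph (an illustration of the mechanism, not exactly what Keras does internally). A Python random call runs only once, while the function is traced, and its result is frozen into the graph as a constant; a tf.random op is part of the graph and draws a new value on every call. The function names are hypothetical.

```python
import random

import tensorflow as tf


@tf.function
def python_random_pad():
    # random.randrange runs only while tf.function traces this Python code,
    # so the drawn number is baked into the graph as a constant.
    return tf.constant(random.randrange(40))


@tf.function
def tf_random_pad():
    # tf.random.uniform is a graph op, so it is sampled again on every call.
    return tf.random.uniform((), 0, 40, dtype=tf.int32)


frozen = {int(python_random_pad()) for _ in range(20)}
fresh = {int(tf_random_pad()) for _ in range(20)}
print("python random:", frozen)  # a single value, repeated on every call
print("tf.random:", fresh)       # typically many different values
```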

My question is: how can I make rand_int and p change for every batch?
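One way to get a fresh choice for every batch, sketched as a custom layer under the assumption of TensorFlow 2.x Keras. RandomFieldOfView and its parameters (size, max_pad, crop_fraction) are hypothetical names; the defaults mirror the question's 240x240 inputs, randrange(40) and central_crop(..., 0.85). Putting the random ops inside the layer's call() makes them part of the per-batch computation, and tf.cond makes the branch explicit instead of relying on AutoGraph to convert a Python if on a tensor.

```python
import tensorflow as tf
from tensorflow.keras import layers


class RandomFieldOfView(layers.Layer):
    """Hypothetical layer: zoom a whole batch in or out by one random amount.

    The pad-vs-crop choice and the padding size come from tf.random ops,
    so they are re-sampled every time the layer runs, i.e. once per batch.
    """

    def __init__(self, size=240, max_pad=40, crop_fraction=0.85, **kwargs):
        super().__init__(**kwargs)
        self.size = size
        self.max_pad = max_pad
        self.crop_fraction = crop_fraction

    def _zoom_out(self, images):
        # One random p in [0, max_pad) shared by every image in the batch.
        p = tf.random.uniform((), 0, self.max_pad, dtype=tf.int32)
        # Pad height and width by p on both sides: [[0, 0], [p, p], [p, p], [0, 0]].
        paddings = tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]]) * p
        padded = tf.pad(images, paddings, "CONSTANT")
        return tf.image.resize(padded, [self.size, self.size])

    def _zoom_in(self, images):
        # Keep the central crop_fraction of each image, then resize back.
        cropped = tf.image.central_crop(images, self.crop_fraction)
        return tf.image.resize(cropped, [self.size, self.size])

    def call(self, images, training=None):
        # Only augment while training; pass images through unchanged otherwise.
        if not training:
            return images
        choice = tf.random.uniform((), 0, 2, dtype=tf.int32)
        return tf.cond(tf.equal(choice, 0),
                       lambda: self._zoom_out(images),
                       lambda: self._zoom_in(images))
```

In the question's model this would take the place of the field_of_view call, e.g. `X = RandomFieldOfView()(X_input)`. Keras passes training=True inside fit() and training=False in evaluate() and predict(), so the zoom is applied only while training; remove the training check if the augmentation should also run at inference time.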

Questioner: PiccolMan
Answer by Andrey, 2020-12-02 16:22:56

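Use tf.random.uniform() instead of Python's random.randrange(). A Python random call runs only once, while the model is being built, so its result is baked into the model as a constant. tf.random.uniform() is a TensorFlow op, so when it is part of the model's graph it draws a new value each time the graph runs, i.e. for every batch: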

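import tensorflow as tf

def field_of_view(X_input):
    # was: rand_int = random.randrange(2)
    # tf.random.uniform is a TensorFlow op, so a new value is drawn on every run
    rand_int = tf.random.uniform((), 0, 2, dtype=tf.int32)

    # A Python `if` on a tensor relies on AutoGraph converting it to tf.cond
    # (as happens inside tf.function or a Keras layer's call); otherwise use tf.cond.
    if rand_int == 0:
        # Add padding to the batch of images and resize to 240x240. This effectively scales down the image.
        # was: p = random.randrange(40)
        p = tf.random.uniform((1, 1), 0, 40, dtype=tf.int32)  # shape (1, 1)
        p = tf.tile(p, [2, 2])                  # [[p, p], [p, p]]
        paddings = tf.pad(p, [[1, 1], [0, 0]])  # [[0, 0], [p, p], [p, p], [0, 0]]
        # was: paddings = tf.constant([[0, 0], [p, p], [p, p], [0, 0]])
        X_input = tf.pad(X_input, paddings, "CONSTANT")
        X_input = tf.image.resize(X_input, [240, 240])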