Note: This article is reproduced from serverfault.com.

TensorFlow: Convert GRUCell weights from compat.v1 to tensorflow 2

Posted on 2020-11-16 15:38:06

I am trying to convert a model saved with TensorFlow 1 to TensorFlow 2. I am migrating the code to TensorFlow 2, as highlighted in the TensorFlow docs. However, I would simply like to update my model_weights.ckpt to TensorFlow 2. Some weights (Linear, Embedding) already have the same shape as in the TensorFlow 2 API, but I am struggling to transform the weights of my GRUCell.

How can I convert the GRUCell weights from compat.v1.nn.rnn_cell.GRUCell to keras.layers.GRUCell?

The GRUCell has four weights (S is the input size, H the hidden size):

  • gru_cell/gates/kernel:0 of shape (S + H, 2 x H),
  • gru_cell/gates/bias:0 of shape (2 x H, ),
  • gru_cell/candidate/kernel:0 of shape (S + H, H),
  • gru_cell/candidate/bias:0 of shape (H, )
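To see how these four tensors are used, here is a minimal NumPy sketch of one step of the v1 cell. It follows my reading of the compat.v1 GRUCell source with its default sigmoid gates and tanh activation, and `v1_gru_step` is a hypothetical name, not a TensorFlow function. The kernel rows are ordered [inputs; state], and the gates block is split into the reset gate r followed by the update gate u.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def v1_gru_step(x, h, gates_kernel, gates_bias, candidate_kernel, candidate_bias):
    """One step of a compat.v1-style GRUCell (hypothetical re-implementation).

    x: (batch, S) inputs, h: (batch, H) previous state.
    """
    # Gates: one matmul over the concatenation [inputs, state] -> (batch, 2H)
    gates = sigmoid(np.concatenate([x, h], axis=1) @ gates_kernel + gates_bias)
    # First H columns are the reset gate r, last H columns the update gate u
    r, u = np.split(gates, 2, axis=1)
    # Candidate: the reset gate scales the state *before* the matmul
    c = np.tanh(np.concatenate([x, r * h], axis=1) @ candidate_kernel + candidate_bias)
    # The update gate interpolates between the old state and the candidate
    return u * h + (1.0 - u) * c
```

With all-zero weights both gates are 0.5 and the candidate is 0, so the new state is simply half of the old one, which is a quick sanity check of the wiring.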

I would like to have weights with the same shapes as in the TensorFlow 2 API (or the PyTorch API), i.e. a GRUCell with the following weights:

  • gru_cell/kernel:0 of shape (S, 3 x H)
  • gru_cell/recurrent_kernel:0 of shape (H, 3 x H)
  • gru_cell/bias:0 of shape (2, 3 x H)
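For comparison, here is a NumPy sketch of how a Keras-style cell consumes these three tensors. It is based on my reading of the TF 2.x keras.layers.GRUCell source with the default tanh activation and sigmoid recurrent activation; `keras_gru_step` is a hypothetical name. Column blocks are ordered [z | r | h] (update, reset, candidate). The (2, 3 x H) bias appears only with the default reset_after=True, where row 0 is the input bias and row 1 the recurrent bias; with reset_after=False the bias is a single (3 x H,) vector.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def keras_gru_step(x, h, kernel, recurrent_kernel, bias, reset_after=True):
    """One step of a Keras-style GRUCell (hypothetical re-implementation).

    kernel: (S, 3H), recurrent_kernel: (H, 3H), columns ordered [z | r | h].
    bias: (2, 3H) [input bias; recurrent bias] if reset_after else (3H,).
    """
    H = h.shape[1]
    if reset_after:
        input_bias, recurrent_bias = bias[0], bias[1]
    else:
        # No separate recurrent bias in the "before" convention
        input_bias, recurrent_bias = bias, np.zeros(3 * H)
    x_proj = x @ kernel + input_bias  # (batch, 3H)
    # Recurrent contributions for the z and r blocks
    h_proj = h @ recurrent_kernel[:, :2 * H] + recurrent_bias[:2 * H]
    z = sigmoid(x_proj[:, :H] + h_proj[:, :H])
    r = sigmoid(x_proj[:, H:2 * H] + h_proj[:, H:])
    if reset_after:
        # Reset gate applied after the recurrent matmul (default, cuDNN compatible)
        rec_h = r * (h @ recurrent_kernel[:, 2 * H:] + recurrent_bias[2 * H:])
    else:
        # Reset gate applied to the state before the matmul
        rec_h = (r * h) @ recurrent_kernel[:, 2 * H:]
    hh = np.tanh(x_proj[:, 2 * H:] + rec_h)
    return z * h + (1.0 - z) * hh
```

Note that z plays the role that u plays in the v1 cell (it weights the old state), and that the two reset_after conventions compute different candidates as soon as the recurrent kernel is non-zero.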

To illustrate, you can reproduce these results:

1. GRUCell with tensorflow 1 API

import tensorflow as tf

SEQ_LENGTH = 4
HIDDEN_SIZE = 512
BATCH_SIZE = 1
# Inputs of shape (batch, features): SEQ_LENGTH acts as the input size S here
inputs = tf.random.normal([BATCH_SIZE, SEQ_LENGTH])

# GRU cell
gru = tf.compat.v1.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
# Hidden state
state = gru.zero_state(BATCH_SIZE, tf.float32)
# Forward
output, state = gru(inputs, state)

for weight in gru.weights:
    print(weight.name, weight.shape)

Output:

gru_cell/gates/kernel:0 (516, 1024)
gru_cell/gates/bias:0 (1024,)
gru_cell/candidate/kernel:0 (516, 512)
gru_cell/candidate/bias:0 (512,)
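The printed shapes follow directly from S = 4 and H = 512: the kernels have S + H = 516 rows because they multiply the concatenation of inputs and state. A small sketch of that arithmetic (the helper name `v1_gru_weight_shapes` is hypothetical):

```python
def v1_gru_weight_shapes(input_size, hidden_size):
    """Expected shapes of the four compat.v1 GRUCell variables (hypothetical helper)."""
    S, H = input_size, hidden_size
    return {
        "gru_cell/gates/kernel:0": (S + H, 2 * H),    # reset and update gates
        "gru_cell/gates/bias:0": (2 * H,),
        "gru_cell/candidate/kernel:0": (S + H, H),    # candidate state
        "gru_cell/candidate/bias:0": (H,),
    }


# S = 4 (the SEQ_LENGTH feature dimension above) and H = HIDDEN_SIZE = 512
for name, shape in v1_gru_weight_shapes(4, 512).items():
    print(name, shape)
```

This prints the same four lines as the output above.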

2. GRUCell with tensorflow 2 API

import tensorflow as tf

SEQ_LENGTH = 4
HIDDEN_SIZE = 512
BATCH_SIZE = 1
# Inputs of shape (batch, features): SEQ_LENGTH acts as the input size S here
inputs = tf.random.normal([BATCH_SIZE, SEQ_LENGTH])

# GRU cell
gru = tf.keras.layers.GRUCell(HIDDEN_SIZE)
# Hidden state
state = tf.zeros((BATCH_SIZE, HIDDEN_SIZE), dtype=tf.float32)
# Forward
output, state = gru(inputs, state)

# Display the weights
for weight in gru.weights:
    print(weight.name, weight.shape)

Output:

gru_cell/kernel:0 (4, 1536)
gru_cell/recurrent_kernel:0 (512, 1536)
gru_cell/bias:0 (2, 1536)
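Here the input and recurrent weights are stored separately and all three gate blocks are stacked along the columns, hence 3 x 512 = 1536. The leading 2 in the bias comes from Keras's default reset_after=True; as far as I can tell from the Keras docs, a cell created with reset_after=False stores a single bias of shape (3 x H,). A sketch of the arithmetic (`keras_gru_weight_shapes` is a hypothetical helper):

```python
def keras_gru_weight_shapes(input_size, hidden_size, reset_after=True):
    """Expected shapes of keras.layers.GRUCell weights (hypothetical helper)."""
    S, H = input_size, hidden_size
    return {
        "gru_cell/kernel:0": (S, 3 * H),               # input weights [z | r | h]
        "gru_cell/recurrent_kernel:0": (H, 3 * H),     # state weights [z | r | h]
        # reset_after=True keeps separate input and recurrent biases
        "gru_cell/bias:0": (2, 3 * H) if reset_after else (3 * H,),
    }


for name, shape in keras_gru_weight_shapes(4, 512).items():
    print(name, shape)
```

With the defaults this reproduces the three lines printed above.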

Questioner: polop

Answer (community wiki, 2020-11-28 22:13:52)

For the benefit of the community, I am posting the solution here, although it was originally given on GitHub.

In short, the weights of compat.v1.nn.rnn_cell.GRUCell and keras.layers.GRUCell are not compatible with each other. There is no function to convert between them; if you really want to do it, you will need to do it manually.

Math-wise, if you have the NumPy values of the v1 weights, the formulas are:

S = input_size (the last dimension of the inputs, 4 in the question)

H = state_size

  1. all_kernel = np.concatenate([gru_cell/gates/kernel, gru_cell/candidate/kernel], axis=1) # shape (S + H, 3 * H)
  2. kernel = all_kernel[:S] # shape (S, 3 * H)
  3. recurrent_kernel = all_kernel[S:] # shape (H, 3 * H)
  4. bias = np.concatenate([gru_cell/gates/bias, gru_cell/candidate/bias], axis=0) # shape (3 * H,)
  5. zero_bias = np.zeros([3 * H]) # shape (3 * H,)
  6. bias = np.stack([bias, zero_bias], axis=0) # shape (2, 3 * H)
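Reading the two cell implementations side by side, two details seem to matter for an exact match. First, the recipe above keeps v1's [r | u] column order, whereas Keras orders the blocks [z | r | h], with z playing the role of v1's update gate u. Second, the v1 cell applies the reset gate to the state before the candidate matmul, which corresponds to Keras's reset_after=False; with the default reset_after=True the candidate uses r * (h @ U_h + b), which in general differs from (r * h) @ U_h, so a zero recurrent bias alone does not make the cells agree. The following is a minimal sketch under those assumptions (my reading of the TF 2.x sources; `convert_v1_gru_weights` is a hypothetical helper). It targets a cell built as tf.keras.layers.GRUCell(H, reset_after=False), so the bias comes out as (3 * H,) rather than the (2, 3 * H) shape asked for in the question.

```python
import numpy as np


def convert_v1_gru_weights(gates_kernel, gates_bias, candidate_kernel, candidate_bias):
    """Map compat.v1 GRUCell weights to a Keras GRUCell built with reset_after=False.

    Returns [kernel (S, 3H), recurrent_kernel (H, 3H), bias (3H,)], the order in
    which keras.layers.GRUCell creates (and get_weights() returns) its weights.
    """
    H = candidate_bias.shape[0]
    S = gates_kernel.shape[0] - H  # v1 kernels have S + H rows
    # v1 gate columns are [r | u]; Keras expects [z | r | h] with z == v1's u
    r_k, u_k = gates_kernel[:, :H], gates_kernel[:, H:]
    r_b, u_b = gates_bias[:H], gates_bias[H:]
    all_kernel = np.concatenate([u_k, r_k, candidate_kernel], axis=1)  # (S + H, 3H)
    kernel = all_kernel[:S]            # rows that multiply the inputs
    recurrent_kernel = all_kernel[S:]  # rows that multiply the (reset) state
    bias = np.concatenate([u_b, r_b, candidate_bias])  # (3H,)
    return [kernel, recurrent_kernel, bias]
```

Assuming TensorFlow 2 is available, you would then create `cell = tf.keras.layers.GRUCell(H, reset_after=False)`, build it for S input features (for example with `cell.build((None, S))` or by calling it once), and pass the returned list to `cell.set_weights(...)`.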