How can I customize the gradient computation at training time in keras?

I want to implement the natural gradient in a Keras layer. This has to happen inside the custom gradient that is already in place. I want to be able to choose which implementation is computed (regular or natural gradient) when I call the optimizer.

The problem I'm facing is that when I pass the bool nat_grad=True to the Op at training time (not at graph build time), AutoGraph isn't happy.

At the moment the pseudocode of what is happening is as follows:

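import numpy as np
import tensorflow as tf

@tf.custom_gradient
def MyOp(inputs, w, nat_grad=False):
    output = w*inputs
    def grad(dy):
        # Placeholder gradients for (inputs, w)
        if nat_grad:
            return dy, 1.0
        else:
            return -dy, -1.0
    return output, grad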

class MyKerasLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()  # required before setting attributes on a Layer
        self.nat_grad = False
    def build(self, input_shape):
        self.w = self.add_weight(name="w", dtype=tf.float32, trainable=True, initializer=tf.random_normal_initializer)

    def call(self, inputs):
        return MyOp(inputs, self.w, self.nat_grad)

class MyModel(tf.keras.Sequential):
    def __init__(self, num_layers):
        super().__init__([tf.keras.Input(shape=[1], batch_size=None, dtype=tf.float32)]+[MyKerasLayer() for _ in range(num_layers)])

def optimize(model, X, Y, nat_grad:bool):
    for layer in model.layers:
        layer.nat_grad = nat_grad
    model.fit(x=X, y=Y)

model = MyModel(5)
model.compile(optimizer='SGD', loss=lambda x,y:x-y, metrics=[])
X = np.array([1.0, 2.0, 3.0])
Y = np.array([1.0, 2.0, 3.0])
optimize(model, X, Y, nat_grad=True)
>>> OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

What is the right way to do this?

David Vander Mijnsbrugge 2020-11-29 01:26:56

TensorFlow 2.x allows functions to be executed as graphs via tf.function [1]. Decorating grad(dy) with @tf.function should therefore work, but you will then run into a new error: because MyOp takes nat_grad as an input, a gradient is expected for that argument as well [2].

@tf.custom_gradient
def MyOp(inputs, w, nat_grad=False):
    output = w*inputs
    @tf.function
    def grad(dy):
        # One gradient per input: (inputs, w, nat_grad)
        if nat_grad:
            return dy, 1.0, 0.
        else:
            return -dy, -1.0, 0.
    return output, grad

It seems to me that this is not the right way to do it. Instead, split the gradient op into two separate ops and choose between them in call.

@tf.custom_gradient
def NatOp(inputs, w):
    output = w*inputs
    def grad(dy):
        return dy, 1.0
    return output, grad

@tf.custom_gradient
def RegOp(inputs, w):
    output = w*inputs
    def grad(dy):
        return -dy, -1.0
    return output, grad
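
As a quick sanity check of the split, here is a minimal sketch, assuming both ops are wrapped with @tf.custom_gradient and using constant tensors in place of the layer weight. The lowercase names nat_op and reg_op and the helper grads_for are illustrative only, and the gradient values are the same placeholders as in the ops above (written as tensors rather than Python floats).

```python
import tensorflow as tf

@tf.custom_gradient
def nat_op(inputs, w):
    output = w * inputs
    def grad(dy):
        # Placeholder "natural" gradients for (inputs, w)
        return dy, tf.ones_like(w)
    return output, grad

@tf.custom_gradient
def reg_op(inputs, w):
    output = w * inputs
    def grad(dy):
        # Placeholder "regular" gradients for (inputs, w)
        return -dy, -tf.ones_like(w)
    return output, grad

def grads_for(op):
    """Hypothetical helper: run op under a tape and return its custom gradients."""
    x = tf.constant([1.0, 2.0, 3.0])
    w = tf.constant(2.0)
    with tf.GradientTape() as tape:
        tape.watch([x, w])  # constants are not watched automatically
        y = op(x, w)
    return tape.gradient(y, [x, w])

nat_dx, nat_dw = grads_for(nat_op)
reg_dx, reg_dw = grads_for(reg_op)
print(nat_dx.numpy(), float(nat_dw))
print(reg_dx.numpy(), float(reg_dw))
```

Because each op has exactly two inputs, each grad returns exactly two values, and no gradient is needed for a flag.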

class MyKerasLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()  # required before setting attributes on a Layer
        self.nat_grad = False
    def build(self, input_shape):
        self.w = self.add_weight(name="w", dtype=tf.float32, trainable=True, initializer=tf.random_normal_initializer)

    def call(self, inputs):
        return NatOp(inputs, self.w) if self.nat_grad else RegOp(inputs, self.w)
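
One thing to keep in mind with this approach: self.nat_grad is a plain Python attribute, so the branch in call is resolved once, when the function is traced into a graph. The sketch below illustrates that behaviour with a bare tf.function; the Switch class and the function names are hypothetical stand-ins for the layer, not part of the code above.

```python
import tensorflow as tf

class Switch:
    """Hypothetical stand-in for a layer holding a Python bool attribute."""
    def __init__(self):
        self.nat_grad = False

switch = Switch()

@tf.function
def forward(x):
    # Plain Python branch: evaluated once, while tracing.
    if switch.nat_grad:
        return x
    return -x

x = tf.constant(3.0)
first = float(forward(x))    # traced with nat_grad=False
switch.nat_grad = True
second = float(forward(x))   # same input signature, so the cached graph is reused

@tf.function
def forward_again(x):
    # Identical body, but traced only after the flag was flipped.
    if switch.nat_grad:
        return x
    return -x

third = float(forward_again(x))
print(first, second, third)
```

The same likely applies to model.fit, which reuses its traced training step, so it is safest to set nat_grad on the layers before the first call to fit, as the optimize function in the question does.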

