# Module context (assumed, not shown in this excerpt): this method belongs to
# an rllab-style Hessian-vector-product optimizer class and relies on
# graph-mode `tf.gradients` plus rllab's `ext.lazydict` and `tensor_utils`
# helpers (`new_tensor_like`, `flatten_tensor_variables`, `compile_function`).
def update_opt(self, f, target, inputs, reg_coeff):
    """Set up a compiled function computing Hessian-vector products of f.

    Args:
        f: Scalar constraint tensor whose Hessian is implicitly used.
        target: Parameterized object exposing get_params(trainable=True).
        inputs: Input tensors that f depends on.
        reg_coeff: Damping coefficient, stored here for callers to apply.
    """
    self.target = target
    self.reg_coeff = reg_coeff
    params = target.get_params(trainable=True)

    # Gradients of f w.r.t. each parameter. Parameters that do not influence
    # f yield None; replace those with explicit zeros so the list stays
    # aligned with `params`.
    constraint_grads = tf.gradients(f, xs=params)
    for idx, (grad, param) in enumerate(zip(constraint_grads, params)):
        if grad is None:
            constraint_grads[idx] = tf.zeros_like(param)

    # One placeholder per parameter, shaped like it, holding the vector x
    # that the Hessian will be multiplied by.
    xs = tuple(
        tensor_utils.new_tensor_like(p.name.split(":")[0], p)
        for p in params)

    def Hx_plain():
        # Pearlmutter's trick: grad((grad f) . x) = H x, so the
        # Hessian-vector product costs one extra backward pass instead of
        # materializing H explicitly.
        Hx_plain_splits = tf.gradients(
            tf.reduce_sum(
                tf.stack([
                    tf.reduce_sum(g * x)
                    for g, x in zip(constraint_grads, xs)
                ])),
            params)
        # As above, second-order gradients can come back as None; zero-fill
        # them before flattening into a single vector.
        for idx, (Hx, param) in enumerate(zip(Hx_plain_splits, params)):
            if Hx is None:
                Hx_plain_splits[idx] = tf.zeros_like(param)
        return tensor_utils.flatten_tensor_variables(Hx_plain_splits)

    # Compile lazily: the graph for H x is only built the first time
    # f_Hx_plain is actually requested.
    self.opt_fun = ext.lazydict(
        f_Hx_plain=lambda: tensor_utils.compile_function(
            inputs=inputs + xs,
            outputs=Hx_plain(),
            log_name="f_Hx_plain",
        ),
    )
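
# --- Not part of the original file: a minimal standalone sketch verifying
# the identity Hx_plain relies on. For a quadratic f = 0.5 w^T A w with
# symmetric A, the Hessian is A, and Pearlmutter's construction
# grad_w((grad_w f) . x) should therefore return A @ x exactly. Assumes
# TF1-style graph mode (tensorflow.compat.v1 with eager disabled); all names
# below are local to the sketch.
def _pearlmutter_hvp_sketch():
    import numpy as np
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    w = tf.get_variable("w", initializer=np.array([1.0, 2.0]))
    A = tf.constant(np.array([[3.0, 1.0], [1.0, 2.0]]))
    f = 0.5 * tf.reduce_sum(w * tf.linalg.matvec(A, w))  # Hessian of f is A

    x = tf.placeholder(tf.float64, shape=(2,))
    (g,) = tf.gradients(f, [w])
    # Pearlmutter's trick, as in Hx_plain: grad((grad f) . x) = H x.
    (hx,) = tf.gradients(tf.reduce_sum(g * x), [w])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        v = np.array([1.0, -1.0])
        print(sess.run(hx, feed_dict={x: v}))  # A @ v = [2.0, -1.0]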