def _create_averager(self):
    """Build the adaptive averager for this object.

    Returns:
        AdaptiveAverager: a new averager constructed from
        ``self._tensor_spec`` and ``self._speed``.
    """
    averager_kwargs = dict(tensor_spec=self._tensor_spec, speed=self._speed)
    return AdaptiveAverager(**averager_kwargs)
def __init__(self,
             output_dim,
             noise_dim=32,
             input_tensor_spec=None,
             hidden_layers=(256, ),
             net: Network = None,
             net_moving_average_rate=None,
             entropy_regularization=0.,
             mi_weight=None,
             mi_estimator_cls=MIEstimator,
             par_vi="gfsf",
             optimizer=None,
             name="Generator"):
    r"""Construct a Generator.

    Args:
        output_dim (int): dimension of the generated output.
        noise_dim (int): dimension of the noise fed to the network.
        input_tensor_spec (nested TensorSpec): spec of the conditioning
            inputs; ``None`` when the generator is unconditional.
        hidden_layers (tuple): sizes of the hidden layers of the default
            network (only used when ``net`` is None).
        net (Network): network mapping [noise, inputs] (or noise alone when
            there are no inputs) to outputs. When None, an
            ``EncodingNetwork`` with ``hidden_layers`` is created.
        net_moving_average_rate (float): if given (and non-zero), a copy of
            ``net`` tracked through a target updater with this rate is kept
            for prediction. This has been shown to be effective for GAN
            training (arXiv:1907.02544, arXiv:1812.04948).
        entropy_regularization (float): weight of entropy regularization.
            When it is 0 the plain ``_ml_grad`` gradient is used and
            ``par_vi`` is not consulted.
        mi_weight (float): weight of the mutual information loss. When
            None, no mutual information estimator is created.
        mi_estimator_cls (type): class of the mutual information estimator
            used for maximizing the mutual information between
            [noise, inputs] and [outputs, inputs].
        par_vi (string): ParVI method, one of [``svgd``, ``svgd2``,
            ``svgd3``, ``gfsf``]:

            * svgd: empirical expectation of SVGD is evaluated by a single
              resampled particle. It is the only option that supports the
              conditional case.
            * svgd2: empirical expectation of SVGD is evaluated by
              splitting the sampled batch in halves. A trade-off between
              computational efficiency and convergence speed.
            * svgd3: empirical expectation of SVGD is evaluated by
              resampled particles of the same batch size. Better
              convergence, but the resampling makes it computationally
              less efficient than svgd2.
            * gfsf: Wasserstein gradient flow with smoothed functions. It
              involves a kernel matrix inversion, so it is the most
              expensive option, but in some cases converges faster than
              the svgd variants.
        optimizer (torch.optim.Optimizer): (optional) optimizer for
            training.
        name (str): name of this generator.

    Raises:
        ValueError: if ``entropy_regularization`` is non-zero and
            ``par_vi`` is not one of the supported methods.
    """
    super().__init__(train_state_spec=(), optimizer=optimizer, name=name)
    self._noise_dim = noise_dim
    self._entropy_regularization = entropy_regularization
    self._par_vi = par_vi

    if entropy_regularization == 0:
        self._grad_func = self._ml_grad
    else:
        # Map each ParVI option to the method computing its gradient; the
        # method is looked up lazily so only the selected one is touched.
        grad_method_names = {
            'gfsf': '_gfsf_grad',
            'svgd': '_svgd_grad',
            'svgd2': '_svgd_grad2',
            'svgd3': '_svgd_grad3',
        }
        if par_vi not in grad_method_names:
            raise ValueError("Unsupported par_vi method: %s" % par_vi)
        self._grad_func = getattr(self, grad_method_names[par_vi])
        self._kernel_width_averager = AdaptiveAverager(
            tensor_spec=TensorSpec(shape=()))

    noise_spec = TensorSpec(shape=(noise_dim, ))
    conditional = input_tensor_spec is not None

    if net is None:
        net = EncodingNetwork(
            input_tensor_spec=([noise_spec, input_tensor_spec]
                               if conditional else noise_spec),
            fc_layer_params=hidden_layers,
            last_layer_size=output_dim,
            last_activation=math_ops.identity,
            name="Generator")

    self._mi_estimator = None
    self._input_tensor_spec = input_tensor_spec
    if mi_weight is not None:
        mi_x_spec = [noise_spec, input_tensor_spec] if conditional \
            else noise_spec
        mi_y_spec = TensorSpec((output_dim, ))
        self._mi_estimator = mi_estimator_cls(
            mi_x_spec, mi_y_spec, sampler='shift')
        self._mi_weight = mi_weight

    self._net = net
    self._predict_net = None
    self._net_moving_average_rate = net_moving_average_rate
    if net_moving_average_rate:
        averaged_net = net.copy(name="Genrator_average")
        self._predict_net = averaged_net
        self._predict_net_updater = common.get_target_updater(
            self._net, self._predict_net, tau=net_moving_average_rate)
class Generator(Algorithm):
    r"""Generator

    Generator generates outputs given `inputs` (can be None) by transforming
    a random noise and input using `net`:

        outputs = net([noise, input]) if input is not None
                  else net(noise)

    The generator is trained to minimize the following objective:

        :math:`E(loss\_func(net([noise, input]))) - entropy\_regularization \cdot H(P)`

    where P is the (conditional) distribution of outputs given the inputs
    implied by `net` and H(P) is the (conditional) entropy of P.

    If the loss is the (unnormalized) negative log probability of some
    distribution Q and the ``entropy_regularization`` is 1, this objective is
    equivalent to minimizing :math:`KL(P||Q)`.

    It uses two different ways to optimize `net` depending on
    ``entropy_regularization``:

    * ``entropy_regularization`` = 0: the minimization is achieved by simply
      minimizing loss_func(net([noise, inputs]))

    * ``entropy_regularization`` > 0: the minimization is achieved using
      amortized particle-based variational inference (ParVI), in particular,
      two ParVI methods are implemented:

        1. amortized Stein Variational Gradient Descent (SVGD):

        Feng et al "Learning to Draw Samples with Amortized Stein Variational
        Gradient Descent" https://arxiv.org/pdf/1707.06626.pdf

        2. amortized Wasserstein ParVI with Smooth Functions (GFSF):

        Liu, Chang, et al. "Understanding and accelerating particle-based
        variational inference." International Conference on Machine Learning.
        2019.

    It also supports an additional optional objective of maximizing the mutual
    information between [noise, inputs] and outputs by using mi_estimator to
    prevent mode collapse. This might be useful for
    ``entropy_regularization`` = 0 as suggested in section 5.1 of the
    following paper:

    Hjelm et al `Learning Deep Representations by Mutual Information Estimation
    and Maximization <https://arxiv.org/pdf/1808.06670.pdf>`_
    """

    def __init__(self,
                 output_dim,
                 noise_dim=32,
                 input_tensor_spec=None,
                 hidden_layers=(256, ),
                 net: Network = None,
                 net_moving_average_rate=None,
                 entropy_regularization=0.,
                 mi_weight=None,
                 mi_estimator_cls=MIEstimator,
                 par_vi="gfsf",
                 optimizer=None,
                 name="Generator"):
        r"""Create a Generator.

        Args:
            output_dim (int): dimension of output
            noise_dim (int): dimension of noise
            input_tensor_spec (nested TensorSpec): spec of inputs. If there is
                no inputs, this should be None.
            hidden_layers (tuple): size of hidden layers.
            net (Network): network for generating outputs from [noise, inputs]
                or noise (if inputs is None). If None, a default one with
                hidden_layers will be created
            net_moving_average_rate (float): If provided, use a moving average
                version of net to do prediction. This has been shown to be
                effective for GAN training (arXiv:1907.02544, arXiv:1812.04948).
            entropy_regularization (float): weight of entropy regularization.
                When it is 0, ``par_vi`` is not consulted and the plain
                ``_ml_grad`` gradient is used.
            mi_weight (float): weight of the mutual information loss. If None,
                no mutual information estimator is created.
            mi_estimator_cls (type): the class of mutual information estimator
                for maximizing the mutual information between [noise, inputs]
                and [outputs, inputs].
            par_vi (string): ParVI methods, options are
                [``svgd``, ``svgd2``, ``svgd3``, ``gfsf``],

                * svgd: empirical expectation of SVGD is evaluated by a single
                  resampled particle. The main benefit of this choice is it
                  supports conditional case, while all other options do not.
                * svgd2: empirical expectation of SVGD is evaluated by splitting
                  half of the sampled batch. It is a trade-off between
                  computational efficiency and convergence speed.
                * svgd3: empirical expectation of SVGD is evaluated by
                  resampled particles of the same batch size. It has better
                  convergence but involves resampling, so less efficient
                  computationally comparing with svgd2.
                * gfsf: wasserstein gradient flow with smoothed functions. It
                  involves a kernel matrix inversion, so computationally most
                  expensive, but in some case the convergence seems faster
                  than svgd approaches.
            optimizer (torch.optim.Optimizer): (optional) optimizer for training
            name (str): name of this generator

        Raises:
            ValueError: if ``entropy_regularization`` is non-zero and
                ``par_vi`` is not a supported method.
        """
        super().__init__(train_state_spec=(), optimizer=optimizer, name=name)
        self._noise_dim = noise_dim
        self._entropy_regularization = entropy_regularization
        self._par_vi = par_vi
        if entropy_regularization == 0:
            self._grad_func = self._ml_grad
        else:
            if par_vi == 'gfsf':
                self._grad_func = self._gfsf_grad
            elif par_vi == 'svgd':
                self._grad_func = self._svgd_grad
            elif par_vi == 'svgd2':
                self._grad_func = self._svgd_grad2
            elif par_vi == 'svgd3':
                self._grad_func = self._svgd_grad3
            else:
                raise ValueError("Unsupported par_vi method: %s" % par_vi)

            self._kernel_width_averager = AdaptiveAverager(
                tensor_spec=TensorSpec(shape=()))

        noise_spec = TensorSpec(shape=(noise_dim, ))

        if net is None:
            net_input_spec = noise_spec
            if input_tensor_spec is not None:
                net_input_spec = [net_input_spec, input_tensor_spec]
            net = EncodingNetwork(
                input_tensor_spec=net_input_spec,
                fc_layer_params=hidden_layers,
                last_layer_size=output_dim,
                last_activation=math_ops.identity,
                name="Generator")

        self._mi_estimator = None
        self._input_tensor_spec = input_tensor_spec
        if mi_weight is not None:
            x_spec = noise_spec
            y_spec = TensorSpec((output_dim, ))
            if input_tensor_spec is not None:
                x_spec = [x_spec, input_tensor_spec]
            self._mi_estimator = mi_estimator_cls(
                x_spec, y_spec, sampler='shift')
            self._mi_weight = mi_weight
        self._net = net
        self._predict_net = None
        self._net_moving_average_rate = net_moving_average_rate
        if net_moving_average_rate:
            self._predict_net = net.copy(name="Genrator_average")
            self._predict_net_updater = common.get_target_updater(
                self._net, self._predict_net, tau=net_moving_average_rate)

    def _trainable_attributes_to_ignore(self):
        """Return the attribute names to be ignored as trainable attributes.

        ``_predict_net`` is listed because it is refreshed through
        ``_predict_net_updater`` in ``after_update``.
        NOTE(review): how the base class consumes this list is not visible
        here -- confirm against ``Algorithm``.
        """
        return ["_predict_net"]

    @property
    def noise_dim(self):
        """Dimension of the noise vector fed to the generator network."""
        return self._noise_dim

    @staticmethod
    def _neglogp(loss):
        """Return the loss tensor from the value returned by a ``loss_func``.

        ``loss_func`` may return either a Tensor or a namedtuple with a field
        ``loss``; in the latter case that field is returned.
        """
        return loss.loss if isinstance(loss, tuple) else loss

    def _predict(self,
                 inputs=None,
                 noise=None,
                 batch_size=None,
                 training=True):
        """Run the generator network on (sampled) noise and optional inputs.

        Args:
            inputs (nested Tensor): conditioning inputs. Must be None if and
                only if the generator was created without
                ``input_tensor_spec``.
            noise (Tensor): noise of shape (batch_size, noise_dim). If None,
                it is sampled with ``torch.randn``.
            batch_size (int): number of samples. Required when both ``inputs``
                and ``noise`` are None; overridden by the batch size of
                ``inputs`` when ``inputs`` is given.
            training (bool): if False and a moving-average network exists,
                that network is used instead of ``net``.
        Returns:
            tuple:
            - outputs (Tensor): output of the network
            - gen_inputs: what was fed to the network, i.e. ``noise`` or
              ``[noise, inputs]``
        """
        if inputs is None:
            assert self._input_tensor_spec is None
            if noise is None:
                assert batch_size is not None
                noise = torch.randn(batch_size, self._noise_dim)
            gen_inputs = noise
        else:
            nest.assert_same_structure(inputs, self._input_tensor_spec)
            batch_size = nest.get_nest_batch_size(inputs)
            if noise is None:
                noise = torch.randn(batch_size, self._noise_dim)
            else:
                assert noise.shape[0] == batch_size
                assert noise.shape[1] == self._noise_dim
            gen_inputs = [noise, inputs]
        # Explicit None test: a module's truthiness may depend on __len__.
        if self._predict_net is not None and not training:
            outputs = self._predict_net(gen_inputs)[0]
        else:
            outputs = self._net(gen_inputs)[0]
        return outputs, gen_inputs

    def predict_step(self,
                     inputs=None,
                     noise=None,
                     batch_size=None,
                     training=False,
                     state=None):
        """Generate outputs given inputs.

        Args:
            inputs (nested Tensor): if None, the outputs is generated only from
                noise.
            noise (Tensor): input to the generator.
            batch_size (int): batch_size. Must be provided if inputs is None.
                It is ignored if inputs is not None
            training (bool): whether the generator is being trained. If False
                and a moving-average network exists, it is used for
                prediction.
            state: not used
        Returns:
            AlgorithmStep: outputs with shape (batch_size, output_dim)
        """
        outputs, _ = self._predict(
            inputs=inputs,
            noise=noise,
            batch_size=batch_size,
            training=training)
        return AlgStep(output=outputs, state=(), info=())

    def train_step(self,
                   inputs,
                   loss_func,
                   outputs=None,
                   batch_size=None,
                   entropy_regularization=None,
                   state=None):
        """
        Args:
            inputs (nested Tensor): if None, the outputs is generated only from
                noise.
            loss_func (Callable): loss_func([outputs, inputs])
                (loss_func(outputs) if inputs is None) returns a Tensor or
                namedtuple of tensors with field `loss`, which is a Tensor of
                shape [batch_size] a loss term for optimizing the generator.
            outputs (Tensor): generator's output (possibly from previous runs)
                used for this train_step. If None, new outputs are generated.
                Cannot be combined with a mutual information estimator, since
                the generator inputs which produced ``outputs`` are unknown.
            batch_size (int): batch_size. Must be provided if inputs is None.
                It is ignored if inputs is not None.
            entropy_regularization (float): weight of entropy regularization
                for this step. If None, the value given to the constructor is
                used.
            state: not used
        Returns:
            AlgorithmStep:
                outputs: Tensor with shape (batch_size, dim)
                info: LossInfo
        Raises:
            ValueError: if ``outputs`` is provided while a mutual information
                estimator is in use.
        """
        if outputs is None:
            outputs, gen_inputs = self._predict(inputs, batch_size=batch_size)
        elif self._mi_estimator is not None:
            # Fail early with a clear message instead of hitting an unbound
            # ``gen_inputs`` after the gradients have been computed.
            raise ValueError(
                "`outputs` cannot be provided when a mutual information "
                "estimator is used, because the generator inputs that "
                "produced them are unknown.")
        if entropy_regularization is None:
            entropy_regularization = self._entropy_regularization
        loss, loss_propagated = self._grad_func(inputs, outputs, loss_func,
                                                entropy_regularization)
        mi_loss = ()
        if self._mi_estimator is not None:
            mi_step = self._mi_estimator.train_step([gen_inputs, outputs])
            mi_loss = mi_step.info.loss
            loss_propagated = loss_propagated + self._mi_weight * mi_loss

        return AlgStep(
            output=outputs,
            state=(),
            info=LossInfo(
                loss=loss_propagated,
                extra=GeneratorLossInfo(generator=loss, mi_estimator=mi_loss)))

    def _ml_grad(self, inputs, outputs, loss_func,
                 entropy_regularization=None):
        """Compute the gradient of the plain loss (no entropy term).

        Args:
            inputs (nested Tensor): conditioning inputs or None.
            outputs (Tensor): generator outputs.
            loss_func (Callable): see ``train_step``.
            entropy_regularization: not used.
        Returns:
            tuple:
            - loss: the value returned by ``loss_func``
            - loss_propagated (Tensor): surrogate loss whose gradient w.r.t.
              ``outputs`` is the gradient of the loss
        """
        loss_inputs = outputs if inputs is None else [outputs, inputs]
        loss = loss_func(loss_inputs)
        # loss_func may return a namedtuple, as handled by the ParVI variants.
        neglogp = self._neglogp(loss)
        grad = torch.autograd.grad(neglogp.sum(), outputs)[0]
        loss_propagated = torch.sum(grad.detach() * outputs, dim=-1)
        return loss, loss_propagated

    def _kernel_width(self, dist):
        """Update kernel_width averager and get latest kernel_width.

        The new sample fed to the averager is ``median(dist) / log(len(dist))``.

        Args:
            dist (Tensor): squared distances, of dimension 1 or 2. A
                2-dimensional tensor is summed over its last dimension first.
        Returns:
            the current value of the kernel width averager.
        """
        if dist.ndim > 1:
            dist = torch.sum(dist, dim=-1)
            assert dist.ndim == 1, "dist must have dimension 1 or 2."
        width, _ = torch.median(dist, dim=0)
        width = width / np.log(len(dist))
        self._kernel_width_averager.update(width)

        return self._kernel_width_averager.get()

    def _rbf_func(self, x, y):
        """Compute RBF kernel between paired rows, used by svgd_grad.

        Args:
            x (Tensor): particles, paired elementwise with ``y``.
            y (Tensor): particles of the same shape as ``x``.
        Returns:
            Tensor: ``exp(-||x - y||^2 / h)`` where the squared distance is
            summed over the last dimension and ``h`` is the kernel width.
        """
        d = (x - y)**2
        d = torch.sum(d, -1)
        h = self._kernel_width(d)
        w = torch.exp(-d / h)

        return w

    def _rbf_func2(self, x, y):
        r"""
        Compute the rbf kernel and its gradient w.r.t. first entry
        :math:`K(x, y), \nabla_x K(x, y)`, used by svgd_grad2 and svgd_grad3.

        Args:
            x (Tensor): set of Nx particles, shape (Nx x W), where W is the
                dimension of each particle
            y (Tensor): set of Ny particles, shape (Ny x W), where W is the
                dimension of each particle
        Returns:
            :math:`K(x, y)` (Tensor): the RBF kernel of shape (Nx x Ny)
            :math:`\nabla_x K(x, y)` (Tensor): the derivative of RBF kernel of
                shape (Nx x Ny x W)
        """
        Nx, Dx = x.shape
        Ny, Dy = y.shape
        assert Dx == Dy
        diff = x.unsqueeze(1) - y.unsqueeze(0)  # [Nx, Ny, W]
        dist_sq = torch.sum(diff**2, -1)  # [Nx, Ny]
        h = self._kernel_width(dist_sq.view(-1))

        kappa = torch.exp(-dist_sq / h)  # [Nx, Ny]
        kappa_grad = torch.einsum('ij,ijk->ijk', kappa,
                                  -2 * diff / h)  # [Nx, Ny, W]
        return kappa, kappa_grad

    def _score_func(self, x, alpha=1e-5):
        r"""
        Compute the stein estimator of the score function
        :math:`\nabla\log q = -(K + \alpha I)^{-1}\nabla K`,
        used by gfsf_grad.

        Args:
            x (Tensor): set of N particles, shape (N x D), where D is the
                dimension of each particle
            alpha (float): weight of regularization for inverse kernel
                this parameter turns out to be crucial for convergence.
        Returns:
            :math:`\nabla\log q` (Tensor): the score function of shape (N x D)
        """
        N, D = x.shape
        diff = x.unsqueeze(1) - x.unsqueeze(0)  # [N, N, D]
        dist_sq = torch.sum(diff**2, -1)  # [N, N]
        h, _ = torch.median(dist_sq.view(-1), dim=0)
        h = h / np.log(N)

        kappa = torch.exp(-dist_sq / h)  # [N, N]
        # Build the identity on the same device/dtype as the particles.
        eye = torch.eye(N, dtype=x.dtype, device=x.device)
        kappa_inv = torch.inverse(kappa + alpha * eye)  # [N, N]
        kappa_grad = torch.einsum('ij,ijk->jk', kappa,
                                  -2 * diff / h)  # [N, D]

        return -kappa_inv @ kappa_grad

    def _svgd_grad(self, inputs, outputs, loss_func, entropy_regularization):
        """
        Compute particle gradients via SVGD, empirical expectation
        evaluated by a single resampled particle.

        Args:
            inputs (nested Tensor): conditioning inputs or None.
            outputs (Tensor): generator outputs to be updated.
            loss_func (Callable): see ``train_step``.
            entropy_regularization (float): weight of the kernel (entropy)
                term.
        Returns:
            tuple:
            - loss: the value returned by ``loss_func`` on the resampled
              outputs
            - loss_propagated (Tensor): surrogate loss carrying the SVGD
              gradient to ``outputs``
        """
        outputs2, _ = self._predict(inputs, batch_size=outputs.shape[0])
        kernel_weight = self._rbf_func(outputs, outputs2)
        weight_sum = entropy_regularization * kernel_weight.sum()

        kernel_grad = torch.autograd.grad(weight_sum, outputs2)[0]

        loss_inputs = outputs2 if inputs is None else [outputs2, inputs]
        loss = loss_func(loss_inputs)
        neglogp = self._neglogp(loss)
        # The kernel weights are treated as constants for the loss term.
        weighted_loss = kernel_weight.detach() * neglogp

        loss_grad = torch.autograd.grad(weighted_loss.sum(), outputs2)[0]
        grad = loss_grad - kernel_grad
        loss_propagated = torch.sum(grad.detach() * outputs, dim=-1)

        return loss, loss_propagated

    def _svgd_grad2(self, inputs, outputs, loss_func, entropy_regularization):
        """
        Compute particle gradients via SVGD, empirical expectation
        evaluated by splitting half of the sampled batch.

        Args:
            inputs: must be None; the conditional case is not supported.
            outputs (Tensor): generator outputs with an even batch size. The
                first half is updated, the second half is used to evaluate
                the expectation.
            loss_func (Callable): see ``train_step``.
            entropy_regularization (float): weight of the kernel (entropy)
                term.
        Returns:
            tuple:
            - loss: the value returned by ``loss_func`` on the second half
            - loss_propagated (Tensor): surrogate loss carrying the SVGD
              gradient to the first half of ``outputs``
        """
        assert inputs is None, '"svgd2" does not support conditional generator'
        assert outputs.shape[0] % 2 == 0, (
            '"svgd2" requires an even batch size, got %d' % outputs.shape[0])
        num_particles = outputs.shape[0] // 2
        outputs_i, outputs_j = torch.split(outputs, num_particles, dim=0)
        loss_inputs = outputs_j
        loss = loss_func(loss_inputs)
        neglogp = self._neglogp(loss)
        loss_grad = torch.autograd.grad(neglogp.sum(),
                                        outputs_j)[0]  # [Nj, D]

        # [Nj, Ni], [Nj, Ni, D]
        kernel_weight, kernel_grad = self._rbf_func2(outputs_j, outputs_i)
        kernel_logp = torch.matmul(kernel_weight.t(),
                                   loss_grad) / num_particles  # [Ni, D]
        grad = kernel_logp - entropy_regularization * kernel_grad.mean(0)
        loss_propagated = torch.sum(grad.detach() * outputs_i, dim=-1)
        return loss, loss_propagated

    def _svgd_grad3(self, inputs, outputs, loss_func, entropy_regularization):
        """
        Compute particle gradients via SVGD, empirical expectation
        evaluated by resampled particles of the same batch size.

        Args:
            inputs: must be None; the conditional case is not supported.
            outputs (Tensor): generator outputs to be updated.
            loss_func (Callable): see ``train_step``.
            entropy_regularization (float): weight of the kernel (entropy)
                term.
        Returns:
            tuple:
            - loss: the value returned by ``loss_func`` on the resampled
              outputs
            - loss_propagated (Tensor): surrogate loss carrying the SVGD
              gradient to ``outputs``
        """
        assert inputs is None, '"svgd3" does not support conditional generator'
        num_particles = outputs.shape[0]
        outputs2, _ = self._predict(inputs, batch_size=num_particles)
        loss_inputs = outputs2
        loss = loss_func(loss_inputs)
        neglogp = self._neglogp(loss)
        loss_grad = torch.autograd.grad(neglogp.sum(),
                                        outputs2)[0]  # [N2, D]

        # [N2, N], [N2, N, D]
        kernel_weight, kernel_grad = self._rbf_func2(outputs2, outputs)
        kernel_logp = torch.matmul(kernel_weight.t(),
                                   loss_grad) / num_particles  # [N, D]
        grad = kernel_logp - entropy_regularization * kernel_grad.mean(0)
        loss_propagated = torch.sum(grad.detach() * outputs, dim=-1)

        return loss, loss_propagated

    def _gfsf_grad(self, inputs, outputs, loss_func, entropy_regularization):
        """Compute particle gradients via GFSF (Stein estimator).

        Args:
            inputs: must be None; the conditional case is not supported.
            outputs (Tensor): generator outputs to be updated.
            loss_func (Callable): see ``train_step``.
            entropy_regularization (float): weight of the score-function
                (entropy) term.
        Returns:
            tuple:
            - loss: the value returned by ``loss_func``
            - loss_propagated (Tensor): surrogate loss carrying the GFSF
              gradient to ``outputs``
        """
        assert inputs is None, '"gfsf" does not support conditional generator'
        loss_inputs = outputs
        loss = loss_func(loss_inputs)
        neglogp = self._neglogp(loss)
        loss_grad = torch.autograd.grad(neglogp.sum(), outputs)[0]  # [N, D]
        logq_grad = self._score_func(outputs)  # [N, D]
        # _score_func returns grad log q (including the minus sign of the
        # Stein estimator), so the entropy term is added here.
        grad = loss_grad + entropy_regularization * logq_grad
        loss_propagated = torch.sum(grad.detach() * outputs, dim=-1)

        return loss, loss_propagated

    def after_update(self, training_info):
        """Refresh the moving-average prediction network, if there is one.

        Args:
            training_info: not used
        """
        if self._predict_net is not None:
            self._predict_net_updater()
class Generator(Algorithm): r"""Generator Generator generates outputs given `inputs` (can be None) by transforming a random noise and input using `net`: outputs = net([noise, input]) if input is not None else net(noise) The generator is trained to minimize the following objective: :math:`E(loss\_func(net([noise, input]))) - entropy\_regulariztion \cdot H(P)` where P is the (conditional) distribution of outputs given the inputs implied by `net` and H(P) is the (conditional) entropy of P. If the loss is the (unnormalized) negative log probability of some distribution Q and the ``entropy_regularization`` is 1, this objective is equivalent to minimizing :math:`KL(P||Q)`. It uses two different ways to optimize `net` depending on ``entropy_regularization``: * ``entropy_regularization`` = 0: the minimization is achieved by simply minimizing loss_func(net([noise, inputs])) * entropy_regularization > 0: the minimization is achieved using amortized particle-based variational inference (ParVI), in particular, three ParVI methods are implemented: 1. amortized Stein Variational Gradient Descent (SVGD): Feng et al "Learning to Draw Samples with Amortized Stein Variational Gradient Descent" https://arxiv.org/pdf/1707.06626.pdf 2. amortized Wasserstein ParVI with Smooth Functions (GFSF): Liu, Chang, et al. "Understanding and accelerating particle-based variational inference." International Conference on Machine Learning. 2019. 3. amortized Fisher Neural Sampler with Hutchinson's estimator (MINMAX): Hu et at. "Stein Neural Sampler." https://arxiv.org/abs/1810.03545, 2018. It also supports an additional optional objective of maximizing the mutual information between [noise, inputs] and outputs by using mi_estimator to prevent mode collapse. 
This might be useful for ``entropy_regulariztion`` = 0 as suggested in section 5.1 of the following paper: Hjelm et al `Learning Deep Representations by Mutual Information Estimation and Maximization <https://arxiv.org/pdf/1808.06670.pdf>`_ """ def __init__(self, output_dim, noise_dim=32, input_tensor_spec=None, hidden_layers=(256, ), net: Network = None, net_moving_average_rate=None, entropy_regularization=0., mi_weight=None, mi_estimator_cls=MIEstimator, par_vi=None, critic_input_dim=None, critic_hidden_layers=(100, 100), critic_l2_weight=10., critic_iter_num=2, critic_relu_mlp=False, critic_use_bn=True, minmax_resample=True, critic_optimizer=None, optimizer=None, name="Generator"): r"""Create a Generator. Args: output_dim (int): dimension of output noise_dim (int): dimension of noise input_tensor_spec (nested TensorSpec): spec of inputs. If there is no inputs, this should be None. hidden_layers (tuple): sizes of hidden layers. net (Network): network for generating outputs from [noise, inputs] or noise (if inputs is None). If None, a default one with hidden_layers will be created net_moving_average_rate (float): If provided, use a moving average version of net to do prediction. This has been shown to be effective for GAN training (arXiv:1907.02544, arXiv:1812.04948). entropy_regularization (float): weight of entropy regularization. mi_weight (float): weight of mutual information loss. mi_estimator_cls (type): the class of mutual information estimator for maximizing the mutual information between [noise, inputs] and [outputs, inputs]. par_vi (string): ParVI methods, options are [``svgd``, ``svgd2``, ``svgd3``, ``gfsf``, ``minmax``], * svgd: empirical expectation of SVGD is evaluated by a single resampled particle. The main benefit of this choice is it supports conditional case, while all other options do not. * svgd2: empirical expectation of SVGD is evaluated by splitting half of the sampled batch. 
It is a trade-off between computational efficiency and convergence speed. * svgd3: empirical expectation of SVGD is evaluated by resampled particles of the same batch size. It has better convergence but involves resampling, so less efficient computaionally comparing with svgd2. * gfsf: wasserstein gradient flow with smoothed functions. It involves a kernel matrix inversion, so computationally most expensive, but in some case the convergence seems faster than svgd approaches. * minmax: Fisher Neural Sampler, optimal descent direction of the Stein discrepancy is solved by an inner optimization procedure in the space of L2 neural networks. critic_input_dim (int): dimension of critic input, used for ``minmax``. critic_hidden_layers (tuple): sizes of hidden layers of the critic, used for ``minmax``. critic_l2_weight (float): weight of L2 regularization in training the critic, used for ``minmax``. critic_iter_num (int): number of critic updates for each generator train_step, used for ``minmax``. critic_relu_mlp (bool): whether use ReluMLP as the critic constructor, used for ``minmax``. critic_use_bn (book): whether use batch norm for each layers of the critic, used for ``minmax``. minmax_resample (bool): whether resample the generator for each critic update, used for ``minmax``. critic_optimizer (torch.optim.Optimizer): Optimizer for training the critic, used for ``minmax``. 
optimizer (torch.optim.Optimizer): (optional) optimizer for training name (str): name of this generator """ super().__init__(train_state_spec=(), optimizer=optimizer, name=name) self._output_dim = output_dim self._noise_dim = noise_dim self._entropy_regularization = entropy_regularization self._par_vi = par_vi if entropy_regularization == 0: self._grad_func = self._ml_grad else: if par_vi == 'gfsf': self._grad_func = self._gfsf_grad elif par_vi == 'svgd': self._grad_func = self._svgd_grad elif par_vi == 'svgd2': self._grad_func = self._svgd_grad2 elif par_vi == 'svgd3': self._grad_func = self._svgd_grad3 elif par_vi == 'minmax': if critic_input_dim is None: critic_input_dim = output_dim self._grad_func = self._minmax_grad self._critic_iter_num = critic_iter_num self._critic_l2_weight = critic_l2_weight self._critic_relu_mlp = critic_relu_mlp self._minmax_resample = minmax_resample self._critic = CriticAlgorithm( TensorSpec(shape=(critic_input_dim, )), hidden_layers=critic_hidden_layers, use_relu_mlp=critic_relu_mlp, use_bn=critic_use_bn, optimizer=critic_optimizer) else: raise ValueError("Unsupported par_vi method: %s" % par_vi) self._kernel_width_averager = AdaptiveAverager( tensor_spec=TensorSpec(shape=())) noise_spec = TensorSpec(shape=(noise_dim, )) if net is None: net_input_spec = noise_spec if input_tensor_spec is not None: net_input_spec = [net_input_spec, input_tensor_spec] net = EncodingNetwork(input_tensor_spec=net_input_spec, fc_layer_params=hidden_layers, last_layer_size=output_dim, last_activation=math_ops.identity, name="Generator") self._mi_estimator = None self._input_tensor_spec = input_tensor_spec if mi_weight is not None: x_spec = noise_spec y_spec = TensorSpec((output_dim, )) if input_tensor_spec is not None: x_spec = [x_spec, input_tensor_spec] self._mi_estimator = mi_estimator_cls(x_spec, y_spec, sampler='shift') self._mi_weight = mi_weight self._net = net self._predict_net = None self._net_moving_average_rate = net_moving_average_rate if 
net_moving_average_rate: self._predict_net = net.copy(name="Genrator_average") self._predict_net_updater = common.get_target_updater( self._net, self._predict_net, tau=net_moving_average_rate) def _trainable_attributes_to_ignore(self): return ["_predict_net", "_critic"] @property def noise_dim(self): return self._noise_dim def _predict(self, inputs=None, noise=None, batch_size=None, training=True): if inputs is None: assert self._input_tensor_spec is None if noise is None: assert batch_size is not None noise = torch.randn(batch_size, self._noise_dim) gen_inputs = noise else: nest.assert_same_structure(inputs, self._input_tensor_spec) batch_size = nest.get_nest_batch_size(inputs) if noise is None: noise = torch.randn(batch_size, self._noise_dim) else: assert noise.shape[0] == batch_size assert noise.shape[1] == self._noise_dim gen_inputs = [noise, inputs] if self._predict_net and not training: outputs = self._predict_net(gen_inputs)[0] else: outputs = self._net(gen_inputs)[0] return outputs, gen_inputs def predict_step(self, inputs=None, noise=None, batch_size=None, training=False, state=None): """Generate outputs given inputs. Args: inputs (nested Tensor): if None, the outputs is generated only from noise. noise (Tensor): input to the generator. batch_size (int): batch_size. Must be provided if inputs is None. Its is ignored if inputs is not None training (bool): whether train the generator. state: not used Returns: AlgorithmStep: outputs with shape (batch_size, output_dim) """ outputs, _ = self._predict(inputs=inputs, noise=noise, batch_size=batch_size, training=training) return AlgStep(output=outputs, state=(), info=()) def train_step(self, inputs, loss_func, batch_size=None, transform_func=None, entropy_regularization=None, state=None): """ Args: inputs (nested Tensor): if None, the outputs is generated only from noise. 
loss_func (Callable): loss_func([outputs, inputs]) (loss_func(outputs) if inputs is None) returns a Tensor or namedtuple of tensors with field `loss`, which is a Tensor of shape [batch_size] a loss term for optimizing the generator. batch_size (int): batch_size. Must be provided if inputs is None. Its is ignored if inputs is not None. transform_func (Callable): transform function on generator's outputs. Used in function value based par_vi (currently supported by [``svgd2``, ``svgd3``, ``gfsf``]) for evaluating the network(s) parameterized by the generator's outputs (given by self._predict) on the training batch (predefined with transform_func). It can be called in two ways - transform_func(params): params is a tensor of parameters for a network, of shape ``[D]`` or ``[B, D]`` - ``B``: batch size - ``D``: length of network parameters In this case, transform_func first samples additional data besides the predefined training batch and then evaluate the network(s) parameterized by ``params`` on the training batch plus additional sampled data. - transform_func((params, extra_samples)): params is the same as above case and extra_samples is the tensor of additional sampled data. In this case, transform_func evaluates the network(s) parameterized by ``params`` on predefined training batch plus ``extra_samples``. It returns three tensors - outputs: outputs of network parameterized by params evaluated on predined training batch. - density_outputs: outputs of network parameterized by params evaluated on additional sampled data. - extra_samples: additional sampled data, same as input extra_samples if called as transform_func((params, extra_samples)) entropy_regularization (float): weight of entropy regularization. 
state: not used Returns: AlgorithmStep: outputs: Tensor with shape (batch_size, dim) info: LossInfo """ outputs, gen_inputs = self._predict(inputs, batch_size=batch_size) if entropy_regularization is None: entropy_regularization = self._entropy_regularization loss, loss_propagated = self._grad_func(inputs, outputs, loss_func, entropy_regularization, transform_func) mi_loss = () if self._mi_estimator is not None: mi_step = self._mi_estimator.train_step([gen_inputs, outputs]) mi_loss = mi_step.info.loss loss_propagated = loss_propagated + self._mi_weight * mi_loss return AlgStep(output=outputs, state=(), info=LossInfo( loss=loss_propagated, extra=GeneratorLossInfo(generator=loss, mi_estimator=mi_loss))) def _ml_grad(self, inputs, outputs, loss_func, entropy_regularization=None, transform_func=None): assert transform_func is None, ( "function value based vi is not supported for ml_grad") loss_inputs = outputs if inputs is None else [outputs, inputs] loss = loss_func(loss_inputs) grad = torch.autograd.grad(loss.sum(), outputs)[0] loss_propagated = torch.sum(grad.detach() * outputs, dim=-1) return loss, loss_propagated def _kernel_width(self, dist): """Update kernel_width averager and get latest kernel_width. """ if dist.ndim > 1: dist = torch.sum(dist, dim=-1) assert dist.ndim == 1, "dist must have dimension 1 or 2." width, _ = torch.median(dist, dim=0) width = width / np.log(len(dist)) self._kernel_width_averager.update(width) return self._kernel_width_averager.get() def _rbf_func(self, x, y): """Compute RGF kernel, used by svgd_grad. """ d = (x - y)**2 d = torch.sum(d, -1) h = self._kernel_width(d) w = torch.exp(-d / h) return w def _rbf_func2(self, x, y): r""" Compute the rbf kernel and its gradient w.r.t. first entry :math:`K(x, y), \nabla_x K(x, y)`, used by svgd_grad2 and svgd_grad3. Args: x (Tensor): set of N particles, shape (Nx, ...), where Nx is the number of particles. y (Tensor): set of N particles, shape (Ny, ...), where Ny is the number of particles. 
Returns: :math:`K(x, y)` (Tensor): the RBF kernel of shape (Nx x Ny) :math:`\nabla_x K(x, y)` (Tensor): the derivative of RBF kernel of shape (Nx x Ny x D) """ Nx = x.shape[0] Ny = y.shape[0] x = x.view(Nx, -1) y = y.view(Ny, -1) Dx = x.shape[1] Dy = y.shape[1] assert Dx == Dy diff = x.unsqueeze(1) - y.unsqueeze(0) # [Nx, Ny, D] dist_sq = torch.sum(diff**2, -1) # [Nx, Ny] h = self._kernel_width(dist_sq.view(-1)) kappa = torch.exp(-dist_sq / h) # [Nx, Nx] kappa_grad = kappa.unsqueeze(-1) * (-2 * diff / h) # [Nx, Ny, D] return kappa, kappa_grad def _score_func(self, x, alpha=1e-5): r""" Compute the stein estimator of the score function :math:`\nabla\log q = -(K + \alpha I)^{-1}\nabla K`, used by gfsf_grad. Args: x (Tensor): set of N particles, shape (N x D), where D is the dimenseion of each particle alpha (float): weight of regularization for inverse kernel this parameter turns out to be crucial for convergence. Returns: :math:`\nabla\log q` (Tensor): the score function of shape (N x D) """ N, D = x.shape diff = x.unsqueeze(1) - x.unsqueeze(0) # [N, N, D] dist_sq = torch.sum(diff**2, -1) # [N, N] h, _ = torch.median(dist_sq.view(-1), dim=0) h = h / np.log(N) kappa = torch.exp(-dist_sq / h) # [N, N] kappa_inv = torch.inverse(kappa + alpha * torch.eye(N)) # [N, N] kappa_grad = torch.einsum('ij,ijk->jk', kappa, -2 * diff / h) # [N, D] return -kappa_inv @ kappa_grad def _svgd_grad(self, inputs, outputs, loss_func, entropy_regularization, transform_func=None): """ Compute particle gradients via SVGD, empirical expectation evaluated by a single resampled particle. 
""" outputs2, _ = self._predict(inputs, batch_size=outputs.shape[0]) assert transform_func is None, ( "function value based vi is not supported for svgd_grad") kernel_weight = self._rbf_func(outputs, outputs2) weight_sum = entropy_regularization * kernel_weight.sum() kernel_grad = torch.autograd.grad(weight_sum, outputs2)[0] loss_inputs = outputs2 if inputs is None else [outputs2, inputs] loss = loss_func(loss_inputs) if isinstance(loss, tuple): neglogp = loss.loss else: neglogp = loss weighted_loss = kernel_weight.detach() * neglogp loss_grad = torch.autograd.grad(weighted_loss.sum(), outputs2)[0] grad = loss_grad - kernel_grad loss_propagated = torch.sum(grad.detach() * outputs, dim=-1) return loss, loss_propagated def _svgd_grad2(self, inputs, outputs, loss_func, entropy_regularization, transform_func=None): """ Compute particle gradients via SVGD, empirical expectation evaluated by splitting half of the sampled batch. """ assert inputs is None, '"svgd2" does not support conditional generator' if transform_func is not None: outputs, extra_outputs, _ = transform_func(outputs) aug_outputs = torch.cat([outputs, extra_outputs], dim=-1) else: aug_outputs = outputs num_particles = outputs.shape[0] // 2 outputs_i, outputs_j = torch.split(outputs, num_particles, dim=0) aug_outputs_i, aug_outputs_j = torch.split(aug_outputs, num_particles, dim=0) loss_inputs = outputs_j loss = loss_func(loss_inputs) if isinstance(loss, tuple): neglogp = loss.loss else: neglogp = loss loss_grad = torch.autograd.grad(neglogp.sum(), loss_inputs)[0] # [Nj, D] # [Nj, Ni], [Nj, Ni, D'] kernel_weight, kernel_grad = self._rbf_func2(aug_outputs_j.detach(), aug_outputs_i.detach()) kernel_logp = torch.matmul(kernel_weight.t(), loss_grad) / num_particles # [Ni, D] loss_prop_kernel_logp = torch.sum(kernel_logp.detach() * outputs_i, dim=-1) loss_prop_kernel_grad = torch.sum(-entropy_regularization * kernel_grad.mean(0).detach() * aug_outputs_i, dim=-1) loss_propagated = loss_prop_kernel_logp + 
loss_prop_kernel_grad return loss, loss_propagated def _svgd_grad3(self, inputs, outputs, loss_func, entropy_regularization, transform_func=None): """ Compute particle gradients via SVGD, empirical expectation evaluated by resampled particles of the same batch size. """ assert inputs is None, '"svgd3" does not support conditional generator' num_particles = outputs.shape[0] outputs2, _ = self._predict(inputs, batch_size=num_particles) if transform_func is not None: outputs, extra_outputs, samples = transform_func(outputs) outputs2, extra_outputs2, _ = transform_func((outputs2, samples)) aug_outputs = torch.cat([outputs, extra_outputs], dim=-1) aug_outputs2 = torch.cat([outputs2, extra_outputs2], dim=-1) else: aug_outputs = outputs # [N, D'] aug_outputs2 = outputs2 # [N2, D'] loss_inputs = outputs2 loss = loss_func(loss_inputs) if isinstance(loss, tuple): neglogp = loss.loss else: neglogp = loss loss_grad = torch.autograd.grad(neglogp.sum(), loss_inputs)[0] # [N2, D] # [N2, N], [N2, N, D'] kernel_weight, kernel_grad = self._rbf_func2(aug_outputs2.detach(), aug_outputs.detach()) kernel_logp = torch.matmul(kernel_weight.t(), loss_grad) / num_particles # [N, D] loss_prop_kernel_logp = torch.sum(kernel_logp.detach() * outputs, dim=-1) loss_prop_kernel_grad = torch.sum(-entropy_regularization * kernel_grad.mean(0).detach() * aug_outputs, dim=-1) loss_propagated = loss_prop_kernel_logp + loss_prop_kernel_grad return loss, loss_propagated def _gfsf_grad(self, inputs, outputs, loss_func, entropy_regularization, transform_func=None): """Compute particle gradients via GFSF (Stein estimator). 
""" assert inputs is None, '"gfsf" does not support conditional generator' if transform_func is not None: outputs, extra_outputs, _ = transform_func(outputs) aug_outputs = torch.cat([outputs, extra_outputs], dim=-1) else: aug_outputs = outputs score_inputs = aug_outputs.detach() loss_inputs = outputs loss = loss_func(loss_inputs) if isinstance(loss, tuple): neglogp = loss.loss else: neglogp = loss loss_grad = torch.autograd.grad(neglogp.sum(), loss_inputs)[0] # [N2, D] logq_grad = self._score_func(score_inputs) * entropy_regularization loss_prop_neglogp = torch.sum(loss_grad.detach() * outputs, dim=-1) loss_prop_logq = torch.sum(logq_grad.detach() * aug_outputs, dim=-1) loss_propagated = loss_prop_neglogp + loss_prop_logq return loss, loss_propagated def _jacobian_trace(self, fx, x): """Hutchinson's trace Jacobian estimator O(1) call to autograd, used by "\"minmax\" method""" assert fx.shape[-1] == x.shape[-1], ( "Jacobian is not square, no trace defined.") eps = torch.randn_like(fx) jvp = torch.autograd.grad(fx, x, grad_outputs=eps, retain_graph=True, create_graph=True)[0] tr_jvp = torch.einsum('bi,bi->b', jvp, eps) return tr_jvp def _critic_train_step(self, inputs, loss_func, entropy_regularization=1.): """ Compute the loss for critic training. 
""" loss = loss_func(inputs) if isinstance(loss, tuple): neglogp = loss.loss else: neglogp = loss loss_grad = torch.autograd.grad(neglogp.sum(), inputs)[0] # [N, D] if self._critic_relu_mlp: critic_step = self._critic.predict_step(inputs, requires_jac_diag=True) outputs, jac_diag = critic_step.output tr_gradf = jac_diag.sum(-1) # [N] else: outputs = self._critic.predict_step(inputs).output tr_gradf = self._jacobian_trace(outputs, inputs) # [N] f_loss_grad = (loss_grad.detach() * outputs).sum(1) # [N] loss_stein = f_loss_grad - entropy_regularization * tr_gradf # [N] l2_penalty = (outputs * outputs).sum(1).mean() * self._critic_l2_weight critic_loss = loss_stein.mean() + l2_penalty return critic_loss def _minmax_grad(self, inputs, outputs, loss_func, entropy_regularization, transform_func=None): """ Compute particle gradients via minmax svgd (Fisher Neural Sampler). """ assert inputs is None, '"minmax" does not support conditional generator' # optimize the critic using resampled particles assert transform_func is None, ( "function value based vi is not supported for minmax_grad") num_particles = outputs.shape[0] for i in range(self._critic_iter_num): if self._minmax_resample: critic_inputs, _ = self._predict(inputs, batch_size=num_particles) else: critic_inputs = outputs.detach().clone() critic_inputs.requires_grad = True critic_loss = self._critic_train_step(critic_inputs, loss_func, entropy_regularization) self._critic.update_with_gradient(LossInfo(loss=critic_loss)) # compute amortized svgd loss = loss_func(outputs.detach()) critic_outputs = self._critic.predict_step(outputs.detach()).output loss_propagated = torch.sum(-critic_outputs.detach() * outputs, dim=-1) return loss, loss_propagated def after_update(self, training_info): if self._predict_net: self._predict_net_updater()
def __init__(self,
             output_dim,
             noise_dim=32,
             input_tensor_spec=None,
             hidden_layers=(256, ),
             net: Network = None,
             net_moving_average_rate=None,
             entropy_regularization=0.,
             kernel_sharpness=2.,
             mi_weight=None,
             mi_estimator_cls=MIEstimator,
             optimizer: tf.optimizers.Optimizer = None,
             name="Generator"):
    """Create a Generator.

    Args:
        output_dim (int): dimension of output
        noise_dim (int): dimension of noise
        input_tensor_spec (nested TensorSpec): spec of inputs. If there is
            no inputs, this should be None.
        hidden_layers (tuple): size of hidden layers.
        net (Network): network for generating outputs from [noise, inputs]
            or noise (if inputs is None). If None, a default one with
            hidden_layers will be created
        net_moving_average_rate (float): If provided, use a moving average
            version of net to do prediction. This has been shown to be
            effective for GAN training (arXiv:1907.02544, arXiv:1812.04948).
        entropy_regularization (float): weight of entropy regularization
        kernel_sharpness (float): Used only for entropy_regularization > 0.
            We calculate the kernel in SVGD as:
                exp(-kernel_sharpness * reduce_mean((x-y)^2/width)),
            where width is the elementwise moving average of (x-y)^2
        mi_weight (float): if not None, a mutual information estimator is
            created and this weight is stored alongside it.
        mi_estimator_cls (type): the class of mutual information estimator
            for maximizing the mutual information between [noise, inputs]
            and [outputs, inputs].
        optimizer (tf.optimizers.Optimizer): optimizer (optional)
        name (str): name of this generator
    """
    super().__init__(train_state_spec=(), optimizer=optimizer, name=name)

    self._noise_dim = noise_dim
    self._entropy_regularization = entropy_regularization
    use_stein = entropy_regularization != 0
    if use_stein:
        self._grad_func = self._stein_grad
        # Kernel state is only needed by the Stein gradient.
        self._kernel_width_averager = AdaptiveAverager(
            tensor_spec=tf.TensorSpec(shape=(output_dim, )))
        self._kernel_sharpness = kernel_sharpness
    else:
        self._grad_func = self._ml_grad

    noise_spec = tf.TensorSpec(shape=[noise_dim])
    has_inputs = input_tensor_spec is not None

    if net is None:
        # Default network: consumes noise alone, or [noise, inputs].
        if has_inputs:
            net_input_spec = [noise_spec, input_tensor_spec]
        else:
            net_input_spec = noise_spec
        net = EncodingNetwork(
            name="Generator",
            input_tensor_spec=net_input_spec,
            fc_layer_params=hidden_layers,
            last_layer_size=output_dim)

    self._mi_estimator = None
    self._input_tensor_spec = input_tensor_spec
    if mi_weight is not None:
        y_spec = tf.TensorSpec((output_dim, ))
        x_spec = [noise_spec, input_tensor_spec] if has_inputs else noise_spec
        self._mi_estimator = mi_estimator_cls(x_spec, y_spec, sampler='shift')
        self._mi_weight = mi_weight

    self._net = net
    self._predict_net = None
    self._net_moving_average_rate = net_moving_average_rate
    if net_moving_average_rate:
        self._predict_net = net.copy(name="Genrator_average")
        # tau=1.0 makes the averaged copy start from the same weights as net.
        tfa_common.soft_variables_update(
            self._net.variables, self._predict_net.variables, tau=1.0)
class Generator(Algorithm):
    """Generator

    Generator produces outputs from `inputs` (which can be None) by passing
    a random noise, together with the input when there is one, through
    `net`:

        outputs = net([noise, input]) if input is not None else net(noise)

    Training minimizes the following objective:

        E(loss_func(net([noise, input]))) - entropy_regularization * H(P)

    where P is the (conditional) distribution of outputs given the inputs
    implied by `net` and H(P) is the (conditional) entropy of P.

    If the loss is the (unnormalized) negative log probability of some
    distribution Q and the entropy_regularization is 1, this objective is
    equivalent to minimizing KL(P||Q).

    How `net` is optimized depends on entropy_regularization:

    * entropy_regularization = 0: the minimization is achieved by simply
      minimizing loss_func(net([noise, inputs]))

    * entropy_regularization > 0: the minimization is achieved using
      amortized Stein variational gradient descent (SVGD). See the
      following paper for reference:

      Feng et al "Learning to Draw Samples with Amortized Stein Variational
      Gradient Descent" https://arxiv.org/pdf/1707.06626.pdf

    An additional optional objective is supported: maximizing the mutual
    information between [noise, inputs] and outputs by using mi_estimator,
    to prevent mode collapse. This might be useful for
    entropy_regularization = 0 as suggested in section 5.1 of the following
    paper:

      Hjelm et al "Learning Deep Representations by Mutual Information
      Estimation and Maximization" https://arxiv.org/pdf/1808.06670.pdf
    """

    def __init__(self,
                 output_dim,
                 noise_dim=32,
                 input_tensor_spec=None,
                 hidden_layers=(256, ),
                 net: Network = None,
                 net_moving_average_rate=None,
                 entropy_regularization=0.,
                 kernel_sharpness=2.,
                 mi_weight=None,
                 mi_estimator_cls=MIEstimator,
                 optimizer: tf.optimizers.Optimizer = None,
                 name="Generator"):
        """Create a Generator.

        Args:
            output_dim (int): dimension of output
            noise_dim (int): dimension of noise
            input_tensor_spec (nested TensorSpec): spec of inputs. If there
                is no inputs, this should be None.
            hidden_layers (tuple): size of hidden layers.
            net (Network): network for generating outputs from
                [noise, inputs] or noise (if inputs is None). If None, a
                default one with hidden_layers will be created
            net_moving_average_rate (float): If provided, use a moving
                average version of net to do prediction. This has been
                shown to be effective for GAN training (arXiv:1907.02544,
                arXiv:1812.04948).
            entropy_regularization (float): weight of entropy regularization
            kernel_sharpness (float): Used only for
                entropy_regularization > 0. We calculate the kernel in SVGD
                as:
                    exp(-kernel_sharpness * reduce_mean((x-y)^2/width)),
                where width is the elementwise moving average of (x-y)^2
            mi_weight (float): if not None, a mutual information estimator
                is created and its loss, scaled by this weight, is added to
                the propagated loss in ``train_step``.
            mi_estimator_cls (type): the class of mutual information
                estimator for maximizing the mutual information between
                [noise, inputs] and [outputs, inputs].
            optimizer (tf.optimizers.Optimizer): optimizer (optional)
            name (str): name of this generator
        """
        super().__init__(train_state_spec=(), optimizer=optimizer, name=name)

        self._noise_dim = noise_dim
        self._entropy_regularization = entropy_regularization
        use_stein = entropy_regularization != 0
        if use_stein:
            self._grad_func = self._stein_grad
            # Kernel state is only consumed by _kernel_func, which only the
            # Stein gradient calls.
            self._kernel_width_averager = AdaptiveAverager(
                tensor_spec=tf.TensorSpec(shape=(output_dim, )))
            self._kernel_sharpness = kernel_sharpness
        else:
            self._grad_func = self._ml_grad

        noise_spec = tf.TensorSpec(shape=[noise_dim])
        has_inputs = input_tensor_spec is not None

        if net is None:
            # Default network: consumes noise alone, or [noise, inputs].
            if has_inputs:
                net_input_spec = [noise_spec, input_tensor_spec]
            else:
                net_input_spec = noise_spec
            net = EncodingNetwork(
                name="Generator",
                input_tensor_spec=net_input_spec,
                fc_layer_params=hidden_layers,
                last_layer_size=output_dim)

        self._mi_estimator = None
        self._input_tensor_spec = input_tensor_spec
        if mi_weight is not None:
            y_spec = tf.TensorSpec((output_dim, ))
            x_spec = ([noise_spec, input_tensor_spec]
                      if has_inputs else noise_spec)
            self._mi_estimator = mi_estimator_cls(
                x_spec, y_spec, sampler='shift')
            self._mi_weight = mi_weight

        self._net = net
        self._predict_net = None
        self._net_moving_average_rate = net_moving_average_rate
        if net_moving_average_rate:
            self._predict_net = net.copy(name="Genrator_average")
            # tau=1.0 makes the averaged copy start from net's weights.
            tfa_common.soft_variables_update(
                self._net.variables, self._predict_net.variables, tau=1.0)

    def _trainable_attributes_to_ignore(self):
        """Return the attribute names excluded from training."""
        return ["_predict_net"]

    def _predict(self, inputs, batch_size=None, training=True):
        """Sample noise and run it (with ``inputs``, if any) through a net.

        Args:
            inputs (nested Tensor): if None, outputs are generated from
                noise only and ``batch_size`` must be given.
            batch_size (int): number of samples; overwritten by the batch
                dimension of ``inputs`` when ``inputs`` is not None.
            training (bool): when False and a moving-average net exists,
                that net is used instead of ``self._net``.

        Returns:
            tuple:
            - outputs: first element of the net's return value
            - gen_inputs: what was fed to the net (noise or [noise, inputs])
        """
        if inputs is not None:
            tf.nest.assert_same_structure(inputs, self._input_tensor_spec)
            first_input = tf.nest.flatten(inputs)[0]
            batch_size = tf.shape(first_input)[0]
        else:
            assert self._input_tensor_spec is None
            assert batch_size is not None

        noise_shape = common.concat_shape([batch_size], [self._noise_dim])
        noise = tf.random.normal(shape=noise_shape)
        gen_inputs = [noise, inputs] if inputs is not None else noise

        if self._predict_net and not training:
            network = self._predict_net
        else:
            network = self._net
        outputs = network(gen_inputs)[0]
        return outputs, gen_inputs

    def predict(self, inputs, batch_size=None, state=None):
        """Generate outputs given inputs.

        Args:
            inputs (nested Tensor): if None, the outputs is generated only
                from noise.
            batch_size (int): batch_size. Must be provided if inputs is
                None. It is ignored if inputs is not None
            state: not used

        Returns:
            AlgorithmStep: outputs with shape (batch_size, output_dim)
        """
        generated, _ = self._predict(inputs, batch_size, training=False)
        return AlgorithmStep(outputs=generated, state=(), info=())

    def train_step(self, inputs, loss_func, batch_size=None, state=None):
        """Generate outputs and build the loss to be propagated.

        Args:
            inputs (nested Tensor): if None, the outputs is generated only
                from noise.
            loss_func (Callable): loss_func([outputs, inputs])
                (loss_func(outputs) if inputs is None) returns a Tensor with
                shape [batch_size] as a loss for optimizing the generator
            batch_size (int): batch_size. Must be provided if inputs is
                None. It is ignored if inputs is not None
            state: not used

        Returns:
            AlgorithmStep:
                outputs: Tensor with shape (batch_size, dim)
                info: LossInfo
        """
        outputs, gen_inputs = self._predict(inputs, batch_size)
        loss, grad = self._grad_func(inputs, outputs, loss_func)
        # Surrogate loss whose gradient w.r.t. outputs equals ``grad``.
        loss_propagated = tf.reduce_sum(
            tf.stop_gradient(grad) * outputs, axis=-1)

        mi_loss = ()
        if self._mi_estimator is not None:
            mi_step = self._mi_estimator.train_step([gen_inputs, outputs])
            mi_loss = mi_step.info.loss
            loss_propagated = loss_propagated + self._mi_weight * mi_loss

        loss_info = LossInfo(
            loss=loss_propagated,
            extra=GeneratorLossInfo(generator=loss, mi_estimator=mi_loss))
        return AlgorithmStep(outputs=outputs, state=(), info=loss_info)

    def _ml_grad(self, inputs, outputs, loss_func):
        """Return the loss and the gradient of its sum w.r.t. ``outputs``."""
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(outputs)
            if inputs is None:
                loss_inputs = outputs
            else:
                loss_inputs = [outputs, inputs]
            loss = loss_func(loss_inputs)
            scalar_loss = tf.reduce_sum(loss)
        return loss, tape.gradient(scalar_loss, outputs)

    def _kernel_func(self, x, y):
        """Return exp(-kernel_sharpness * reduce_mean((x-y)^2 / width)).

        The width averager is updated with the mean of (x-y)^2 over axis 0
        before it is read.
        """
        sq_diff = tf.square(x - y)
        self._kernel_width_averager.update(tf.reduce_mean(sq_diff, axis=0))
        width = self._kernel_width_averager.get()
        scaled = tf.reduce_mean(sq_diff / width, axis=-1)
        return tf.math.exp(-self._kernel_sharpness * scaled)

    def _stein_grad(self, inputs, outputs, loss_func):
        """Return the loss on resampled outputs and the Stein gradient.

        A second batch of outputs is sampled; the result is the gradient of
        the kernel-weighted loss minus the gradient of the
        entropy-regularized kernel sum, both taken w.r.t. the second batch.
        """
        outputs2, _ = self._predict(inputs, batch_size=tf.shape(outputs)[0])

        with tf.GradientTape(watch_accessed_variables=False) as kernel_tape:
            kernel_tape.watch(outputs2)
            kernel_weight = self._kernel_func(outputs, outputs2)
            weight_sum = self._entropy_regularization * tf.reduce_sum(
                kernel_weight)
        kernel_grad = kernel_tape.gradient(weight_sum, outputs2)

        with tf.GradientTape(watch_accessed_variables=False) as loss_tape:
            loss_tape.watch(outputs2)
            if inputs is None:
                loss_inputs = outputs2
            else:
                loss_inputs = [outputs2, inputs]
            loss = loss_func(loss_inputs)
            # Kernel weights are constants for the loss gradient.
            weighted_loss = tf.stop_gradient(kernel_weight) * loss
            scalar_loss = tf.reduce_sum(weighted_loss)
        loss_grad = loss_tape.gradient(scalar_loss, outputs2)

        return loss, loss_grad - kernel_grad

    def after_train(self, training_info):
        """Move the averaged net towards ``net`` when averaging is enabled.

        Args:
            training_info: not used
        """
        if not self._predict_net:
            return
        tfa_common.soft_variables_update(
            self._net.variables,
            self._predict_net.variables,
            tau=self._net_moving_average_rate)