def test_step(self, data):
    """Run one evaluation step and report each loss term under its display name.

    Args:
        data: An ``(x, y)`` pair. It is passed through
            ``data_adapter.expand_1d`` before being unpacked.

    Returns:
        dict: One entry per field of the object returned by ``self.loss_fn``
        (read through its ``_asdict()``), keyed via
        ``self._output_keys_renamed``.
    """
    data = data_adapter.expand_1d(data)
    x, y = data
    y_pred = self.network(x, training=False)
    losses = self.loss_fn(
        y, y_pred, VaeLossNet.InputWeight(), training=False
    )
    # A mean over ``losses.loss`` used to be computed here and then dropped;
    # the per-field values below are the only thing this step reports.
    return {
        self._output_keys_renamed[field]: value
        for field, value in losses._asdict().items()
    }
def custom_train_step(self, data):
    """
    Custom training logic

    Runs a forward pass of ``self.keras_model``, builds the task specific
    output/variance loss callables, applies one optimizer update and refreshes
    the compiled metrics.

    :param data: batch handed over by Keras; unpacked into inputs, targets and sample weights
    :return: dict mapping each metric name of ``self.keras_model`` to its current result
    :raises RuntimeError: if ``self.task`` is not a supported task name
    """
    data = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

    with tf.GradientTape() as tape:
        y_pred = self.keras_model(x, training=True)
        # NOTE(review): this value is replaced further down; presumably the
        # call is kept for its effect on the compiled loss container - confirm.
        loss = self.keras_model.compiled_loss(
            y,
            y_pred,
            sample_weight,
            regularization_losses=self.keras_model.losses,
        )

        task = self.task
        if task == 'regression':
            variance_loss_fn = mse_var_wrapper(y_pred[0], x['labels_err'])
            output_loss_fn = mse_lin_wrapper(y_pred[1], x['labels_err'])
        elif task == 'classification':
            output_loss_fn = bayesian_categorical_crossentropy_wrapper(
                y_pred[1])
            variance_loss_fn = bayesian_categorical_crossentropy_var_wrapper(
                y_pred[0])
        elif task == 'binary_classification':
            output_loss_fn = bayesian_binary_crossentropy_wrapper(y_pred[1])
            variance_loss_fn = bayesian_binary_crossentropy_var_wrapper(
                y_pred[0])
        else:
            raise RuntimeError(
                'Only "regression", "classification" and "binary_classification" are supported'
            )

        # The loss that is actually optimised: output term plus variance term.
        loss = output_loss_fn(y['output'], y_pred[0]) + variance_loss_fn(
            y['variance_output'], y_pred[1])

    # apply gradient here; before TF 2.4 the private Keras helper is used.
    if version.parse(tf.__version__) < version.parse("2.4.0"):
        tf.python.keras.engine.training._minimize(
            self.keras_model.distribute_strategy,
            tape,
            self.keras_model.optimizer,
            loss,
            self.keras_model.trainable_variables,
        )
    else:
        self.keras_model.optimizer.minimize(
            loss, self.keras_model.trainable_variables, tape=tape)

    self.keras_model.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {
        metric.name: metric.result() for metric in self.keras_model.metrics
    }
def train_step(self, data):
    """Run one optimisation step with the model's own ``compute_loss``.

    Args:
        data: Batch handed over by Keras. Sample weights, if any, are ignored.

    Returns:
        dict: Metric name -> current result for every entry of ``self.metrics``.
    """
    expanded = data_adapter.expand_1d(data)
    x, ground_truths, _ = data_adapter.unpack_x_y_sample_weight(expanded)

    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        # Axes 1 and 2 of the image tensor, cast to the compute dtype.
        image_dims = tf.shape(x[DatasetField.IMAGES])[1:3]
        input_shape = tf.cast(image_dims, self.compute_dtype)
        loss = self.compute_loss(ground_truths, y_pred, input_shape)
        # No compiled targets are passed; only ``self.losses`` contribute here.
        loss += self.compiled_loss(
            None, y_pred, None, regularization_losses=self.losses
        )

    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    self.loss_metric.update_state(loss)
    return {metric.name: metric.result() for metric in self.metrics}
def test_step(self, data):
    """Evaluate one batch, ignoring every target equal to ``-1``.

    Args:
        data: Batch handed over by Keras (inputs, targets, optional weights).

    Returns:
        dict: Metric name -> current result for every entry of ``self.metrics``.
    """
    expanded = data_adapter.expand_1d(data)
    x, targets, sample_weight = data_adapter.unpack_x_y_sample_weight(expanded)
    predictions = self(x, training=False)

    # Entries whose target is -1 are dropped from both targets and predictions.
    keep = tf.not_equal(targets, -1)
    kept_targets = tf.expand_dims(tf.boolean_mask(targets, keep), 1)
    kept_predictions = tf.boolean_mask(predictions, keep)

    self.compiled_loss(
        kept_targets,
        kept_predictions,
        sample_weight,
        regularization_losses=self.losses,
    )
    self.compiled_metrics.update_state(
        kept_targets, kept_predictions, sample_weight
    )
    return {metric.name: metric.result() for metric in self.metrics}
def test_step(self, data):
    """Evaluate one batch with the ground truths fed in as a model input.

    Args:
        data: Batch handed over by Keras. ``x`` must support item assignment
            because the targets are stored under ``x['ground_truths']``.

    Returns:
        dict: Metric name -> current result for every entry of ``self.metrics``.
    """
    expanded = data_adapter.expand_1d(data)
    x, y, _ = data_adapter.unpack_x_y_sample_weight(expanded)
    x['ground_truths'] = y

    # The metrics of this graph are computed inside ``call``, so the forward
    # pass runs with training=True to obtain them. No backpropagation happens
    # in the test step.
    y_pred = self(x, training=True)
    self.compiled_loss(None, y_pred, None, regularization_losses=self.losses)
    return {metric.name: metric.result() for metric in self.metrics}
def test_step(self, data):
    """Evaluate one batch and track the resulting loss in ``self.loss_metric``.

    Args:
        data: Batch handed over by Keras. Sample weights, if any, are ignored.

    Returns:
        dict: Metric name -> current result for every entry of ``self.metrics``.
    """
    expanded = data_adapter.expand_1d(data)
    x, ground_truths, _ = data_adapter.unpack_x_y_sample_weight(expanded)

    # The loss needs the results of each decoder layer, which the model only
    # provides when called with training=True.
    y_pred = self(x, training=True)

    image_dims = tf.shape(x[DatasetField.IMAGES])[1:3]
    input_shape = tf.cast(image_dims, self.compute_dtype)

    loss = self.compute_loss(ground_truths, y_pred, input_shape)
    loss += self.compiled_loss(
        None, y_pred, None, regularization_losses=self.losses
    )
    self.loss_metric.update_state(loss)
    return {metric.name: metric.result() for metric in self.metrics}
def test_step(self, data):
    """
    Overwrite function for the Keras model indicating how a test step will
    operate.

    :param Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]] data: The input
        data is expected to be provided in the form
        (input_features, (true_labels, true_concepts)).
    :return: dict keyed by the tracker names (concept accuracy, concept loss,
        task loss, total loss) plus one entry per metric in
        ``self.extra_metrics``.
    """
    expanded = data_adapter.expand_1d(data)
    input_features, (true_labels, true_concepts), sample_weight = \
        data_adapter.unpack_x_y_sample_weight(expanded)

    # Forward pass in inference mode: labels, concepts and any extra losses.
    predicted_labels, predicted_concepts, extra_losses = self._call_fn(
        input_features,
        training=False,
    )

    task_loss, concept_loss, concept_accuracy = self._compute_losses(
        predicted_labels=predicted_labels,
        predicted_concepts=predicted_concepts,
        true_labels=true_labels,
        true_concepts=true_concepts,
    )

    # Total = task loss + alpha-weighted concept loss + every extra loss,
    # accumulated in the order the extra losses were produced.
    total_loss = task_loss + self.alpha * concept_loss
    for extra_loss in extra_losses:
        total_loss += extra_loss

    result = {
        self.concept_accuracy_tracker.name: concept_accuracy,
        self.concept_loss_tracker.name: concept_loss,
        self.task_loss_tracker.name: task_loss,
        self.total_loss_tracker.name: total_loss,
    }
    result.update({
        metric.name: metric(true_labels, predicted_labels, sample_weight)
        for metric in self.extra_metrics
    })
    return result
def print_data_and_train_step(original_data):
    """Print the batch and the model prediction, then run the wrapped train step.

    Relies on ``keras_model``, ``original_train_step`` and ``K`` from the
    enclosing scope.

    Args:
        original_data: The batch exactly as Keras passes it to ``train_step``.

    Returns:
        Whatever ``original_train_step(original_data)`` returns.
    """
    # Basically copied one-to-one from https://git.io/JvDTv
    expanded = data_adapter.expand_1d(original_data)
    x, y_true, w = data_adapter.unpack_x_y_sample_weight(expanded)
    y_pred = keras_model(x, training=True)

    # this is pretty much like on_train_batch_begin
    to_print = (
        (w, "Sample weight (w) ="),
        (x, "Batch input (x) ="),
        (y_true, "Batch output (y_true) ="),
        (y_pred, "Prediction (y_pred) ="),
    )
    for tensor, message in to_print:
        K.print_tensor(tensor, message)

    # add anything after this call for on_train_batch_end-like behavior
    return original_train_step(original_data)
def train_step(self, data, training: bool = False):
    """Run one adversarial training step for the generator and discriminator.

    Args:
        data: An ``(x, y)`` pair; it is passed through
            ``data_adapter.expand_1d`` before being unpacked.
        training: Not read anywhere in this method.

    Returns:
        dict: The mean generator and discriminator losses under
        ``"loss/loss_generative"`` / ``"loss/loss_descriminative"``, plus every
        entry of ``self.metrics`` prefixed with ``"acc/"`` when its name
        contains ``"accuracy"`` and with ``"loss/"`` otherwise.
    """
    data = data_adapter.expand_1d(data)
    x, y = data
    # Only ``inputs`` is used below; ``temp`` and ``weights`` are ignored.
    inputs, temp, weights = self.call_inputs(x)
    # Use a single pass over the network for efficiency.
    # Normally one would sequentially call the generative and then the
    # discriminative nets.
    # Take multiple passes of the discriminator player according to 4.4 of
    # https://arxiv.org/pdf/1701.00160.pdf to balance G and D.
    for i in range(self.config.training_ratio):
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            y_pred = self.network(inputs, y, training=True)
            gen_losses, descrim_losses = self.lossnet(y, y_pred, training=True)
            descrim_loss = tf.reduce_mean(descrim_losses)
            gen_loss = tf.reduce_mean(gen_losses)
        # Train the discriminator to identify real from fake samples
        self.optimizer.minimize(
            descrim_loss,
            self.network.descriminator.trainable_variables,
            tape=disc_tape,
        )
    # NOTE(review): the original indentation of this file was lost. The
    # generator update is placed after the loop because the comment above
    # speaks of multiple passes of the discriminator only - confirm against
    # the upstream source that it does not belong inside the loop.
    # Train the generator to fool the discriminator
    self.optimizer.minimize(
        gen_loss,
        self.network.generatornet.trainable_variables,
        tape=gen_tape,
    )
    return {
        "loss/loss_generative": gen_loss,
        "loss/loss_descriminative": descrim_loss,
        **{
            "loss/" + v.name: v.result()
            for v in self.metrics
            if "accuracy" not in v.name
        },
        **{
            "acc/" + v.name: v.result()
            for v in self.metrics
            if "accuracy" in v.name
        },
    }
def train_step(self, data):
    """The logic for one training step.

    The forward pass used for the loss receives ``(x, y)`` and
    ``samples=self.training_mode_samples``. When compiled metrics exist, a
    second forward pass in inference mode
    (``samples=self.inference_samples_train``) feeds them.

    Configuration details for *how* this logic is run (e.g. `tf.function` and
    `tf.distribute.Strategy` settings) are left to `Model.make_train_function`.

    Arguments:
        data: A nested structure of `Tensor`s.

    Returns:
        A `dict` containing values that will be passed to
        `tf.keras.callbacks.CallbackList.on_train_batch_end`: the name and
        current result of every entry of ``self.metrics``.
    """
    # These are the only transformations `Model.fit` applies to user-input
    # data when a `tf.data.Dataset` is provided.
    data = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

    with backprop.GradientTape() as tape:
        y_pred = self((x, y), training=True,
                      samples=self.training_mode_samples)
        loss = self.compiled_loss(
            y, y_pred, sample_weight, regularization_losses=self.losses)

    # Compare (major, minor) numerically. Stripping the dots and calling
    # ``int`` on the remainder fails for suffixed versions such as "2.4.0rc0"
    # and mis-orders multi-digit components.
    tf_major_minor = tuple(
        int(part) for part in tf.__version__.split(".")[:2]
    )
    if tf_major_minor < (2, 4):
        _minimize(self.distribute_strategy, tape, self.optimizer, loss,
                  self.trainable_variables)
    else:
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)

    # Run in inference mode for other metrics
    if self.compiled_metrics._metrics is not None:
        y_pred_inference = self(
            x,
            training=False,
            samples=self.inference_samples_train,
            verbose=True,
        )
        self.compiled_metrics.update_state(
            y, y_pred_inference, sample_weight)
    return {m.name: m.result() for m in self.metrics}
def train_step(self, data):
    """The logic for one training step, keeping the inputs and tape on ``self``.

    Besides the usual forward pass, loss computation, optimizer update and
    metric update, this step stores the batch inputs in
    ``self.assigned_inputs`` and the persistent gradient tape, which also
    watches the inputs, in ``self.tape_handler``.

    Arguments:
        data: A nested structure of `Tensor`s.

    Returns:
        A `dict` containing values that will be passed to
        `tf.keras.callbacks.CallbackList.on_train_batch_end`: the name and
        current result of every entry of ``self.metrics``.
    """
    # These are the only transformations `Model.fit` applies to user-input
    # data when a `tf.data.Dataset` is provided.
    expanded = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(expanded)
    self.assigned_inputs = x

    with backprop.GradientTape(persistent=True) as tape:
        self.tape_handler = tape
        tape.watch(x)
        y_pred = self(x, training=True)
        loss = self.compiled_loss(
            y, y_pred, sample_weight, regularization_losses=self.losses
        )

    # `_minimize` performs the gradient computation and the optimizer update
    # in a single call.
    _minimize(
        self.distribute_strategy,
        tape,
        self.optimizer,
        loss,
        self.trainable_variables,
    )

    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {metric.name: metric.result() for metric in self.metrics}
def train_step(self, data):
    """Run one training step and hand the result of ``_minimize`` to the layers.

    Args:
        data: Batch handed over by Keras (inputs, targets, optional weights).

    Returns:
        dict: Metric name -> current result for every entry of ``self.metrics``.
    """
    expanded = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(expanded)

    with backprop.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compiled_loss(
            y, y_pred, sample_weight, regularization_losses=self.losses
        )

    gradients = _minimize(
        self.distribute_strategy,
        tape,
        self.optimizer,
        loss,
        self.trainable_variables,
    )
    # Add context loss to layers
    self.add_context_loss(gradients)

    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {metric.name: metric.result() for metric in self.metrics}
def train_step(self, data):
    """Run one optimizer update from gradients accumulated over minibatches.

    The batch is cut into ``self.gradient_accumulation_steps`` consecutive
    slices of ``batch_size // gradient_accumulation_steps`` rows each. Any
    remainder rows at the end of the batch are not used. The gradients of the
    slices are summed, divided by the number of steps and applied once.

    Args:
        data: Batch handed over by Keras (inputs, targets, optional weights).
            Inputs, targets and weights may each be a single tensor or a tuple
            of tensors.

    Returns:
        dict: Metric name -> current result for every entry of ``self.metrics``.
    """
    data = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    use_weights = sample_weight is not None

    x = x if isinstance(x, tuple) else [x]
    y = y if isinstance(y, tuple) else [y]
    # Decide on the weights' own type. The wrapping decision used to test
    # ``x``, so a tuple of inputs combined with a single weight tensor left
    # the weights unwrapped and the slicing below iterated over their rows.
    sample_weight = (
        sample_weight if isinstance(sample_weight, tuple) else [sample_weight]
    )

    batch_size = x[0].shape[0]
    minibatch_size = batch_size // self.gradient_accumulation_steps
    train_vars = self.trainable_variables
    accum_gradient = [tf.zeros_like(v) for v in train_vars]

    for step in range(self.gradient_accumulation_steps):
        start = step * minibatch_size
        end = start + minibatch_size
        x_step = [xi[start:end] for xi in x]
        y_step = [yi[start:end] for yi in y]
        weights_step = (
            [weights_i[start:end] for weights_i in sample_weight]
            if use_weights else None
        )

        # Only the trainable variables are watched explicitly.
        with backprop.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(train_vars)
            y_pred = self(x_step, training=True)
            loss = self.compiled_loss(
                y_step,
                y_pred,
                weights_step,
                regularization_losses=self.losses,
            )
        gradients = tape.gradient(loss, train_vars)
        accum_gradient = [
            accumulated + grad
            for accumulated, grad in zip(accum_gradient, gradients)
        ]

        # Update the metrics with the targets that belong to this slice.
        # Previously the full-batch targets were compared once, after the
        # loop, against the predictions of the last slice only.
        self.compiled_metrics.update_state(y_step, y_pred, weights_step)

    accum_gradient = [
        grad / self.gradient_accumulation_steps for grad in accum_gradient
    ]
    self.optimizer.apply_gradients(zip(accum_gradient, train_vars))
    return {m.name: m.result() for m in self.metrics}
def test_step(self, data):
    """Score one batch under a decayed-scale and a constant-scale Normal.

    Args:
        data: Batch handed over by Keras; only the inputs are used.

    Returns:
        dict: ``"ll"`` and ``"llc"``, each the ``tf.reduce_logsumexp`` of the
        Normal log-probabilities of ``x_q`` around ``x_mu``, using
        ``self.sigma(self.optimizer.iterations)`` and ``self.const_sigma``
        as scale respectively.
    """
    expanded = data_adapter.expand_1d(data)
    x, _, _ = data_adapter.unpack_x_y_sample_weight(expanded)
    x_mu, x_q, _ = self.sample(x)

    def _logsumexp_log_prob(scale):
        """Reduce the log-probabilities of ``x_q`` under Normal(x_mu, scale)."""
        distribution = tfp.distributions.Normal(loc=x_mu, scale=scale)
        return tf.reduce_logsumexp(distribution.log_prob(x_q))

    decayed_sigma = self.sigma(self.optimizer.iterations)
    log_likelihood = _logsumexp_log_prob(decayed_sigma)
    log_likelihood_const_var = _logsumexp_log_prob(self.const_sigma)

    return {
        "ll": log_likelihood,
        "llc": log_likelihood_const_var,
    }
def test_step(self, data):
    """Run one evaluation step and report each loss term under its display name.

    Args:
        data: An ``(x, y)`` pair; it is passed through
            ``data_adapter.expand_1d`` before being unpacked.

    Returns:
        dict: One entry per field of the object returned by ``self.loss_fn``
        (read through its ``_asdict()``), keyed via
        ``self._output_keys_renamed``.
    """
    data = data_adapter.expand_1d(data)
    x, y = data
    inputs, temp, weights = self.call_inputs(x)
    if temp is not None:
        # Evaluate with the fixed value 0.5 in place of the supplied one.
        # This assignment used to target the misspelled name ``inpputs`` and
        # therefore never reached the network call below.
        inputs = (x, 0.5)
    y_pred = self.network(inputs, training=False)
    losses = self.loss_fn(
        y, y_pred, self.lossnet.InputWeight(), training=False
    )
    # A mean over ``losses.loss`` used to be computed here and then dropped;
    # the per-field values below are the only thing this step reports.
    return {
        self._output_keys_renamed[field]: value
        for field, value in losses._asdict().items()
    }
def train_step(self, data):
    """Run one training step with the ground truths fed in as a model input.

    Args:
        data: Batch handed over by Keras. ``x`` must support item assignment
            because the targets are stored under ``x['ground_truths']``.

    Returns:
        dict: Metric name -> current result for every entry of ``self.metrics``.
    """
    # These are the only transformations `Model.fit` applies to user-input
    # data when a `tf.data.Dataset` is provided.
    expanded = data_adapter.expand_1d(data)
    x, y, _ = data_adapter.unpack_x_y_sample_weight(expanded)

    with tf.GradientTape() as tape:
        x['ground_truths'] = y
        y_pred = self(x, training=True)
        # All the losses are computed inside the call and end up in
        # ``self.losses``; the compiled loss only has to collect them.
        loss = self.compiled_loss(
            None, y_pred, None, regularization_losses=self.losses
        )

    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    return {metric.name: metric.result() for metric in self.metrics}
def predict_step(self, data):
    """The logic for one inference step.

    This method is called by `Model.make_predict_function` and contains the
    forward pass only. Configuration details for *how* this logic is run
    (e.g. `tf.function` and `tf.distribute.Strategy` settings) are left to
    `Model.make_predict_function`.

    Args:
        data: A nested structure of `Tensor`s.

    Returns:
        The output of calling the model on the inputs in inference mode with
        ``samples=self.inference_samples_predict``.
    """
    expanded = data_adapter.expand_1d(data)
    inputs, _, _ = data_adapter.unpack_x_y_sample_weight(expanded)
    prediction = self(
        inputs, training=False, samples=self.inference_samples_predict
    )
    return prediction
def predict_step(self, data):
    """The logic for one inference step.

    Standard prediction is performed with one sample. To accommodate
    variational inference, the model output is averaged over its leading
    axis.

    Arguments:
        data: A nested structure of `Tensor`s.

    Returns:
        The mean over axis 0 of the model output in inference mode.
    """
    expanded = data_adapter.expand_1d(data)
    inputs, _, _ = data_adapter.unpack_x_y_sample_weight(expanded)
    sampled_output = self(inputs, training=False)
    return tf.reduce_mean(sampled_output, axis=0)
def test_step(self, data):
    """Evaluate one batch and attach the raw tensors to the returned logs.

    Args:
        data: Batch handed over by Keras. ``x`` must unpack into three items;
            the first and the third are echoed back in the logs.

    Returns:
        dict: Metric name -> current result for every entry of
        ``self.metrics``, plus ``"__outputs__"`` holding
        ``((inputs, pre_proc_targets, meta), y, y_pred)``.
    """
    expanded = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(expanded)
    inputs, _, meta = x

    # The model returns its prediction together with pre-processed targets.
    y_pred, pre_proc_targets = self(x, training=False)

    # Updates stateful loss metrics.
    self.compiled_loss(
        y, y_pred, sample_weight, regularization_losses=self.losses
    )
    self.compiled_metrics.update_state(y, y_pred, sample_weight)

    logs = {metric.name: metric.result() for metric in self.metrics}
    logs["__outputs__"] = ((inputs, pre_proc_targets, meta), y, y_pred)
    return logs
def train_step(self, batch):
    """Run one training step while counting and printing every invocation.

    Args:
        batch: Batch handed over by Keras (inputs, targets, optional weights).

    Returns:
        dict: Metric name -> current result for every entry of ``self.metrics``.
    """
    self.my_test_count += 1
    tf.print("Train step", self.my_test_count)
    print(self.my_test_count, batch)

    # Unpack the *expanded* data. The raw ``batch`` used to be unpacked
    # here, which silently discarded the result of ``expand_1d``.
    data = data_adapter.expand_1d(batch)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

    # Perform a forward pass and calculate gradients
    y_pred, gradients = self._forward_pass(x, y, sample_weight)

    # Apply the gradients to the model
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

    # Update the metrics
    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {metric.name: metric.result() for metric in self.metrics}
def train_step(self, data):
    """Run one training step and print the mean absolute gradient per variable.

    The debug print calls ``.numpy()`` on each gradient.
    NOTE(review): ``.numpy()`` is only available on eager tensors, so this
    step presumably needs the model to run eagerly - confirm how it is
    compiled.

    Args:
        data: Batch handed over by Keras (inputs, targets, optional weights).

    Returns:
        dict: Metric name -> current result for every entry of ``self.metrics``.
    """
    data = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

    with backprop.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compiled_loss(
            y, y_pred, sample_weight, regularization_losses=self.losses)

    trainable_variables = self.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)

    for var, grad in zip(trainable_variables, gradients):
        if grad is None:
            # ``tape.gradient`` yields None for a variable the loss does not
            # depend on; report it instead of failing on ``None.numpy()``.
            print(var.name, None)
            continue
        print(var.name, np.mean(np.abs(grad.numpy())))

    self.optimizer.apply_gradients(zip(gradients, trainable_variables))
    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {m.name: m.result() for m in self.metrics}
def train_step(self, data):
    """Run one training step and expose labels and gradients to the debug hook.

    Args:
        data: Batch handed over by Keras (inputs, targets, optional weights).

    Returns:
        dict: Metric name -> current result for every entry of
        ``self.metrics``, plus the targets under ``SMDEBUG_PREFIX + "y"`` and
        the gradients under ``SMDEBUG_PREFIX + "gradients"``.
    """
    expanded = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(expanded)

    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)
        loss = self.compiled_loss(
            y, y_pred, sample_weight, regularization_losses=self.losses
        )

    trainable_variables = self.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, trainable_variables))
    self.compiled_metrics.update_state(y, y_pred, sample_weight)

    result_dict = {metric.name: metric.result() for metric in self.metrics}
    # To pass gradients and labels to the hook, log entries are added under
    # keys that start with SMDEBUG_PREFIX: "<prefix>y" carries the labels and
    # "<prefix>gradients" carries the gradients.
    result_dict[f"{SMDEBUG_PREFIX}y"] = y
    result_dict[f"{SMDEBUG_PREFIX}gradients"] = gradients
    return result_dict
def custom_train_step(self, data):
    """
    Custom training logic

    Before the compiled loss is evaluated, its private ``_losses`` list is
    rebuilt from ``self._output_loss`` so that the loss depends on the second
    model output and on ``x['labels_err']``.

    :param data: batch handed over by Keras; unpacked into inputs, targets and sample weights
    :return: dict of metric results; a metric whose result is itself a dict is merged in, any other is stored under its name
    """
    data = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    model = self.keras_model

    # Run forward pass.
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)

        # Rebuild the loss list in three steps: raw losses, then loss
        # objects, then a flat list. The order of these assignments is kept
        # exactly as is.
        model.compiled_loss._losses = self._output_loss(
            y_pred[1], x['labels_err'])
        model.compiled_loss._losses = nest.map_structure(
            model.compiled_loss._get_loss_object,
            model.compiled_loss._losses)
        model.compiled_loss._losses = nest.flatten(
            model.compiled_loss._losses)

        loss = model.compiled_loss(
            y, y_pred, sample_weight, regularization_losses=model.losses)

    # Run backwards pass.
    model.optimizer.minimize(loss, model.trainable_variables, tape=tape)
    model.compiled_metrics.update_state(y, y_pred, sample_weight)

    # Collect metrics to return
    return_metrics = {}
    for metric in model.metrics:
        value = metric.result()
        if not isinstance(value, dict):
            return_metrics[metric.name] = value
            continue
        return_metrics.update(value)
    return return_metrics
def test_step(self, data):
    """Compute generator and discriminator losses for one evaluation batch.

    The generator and the discriminator are both called with
    ``training=True`` here.

    Args:
        data: Batch handed over by Keras; the inputs are the noisy images and
            the targets are the clean images. Sample weights are not used.

    Returns:
        dict: ``'generator_loss'``, ``'generator_mae'`` and
        ``'discriminator_loss'``.
    """
    expanded = data_adapter.expand_1d(data)
    noise_img, clean_img, _ = data_adapter.unpack_x_y_sample_weight(expanded)

    generated_img = self._generator(noise_img, training=True)
    real_output = self._discriminator(noise_img, clean_img, training=True)
    gen_output = self._discriminator(noise_img, generated_img, training=True)

    disc_loss = self.disc_loss_fn(real_output, gen_output)
    gen_loss = self.gen_loss_fn(clean_img, generated_img, gen_output)
    # The gradient penalty, scaled by ``gp_weight``, is added on top.
    disc_loss += self.gradient_penalty(clean_img, generated_img) * self.gp_weight

    absolute_error = tf.keras.metrics.mean_absolute_error(
        clean_img, generated_img
    )
    mae = tf.reduce_mean(absolute_error)

    return {
        'generator_loss': gen_loss,
        'generator_mae': mae,
        'discriminator_loss': disc_loss
    }
def test_step(self, data):
    """The logic for one evaluation step.

    The loss comes from a forward pass in training mode that receives
    ``(x, y)`` and ``samples=self.training_mode_samples``. When compiled
    metrics exist, a second forward pass in inference mode
    (``samples=self.inference_samples_test``) feeds them.

    Configuration details for *how* this logic is run (e.g. `tf.function` and
    `tf.distribute.Strategy` settings) are left to `Model.make_test_function`.

    Arguments:
        data: A nested structure of `Tensor`s.

    Returns:
        A `dict` with the name and current result of every entry of
        ``self.metrics``.
    """
    expanded = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(expanded)

    # Run in training mode to calculate loss
    y_pred = self(
        (x, y), training=True, samples=self.training_mode_samples
    )
    # Updates stateful loss metrics.
    self.compiled_loss(
        y, y_pred, sample_weight, regularization_losses=self.losses
    )

    # Run in inference mode for other metrics
    has_compiled_metrics = self.compiled_metrics._metrics is not None
    if has_compiled_metrics:
        y_pred_inference = self(
            x,
            training=False,
            samples=self.inference_samples_test,
            verbose=True,
        )
        self.compiled_metrics.update_state(
            y, y_pred_inference, sample_weight
        )

    return {metric.name: metric.result() for metric in self.metrics}
def train_step(self, data: Any) -> Mapping[str, Any]:
    """
    The logic for one training step.

    Forward pass and compiled loss are recorded on a gradient tape, the
    backward pass is delegated to ``self._apply_backwards_pass`` and the
    compiled metrics are updated afterwards. For more details, see
    TensorFlow's documentation of how to `customize what happens in Model.fit
    <https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit>`_.

    :param data: batch handed over by Keras (inputs, targets, optional weights)
    :return: mapping of metric name to current result for every entry of ``self.metrics``
    """
    from tensorflow.python.keras.engine import data_adapter

    expanded = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(expanded)

    with tf.GradientTape() as tape:
        y_pred = self.__call__(x, training=True)
        loss = self.compiled_loss(
            y, y_pred, sample_weight, regularization_losses=self.losses
        )

    self._apply_backwards_pass(loss, tape=tape)
    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {metric.name: metric.result() for metric in self.metrics}
def train_step(self, data):
    """Delegate one training step to ``self.trainer`` and collect its logs.

    Args:
        data: Batch handed over by Keras. Sample weights are unpacked but
            not passed on.

    Returns:
        dict: ``"damping_factor"``, ``"attempts"`` and ``"loss"``, followed by
        metric name -> current result for every entry of ``self.metrics``.
    """
    expanded = data_adapter.expand_1d(data)
    inputs, targets, _ = data_adapter.unpack_x_y_sample_weight(expanded)

    loss, outputs, attempts, stop_training = self.trainer.train_step(
        inputs, targets
    )
    self.compiled_metrics.update_state(targets, outputs)

    logs = {
        "damping_factor": self.trainer.damping_factor,
        "attempts": attempts,
        "loss": loss,
        **{metric.name: metric.result() for metric in self.metrics},
    }

    # BUG: In tensorflow v2.2.0 and v2.3.0 setting model.stop_training=True
    # does not stop training immediately, but only at the end of the epoch.
    # https://github.com/tensorflow/tensorflow/issues/41174
    self.stop_training = stop_training
    return logs
def test_step(self, data):
    """Compute the generative, discriminative and reconstruction losses.

    Args:
        data: An ``(x, y)`` pair; it is passed through
            ``data_adapter.expand_1d`` before being unpacked.

    Returns:
        dict: The three mean losses under ``"loss/loss_generative"``,
        ``"loss/loss_descriminative"`` and ``"loss/loss_reconstruction"``,
        plus every entry of ``self.metrics`` prefixed with ``"acc/"`` when its
        name contains ``"accuracy"`` and with ``"loss/"`` otherwise.
    """
    data = data_adapter.expand_1d(data)
    x, y = data
    inputs, temp, weights = self.call_inputs(x)
    if temp is not None:
        # Evaluate with the fixed value 0.5 in place of the supplied one.
        # This assignment used to target the misspelled name ``inpputs`` and
        # therefore never reached the network call below.
        inputs = (x, 0.5)
    y_pred = self.network(inputs, y, training=False)

    # The loss network is deliberately left in the mode it was called with
    # before (training=True); only the forward pass above runs in inference
    # mode.
    gen_losses, descrim_losses, recon_losses = self.lossnet(
        self.lossnet.Input.from_output(
            self.network,
            self.lossnet.generator_lossnet,
            y,
            y_pred,
            weights,
        ),
        training=True,
    )
    descrim_loss = tf.reduce_mean(descrim_losses)
    gen_loss = tf.reduce_mean(gen_losses)
    recon_loss = tf.reduce_mean(recon_losses)

    return {
        "loss/loss_generative": gen_loss,
        "loss/loss_descriminative": descrim_loss,
        "loss/loss_reconstruction": recon_loss,
        **{
            "loss/" + v.name: v.result()
            for v in self.metrics
            if "accuracy" not in v.name
        },
        **{
            "acc/" + v.name: v.result()
            for v in self.metrics
            if "accuracy" in v.name
        },
    }
def test_step(self, data):
    """Evaluate one batch with the custom loss and return its individual terms.

    Args:
        data: Batch handed over by Keras. The inputs are indexed at 0, 1 and
            2, the targets at 0 and 1; sample weights are not used.

    Returns:
        dict: A copy of ``self.custom_loss.loss_vals`` with the value returned
        by ``self.custom_loss.calc`` added under ``"sum"``.
    """
    data = data_adapter.expand_1d(data)
    input_data, gt, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
    # Inputs 0 and 1 are concatenated along axis 0 in both orders, so each
    # half of the combined batch sees the pair the other way round. Input 2
    # is simply duplicated.
    combined_input = (
        tf.concat([input_data[0], input_data[1]], axis=0),
        tf.concat([input_data[1], input_data[0]], axis=0),
        tf.concat([input_data[2], input_data[2]], axis=0)
    )
    # The model is called with training=True even though this is a test step.
    y_pred = self(combined_input, training=True)
    depth0, depth1, obj_tran, bg_tran, rot, _ = y_pred
    # NOTE(review): ``obj_tran_inv``, ``bg_tran_inv`` and ``rot_inv`` are not
    # assigned anywhere in this method. Unless they exist in an enclosing or
    # module scope this call raises NameError - confirm where they are meant
    # to come from.
    loss_val = self.custom_loss.calc(
        combined_input[0], combined_input[1], depth0, depth1, obj_tran,
        obj_tran_inv, bg_tran, bg_tran_inv, rot, rot_inv, combined_input[2],
        gt[0], gt[1], self.train_step_counter)
    loss_dict = self.custom_loss.loss_vals.copy()
    loss_dict["sum"] = loss_val
    return loss_dict
def train_step(self, data):
    """Logic for one training step.

    Arguments:
        data: A nested structure of `Tensor`s.

    Returns:
        A `dict` containing values that will be passed to
        `tf.keras.callbacks.CallbackList.on_train_batch_end`. Typically,
        the values of the `Model`'s metrics are returned. Example:
        `{'loss': 0.2, 'accuracy': 0.7}`.

    Notes:
        It is assumed that the loss uses SUM_OVER_BATCH_SIZE. In the
        variational inference case, the loss for the entire training
        set is essentially

            loss = KL - CCE,    (Eq. 1)

        where CCE denotes the sum of the CCE over all observations. Note
        that CCE is also an expectation over posterior samples.

        There are two issues:

        1) A *batch* update strategy is used, which slightly alters the
        equation and means the KL contribution must not be overcounted.
        With `b` denoting an arbitrary batch:

            loss_b = (KL / n_batch) - CCE_b.    (Eq. 2)

        2) The default TF reduction strategy `SUM_OVER_BATCH_SIZE` means
        that an average is computed rather than the sum `CCE_b`:

            CCE_bavg = CCE_b / batch_size.

        To compensate, the KL term is scaled proportionately,

            loss_b = KL / (n_batch * batch_size) - CCE_bavg    (Eq. 3)

        or, expressed more simply,

            loss_batch = kl_weight * KL - CCE_bavg    (Eq. 4)

        where kl_weight = 1 / train_size.

        TODO ISSUE
        Observations may be weighted differently, which yields a distorted
        CCE_bavg, since a proper average would divide by
        `effective_batch_size` (i.e., the sum of the weights) and not by
        `batch_size`. There are a few imperfect remedies:

        1) Do not use `SUM_OVER_BATCH_SIZE`. This has many side-effects:
        regularization and the computation of the mean loss must be
        handled manually. A mean loss is desirable for optimization
        stability, although it is not strictly necessary.

        2) Require the weights to sum to n_sample. Close, but not actually
        correct. To be correct the weights of each batch would need to
        sum to `batch_size`, which means the semantics of the weights
        changes from batch to batch.

        3) Multiply Eq. 3 by (batch_size / effective_batch_size). This is
        tricky since it must be done before non-KL regularization is
        applied, which is handled inside TF's `compiled_loss`. It could be
        hacked by writing a custom CCE that "pre-applies" the correction
        term:

            loss_b = KL / (n_batch * effective_batch_size) -
                (batch_size / effective_batch_size) * CCE_bavg
            loss_b = KL / (effective_train_size) -
                (batch_size / effective_batch_size) * CCE_bavg

        4) Pretend it is not a problem since both terms are being divided
        by the same incorrect `effective_batch_size`.

    """
    expanded = data_adapter.expand_1d(data)
    x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(expanded)

    # NOTE: Computing the gradients creates IndexedSlices, which triggers a
    # TensorFlow warning. No implementation that avoids IndexedSlices has
    # been found, so the offending warning is silenced for this section.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore', category=UserWarning, module=r'.*indexed_slices'
        )
        with backprop.GradientTape() as tape:
            # Average over samples.
            sampled_output = self(x, training=True)
            y_pred = tf.reduce_mean(sampled_output, axis=0)
            loss = self.compiled_loss(
                y, y_pred, sample_weight, regularization_losses=self.losses
            )

        # Custom training steps:
        trainable_variables = self.trainable_variables
        gradients = tape.gradient(loss, trainable_variables)

        # NOTE: There is an open issue for using constraints with
        # embedding-like layers (e.g., tf.keras.layers.Embedding,
        # psiz.keras.layers.GroupAttention), see
        # https://github.com/tensorflow/tensorflow/issues/33755.
        # There are also issues when using Eager Execution. As a
        # work-around, gradients that come back as tf.IndexedSlices are
        # converted into dense tensors.
        gradients = [
            tf.convert_to_tensor(grad)
            if grad.__class__.__name__ == 'IndexedSlices' else grad
            for grad in gradients
        ]
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))

    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {metric.name: metric.result() for metric in self.metrics}