Example #1
    def validation_step(self, inputs, model, metrics=None):
        """Validatation step.

    Args:
      inputs: a dictionary of input tensors. output_dict = {
        "video_ids": batch_video_ids,
        "video_matrix": batch_video_matrix,
        "labels": batch_labels,
        "num_frames": batch_frames, }
      model: the model, forward definition
      metrics: a nested structure of metrics objects.

    Returns:
      a dictionary of logs.
    """
        features, labels = inputs['video_matrix'], inputs['labels']
        num_frames = inputs['num_frames']

        # L2-normalize the input features along the last (feature) axis.
        feature_dim = len(features.shape) - 1
        features = tf.nn.l2_normalize(features, feature_dim)

        # Sample random frames / a random sequence:
        # (None, 5, 1152) -> (None, 30, 1152).
        sample_frames = self.task_config.validation_data.num_frames
        if self.task_config.model.sample_random_frames:
            features = utils.sample_random_frames(features, num_frames,
                                                  sample_frames)
        else:
            features = utils.sample_random_sequence(features, num_frames,
                                                    sample_frames)

        outputs = self.inference_step(features, model)
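        # Cast outputs to float32; under mixed precision the model may
        # emit float16/bfloat16 tensors.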
        outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                                        outputs)
        if self.task_config.validation_data.segment_labels:
            # Workaround to ignore unrated labels: zero them out with the
            # per-label weights.
            outputs *= inputs['label_weights']
            # Drop padded segments, i.e. rows whose labels are all -1.
            keep = ~tf.reduce_all(labels == -1, axis=1)
            outputs = outputs[keep]
            labels = labels[keep]
        loss, model_loss = self.build_losses(model_outputs=outputs,
                                             labels=labels,
                                             aux_losses=model.losses)

        logs = {self.loss: loss}

        all_losses = {'total_loss': loss, 'model_loss': model_loss}

        # Stash raw (labels, outputs) in the logs; the average-precision
        # metric is computed from them after the step.
        logs.update({self.avg_prec_metric.name: (labels, outputs)})

        if metrics:
            for m in metrics:
                m.update_state(all_losses[m.name])
                logs.update({m.name: m.result()})
        return logs
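The segment-label branch above drops padded segments with a boolean mask built from the labels. Here is a minimal standalone sketch of that trick, with dummy tensors and made-up shapes rather than the actual YT8M pipeline:

import tensorflow as tf

# Dummy batch of 4 segments over 3 classes; rows of all -1 are padding.
labels = tf.constant([[1.0, 0.0, 1.0],
                      [-1.0, -1.0, -1.0],
                      [0.0, 1.0, 0.0],
                      [-1.0, -1.0, -1.0]])
outputs = tf.random.uniform((4, 3))

# A segment is padding iff every entry in its label row equals -1.
keep = ~tf.reduce_all(labels == -1, axis=1)  # bool tensor, shape (4,)
outputs = outputs[keep]  # boolean-mask indexing keeps the real segments
labels = labels[keep]
print(labels.shape)  # (2, 3)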
Example #2
    def train_step(self, inputs, model, optimizer, metrics=None):
        """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors. output_dict = {
          "video_ids": batch_video_ids,
          "video_matrix": batch_video_matrix,
          "labels": batch_labels,
          "num_frames": batch_frames, }
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      a dictionary of logs.
    """
        features, labels = inputs['video_matrix'], inputs['labels']
        num_frames = inputs['num_frames']

        # L2-normalize the input features along the last (feature) axis.
        feature_dim = len(features.shape) - 1
        features = tf.nn.l2_normalize(features, feature_dim)

        # Sample random frames / a random sequence; the sampling
        # utilities take the frame counts as float32.
        num_frames = tf.cast(num_frames, tf.float32)
        sample_frames = self.task_config.train_data.num_frames
        if self.task_config.model.sample_random_frames:
            features = utils.sample_random_frames(features, num_frames,
                                                  sample_frames)
        else:
            features = utils.sample_random_sequence(features, num_frames,
                                                    sample_frames)

        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
        with tf.GradientTape() as tape:
            outputs = model(features, training=True)
            # Cast outputs to float32: under the mixed_float16 or
            # mixed_bfloat16 policy the model emits reduced-precision
            # tensors, while the loss should be computed in float32.
            outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
                                            outputs)

            # Computes per-replica loss
            loss, model_loss = self.build_losses(model_outputs=outputs,
                                                 labels=labels,
                                                 aux_losses=model.losses)
            # Scale the loss by the replica count, since the default
            # gradient allreduce sums (rather than averages) across
            # replicas inside the optimizer.
            scaled_loss = loss / num_replicas

            # Under a mixed-precision policy with a LossScaleOptimizer, the
            # loss is additionally scaled for numerical stability.
            if isinstance(optimizer,
                          tf.keras.mixed_precision.LossScaleOptimizer):
                scaled_loss = optimizer.get_scaled_loss(scaled_loss)

        tvars = model.trainable_variables
        grads = tape.gradient(scaled_loss, tvars)
        # Scale the gradients back before apply_gradients when a
        # LossScaleOptimizer is used.
        if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
            grads = optimizer.get_unscaled_gradients(grads)

        # Apply gradient clipping.
        if self.task_config.gradient_clip_norm > 0:
            grads, _ = tf.clip_by_global_norm(
                grads, self.task_config.gradient_clip_norm)
        optimizer.apply_gradients(list(zip(grads, tvars)))

        logs = {self.loss: loss}

        all_losses = {'total_loss': loss, 'model_loss': model_loss}

        if metrics:
            for m in metrics:
                m.update_state(all_losses[m.name])
                logs.update({m.name: m.result()})

        return logs
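The scaled-loss / unscaled-gradients dance in the training step is the standard tf.keras.mixed_precision pattern. Below is a minimal self-contained sketch of it, assuming TF 2.x and a toy Dense model in place of the YT8M one (the data shapes and clip norm are illustrative):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.Adam())

x = tf.random.uniform((8, 4))
y = tf.random.uniform((8, 1))

with tf.GradientTape() as tape:
    # Under mixed_float16 the model emits float16; cast for a float32 loss.
    preds = tf.cast(model(x, training=True), tf.float32)
    loss = tf.reduce_mean(tf.square(preds - y))
    # Scale the loss up so small float16 gradients do not underflow.
    scaled_loss = optimizer.get_scaled_loss(loss)

grads = tape.gradient(scaled_loss, model.trainable_variables)
# Undo the loss scaling before clipping or applying the gradients.
grads = optimizer.get_unscaled_gradients(grads)
grads, _ = tf.clip_by_global_norm(grads, 1.0)
optimizer.apply_gradients(zip(grads, model.trainable_variables))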