示例#1
0
    def call(self, data, steps, seq_lens, training):
        """Runs one forward pass and returns embeddings grouped per video.

        Args:
          data: dict of batched tensors gathered from many videos. Documented
            keys: 'audio', 'frames', 'labels'. Passed unchanged to
            get_cnn_feats.
          steps: Tensor of indices of the frames chosen in each video. Not
            used by this method.
          seq_lens: Tensor of full-video sequence lengths. Not used by this
            method.
          training: bool; True selects CONFIG.TRAIN.NUM_FRAMES, False selects
            CONFIG.EVAL.NUM_FRAMES. Also forwarded to get_cnn_feats.

        Returns:
          Float tensor of embeddings reshaped to [-1, num_frames, channels],
          where channels is the static last dimension of the embedder output.

        Raises:
          ValueError: listed for invalid configs, but nothing in this method
            raises it directly; presumably it propagates from get_cnn_feats
            or the embedder — verify.
        """
        backbone = self.model['cnn']
        embedder = self.model['emb']

        num_frames = (CONFIG.TRAIN.NUM_FRAMES if training
                      else CONFIG.EVAL.NUM_FRAMES)

        features = get_cnn_feats(backbone, data, training)
        flat_embs = embedder(features, num_frames)

        # Regroup the flat embedder output as [videos, num_frames, channels];
        # the leading (batch) dimension is inferred via -1.
        return tf.reshape(flat_embs, [-1, num_frames, flat_embs.shape[-1]])
示例#2
0
    def call(self, data, steps, seq_lens, training):
        """Runs one forward pass over jointly sampled anchor/positive frames.

        Args:
          data: dict of batched video tensors, passed unchanged to
            get_cnn_feats.
          steps: Not used by this method.
          seq_lens: Not used by this method.
          training: bool; True selects the CONFIG.TRAIN frame count, False the
            CONFIG.EVAL one. Also forwarded to get_cnn_feats.

        Returns:
          Tensor obtained by splitting the embedder output into
          2 * NUM_FRAMES equal chunks along axis 0 and stacking the chunks
          along axis 1, i.e. shape [rows / (2 * NUM_FRAMES), 2 * NUM_FRAMES,
          ...].
        """
        backbone = self.model['cnn']
        embedder = self.model['emb']

        phase = CONFIG.TRAIN if training else CONFIG.EVAL
        num_frames = phase.NUM_FRAMES
        # Frame count is doubled because anchors and positives are sampled
        # together.
        paired_frames = 2 * num_frames

        features = get_cnn_feats(
            backbone, data, training, 2 * (num_frames * CONFIG.DATA.NUM_STEPS))

        flat_embs = embedder(features, paired_frames)

        # NOTE(review): split+stack maps embedder row r to [r % B, r // B]
        # with B = rows / paired_frames, i.e. it assumes rows are grouped by
        # frame position rather than by video. A reshape to
        # [-1, paired_frames, D] would assume video-grouped rows instead —
        # confirm which order the embedder produces.
        per_frame_chunks = tf.split(flat_embs, paired_frames, axis=0)
        return tf.stack(per_frame_chunks, axis=1)