コード例 #1
0
  def build_model(self):
    """Creates the video SSL pretraining model from the task config.

    The input shape is derived by comparing the train and validation feature
    shapes position by position: a dimension is kept only when both configs
    agree on it, otherwise it is left as `None`. A leading `None` entry is
    prepended before the shape is wrapped in an `InputSpec`.

    Returns:
      Whatever `factory_3d.build_model` returns for the configured model type.
    """
    task_config = self.task_config
    train_shape = task_config.train_data.feature_shape
    eval_shape = task_config.validation_data.feature_shape

    # `zip` pairs the dimensions up and stops at the shorter of the two shapes.
    common_input_shape = []
    for train_dim, eval_dim in zip(train_shape, eval_shape):
      common_input_shape.append(train_dim if train_dim == eval_dim else None)

    input_specs = tf.keras.layers.InputSpec(shape=[None] + common_input_shape)
    logging.info('Build model input %r', common_input_shape)

    return factory_3d.build_model(
        task_config.model.model_type,
        input_specs=input_specs,
        model_config=task_config.model,
        num_classes=task_config.train_data.num_classes)
コード例 #2
0
    def build_model(self):
        """Constructs the video classification model from the task config.

        The feature shape from `self._get_feature_shape()` gets a leading
        `None` entry prepended and is wrapped in an `InputSpec`. An L2 kernel
        regularizer is created only when `losses.l2_weight_decay` is truthy;
        otherwise `None` is passed to the factory.

        Returns:
            Whatever `factory_3d.build_model` returns for the configured model
            type.
        """
        feature_shape = self._get_feature_shape()
        input_specs = tf.keras.layers.InputSpec(shape=[None] + feature_shape)
        logging.info('Build model input %r', feature_shape)

        weight_decay = self.task_config.losses.l2_weight_decay
        if weight_decay:
            # Halve the weight decay to match the implementation of
            # tf.nn.l2_loss.
            # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
            # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
            regularizer = tf.keras.regularizers.l2(weight_decay / 2.0)
        else:
            regularizer = None

        return factory_3d.build_model(
            self.task_config.model.model_type,
            input_specs=input_specs,
            model_config=self.task_config.model,
            num_classes=self._get_num_classes(),
            l2_regularizer=regularizer)