def test_resnet3d_network_creation(
        self, model_id, temporal_size, spatial_size, activation,
        aggregate_endpoints):
        """Checks that a ResNet3D-backed video classifier yields per-class logits.

        A ResNet3D backbone is built for the parameterized `model_id` and
        `activation`, wrapped in a `VideoClassificationModel`, and run on a
        random batch of two clips. The output must have shape
        `[2, num_classes]`.
        """
        batch_size = 2
        class_count = 1000
        # Per-example shape passed to both the input spec and the random batch.
        clip_shape = [temporal_size, spatial_size, spatial_size, 3]

        input_specs = tf.keras.layers.InputSpec(shape=[None] + clip_shape)

        tf.keras.backend.set_image_data_format('channels_last')

        network = backbones.ResNet3D(
            model_id=model_id,
            temporal_strides=[1, 1, 1, 1],
            temporal_kernel_sizes=[
                (3, 3, 3),
                (3, 1, 3, 1),
                (3, 1, 3, 1, 3, 1),
                (1, 3, 1),
            ],
            input_specs=input_specs,
            activation=activation)

        classifier = video_classification_model.VideoClassificationModel(
            backbone=network,
            num_classes=class_count,
            input_specs={'image': input_specs},
            dropout_rate=0.2,
            aggregate_endpoints=aggregate_endpoints,
        )

        clips = np.random.rand(batch_size, *clip_shape)
        logits = classifier(clips)
        self.assertAllEqual([batch_size, class_count], logits.numpy().shape)
Example #2
0
    def test_resnet_3d_creation(self, model_type):
        """Checks that a ResNet3D backbone can be constructed from its config.

        The per-block temporal settings of the `Backbone3D` config selected by
        `model_type` are flattened into parallel lists and passed to the
        `ResNet3D` constructor. The test passes if construction does not raise.
        """
        cfg = backbones_3d_cfg.Backbone3D(type=model_type).get()
        specs = cfg.block_specs

        # One entry per block spec, in config order.
        strides_per_block = [spec.temporal_strides for spec in specs]
        kernels_per_block = [spec.temporal_kernel_sizes for spec in specs]

        _ = backbones.ResNet3D(
            model_id=cfg.model_id,
            temporal_strides=strides_per_block,
            temporal_kernel_sizes=kernels_per_block,
            norm_momentum=0.99,
            norm_epsilon=1e-5)
Example #3
0
def build_backbone_3d(
        input_specs: tf.keras.layers.InputSpec,
        model_config,
        l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Builds a 3D backbone from a config.

    The backbone type is validated before any type-specific config field is
    read, so an unsupported type is always reported through the `ValueError`
    below instead of failing earlier on a config that lacks `block_specs`.

    Args:
      input_specs: tf.keras.layers.InputSpec forwarded to the backbone.
      model_config: Model config. Must expose `backbone` (providing `type` and
        `get()`) and `norm_activation` (providing `activation`, `use_sync_bn`,
        `norm_momentum` and `norm_epsilon`).
      l2_regularizer: tf.keras.regularizers.Regularizer instance passed as the
        backbone's kernel regularizer. Defaults to None.

    Returns:
      tf.keras.Model instance of the backbone.

    Raises:
      ValueError: If `model_config.backbone.type` is not 'resnet_3d'.
    """
    backbone_type = model_config.backbone.type
    # Fail fast: only 'resnet_3d' configs are known to carry `block_specs`.
    if backbone_type != 'resnet_3d':
        raise ValueError('Backbone {!r} not implement'.format(backbone_type))

    backbone_cfg = model_config.backbone.get()
    norm_activation_config = model_config.norm_activation

    # Flatten the per-block configs into parallel lists before passing them to
    # the backbone. Materialize once so the specs can be traversed repeatedly.
    block_specs = tuple(backbone_cfg.block_specs)
    temporal_strides = [spec.temporal_strides for spec in block_specs]
    temporal_kernel_sizes = [
        spec.temporal_kernel_sizes for spec in block_specs
    ]
    use_self_gating = [spec.use_self_gating for spec in block_specs]

    return backbones.ResNet3D(
        model_id=backbone_cfg.model_id,
        temporal_strides=temporal_strides,
        temporal_kernel_sizes=temporal_kernel_sizes,
        use_self_gating=use_self_gating,
        input_specs=input_specs,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)
    def test_serialize_deserialize(self):
        """Checks that the classifier round-trips through its config.

        A model is rebuilt from `get_config()` via `from_config()`. The rebuilt
        model must be convertible to JSON and report the same config as the
        original.
        """
        backbone = backbones.ResNet3D(
            model_id=50,
            temporal_strides=[1, 1, 1, 1],
            temporal_kernel_sizes=[
                (3, 3, 3),
                (3, 1, 3, 1),
                (3, 1, 3, 1, 3, 1),
                (1, 3, 1),
            ])
        model_cls = video_classification_model.VideoClassificationModel
        original = model_cls(backbone=backbone, num_classes=1000)

        snapshot = original.get_config()
        restored = model_cls.from_config(snapshot)

        # The rebuilt model's config must be convertible to JSON.
        _ = restored.to_json()

        # A successful round trip leaves the config unchanged.
        self.assertAllEqual(original.get_config(), restored.get_config())