Example no. 1
0
  def test_model_creation(self, project_dim, num_proj_layers, ft_proj_idx):
    """Builds a SimCLR pretrain model end to end and checks output shapes.

    Args:
      project_dim: Output dimension of the projection head.
      num_proj_layers: Number of dense layers in the projection head.
      ft_proj_idx: Index of the projection layer used for fine-tuning.
    """
    input_size = 224
    inputs = np.random.rand(2, input_size, input_size, 3)
    input_specs = tf.keras.layers.InputSpec(
        shape=[None, input_size, input_size, 3])

    tf.keras.backend.set_image_data_format('channels_last')

    backbone = backbones.ResNet(model_id=50, activation='relu',
                                input_specs=input_specs)
    projection_head = simclr_head.ProjectionHead(
        proj_output_dim=project_dim,
        num_proj_layers=num_proj_layers,
        ft_proj_idx=ft_proj_idx
    )
    num_classes = 10
    supervised_head = simclr_head.ClassificationHead(
        # Use the local variable rather than a duplicated literal so the
        # shape assertion below can never drift out of sync with the head.
        num_classes=num_classes
    )

    model = simclr_model.SimCLRModel(
        input_specs=input_specs,
        backbone=backbone,
        projection_head=projection_head,
        supervised_head=supervised_head,
        mode=simclr_model.PRETRAIN
    )
    outputs = model(inputs)
    projection_outputs = outputs[simclr_model.PROJECTION_OUTPUT_KEY]
    supervised_outputs = outputs[simclr_model.SUPERVISED_OUTPUT_KEY]

    # Projection output must match the configured projection dimension;
    # supervised output must have one logit per class.
    self.assertAllEqual(projection_outputs.shape.as_list(),
                        [2, project_dim])
    self.assertAllEqual([2, num_classes],
                        supervised_outputs.numpy().shape)
Example no. 2
0
    def test_outputs(self, num_proj_layers, proj_output_dim, ft_proj_idx):
        """Checks projection-head output shapes for each layer configuration."""
        head = simclr_head.ProjectionHead(
            num_proj_layers=num_proj_layers,
            proj_output_dim=proj_output_dim,
            ft_proj_idx=ft_proj_idx)

        batch_size = 2
        input_dim = 64
        inputs = np.random.rand(batch_size, input_dim)
        head_output, finetune_output = head(inputs)

        if num_proj_layers == 0:
            # With no projection layers both outputs pass through unchanged.
            self.assertAllClose(inputs, head_output)
            self.assertAllClose(inputs, finetune_output)
            return

        # The head output is always projected down to proj_output_dim.
        self.assertAllEqual(head_output.shape.as_list(),
                            [batch_size, proj_output_dim])
        if ft_proj_idx == 0:
            # Index 0 selects the raw input as the fine-tune feature.
            self.assertAllClose(inputs, finetune_output)
        elif ft_proj_idx < num_proj_layers:
            # Intermediate projection layers keep the input width.
            self.assertAllEqual(finetune_output.shape.as_list(),
                                [batch_size, input_dim])
        else:
            # The final layer emits the projected width.
            self.assertAllEqual(finetune_output.shape.as_list(),
                                [batch_size, proj_output_dim])
Example no. 3
0
  def __init__(self, config: simclr_multitask_config.SimCLRMTModelConfig,
               **kwargs):
    """Builds the shared backbone and projection head described by `config`.

    Args:
      config: Multitask SimCLR model configuration.
      **kwargs: Forwarded to the parent constructor.
    """
    self._config = config

    # Shared backbone input signature: batched images of the configured size.
    self._input_specs = tf.keras.layers.InputSpec(
        shape=[None] + config.input_size)

    weight_decay = config.l2_weight_decay
    # tf.keras.regularizers.l2 uses the raw factor while tf.nn.l2_loss
    # already includes a 1/2, so halve the decay to match tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    if weight_decay:
      self._l2_regularizer = tf.keras.regularizers.l2(weight_decay / 2.0)
    else:
      self._l2_regularizer = None

    self._backbone = backbones.factory.build_backbone(
        input_specs=self._input_specs,
        backbone_config=config.backbone,
        norm_activation_config=config.norm_activation,
        l2_regularizer=self._l2_regularizer)

    # Projection head shared across all sub-tasks.
    proj_cfg = self._config.projection_head
    norm_cfg = self._config.norm_activation
    self._projection_head = simclr_head.ProjectionHead(
        proj_output_dim=proj_cfg.proj_output_dim,
        num_proj_layers=proj_cfg.num_proj_layers,
        ft_proj_idx=proj_cfg.ft_proj_idx,
        kernel_regularizer=self._l2_regularizer,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon)

    super().__init__(**kwargs)
Example no. 4
0
  def build_model(self):
    """Assembles a SimCLRModel from the task configuration.

    Returns:
      A `simclr_model.SimCLRModel` with backbone, projection head and an
      optional supervised classification head.
    """
    model_config = self.task_config.model
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + model_config.input_size)

    weight_decay = self.task_config.loss.l2_weight_decay
    # tf.keras.regularizers.l2 uses the raw factor while tf.nn.l2_loss
    # already includes a 1/2, so halve the decay to match tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    if weight_decay:
      l2_regularizer = tf.keras.regularizers.l2(weight_decay / 2.0)
    else:
      l2_regularizer = None

    # Backbone.
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=model_config.norm_activation,
        l2_regularizer=l2_regularizer)

    # Projection head.
    proj_cfg = model_config.projection_head
    norm_cfg = model_config.norm_activation
    projection_head = simclr_head.ProjectionHead(
        proj_output_dim=proj_cfg.proj_output_dim,
        num_proj_layers=proj_cfg.num_proj_layers,
        ft_proj_idx=proj_cfg.ft_proj_idx,
        kernel_regularizer=l2_regularizer,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon)

    # Optional supervised head.
    supervised_head = None
    head_cfg = model_config.supervised_head
    if head_cfg:
      # zero_init gives a deterministic all-zero classifier start.
      initializer = 'zeros' if head_cfg.zero_init else 'random_uniform'
      supervised_head = simclr_head.ClassificationHead(
          num_classes=head_cfg.num_classes,
          kernel_initializer=initializer,
          kernel_regularizer=l2_regularizer)

    model = simclr_model.SimCLRModel(
        input_specs=input_specs,
        backbone=backbone,
        projection_head=projection_head,
        supervised_head=supervised_head,
        mode=model_config.mode,
        backbone_trainable=model_config.backbone_trainable)

    logging.info(model.get_config())

    return model
Example no. 5
0
    def test_head_creation(self, num_proj_layers, proj_output_dim):
        """Checks symbolic output shapes of a freshly built projection head."""
        head = simclr_head.ProjectionHead(
            num_proj_layers=num_proj_layers, proj_output_dim=proj_output_dim)

        input_dim = 64
        x = tf.keras.Input(shape=(input_dim, ))
        head_output, finetune_output = head(x)

        # Without projection layers the head passes the input through
        # unchanged; otherwise it projects to proj_output_dim.
        expected_dim = proj_output_dim if num_proj_layers > 0 else input_dim
        self.assertAllEqual(head_output.shape.as_list(),
                            [None, expected_dim])

        if num_proj_layers > 0:
            # The default fine-tune tap keeps the input width.
            self.assertAllEqual(finetune_output.shape.as_list(),
                                [None, input_dim])
Example no. 6
0
    def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
        """Builds one SimCLRModel per configured head, keyed by its mode.

        All sub-task models share the backbone and a single projection head;
        each head config may additionally attach its own supervised head.

        Returns:
          Mapping from head mode to the corresponding `SimCLRModel`.
        """
        # One shared projection head for every sub-task.
        proj_cfg = self._config.projection_head
        norm_cfg = self._config.norm_activation
        shared_projection_head = simclr_head.ProjectionHead(
            proj_output_dim=proj_cfg.proj_output_dim,
            num_proj_layers=proj_cfg.num_proj_layers,
            ft_proj_idx=proj_cfg.ft_proj_idx,
            kernel_regularizer=self._l2_regularizer,
            use_sync_bn=norm_cfg.use_sync_bn,
            norm_momentum=norm_cfg.norm_momentum,
            norm_epsilon=norm_cfg.norm_epsilon)

        tasks = {}
        for head_cfg in self._config.heads:
            supervised_head = None
            sup_cfg = head_cfg.supervised_head
            if sup_cfg:
                # zero_init gives a deterministic all-zero classifier start.
                initializer = 'zeros' if sup_cfg.zero_init else 'random_uniform'
                supervised_head = simclr_head.ClassificationHead(
                    num_classes=sup_cfg.num_classes,
                    kernel_initializer=initializer,
                    kernel_regularizer=self._l2_regularizer)

            tasks[head_cfg.mode] = simclr_model.SimCLRModel(
                input_specs=self._input_specs,
                backbone=self._backbone,
                projection_head=shared_projection_head,
                supervised_head=supervised_head,
                mode=head_cfg.mode,
                backbone_trainable=self._config.backbone_trainable)

        return tasks