示例#1
0
    def test_simple_point_embedder_forward_pass(self):
        """Pins the exact forward output of a small all-ones-weighted model.

        Weights are initialized to ones and batch norm is disabled. The
        primary output for the fixed input below is asserted to be 1937.0 in
        each of the four output slots.
        """
        # NOTE(review): despite the test name, this builds the model with the
        # Gaussian embedder, not the point embedder -- confirm this is intended.
        features = tf.constant([[1.0, 2.0, 3.0]])
        net = models.SimpleModel(
            output_shape=(4,),
            embedder=models.TYPE_EMBEDDER_GAUSSIAN,
            hidden_dim=2,
            num_residual_linear_blocks=3,
            num_layers_per_block=2,
            use_batch_norm=False,
            weight_initializer='ones')

        expected = [[1937.0, 1937.0, 1937.0, 1937.0]]
        self.assertAllClose(net(features)[0], expected)
示例#2
0
    def test_simple_gaussian_embedder_shapes(self):
        """Checks output and named intermediate activation shapes for 4 inputs."""
        batch = tf.zeros([4, 6], tf.float32)
        net = models.SimpleModel(
            output_shape=(4,),
            embedder=models.TYPE_EMBEDDER_GAUSSIAN,
            hidden_dim=1024,
            num_residual_linear_blocks=2,
            num_layers_per_block=2)

        result = net(batch)
        self.assertAllEqual(result[0].shape, [4, 4])

        # Expected shape of each named activation in the second output,
        # checked in this (insertion) order.
        expected_shapes = {
            'flatten': [4, 6],
            'fc0': [4, 1024],
            'res_fcs1': [4, 1024],
            'res_fcs2': [4, 1024],
            'embedder': [4, 4],
        }
        activations = result[1]
        for key, shape in expected_shapes.items():
            self.assertAllEqual(activations[key].shape, shape)
示例#3
0
    def build(self, input_shape):
        """Creates the decoder sub-model.

        Args:
          input_shape: A TensorShape for the shape of the input.
        """
        # Everything after the leading (presumably batch) dimension becomes
        # the decoder's output shape.
        target_shape = input_shape[1:]
        self.decoder = models.SimpleModel(
            output_shape=target_shape,
            embedder=models.TYPE_EMBEDDER_POINT,
            hidden_dim=MODEL_LINEAR_HIDDEN_DIM,
            num_residual_linear_blocks=MODEL_LINEAR_NUM_RESIDUAL_BLOCKS,
            num_layers_per_block=MODEL_LINEAR_NUM_LAYERS_PER_BLOCK,
            dropout_rate=MODEL_LINEAR_DROPOUT_RATE,
            use_batch_norm=True,
            weight_max_norm=MODEL_LINEAR_WEIGHT_MAX_NORM,
            weight_initializer='he_normal')
示例#4
0
    def test_simple_gaussian_embedder(self):
        """Outputs differ across calls in training mode and match otherwise."""
        features = tf.ones([1, 6], tf.float32)
        net = models.SimpleModel(
            output_shape=(1,),
            embedder=models.TYPE_EMBEDDER_GAUSSIAN,
            hidden_dim=1024,
            num_residual_linear_blocks=2,
            num_layers_per_block=2,
            weight_initializer='ones')

        def first_output(training):
            # Runs the model once and returns only its primary output.
            return net(features, training=training)[0]

        # Seed the global RNG so the training-mode comparison is reproducible.
        # Arguments are evaluated left to right, so the call order matches a
        # sequential "first call, then second call".
        tf.random.set_seed(0)
        self.assertNotAllEqual(first_output(True), first_output(True))

        self.assertAllEqual(first_output(False), first_output(False))
示例#5
0
def get_encoder(embedding_dim, embedder_type=models.TYPE_EMBEDDER_POINT):
    """Builds the default encoder for InfoMix.

    Args:
      embedding_dim: An integer for the dimension of the embedding.
      embedder_type: A string for the type of the embedder.

    Returns:
      A configured encoder: a `models.SimpleModel` whose output shape is
      `(embedding_dim,)`.
    """
    encoder_config = {
        'output_shape': (embedding_dim,),
        'embedder': embedder_type,
        'hidden_dim': MODEL_LINEAR_HIDDEN_DIM,
        'num_residual_linear_blocks': MODEL_LINEAR_NUM_RESIDUAL_BLOCKS,
        'num_layers_per_block': MODEL_LINEAR_NUM_LAYERS_PER_BLOCK,
        'dropout_rate': MODEL_LINEAR_DROPOUT_RATE,
        'use_batch_norm': True,
        'weight_max_norm': MODEL_LINEAR_WEIGHT_MAX_NORM,
        'weight_initializer': 'he_normal',
    }
    return models.SimpleModel(**encoder_config)