def test_simple_point_embedder_forward_pass(self):
    """Checks a deterministic forward pass through a point-embedder model.

    The model uses the point embedder, matching the test name, with batch
    norm disabled and all weights initialized to ones. That makes the output
    a closed-form function of the input.
    """
    input_features = tf.constant([[1.0, 2.0, 3.0]])
    model = models.SimpleModel(
        output_shape=(4,),
        embedder=models.TYPE_EMBEDDER_POINT,
        hidden_dim=2,
        num_residual_linear_blocks=3,
        num_layers_per_block=2,
        use_batch_norm=False,
        weight_initializer='ones')
    outputs = model(input_features)
    # NOTE(review): assuming 'ones' sets every kernel AND bias to 1, and that
    # residual blocks add their input back after two ReLU dense layers
    # (confirm against models.SimpleModel), each hidden unit evolves as:
    #   fc0:      (1 + 2 + 3) + 1          = 7
    #   block 1:  7 -> 15 -> 31,   + 7     = 38
    #   block 2:  38 -> 77 -> 155, + 38    = 193
    #   block 3:  193 -> 387 -> 775, + 193 = 968
    #   embedder: 2 * 968 + 1              = 1937 (for each of 4 outputs)
    self.assertAllClose(outputs[0], [[1937.0, 1937.0, 1937.0, 1937.0]])
def test_simple_gaussian_embedder_shapes(self):
    """Verifies output and intermediate activation shapes of a Gaussian model."""
    batch_size, input_dim = 4, 6
    hidden_dim, embedding_dim = 1024, 4

    input_features = tf.zeros([batch_size, input_dim], tf.float32)
    model = models.SimpleModel(
        output_shape=(embedding_dim,),
        embedder=models.TYPE_EMBEDDER_GAUSSIAN,
        hidden_dim=hidden_dim,
        num_residual_linear_blocks=2,
        num_layers_per_block=2)
    outputs = model(input_features)

    self.assertAllEqual(outputs[0].shape, [batch_size, embedding_dim])

    # The second model output maps activation names to tensors.
    activations = outputs[1]
    expected_activation_shapes = {
        'flatten': [batch_size, input_dim],
        'fc0': [batch_size, hidden_dim],
        'res_fcs1': [batch_size, hidden_dim],
        'res_fcs2': [batch_size, hidden_dim],
        'embedder': [batch_size, embedding_dim],
    }
    for activation_name, expected_shape in expected_activation_shapes.items():
        self.assertAllEqual(activations[activation_name].shape, expected_shape)
def build(self, input_shape):
    """Creates the decoder sub-model once the input shape is known.

    Args:
      input_shape: A TensorShape for the shape of the input.
    """
    # The decoder's output shape is the input shape without its leading
    # dimension (presumably the batch axis), so it reconstructs each example.
    decoder_config = dict(
        output_shape=input_shape[1:],
        embedder=models.TYPE_EMBEDDER_POINT,
        hidden_dim=MODEL_LINEAR_HIDDEN_DIM,
        num_residual_linear_blocks=MODEL_LINEAR_NUM_RESIDUAL_BLOCKS,
        num_layers_per_block=MODEL_LINEAR_NUM_LAYERS_PER_BLOCK,
        dropout_rate=MODEL_LINEAR_DROPOUT_RATE,
        use_batch_norm=True,
        weight_max_norm=MODEL_LINEAR_WEIGHT_MAX_NORM,
        weight_initializer='he_normal',
    )
    self.decoder = models.SimpleModel(**decoder_config)
def test_simple_gaussian_embedder(self):
    """Checks Gaussian embeddings differ across training calls but not inference."""
    input_features = tf.ones([1, 6], tf.float32)
    model = models.SimpleModel(
        output_shape=(1,),
        embedder=models.TYPE_EMBEDDER_GAUSSIAN,
        hidden_dim=1024,
        num_residual_linear_blocks=2,
        num_layers_per_block=2,
        weight_initializer='ones')
    tf.random.set_seed(0)

    def embed(training):
        # Runs one forward pass and keeps only the embedding output.
        return model(input_features, training=training)[0]

    # Arguments are evaluated left to right, so each pair is two sequential
    # forward passes.
    self.assertNotAllEqual(embed(True), embed(True))
    self.assertAllEqual(embed(False), embed(False))
def get_encoder(embedding_dim, embedder_type=models.TYPE_EMBEDDER_POINT):
    """Builds the default InfoMix encoder.

    Args:
      embedding_dim: An integer for the dimension of the embedding.
      embedder_type: A string for the type of the embedder.

    Returns:
      A configured encoder.
    """
    # Backbone hyperparameters shared with the module-level MODEL_LINEAR_*
    # settings; only the output size and embedder type vary per call.
    backbone_kwargs = {
        'hidden_dim': MODEL_LINEAR_HIDDEN_DIM,
        'num_residual_linear_blocks': MODEL_LINEAR_NUM_RESIDUAL_BLOCKS,
        'num_layers_per_block': MODEL_LINEAR_NUM_LAYERS_PER_BLOCK,
        'dropout_rate': MODEL_LINEAR_DROPOUT_RATE,
        'use_batch_norm': True,
        'weight_max_norm': MODEL_LINEAR_WEIGHT_MAX_NORM,
        'weight_initializer': 'he_normal',
    }
    encoder = models.SimpleModel(
        output_shape=(embedding_dim,),
        embedder=embedder_type,
        **backbone_kwargs)
    return encoder