def _define_model(self,
                  model_type,
                  data_dim,
                  num_hidden,
                  kernel_initializer='ones',
                  use_weight_norm=True,
                  data_init=True):
  """Builds a test model around a (possibly weight-normalized) Dense layer.

  Args:
    model_type: One of 'layer', 'sequential', 'sequential_no_input', or
      'functional', selecting how the hidden layer is packaged.
    data_dim: Integer input dimensionality for the model variants that
      declare an input shape.
    num_hidden: Number of units in the hidden Dense layer.
    kernel_initializer: Keras initializer for all Dense kernels.
    use_weight_norm: If True, wrap the hidden layer in `WeightNorm`.
    data_init: Passed through to `WeightNorm` when `use_weight_norm` is True.

  Returns:
    Either the bare hidden layer or a Keras model containing it, depending
    on `model_type`.

  Raises:
    ValueError: If `model_type` is not one of the recognized strings.
  """
  # Both branches name the layer 'maybe_norm_layer' so tests can look it
  # up uniformly regardless of whether wrapping is enabled.
  if use_weight_norm:
    maybe_norm_layer = weight_norm.WeightNorm(
        tfkl.Dense(num_hidden, kernel_initializer=kernel_initializer),
        data_init=data_init,
        name='maybe_norm_layer')
  else:
    maybe_norm_layer = tfkl.Dense(
        num_hidden,
        kernel_initializer=kernel_initializer,
        name='maybe_norm_layer')

  if model_type == 'layer':
    return maybe_norm_layer
  if model_type == 'sequential':
    return tfk.Sequential([
        tfkl.InputLayer((data_dim,)),
        maybe_norm_layer,
        tfkl.Dense(1, kernel_initializer=kernel_initializer),
    ])
  if model_type == 'sequential_no_input':
    return tfk.Sequential([
        maybe_norm_layer,
        tfkl.Dense(1, kernel_initializer=kernel_initializer),
    ])
  if model_type == 'functional':
    inputs = tfkl.Input(shape=(data_dim,))
    hidden = maybe_norm_layer(inputs)
    outputs = tfkl.Dense(1, kernel_initializer=kernel_initializer)(hidden)
    return tfk.Model(inputs=inputs, outputs=outputs)
  raise ValueError('{} is not a valid model type'.format(model_type))
def testVariableCreationNoBias(self):
  """A bias-free wrapped conv layer exposes exactly two trainable variables."""
  unnormed_conv = tfkl.Conv2D(
      filters=self.num_conv_filters,
      kernel_size=(2, 2),
      kernel_initializer='ones',
      use_bias=False)
  wrapped = weight_norm.WeightNorm(unnormed_conv, name='norm_layer')
  wrapped.build(self.conv_random_input.shape)
  # With use_bias=False the only trainable variables should be the kernel
  # direction and the weight-norm scale.
  self.assertLen(wrapped.trainable_variables, 2)
def testConv2DInitializedCorrectly(self, transpose):
  """Checks data-dependent initialization of `g` and bias for conv layers.

  Args:
    transpose: If True, test `Conv2DTranspose`; otherwise `Conv2D`.
  """
  conv_cls = tfkl.Conv2DTranspose if transpose else tfkl.Conv2D
  wrapped = weight_norm.WeightNorm(
      conv_cls(
          filters=self.num_conv_filters,
          kernel_size=(2, 2),
          kernel_initializer='ones'),
      name='norm_layer')
  wrapped.build(self.conv_random_input.shape)
  self.evaluate([v.initializer for v in wrapped.weights])
  # Run one forward pass so data-dependent initialization takes place.
  self.evaluate(wrapped(self.conv_random_input))
  expected_g, expected_bias = self._calculate_true_initial_variables_conv(
      self.conv_random_input, self.num_conv_filters, transpose=transpose)
  self.assertAllClose(expected_g, self.evaluate(wrapped.g))
  self.assertAllClose(expected_bias, self.evaluate(wrapped.layer.bias))
def wrapped_layer(*args, **kwargs):
  """Constructs the inner layer and wraps it in `WeightNorm`.

  NOTE(review): `layer` and `use_data_init` are free variables captured
  from the enclosing scope (not visible in this chunk).
  """
  inner = layer(*args, **kwargs)
  return weight_norm.WeightNorm(inner, data_init=use_data_init)