def forward_MLP(self, name, all_params, input_tensor=None,
                batch_normalization=False, reuse=True, is_training=False,
                bias_transform=False):
    """Build the MLP forward pass from an explicit dictionary of parameters.

    The network is assembled inside ``tf.variable_scope(name)``.  Each of the
    ``self.n_hidden`` hidden layers and the final output layer is produced by
    ``forward_dense_bias_transform`` using the entries of ``all_params``.

    Args:
        name: Variable-scope name under which the layers are built.
        all_params: Mapping holding ``'W<i>'`` and ``'b<i>'`` for
            ``i`` in ``0..self.n_hidden`` (the last index is the output
            layer) and, when ``bias_transform`` is truthy, also
            ``'bias_transform<i>'`` for the same indices.
        input_tensor: Tensor to use as the network input.  When ``None`` a
            new input is created with ``make_input`` from
            ``self.input_shape``.
        batch_normalization: Forwarded as ``batch_norm`` to every hidden
            layer.  The output layer always receives ``batch_norm=False``.
        reuse: Forwarded to the hidden layers.  Only matters for batch
            norm; set it to ``False`` the first time this function is
            called.
        is_training: Forwarded to the hidden layers.  Only matters for
            batch norm.
        bias_transform: When truthy, the per-layer ``'bias_transform<i>'``
            entry of ``all_params`` is handed to each layer; otherwise
            ``None`` is passed instead.

    Returns:
        A ``(l_in, output)`` tuple: the input that was used (either
        ``input_tensor`` or the freshly created one) and the result of the
        output layer.
    """
    with tf.variable_scope(name):
        if input_tensor is not None:
            l_in = input_tensor
        else:
            l_in = make_input(shape=self.input_shape, input_var=None, name='input')

        # Hidden stack: thread the activations through one dense layer at a time.
        activations = l_in
        for layer in range(self.n_hidden):
            suffix = str(layer)
            if bias_transform:
                layer_bias_transform = all_params['bias_transform' + suffix]
            else:
                layer_bias_transform = None
            activations = forward_dense_bias_transform(
                activations,
                all_params['W' + suffix],
                all_params['b' + suffix],
                bias_transform=layer_bias_transform,
                batch_norm=batch_normalization,
                nonlinearity=self.hidden_nonlinearity,
                scope=suffix,
                reuse=reuse,
                is_training=is_training,
            )

        # Output layer: parameters live under index n_hidden; never batch-normed.
        out_suffix = str(self.n_hidden)
        if bias_transform:
            out_bias_transform = all_params['bias_transform' + out_suffix]
        else:
            out_bias_transform = None
        output = forward_dense_bias_transform(
            activations,
            all_params['W' + out_suffix],
            all_params['b' + out_suffix],
            bias_transform=out_bias_transform,
            batch_norm=False,
            nonlinearity=self.output_nonlinearity,
        )
        return l_in, output
def forward_MLP(self, name, all_params, input_tensor=None, input_shape=None,
                n_hidden=-1, hidden_nonlinearity=tf.identity,
                output_nonlinearity=tf.identity, batch_normalization=False,
                reuse=True, is_training=False):
    """Build an MLP forward pass from explicit parameters and layer settings.

    The network is assembled inside ``tf.variable_scope(name)``.  Each of the
    ``n_hidden`` hidden layers and the final output layer is produced by
    ``forward_dense_layer`` using the entries of ``all_params``.

    Args:
        name: Variable-scope name under which the layers are built.
        all_params: Mapping holding ``'W<i>'`` and ``'b<i>'`` for
            ``i`` in ``0..n_hidden`` (the last index is the output layer).
        input_tensor: Tensor to use as the network input.  When ``None`` a
            new input is created with ``make_input``.
        input_shape: Tuple that is appended to ``(None,)`` to form the shape
            of the created input.  Must not be ``None`` when
            ``input_tensor`` is ``None`` (checked with an ``assert``).
        n_hidden: Number of hidden layers.  With the default ``-1`` the
            hidden loop does not run and the output layer looks up
            ``'W-1'`` / ``'b-1'``, so callers are expected to pass the real
            count.
        hidden_nonlinearity: Nonlinearity forwarded to every hidden layer.
        output_nonlinearity: Nonlinearity forwarded to the output layer.
        batch_normalization: Forwarded as ``batch_norm`` to every hidden
            layer.  The output layer always receives ``batch_norm=False``.
        reuse: Forwarded to the hidden layers.  Only matters for batch
            norm; set it to ``False`` the first time this function is
            called.
        is_training: Forwarded to the hidden layers.  Only matters for
            batch norm.

    Returns:
        A ``(l_in, output)`` tuple: the input that was used (either
        ``input_tensor`` or the freshly created one) and the result of the
        output layer.
    """
    with tf.variable_scope(name):
        if input_tensor is not None:
            l_in = input_tensor
        else:
            assert input_shape is not None
            # Prepend an unspecified leading dimension to the given shape.
            l_in = make_input(shape=(None, ) + input_shape, input_var=None, name='input')

        # Hidden stack: thread the activations through one dense layer at a time.
        activations = l_in
        for layer in range(n_hidden):
            suffix = str(layer)
            activations = forward_dense_layer(
                activations,
                all_params['W' + suffix],
                all_params['b' + suffix],
                batch_norm=batch_normalization,
                nonlinearity=hidden_nonlinearity,
                scope=suffix,
                reuse=reuse,
                is_training=is_training,
            )

        # Output layer: parameters live under index n_hidden; never batch-normed.
        out_suffix = str(n_hidden)
        output = forward_dense_layer(
            activations,
            all_params['W' + out_suffix],
            all_params['b' + out_suffix],
            batch_norm=False,
            nonlinearity=output_nonlinearity,
        )
        return l_in, output