def _build(self, inputs, is_training):
    # Split the input into high- and low-frequency feature maps.
    x_h, x_l = inputs if type(inputs) is tuple else (inputs, None)
    if self._stride == 2 or self._stride == (2, 2) or self._stride == [2, 2]:
        x_h = tf.nn.avg_pool(x_h, (1, 2, 2, 1), (1, 2, 2, 1),
                             'SAME') if x_h is not None else None
        x_l = tf.nn.avg_pool(x_l, (1, 2, 2, 1), (1, 2, 2, 1),
                             'SAME') if x_l is not None else None
    _, h, w, _ = x_h.shape.as_list()

    # Channel split between the low- and high-frequency branches; cast to an
    # integer channel count since Conv2D expects an integer.
    l_out = int(self._output_channels * self._ratio)
    h_out = self._output_channels - l_out

    x_h2h, x_h2l, x_l2l, x_l2h = None, None, None, None
    if x_h is not None:
        x_h2h = snt.Conv2D(output_channels=h_out,
                           kernel_shape=self._kernel_shape,
                           padding=snt.SAME, name='h2h')(x_h)
        if l_out > 0:
            x_h2l = tf.nn.avg_pool(x_h, (1, 2, 2, 1), (1, 2, 2, 1),
                                   padding=snt.SAME)
            x_h2l = snt.Conv2D(output_channels=l_out,
                               kernel_shape=self._kernel_shape,
                               padding=snt.SAME, name='h2l')(x_h2l)
    if x_l is not None:
        if l_out > 0:
            x_l2l = snt.Conv2D(output_channels=l_out,
                               kernel_shape=self._kernel_shape,
                               padding=snt.SAME, name='l2l')(x_l)
        x_l2h = snt.Conv2D(output_channels=h_out,
                           kernel_shape=self._kernel_shape,
                           padding=snt.SAME, name='l2h')(x_l)
        x_l2h = tf.image.resize_nearest_neighbor(x_l2h, (h, w))

    y_h = (x_h2h + x_l2h) if x_l2h is not None else x_h2h
    y_l = (x_h2l + x_l2l) if x_l2l is not None else x_h2l

    if self._use_bn:
        bn1 = snt.BatchNormV2(name='h_bn')
        bn2 = snt.BatchNormV2(name='l_bn')
        y_h = bn1(y_h, is_training=is_training,
                  test_local_stats=False) if y_h is not None else None
        y_l = bn2(y_l, is_training=is_training,
                  test_local_stats=False) if y_l is not None else None
    if self._actvation_fn is not None:
        y_h = self._actvation_fn(y_h) if y_h is not None else None
        y_l = self._actvation_fn(y_l) if y_l is not None else None
    return y_h if y_l is None else (y_h, y_l)
def factorized_reduction(cls, inputs, is_training, out_filters, stride):
    assert out_filters % 2 == 0, (
        'Need even number of filters when using this factorized reduction')
    if stride == 1:
        with tf.variable_scope('path_conv'):
            net = snt.Conv2D(out_filters, kernel_shape=1)(inputs)
            net = snt.BatchNormV2(scale=True, decay_rate=0.9,
                                  eps=1e-5)(net, is_training=is_training)
            return net

    # Skip path 1
    path1 = slim.avg_pool2d(inputs, kernel_size=1, stride=stride,
                            padding='VALID')
    with tf.variable_scope('path1_conv'):
        path1 = snt.Conv2D(out_filters // 2, kernel_shape=1)(path1)

    # Skip path 2
    pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
    path2 = tf.pad(inputs, pad_arr)[:, 1:, 1:, :]
    path2 = slim.avg_pool2d(path2, kernel_size=1, stride=stride,
                            padding='VALID')
    with tf.variable_scope('path2_conv'):
        path2 = snt.Conv2D(out_filters // 2, kernel_shape=1)(path2)

    final_path = tf.concat([path1, path2], axis=-1)
    final_path = snt.BatchNormV2(scale=True, decay_rate=0.9,
                                 eps=1e-5)(final_path, is_training=is_training)
    return final_path
def testFusedBatchNormV2(self, is_training, test_local_stats, scale,
                         is_training_python_bool):
    input_shape = (32, 9, 9, 8)
    iterations = 5

    x = tf.placeholder(tf.float32, shape=input_shape)
    bn1 = snt.BatchNormV2(scale=scale)
    bn2 = snt.BatchNormV2(fused=False, scale=scale)

    xx = np.random.random(input_shape)
    feed_dict = {x: xx}
    if not is_training_python_bool:
        is_training_node = tf.placeholder(tf.bool, shape=())
        feed_dict.update({is_training_node: is_training})
        is_training = is_training_node
        test_local_stats_node = tf.placeholder(tf.bool, shape=())
        feed_dict.update({test_local_stats_node: test_local_stats})
        test_local_stats = test_local_stats_node

    o1 = bn1(x, is_training=is_training, test_local_stats=test_local_stats)
    o2 = bn2(x, is_training=is_training, test_local_stats=test_local_stats)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        params = [
            o1, o2, bn1._moving_mean, bn1._moving_variance, bn2._moving_mean,
            bn2._moving_variance
        ]
        for _ in range(iterations):
            y1, y2, mean1, var1, mean2, var2 = sess.run(params,
                                                        feed_dict=feed_dict)
            self.assertAllClose(y1, y2, atol=1e-4)
            self.assertAllClose(mean1, mean2, atol=1e-4)
            self.assertAllClose(var1, var2, atol=1e-4)
def __init__(self, observation_spaces, action_spaces, shared_policy=False,
             shared_critic=False, hyperparameters=None, name=None):
    name = 'maddpg_module' if name is None else name
    super().__init__(name=name)
    hyperparameters = hyperparameters if hyperparameters else {}
    self.policy_group = PolicyGroup(
        observation_spaces, action_spaces, shared=shared_policy,
        hyperparameters=hyperparameters.get('policy', {}).copy())
    self.critic_group = CriticGroup(
        observation_spaces, action_spaces, shared=shared_critic,
        hyperparameters=hyperparameters.get('critic', {}).copy())
    self.normalize = hyperparameters.get('normalize', {}).copy()
    if self.normalize:
        if self.normalize.get('reward'):
            self.normalize['reward'] = {
                name: snt.BatchNormV2(decay_rate=1.0 - 1e-4)
                for name in action_spaces
            }
        if self.normalize.get('observation'):
            self.normalize['observation'] = {
                name: snt.BatchNormV2(decay_rate=1.0 - 1e-4)
                for name in action_spaces
            }
    self.observation_spaces = observation_spaces
    self.action_spaces = action_spaces
def testCheckpointCompatibility(self):
    save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")

    input_shape_1 = (31, 7, 7, 5)
    input_shape_2 = (31, 5, 7, 7)

    x1 = tf.placeholder(tf.float32, shape=input_shape_1)
    bn1 = snt.BatchNormV2(data_format="NHWC")
    bn1(x1, is_training=True)
    saver1 = snt.get_saver(bn1)

    x2 = tf.placeholder(tf.float32, shape=input_shape_2)
    bn2 = snt.BatchNormV2(data_format="NCHW")
    bn2(x2, is_training=False)
    saver2 = snt.get_saver(bn2)

    x3 = tf.placeholder(tf.float32, shape=input_shape_1)
    bn3 = snt.BatchNormV2(data_format="NCHW")
    bn3(x3, is_training=False)
    saver3 = snt.get_saver(bn3)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        saver1.save(sess, save_path)
        saver2.restore(sess, save_path)
        with self.assertRaises(tf.errors.InvalidArgumentError):
            saver3.restore(sess, save_path)
def testInvalidRegularizationParameters(self):
    with self.assertRaisesRegexp(KeyError, "Invalid regularizer keys.*"):
        snt.BatchNormV2(
            regularizers={"not_gamma": tf.contrib.layers.l1_regularizer(0.5)})

    err = "Regularizer for 'gamma' is not a callable function"
    with self.assertRaisesRegexp(TypeError, err):
        snt.BatchNormV2(regularizers={"gamma": tf.zeros([1, 2, 3])})
def enas_layer(self, prev_layers, layer_id, start_idx, sample_arc,
               is_training):
    layer = LayerCollection(self.layer_setups, self.out_filters,
                            self.use_default_layer_func, name='layercol')
    inputs = prev_layers[-1]
    layer_type_id = sample_arc[start_idx]
    out = layer(inputs, is_training, layer_type_id)

    if layer_id > 0:
        skip_start = start_idx + 1
        layer_skip_count = layer_id
        skip = sample_arc[skip_start:skip_start + layer_skip_count]
        with tf.variable_scope('skip'):
            res_layers = []
            for i in range(layer_skip_count):
                res_layers.append(
                    tf.cond(tf.equal(skip[i], 1),
                            lambda: prev_layers[i + 1],
                            lambda: tf.zeros_like(prev_layers[i + 1])))
            res_layers.append(out)
            out = tf.add_n(res_layers)
            out = snt.BatchNormV2(scale=True, decay_rate=0.9,
                                  eps=1e-5)(out, is_training=is_training)
    return out
def _build(self, inputs, is_training, k=-1, r=0):
    if not isinstance(k, int):
        inputs = tf.cond(tf.equal(k, 0),
                         true_fn=lambda: mixup_process(inputs=inputs, r=r),
                         false_fn=lambda: inputs)
    h = snt.Linear(output_size=self.hidden_size)(inputs)
    h = tf.layers.Dropout(rate=self.drop_rate)(h, is_training)
    if not isinstance(k, int):
        h = tf.cond(tf.equal(k, 1),
                    true_fn=lambda: mixup_process(inputs=h, r=r),
                    false_fn=lambda: h)

    for i in range(self.num_highways):
        h = Highway()(h)
        if self.use_batch_norm:
            h = snt.BatchNormV2(data_format='NC')(h, is_training)
        elif self.use_layer_norm:
            h = snt.LayerNorm(axis=1)(h)
        if self.activation != 'linear':
            h = Activation(activation=self.activation)(h)
        if self.use_dropout:
            h = tf.layers.Dropout(rate=self.drop_rate)(h, is_training)
        if not isinstance(k, int):
            h = tf.cond(tf.equal(k, i + 2),
                        true_fn=lambda: mixup_process(inputs=h, r=r),
                        false_fn=lambda: h)

    outputs = snt.Linear(output_size=self.output_size)(h)
    return outputs
def testUpdatesInsideCond(self):
    """Demonstrate that updates inside a cond fail."""
    _, input_v, inputs = self._get_inputs()

    bn = snt.BatchNormV2(
        offset=False,
        scale=False,
        decay_rate=0.5,
        update_ops_collection=tf.GraphKeys.UPDATE_OPS)
    condition = tf.placeholder(tf.bool)
    cond = tf.cond(condition,
                   lambda: bn(inputs, is_training=True),
                   lambda: inputs)
    init = tf.global_variables_initializer()

    with self.test_session() as sess:
        sess.run(init)
        out_v = sess.run(cond, feed_dict={condition: False})
        self.assertAllClose(input_v, out_v)

        out_v = sess.run(cond, feed_dict={condition: True})
        self.assertAllClose(np.zeros([7, 6]), out_v, rtol=1e-4, atol=1e-4)

        # Variables are accessible outside the tf.cond()
        mm, mv = sess.run([bn.moving_mean, bn.moving_variance])
        self.assertAllClose(np.zeros([1, 6]), mm)
        self.assertAllClose(np.ones([1, 6]), mv)

        # Tensors are not accessible outside the tf.cond()
        with self.assertRaisesRegexp(ValueError, "Operation"):
            sess.run(tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))
def _build(self, inputs, is_training):
    t = self.kernel_shape[0]
    h = self.kernel_shape[1]
    w = self.kernel_shape[2]
    t_stride = self.stride[0]
    h_stride = self.stride[1]
    w_stride = self.stride[2]

    # Factorized 3D convolution: a spatial (1, h, w) convolution followed by
    # a temporal (t, 1, 1) convolution.
    net = snt.Conv3D(output_channels=self.output_channels,
                     kernel_shape=(1, h, w),
                     stride=(1, h_stride, w_stride),
                     padding=snt.SAME,
                     initializers=self.initializer,
                     use_bias=self.use_bias,
                     regularizers=regularizers,
                     name='conv_3d')(inputs)
    net = tf.nn.relu(net)
    net = snt.Conv3D(output_channels=self.output_channels,
                     kernel_shape=(t, 1, 1),
                     stride=(t_stride, 1, 1),
                     padding=snt.SAME,
                     initializers=ones_initializer,
                     use_bias=self.use_bias,
                     regularizers=regularizers,
                     name='conv_3d_temporal')(net)

    if self.use_batch_norm:
        bn = snt.BatchNormV2(scale=True)
        net = bn(net, is_training=is_training, test_local_stats=False)
    if self.activation_fn is not None:
        net = self.activation_fn(net)
    return net
def testInitializers(self, offset, scale):
    initializers = {
        "moving_mean": tf.constant_initializer(2.0),
        "moving_variance": tf.constant_initializer(3.0),
    }
    if scale:
        initializers["gamma"] = tf.constant_initializer(4.0)
    if offset:
        initializers["beta"] = tf.constant_initializer(5.0)

    inputs_shape = [10, 10]
    inputs = tf.placeholder(tf.float32, shape=[None] + inputs_shape)
    bn = snt.BatchNormV2(
        offset=offset, scale=scale, initializers=initializers)
    self.assertEqual(bn.initializers, initializers)
    bn(inputs, is_training=True)

    init = tf.global_variables_initializer()
    with self.test_session() as sess:
        sess.run(init)
        ones_v = np.ones([1, 1, inputs_shape[-1]])
        self.assertAllClose(bn.moving_mean.eval(), ones_v * 2.0)
        self.assertAllClose(bn.moving_variance.eval(), ones_v * 3.0)
        if scale:
            self.assertAllClose(bn.gamma.eval(), ones_v * 4.0)
        if offset:
            self.assertAllClose(bn.beta.eval(), ones_v * 5.0)
def testUpdateImproveStatistics(self):
    """Test that updating the moving_mean improves statistics."""
    _, _, inputs = self._get_inputs()

    # Use small decay_rate to update faster.
    bn = snt.BatchNormV2(
        offset=False,
        scale=False,
        decay_rate=0.1,
        update_ops_collection=tf.GraphKeys.UPDATE_OPS)
    out1 = bn(inputs, is_training=False, test_local_stats=False)

    # Build the update ops.
    bn(inputs, is_training=True)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        out_v = sess.run(out1)

        # Before updating the moving_mean the results are off.
        self.assertBetween(np.max(np.abs(np.zeros([7, 6]) - out_v)), 2, 5)

        sess.run(tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))

        # After updating the moving_mean the results are better.
        out_v = sess.run(out1)
        self.assertBetween(np.max(np.abs(np.zeros([7, 6]) - out_v)), 1, 2)
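# Illustrative sketch (not from the original sources): the pattern the tests
# above rely on when `update_ops_collection=tf.GraphKeys.UPDATE_OPS` is set.
# The moving statistics only change when the collected update ops are run,
# typically attached to the training step via a control dependency. The names
# `loss` and `learning_rate` below are placeholders for this sketch.
def train_op_with_batchnorm_updates(loss, learning_rate=1e-3):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Run the moving-mean/variance updates together with every optimizer step.
    with tf.control_dependencies(update_ops):
        return optimizer.minimize(loss)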
def testVariableBatchSize(self):
    """Check the inputs batch_size can change."""
    inputs_shape = [10, 10]
    inputs = tf.placeholder(tf.float32, shape=[None] + inputs_shape)
    bn = snt.BatchNormV2(offset=False, scale=False)

    # Outputs should be equal to inputs.
    out = bn(inputs, is_training=False, test_local_stats=False)

    init = tf.global_variables_initializer()
    update_ops = tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    with self.test_session() as sess:
        sess.run(init)

        for batch_size in [1, 3, 10]:
            input_data = np.random.rand(batch_size, *inputs_shape)
            out_v = sess.run(out, feed_dict={inputs: input_data})
            self.assertAllClose(input_data / np.sqrt(1.0 + bn._eps), out_v)
            sess.run(update_ops, feed_dict={inputs: input_data})
def testRegularizersInRegularizationLosses(self, offset, scale):
    regularizers = {}
    if offset:
        regularizers["beta"] = tf.contrib.layers.l1_regularizer(scale=0.5)
    if scale:
        regularizers["gamma"] = tf.contrib.layers.l2_regularizer(scale=0.5)

    inputs_shape = [10, 10]
    inputs = tf.placeholder(tf.float32, shape=[None] + inputs_shape)
    bn = snt.BatchNormV2(
        offset=offset, scale=scale, regularizers=regularizers)
    self.assertEqual(bn.regularizers, regularizers)
    bn(inputs, is_training=True)

    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if not offset and not scale:
        self.assertFalse(graph_regularizers)
    if offset and not scale:
        self.assertRegexpMatches(graph_regularizers[0].name,
                                 ".*l1_regularizer.*")
    if scale and not offset:
        self.assertRegexpMatches(graph_regularizers[0].name,
                                 ".*l2_regularizer.*")
    if scale and offset:
        self.assertRegexpMatches(graph_regularizers[0].name,
                                 ".*l1_regularizer.*")
        self.assertRegexpMatches(graph_regularizers[1].name,
                                 ".*l2_regularizer.*")
def __init__(self, conf, name="encoder_layer"):
    """Inits the module.

    Args:
        conf: Configuration object providing embedding_dim and ff_hidden_size.
        name: The module name.
    """
    super(EncoderLayer, self).__init__(name=name)
    self.conf = conf
    self.training = True
    with self._enter_variable_scope():
        batch_initializer = {
            'gamma': utils.initializer(conf.embedding_dim),
            'moving_mean': utils.initializer(conf.embedding_dim),
            'moving_variance': utils.initializer(conf.embedding_dim),
            'beta': utils.initializer(conf.embedding_dim)
        }
        self._mha = MultiHeadAttentionResidual(conf)
        self._batch_norm0 = snt.BatchNormV2(scale=True,
                                            initializers=batch_initializer,
                                            name="batch_norm0")
        self._batch_norm1 = snt.BatchNormV2(scale=True,
                                            initializers=batch_initializer,
                                            name="batch_norm1")
        self._lin_to_hidden = snt.Linear(
            output_size=conf.ff_hidden_size,
            initializers={
                'w': utils.initializer(conf.embedding_dim),
                'b': utils.initializer(conf.embedding_dim)
            },
            name="lin_to_hidden")
        self._hidden_to_ouput = snt.Linear(
            output_size=conf.embedding_dim,
            initializers={
                'w': utils.initializer(conf.ff_hidden_size),
                'b': utils.initializer(conf.ff_hidden_size)
            },
            name="hidden_to_ouput")
        self._feed_forward = snt.Sequential(
            [self._lin_to_hidden, tf.nn.relu, self._hidden_to_ouput],
            name="feed_forward")
        self._feed_forward_residual = snt.Residual(
            self._feed_forward, name="feed_forward_residual")
def test16Bit(self, dtype):
    inputs = tf.placeholder(dtype, shape=[None, 64, 32, 3])
    batch_norm = snt.BatchNormV2(offset=True, scale=True, fused=False)
    output = batch_norm(inputs, is_training=True)
    self.assertEqual(dtype, output.dtype)
    self.assertEqual(tf.float32, batch_norm.moving_mean.dtype)
    self.assertEqual(tf.float32, batch_norm.moving_variance.dtype)
    self.assertEqual(dtype, batch_norm.gamma.dtype)
    self.assertEqual(dtype, batch_norm.beta.dtype)
def _fn(batch):
    net = snt.BatchFlatten()(batch["image"])
    for i, h in enumerate(hidden_units):
        net = snt.Linear(h)(net)
        if i != (len(hidden_units) - 1):
            net = snt.BatchNormV2()(net, is_training=False)
            net = activation(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=net)
    return tf.reduce_mean(loss_vec)
def _build(self, main_inputs, attention_inputs, is_training):
    main_theta = main_inputs
    attention_theta = attention_inputs

    height = main_theta.shape.as_list()[1]
    width = main_theta.shape.as_list()[2]
    dim_inner = main_theta.shape.as_list()[-1]
    if self._batch_size is None:
        self._batch_size = -1

    main_theta = tf.reshape(
        main_theta, [self._batch_size, dim_inner, height * width],
        name='main_theta_reshape')
    attention_theta = tf.reshape(
        attention_theta,
        [self._batch_size, dim_inner,
         attention_theta.shape.as_list()[1] *
         attention_theta.shape.as_list()[2]],
        name='attention_theta_reshape')

    main_attention_theta = tf.matmul(main_theta, attention_theta,
                                     transpose_a=False, transpose_b=True,
                                     name='main_attention_theta_matmul')

    if self._use_softmax:
        p = tf.nn.softmax(main_attention_theta,
                          name='main_attention_theta_softmax')
    else:
        # Dot-product normalization: divide by a constant element count.
        # Float constants are used so the division type-checks, and the
        # denominator is kept out of the gradient.
        ones = tf.constant(1.0, shape=main_attention_theta.shape)
        ones = tf.reduce_sum(ones)
        zeros = tf.constant(0.0, shape=main_attention_theta.shape)
        denom = tf.add(zeros, ones)
        denom = tf.stop_gradient(denom)
        p = tf.div(main_attention_theta, denom,
                   name='main_attention_theta_dot_product')

    main_attention = tf.matmul(p, attention_theta)
    main_attention = tf.reshape(
        main_attention, [self._batch_size, height, width, dim_inner])
    main_attention = snt.BatchNormV2(scale=True)(main_attention,
                                                 is_training=is_training,
                                                 test_local_stats=False)
    return main_attention
def conv_layer(cls, inputs, is_training, out_filters, filter_size,
               channel_multiplier, separable=False):
    with tf.variable_scope('inp_conv_1'):
        net = snt.Conv2D(out_filters, kernel_shape=1)(inputs)
        net = snt.BatchNormV2(scale=True, decay_rate=0.9,
                              eps=1e-5)(net, is_training=is_training)
        net = tf.nn.relu(net)

    with tf.variable_scope(f'out_conv_{out_filters}'):
        if separable:
            net = snt.SeparableConv2D(out_filters, channel_multiplier,
                                      filter_size)(net)
        else:
            net = snt.Conv2D(out_filters, kernel_shape=filter_size)(net)
        net = snt.BatchNormV2(scale=True, decay_rate=0.9,
                              eps=1e-5)(net, is_training=is_training)
    return net
def __init__(self, observation_spaces, action_spaces, shared_policy=False,
             shared_critic=False, normalize=None, name='maddpg'):
    name = 'maddpg_module' if name is None else name
    super().__init__(name=name)
    self.best_policy_group = PolicyGroup(observation_spaces, action_spaces,
                                         shared=shared_policy,
                                         name='best_policy')
    self.worst_policy_group = PolicyGroup(observation_spaces, action_spaces,
                                          shared=shared_policy,
                                          name='worst_policy')
    self.global_critic_group = CriticGroup(observation_spaces, action_spaces,
                                           shared=True,
                                           name='global_critic')
    self.personal_critic_group = CriticGroup(observation_spaces, action_spaces,
                                             shared=False,
                                             name='personal_critic')
    self.normalize = {}
    if normalize:
        if normalize.get('reward'):
            self.normalize['reward'] = {
                name: snt.BatchNormV2() for name in action_spaces
            }
        if normalize.get('observation'):
            self.normalize['observation'] = {
                name: snt.BatchNormV2() for name in action_spaces
            }
    self.observation_spaces = observation_spaces
    self.action_spaces = action_spaces
def _build(self, main_inputs, attention_inputs, is_training):
    # rgb_inputs and flow_inputs already have the same shape,
    # e.g. (None, frame_counts, 56, 56, 256).
    if self._use_conv:
        main_theta = Unit3d(main_inputs.shape.as_list()[-1],
                            activation_fn=None,
                            use_batch_norm=False)(main_inputs,
                                                  is_training=is_training)
        attention_theta = Unit3d(attention_inputs.shape.as_list()[-1],
                                 activation_fn=None,
                                 use_batch_norm=False)(attention_inputs,
                                                       is_training=is_training)
    else:
        main_theta = main_inputs
        attention_theta = attention_inputs

    temporal = main_theta.shape.as_list()[1]
    height = main_theta.shape.as_list()[2]
    width = main_theta.shape.as_list()[3]
    dim_inner = main_theta.shape.as_list()[-1]
    if self._batch_size is None:
        self._batch_size = -1

    if self._space:
        main_theta = tf.reshape(
            main_theta,
            [self._batch_size, temporal, height * width, dim_inner],
            name='main_theta_reshape')
        attention_theta = tf.reshape(
            attention_theta,
            [self._batch_size, temporal,
             attention_theta.shape.as_list()[2] *
             attention_theta.shape.as_list()[3],
             dim_inner],
            name='attention_theta_reshape')
    else:
        main_theta = tf.reshape(
            main_theta,
            [self._batch_size, height * width, dim_inner * temporal],
            name='main_theta_reshape')
        attention_theta = tf.reshape(
            attention_theta,
            [self._batch_size,
             attention_theta.shape.as_list()[2] *
             attention_theta.shape.as_list()[3],
             dim_inner * temporal],
            name='attention_theta_reshape')

    main_attention_theta = tf.matmul(main_theta, attention_theta,
                                     transpose_a=False, transpose_b=True,
                                     name='main_attention_theta_matmul')

    if self._use_softmax:
        p = tf.nn.softmax(main_attention_theta,
                          name='main_attention_theta_softmax')
    else:
        # Dot-product normalization: divide by a constant element count.
        # Float constants are used so the division type-checks, and the
        # denominator is kept out of the gradient.
        ones = tf.constant(1.0, shape=main_attention_theta.shape)
        ones = tf.reduce_sum(ones)
        zeros = tf.constant(0.0, shape=main_attention_theta.shape)
        denom = tf.add(zeros, ones)
        denom = tf.stop_gradient(denom)
        p = tf.div(main_attention_theta, denom,
                   name='main_attention_theta_dot_product')

    main_attention = tf.matmul(p, attention_theta)
    main_attention = tf.reshape(
        main_attention,
        [self._batch_size, temporal, height, width, dim_inner])
    if self._use_conv:
        main_attention = Unit3d(dim_inner,
                                activation_fn=None,
                                use_batch_norm=False)(main_attention,
                                                      is_training=is_training)
    main_attention = snt.BatchNormV2(scale=True)(main_attention,
                                                 is_training=is_training,
                                                 test_local_stats=False)
    return main_attention
def _build(self, inputs, is_training):
    net = snt.Conv2D(output_channels=self.output_channels,
                     kernel_shape=self.kernel_shape,
                     stride=self.stride,
                     padding=snt.SAME,
                     initializers=initializer,
                     use_bias=self.use_bias,
                     regularizers=regularizers)(inputs)
    if self.use_batch_norm:
        bn = snt.BatchNormV2(scale=True)
        net = bn(net, is_training=is_training, test_local_stats=False)
    if self.activation_fn is not None:
        net = self.activation_fn(net)
    return net
def testFusedBatchNormFloat16(self, is_training, test_local_stats):
    input_shape = (31, 7, 7, 5)
    iterations = 3

    x = tf.placeholder(tf.float16, shape=input_shape)
    bn1 = snt.BatchNormV2(fused=False)
    bn2 = snt.BatchNormV2()

    feed_dict = {x: np.random.random(input_shape)}

    o1 = bn1(x, is_training=is_training, test_local_stats=test_local_stats)
    o2 = bn2(x, is_training=is_training, test_local_stats=test_local_stats)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        params = [
            o1, o2, bn1._moving_mean, bn1._moving_variance, bn2._moving_mean,
            bn2._moving_variance
        ]
        for _ in range(iterations):
            y1, y2, mean1, var1, mean2, var2 = sess.run(params,
                                                        feed_dict=feed_dict)
            self.assertAllClose(y1, y2, atol=1e-2)
            self.assertAllClose(mean1, mean2, atol=1e-2)
            self.assertAllClose(var1, var2, atol=1e-2)
def _build(self, inputs, is_training):
    net = snt.Conv2D(output_channels=self._output_channels,
                     kernel_shape=self._kernel_shape,
                     stride=self._stride,
                     use_bias=self._use_bias,
                     padding=snt.SAME)(inputs)
    if self._use_bn:
        bn = snt.BatchNormV2()
        net = bn(net, is_training=is_training, test_local_stats=False)
    if self._actvation_fn is not None:
        net = self._actvation_fn(net)
    return net
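# Illustrative sketch (not from the original sources): the construction/call
# convention shared by the `_build` methods above. A BatchNormV2 module is
# built inside the graph-building function and applied with explicit
# `is_training` and `test_local_stats` flags; the function name and
# parameters here are placeholders for this sketch.
def apply_conv_bn_relu(inputs, output_channels, is_training):
    net = snt.Conv2D(output_channels=output_channels, kernel_shape=3,
                     padding=snt.SAME)(inputs)
    # Normalize with batch statistics at training time and moving averages
    # at test time (test_local_stats=False).
    net = snt.BatchNormV2(scale=True)(net, is_training=is_training,
                                      test_local_stats=False)
    return tf.nn.relu(net)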
def testConstruct(self):
    inputs = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])

    batch_norm1 = snt.BatchNormV2(offset=False, scale=False, fused=False)
    batch_norm1(inputs, is_training=True)

    err = "Batch normalization doesn't have an offset, so no beta"
    with self.assertRaisesRegexp(snt.Error, err):
        _ = batch_norm1.beta

    err = "Batch normalization doesn't have a scale, so no gamma"
    with self.assertRaisesRegexp(snt.Error, err):
        _ = batch_norm1.gamma

    batch_norm2 = snt.BatchNormV2(offset=True, scale=False)
    batch_norm2(inputs, is_training=True)
    _ = batch_norm2.beta

    batch_norm3 = snt.BatchNormV2(offset=False, scale=True)
    batch_norm3(inputs, is_training=True)
    _ = batch_norm3.gamma

    batch_norm4 = snt.BatchNormV2(offset=True, scale=True)
    batch_norm4(inputs, is_training=True)
    _ = batch_norm4.beta
    _ = batch_norm4.gamma

    batch_norm4(inputs, is_training=True, test_local_stats=True)
    batch_norm4(inputs,
                is_training=tf.constant(True),
                test_local_stats=tf.constant(True))

    is_training_ph = tf.placeholder(tf.bool)
    test_local_stats_ph = tf.placeholder(tf.bool)
    batch_norm4(inputs,
                is_training=is_training_ph,
                test_local_stats=test_local_stats_ph)
def testDataFormats(self, data_format):
    """Check that differing data formats give the correct output shape."""
    dim_sizes = {
        "N": None,
        "D": 10,
        "H": 64,
        "W": 32,
        "C": 3
    }
    inputs = tf.placeholder_with_default(
        tf.zeros([dim_sizes[dim_name] or 5 for dim_name in data_format]),
        [dim_sizes[dim_name] for dim_name in data_format])

    bn_data_formats = [data_format]
    if data_format.endswith("C"):
        bn_data_formats.append(None)

    for bn_data_format in bn_data_formats:
        bn = snt.BatchNormV2(data_format=bn_data_format, offset=False)
        bn(inputs, is_training=True)
        mean_shape = bn.moving_mean.get_shape()
        correct_mean_shape = [
            dim_sizes["C"] if dim_name == "C" else 1
            for dim_name in data_format
        ]
        self.assertEqual(mean_shape, correct_mean_shape)

    for use_gpu in [True, False]:
        with self.test_session(use_gpu=use_gpu) as sess:
            for bn_data_format in "NC NWC NHWC NDHWC NCW NCHW NCDHW".split():
                if len(data_format) != len(bn_data_format):
                    bn = snt.BatchNormV2(data_format=bn_data_format,
                                         offset=False)
                    err = r"Incorrect data format {} for input shape .*".format(
                        bn_data_format)
                    with self.assertRaisesRegexp(snt.IncompatibleShapeError,
                                                 err):
                        outputs = bn(inputs, is_training=True)
                        sess.run(outputs)
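# Illustrative sketch (not from the original sources): as the test above
# exercises, `data_format` tells BatchNormV2 where the channel axis is, so the
# same module class handles both channels-last and channels-first inputs.
nhwc_inputs = tf.placeholder(tf.float32, shape=[None, 64, 32, 3])
nchw_inputs = tf.placeholder(tf.float32, shape=[None, 3, 64, 32])
bn_nhwc = snt.BatchNormV2(data_format="NHWC")(nhwc_inputs, is_training=True)
bn_nchw = snt.BatchNormV2(data_format="NCHW")(nchw_inputs, is_training=True)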
def _build(self, inputs, is_training):
    h = snt.Linear(output_size=self.hidden_size)(inputs)
    h = tf.layers.Dropout(rate=self.drop_rate)(h, is_training)
    for i in range(self.num_highways):
        h = Highway()(h)
        if self.use_batch_norm:
            h = snt.BatchNormV2(data_format='NC')(h, is_training)
        if self.activation != 'linear':
            h = Activation(activation=self.activation)(h)
        # h = tf.layers.Dropout(rate=self.drop_rate)(h, is_training)
    outputs = snt.Linear(output_size=self.num_outputs)(h)
    return outputs
def testDynamicImageShape(self, shape, data_format, fused, batch_unknown):
    """Check that tensors with unknown spatial dimensions work."""
    if batch_unknown:
        shape[0] = None
    input_ph = tf.placeholder(tf.float32, shape=shape)
    bn = snt.BatchNormV2(data_format=data_format, fused=fused)
    output_train = bn(input_ph, is_training=True)
    output_test = bn(input_ph, is_training=False)
    self.assertEqual(output_train.get_shape().as_list(),
                     output_test.get_shape().as_list())
    # Check that no information about the shape has been erased from the input.
    self.assertEqual(output_train.get_shape().as_list(),
                     input_ph.get_shape().as_list())
def testCheckStatsPython(self):
    """The correct normalization is being used for different Python flags."""
    v, input_v, inputs = self._get_inputs()

    bn = snt.BatchNormV2(
        offset=False,
        scale=False,
        decay_rate=0.5,
        update_ops_collection=tf.GraphKeys.UPDATE_OPS)

    out1 = bn(inputs, is_training=True, test_local_stats=True)
    out2 = bn(inputs, is_training=False, test_local_stats=True)
    out3 = bn(inputs, is_training=False, test_local_stats=False)

    update_ops = tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    self.assertLen(update_ops, 2)

    with tf.control_dependencies(update_ops):
        out1 = tf.identity(out1)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        out_v = sess.run(out1)
        mm, mv = sess.run([bn.moving_mean, bn.moving_variance])

        # A single moving-average update step should have happened.
        correct_mm = (1.0 - bn._decay_rate) * v
        correct_mv = np.ones([1, 6]) * bn._decay_rate

        self.assertAllClose(np.reshape(correct_mm, [1, 6]), mm)
        self.assertAllClose(np.reshape(correct_mv, [1, 6]), mv)
        self.assertAllClose(np.zeros([7, 6]), out_v, rtol=1e-6, atol=1e-5)

        out2_, out3_ = sess.run([out2, out3])

        # Out2: Tested using local batch stats.
        # Better numerical precision due to using shifted estimators.
        self.assertAllClose(np.zeros([7, 6]), out2_, rtol=1e-6, atol=1e-5)

        # Out3: Tested using moving average stats.
        self.assertAllClose((input_v - mm) / np.sqrt(mv + bn._eps), out3_)
def _build(self, main_inputs, attention_inputs, is_training):
    # rgb_inputs and flow_inputs already have the same shape,
    # e.g. (None, frame_counts, 56, 56, 256).
    if self._use_conv:
        main_theta = Unit3d(main_inputs.shape.as_list()[-1],
                            activation_fn=None,
                            use_batch_norm=False,
                            name='main_theta')(main_inputs,
                                               is_training=is_training)
        attention_theta = Unit3d(
            attention_inputs.shape.as_list()[-1],
            activation_fn=None,
            use_batch_norm=False,
            name='attention_theta')(attention_inputs,
                                    is_training=is_training)
    else:
        main_theta = main_inputs
        attention_theta = attention_inputs

    temporal = main_theta.shape.as_list()[1]
    height = main_theta.shape.as_list()[2]
    width = main_theta.shape.as_list()[3]
    dim_inner = main_theta.shape.as_list()[-1]
    if self._batch_size is None:
        self._batch_size = -1

    if self._space:
        main_theta = tf.reshape(
            main_theta,
            [self._batch_size, temporal, height * width, dim_inner],
            name='main_theta_reshape')
        attention_theta = tf.reshape(
            attention_theta,
            [self._batch_size, temporal,
             attention_theta.shape.as_list()[2] *
             attention_theta.shape.as_list()[3],
             dim_inner],
            name='attention_theta_reshape')
    else:
        main_theta = tf.reshape(
            main_theta,
            [self._batch_size, height * width, dim_inner * temporal],
            name='main_theta_reshape')
        attention_theta = tf.reshape(
            attention_theta,
            [self._batch_size,
             attention_theta.shape.as_list()[2] *
             attention_theta.shape.as_list()[3],
             dim_inner * temporal],
            name='attention_theta_reshape')

    main_attention_theta = tf.matmul(main_theta, attention_theta,
                                     transpose_a=False, transpose_b=True,
                                     name='main_attention_theta_matmul')

    l = main_attention_theta.shape.as_list()[3]
    main_attention_theta = linear_transform(main_attention_theta, l)

    if self._use_softmax:
        p = tf.nn.softmax(main_attention_theta,
                          name='main_attention_theta_softmax')
    else:
        p = main_attention_theta

    main_attention = tf.matmul(p, attention_theta)
    main_attention = tf.reshape(
        main_attention,
        [self._batch_size, temporal, height, width, dim_inner])
    if self._use_conv:
        main_attention = Unit3d(dim_inner,
                                activation_fn=None,
                                use_batch_norm=False)(main_attention,
                                                      is_training=is_training)
    main_attention = snt.BatchNormV2(scale=True)(main_attention,
                                                 is_training=is_training,
                                                 test_local_stats=False)
    return main_attention