def __init__(self, name='ExitFlowModule'):
    super(ExitFlowModule, self).__init__(name=name)
    with self._enter_variable_scope():
        self.resconv_ex1 = snt.Conv2D(output_channels=1024, kernel_shape=1,
                                      stride=2, name='resconv_ex1')
        self.bn_resex1 = snt.BatchNorm(name='bn_resex1')
        self.sepconv_ex1 = snt.SeparableConv2D(output_channels=728,
                                               channel_multiplier=1,
                                               kernel_shape=3,
                                               name='sepconv_ex1')
        self.bn_sepex1 = snt.BatchNorm(name='bn_sepex1')
        self.sepconv_ex2 = snt.SeparableConv2D(output_channels=1024,
                                               channel_multiplier=1,
                                               kernel_shape=3,
                                               name='sepconv_ex2')
        self.bn_sepex2 = snt.BatchNorm(name='bn_sepex2')
        self.sepconv_ex3 = snt.SeparableConv2D(output_channels=1536,
                                               channel_multiplier=1,
                                               kernel_shape=3,
                                               name='sepconv_ex3')
        self.bn_sepex3 = snt.BatchNorm(name='bn_sepex3')
        self.sepconv_ex4 = snt.SeparableConv2D(output_channels=2048,
                                               channel_multiplier=1,
                                               kernel_shape=3,
                                               name='sepconv_ex4')
        self.bn_sepex4 = snt.BatchNorm(name='bn_sepex4')
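# The constructor above only creates the submodules. A minimal `_build` sketch
# (an assumption, not taken from the source) wiring them in the usual Xception
# exit-flow pattern, with the 1x1 stride-2 projection as the residual branch:
def _build(self, inputs, is_training):
    # Residual branch: 1x1 stride-2 projection to 1024 channels.
    shortcut = self.bn_resex1(self.resconv_ex1(inputs), is_training=is_training)

    # Main branch: relu -> sepconv(728) -> bn -> relu -> sepconv(1024) -> bn -> pool.
    net = tf.nn.relu(inputs)
    net = self.bn_sepex1(self.sepconv_ex1(net), is_training=is_training)
    net = tf.nn.relu(net)
    net = self.bn_sepex2(self.sepconv_ex2(net), is_training=is_training)
    net = tf.nn.max_pool(net, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                         padding='SAME')
    net = net + shortcut

    # Tail: sepconv(1536) -> bn -> relu -> sepconv(2048) -> bn -> relu.
    net = self.bn_sepex3(self.sepconv_ex3(net), is_training=is_training)
    net = tf.nn.relu(net)
    net = self.bn_sepex4(self.sepconv_ex4(net), is_training=is_training)
    net = tf.nn.relu(net)
    return net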
def custom_build(inputs, is_training, keep_prob):
    """A custom build method to wrap into a sonnet Module."""
    x_inputs = tf.reshape(inputs, [-1, 28, 28, 1])
    outputs = snt.Conv2D(output_channels=32, kernel_shape=4, stride=2)(x_inputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    outputs = tf.nn.relu(outputs)
    outputs = tf.nn.max_pool(outputs, ksize=[1, 2, 2, 1],
                             strides=[1, 2, 2, 1], padding='SAME')
    outputs = snt.Conv2D(output_channels=64, kernel_shape=4, stride=2)(outputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    outputs = tf.nn.relu(outputs)
    outputs = tf.nn.max_pool(outputs, ksize=[1, 2, 2, 1],
                             strides=[1, 2, 2, 1], padding='SAME')
    outputs = snt.Conv2D(output_channels=1024, kernel_shape=1, stride=1)(outputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    outputs = tf.nn.relu(outputs)
    outputs = snt.BatchFlatten()(outputs)
    outputs = tf.nn.dropout(outputs, keep_prob=keep_prob)
    outputs = snt.Linear(output_size=10)(outputs)
    # _activation_summary(outputs)
    return outputs
def _depthwise_separable_conv(self, output_fmaps, stride=1):
    """Returns a list of modules/ops implementing a single depthwise
    separable convolution, as defined in the MobileNet paper:

      3x3 depthwise, with optional stride
      batch norm
      relu
      1x1 pointwise with output_fmaps output channels
      batch norm
      relu

    Args:
      output_fmaps (int): the number of feature maps resulting from the
        1x1 convolution.
      stride (Optional[int]): stride, applied evenly in both dimensions.
        Defaults to 1.

    Returns:
      list of modules/ops as above.
    """
    modules = [
        snt.DepthwiseConv2D(1, [3, 3], stride=stride, use_bias=False),
        snt.BatchNorm(),
        tf.nn.relu,
        snt.Conv2D(output_fmaps, [1, 1], use_bias=False),
        snt.BatchNorm(),
        tf.nn.relu
    ]
    return modules
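# Usage sketch (an assumption, not from the source): the returned list mixes
# Sonnet modules and plain ops, and the snt.BatchNorm entries need
# `is_training` at call time, so a caller might thread it through like this:
def _apply_separable_stack(self, net, modules, is_training):
    for layer in modules:
        if isinstance(layer, snt.BatchNorm):
            net = layer(net, is_training=is_training)
        else:
            net = layer(net)
    return net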
def res_block2d(input_features, n_channels, n_down_channels=None,
                activation_fn=tf.nn.relu, initializers=None,
                regularizers=None, convs_per_block=3, mode=False,
                name='conv'):
    """A pre-activated residual block.

    Args:
      input_features: A tensor of shape (b, h, w, c).
      n_channels: An integer specifying the number of output channels.
      n_down_channels: An integer specifying the number of intermediate
        channels.
      activation_fn: A callable activation function.
      initializers: Initializers for the weights and biases.
      regularizers: Regularizers for the weights and biases.
      convs_per_block: An integer specifying the number of convolutional
        layers.
      mode: A boolean (or bool tensor) passed to snt.BatchNorm as
        `is_training` when the global `BN` flag is set.
      name: A string used to prefix the names of the convolution modules.

    Returns:
      A tensor of shape (b, h, w, c).
    """
    # Pre-activate the inputs.
    skip = input_features
    if BN:
        input_features = snt.BatchNorm(scale=True)(input_features,
                                                   is_training=mode)
    residual = activation_fn(input_features)

    # Set the number of intermediate channels that we compress to.
    if n_down_channels is None:
        n_down_channels = n_channels

    for c in range(convs_per_block):
        residual = snt.Conv2D(n_down_channels, (3, 3), padding='SAME',
                              initializers=initializers,
                              regularizers=regularizers,
                              name=name + '-d1-c{}'.format(c))(residual)
        if c < convs_per_block - 1:
            if BN:
                residual = snt.BatchNorm(scale=True)(residual,
                                                     is_training=mode)
            residual = activation_fn(residual)

    # Project the skip connection and/or the residual with 1x1 convolutions
    # if the channel counts do not match.
    incoming_channels = input_features.shape[-1]
    if incoming_channels != n_channels:
        skip = snt.Conv2D(n_channels, (1, 1), padding='SAME',
                          initializers=initializers,
                          regularizers=regularizers,
                          name=name + '-d2-c{}'.format(c))(skip)
    if n_down_channels != n_channels:
        residual = snt.Conv2D(n_channels, (1, 1), padding='SAME',
                              initializers=initializers,
                              regularizers=regularizers,
                              name=name + '-d3-c{}'.format(c))(residual)
    return skip + residual
def residual_linear(x, l, project_shortcut=False, is_training=True):
    # see https://arxiv.org/pdf/1512.03385.pdf
    # and https://blog.waya.ai/deep-residual-learning-9610bb62c355
    use_batch_norm = False
    shortcut = x

    x = snt.Linear(l, use_bias=False)(x)
    if use_batch_norm:
        x = snt.BatchNorm(update_ops_collection=tf.GraphKeys.UPDATE_OPS)(
            x, is_training=is_training)
    x = tf.nn.leaky_relu(x)

    x = snt.Linear(l, use_bias=False)(x)
    if use_batch_norm:
        x = snt.BatchNorm(update_ops_collection=tf.GraphKeys.UPDATE_OPS)(
            x, is_training=is_training)

    if project_shortcut:
        shortcut = snt.Linear(l, use_bias=False)(shortcut)
        if use_batch_norm:
            shortcut = snt.BatchNorm(
                update_ops_collection=tf.GraphKeys.UPDATE_OPS)(
                    shortcut, is_training=is_training)

    x = x + shortcut
    x = tf.nn.leaky_relu(x)
    return x
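# When `use_batch_norm` is switched on, the moving-average updates registered
# in tf.GraphKeys.UPDATE_OPS still have to be run during training. A minimal
# TF1 training-op sketch (the optimizer choice and `loss` are assumptions):
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)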
def __init__(self, hidden_size):
    super(MLP, self).__init__()
    self.dense_1 = snt.Linear(hidden_size, name="hidden1")
    self.bn_1 = snt.BatchNorm(create_scale=True, create_offset=True)
    self.dense_2 = snt.Linear(hidden_size, name="hidden2")
    self.bn_2 = snt.BatchNorm(create_scale=True, create_offset=True)
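# The constructor above only builds the layers. A minimal Sonnet 2 `__call__`
# sketch (an assumption, not from the source) that threads `is_training`
# through both BatchNorm layers:
def __call__(self, inputs, is_training):
    x = self.dense_1(inputs)
    x = self.bn_1(x, is_training=is_training)
    x = tf.nn.relu(x)
    x = self.dense_2(x)
    x = self.bn_2(x, is_training=is_training)
    return x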
def __init__(self, pc_dim=(2048, 3), fc_dims=(64, 128, 512, 1024),
             act=tf.nn.relu, entropy_reg=True, batch_norm=False, name='gen'):
    super(Generator, self).__init__(name=name)
    self.pc_dim = pc_dim
    self.act = act
    self.batch_norm = batch_norm
    self.entropy_reg = entropy_reg

    self.fc_body = []
    self.fc_sigma_body = []
    self.bn_body = []
    self.bn_sigma_body = []

    with self._enter_variable_scope():
        for i, fc_dim in enumerate(fc_dims):
            fc = snt.Linear(fc_dim, name='fc_%d' % i)
            self.fc_body.append(fc)
            self.bn_body.append(
                snt.BatchNorm(offset=True, scale=True, name='bn_%d' % i))
        self.fc_final = snt.Linear(np.prod(pc_dim), name='fc_final')

        for i, fc_dim in enumerate(fc_dims):
            fc = snt.Linear(fc_dim, name='fc_sigma_%d' % i)
            self.fc_sigma_body.append(fc)
            self.bn_sigma_body.append(
                snt.BatchNorm(offset=True, scale=True, name='bn_sigma_%d' % i))
        self.fc_sigma_final = snt.Linear(np.prod(pc_dim), name='fc_sigma_final')
def __init__(self, name='MNIST_Generator', regularization=1.e-4):
    super(MNISTGenerator, self).__init__(name=name)
    reg = {
        'w': l2_regularizer(scale=regularization),
        'b': l2_regularizer(scale=regularization)
    }
    with self._enter_variable_scope():
        self.linear = snt.Linear(name='linear', output_size=3136,
                                 regularizers=reg)
        self.bn1 = snt.BatchNorm(name='batch_norm_1')
        self.reshape = snt.BatchReshape(name='reshape', shape=[7, 7, 64])
        self.deconv1 = snt.Conv2DTranspose(name='tr-conv2d_1',
                                           output_channels=64,
                                           kernel_shape=5, stride=2,
                                           regularizers=reg)
        self.bn2 = snt.BatchNorm(name='batch_norm_2')
        self.deconv2 = snt.Conv2DTranspose(name='tr-conv2d_2',
                                           output_channels=32,
                                           kernel_shape=5, stride=1,
                                           regularizers=reg)
        self.bn3 = snt.BatchNorm(name='batch_norm_3')
        self.deconv3 = snt.Conv2DTranspose(name='tr-conv2d_3',
                                           output_channels=3,
                                           kernel_shape=5, stride=2,
                                           regularizers=reg)
def testInvalidRegularizationParameters(self):
    with self.assertRaisesRegexp(KeyError, "Invalid regularizer keys.*"):
        snt.BatchNorm(
            regularizers={"not_gamma": tf.contrib.layers.l1_regularizer(0.5)})

    err = "Regularizer for 'gamma' is not a callable function"
    with self.assertRaisesRegexp(TypeError, err):
        snt.BatchNorm(regularizers={"gamma": tf.zeros([1, 2, 3])})
def _build(self, inputs, verbose=VERBOSITY):
    if EncodeProcessDecode_v7_edge_segmentation_no_edges_dropout.convnet_tanh:
        activation = tf.nn.tanh
    else:
        activation = tf.nn.relu

    img_shape = get_correct_image_shape(
        config=None,
        get_type="seg",
        depth_data_provided=EncodeProcessDecode_v7_edge_segmentation_no_edges_dropout.depth_data_provided)
    img_data = tf.reshape(inputs, [-1, *img_shape])  # -1 means "all", i.e. batch dimension
    print(img_data.get_shape())

    # 60, 80
    outputs = snt.Conv2D(output_channels=32, kernel_shape=3, stride=2,
                         padding="SAME")(img_data)
    outputs = activation(outputs)
    # NOTE: despite the flag name, this applies batch normalization.
    if EncodeProcessDecode_v7_edge_segmentation_no_edges_dropout.conv_layer_instance_norm:
        outputs = snt.BatchNorm()(outputs, is_training=self._is_training)
    print(outputs.get_shape())

    # 30, 40
    outputs = snt.Conv2D(output_channels=32, kernel_shape=3, stride=2,
                         padding="SAME")(outputs)
    outputs = activation(outputs)
    if EncodeProcessDecode_v7_edge_segmentation_no_edges_dropout.conv_layer_instance_norm:
        outputs = snt.BatchNorm()(outputs, is_training=self._is_training)
    print(outputs.get_shape())

    # 15, 20
    outputs = snt.Conv2D(output_channels=16, kernel_shape=3, stride=2,
                         padding="SAME")(outputs)
    outputs = activation(outputs)
    if EncodeProcessDecode_v7_edge_segmentation_no_edges_dropout.conv_layer_instance_norm:
        outputs = snt.BatchNorm()(outputs, is_training=self._is_training)
    print(outputs.get_shape())

    # 8, 10
    outputs = snt.Conv2D(output_channels=5, kernel_shape=3, stride=2,
                         padding="SAME")(outputs)
    outputs = activation(outputs)
    if EncodeProcessDecode_v7_edge_segmentation_no_edges_dropout.conv_layer_instance_norm:
        outputs = snt.BatchNorm()(outputs, is_training=self._is_training)
    print(outputs.get_shape())

    outputs = tf.layers.flatten(outputs)  # 8, 10, 5 flattened
    return outputs
def custom_build(inputs, is_training, keep_prob):
    """A custom build method to wrap into a sonnet Module."""
    outputs = snt.Conv2D(output_channels=32, kernel_shape=4, stride=2)(inputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    outputs = tf.nn.relu(outputs)
    outputs = snt.Conv2D(output_channels=64, kernel_shape=4, stride=2)(outputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    outputs = tf.nn.relu(outputs)
    outputs = snt.BatchFlatten()(outputs)
    outputs = tf.nn.dropout(outputs, keep_prob=keep_prob)
    outputs = snt.Linear(output_size=10)(outputs)
    return outputs
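# A build function like this is normally wrapped with snt.Module so it gets its
# own variable scope. A minimal usage sketch (placeholder shapes and names are
# assumptions):
images = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
is_training_ph = tf.placeholder(tf.bool)
keep_prob_ph = tf.placeholder(tf.float32)

model = snt.Module(custom_build, name='custom_conv_net')
logits = model(images, is_training_ph, keep_prob_ph)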
def compute_top_delta(self, z):
    """Parameterization of topD. This converts the top level activation
    to an error signal.

    Args:
      z: tf.Tensor
        batch of final layer post activations.

    Returns:
      delta: tf.Tensor
        the error signal.
    """
    s_idx = 0
    with tf.variable_scope('compute_top_delta'), tf.device(self.remote_device):
        # typically this takes [BS, length, input_channels],
        # We are applying this such that we convolve over the batch dimension.
        act = tf.expand_dims(tf.transpose(z, [1, 0]), 2)  # [channels, BS, 1]

        mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[5])
        act = mod(act)

        act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
        act = tf.nn.relu(act)

        bs = act.shape.as_list()[0]
        act = tf.transpose(act, [2, 1, 0])
        act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
        act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
        act = tf.nn.relu(act)
        act = snt.Conv1D(output_channels=bs, kernel_shape=[3])(act)
        act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
        act = tf.nn.relu(act)
        act = tf.transpose(act, [2, 1, 0])

        prev_act = act
        for i in range(self.top_delta_layers):
            mod = snt.Conv1D(output_channels=self.top_delta_size,
                             kernel_shape=[3])
            act = mod(act)

            act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
            act = tf.nn.relu(act)

            prev_act = act

        mod = snt.Conv1D(output_channels=self.delta_dim, kernel_shape=[3])
        act = mod(act)

        # [bs, feature_channels, delta_channels]
        act = tf.transpose(act, [1, 0, 2])
        return act
def __init__(self,
             channels: int,
             stride: Union[int, Sequence[int]],
             use_projection: bool,
             bn_config: Mapping[Text, float],
             name: Optional[Text] = None):
    super(BottleNeckBlockV2, self).__init__(name=name)
    self._channels = channels
    self._stride = stride
    self._use_projection = use_projection
    self._bn_config = bn_config

    batchnorm_args = {"create_scale": True, "create_offset": True}
    batchnorm_args.update(bn_config)

    if self._use_projection:
        self._proj_conv = snt.Conv2D(output_channels=channels * 4,
                                     kernel_shape=1,
                                     stride=stride,
                                     with_bias=False,
                                     padding=snt.pad.same,
                                     name="shortcut_conv")

    self._conv_0 = snt.Conv2D(output_channels=channels,
                              kernel_shape=1,
                              stride=1,
                              with_bias=False,
                              padding=snt.pad.same,
                              name="conv_0")
    self._bn_0 = snt.BatchNorm(name="batchnorm_0", **batchnorm_args)

    self._conv_1 = snt.Conv2D(output_channels=channels,
                              kernel_shape=3,
                              stride=stride,
                              with_bias=False,
                              padding=snt.pad.same,
                              name="conv_1")
    self._bn_1 = snt.BatchNorm(name="batchnorm_1", **batchnorm_args)

    self._conv_2 = snt.Conv2D(output_channels=channels * 4,
                              kernel_shape=1,
                              stride=1,
                              with_bias=False,
                              padding=snt.pad.same,
                              name="conv_2")
    # NOTE: Some implementations of ResNet50 v2 suggest initializing
    # gamma/scale here to zeros.
    self._bn_2 = snt.BatchNorm(name="batchnorm_2", **batchnorm_args)
def __init__(self, name='Code_Discriminator', regularization=1.e-4):
    super(CodeDiscriminator, self).__init__(name=name)
    reg = {
        'w': l2_regularizer(scale=regularization),
        'b': l2_regularizer(scale=regularization)
    }
    with self._enter_variable_scope():
        self.l1 = snt.Linear(name='l1', output_size=750, regularizers=reg)
        self.bn1 = snt.BatchNorm(name='batch_norm_1')
        self.l2 = snt.Linear(name='l2', output_size=750, regularizers=reg)
        self.bn2 = snt.BatchNorm(name='batch_norm_2')
        self.l3 = snt.Linear(name='l3', output_size=1, regularizers=reg)
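# Only the constructor is shown. One plausible `_build` wiring (an assumption,
# not from the source): Linear -> BatchNorm -> ReLU twice, then the final
# single-logit layer:
def _build(self, inputs, is_training=True):
    net = tf.nn.relu(self.bn1(self.l1(inputs), is_training=is_training))
    net = tf.nn.relu(self.bn2(self.l2(net), is_training=is_training))
    return self.l3(net)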
def _build(self, inputs, is_training=True):
    if EncodeProcessDecode_v1.convnet_tanh:
        activation = tf.nn.tanh
    else:
        activation = tf.nn.relu

    # input shape is (batch_size, feature_length) but CNN operates on depth
    # channels --> (batch_size, feature_length, 1)
    inputs = tf.expand_dims(inputs, axis=2)

    # layer 1
    outputs = snt.Conv1D(output_channels=12, kernel_shape=10, stride=2)(inputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    if EncodeProcessDecode_v1.convnet_pooling:
        outputs = tf.layers.max_pooling1d(outputs, 2, 2)
    outputs = activation(outputs)
    # print(outputs.get_shape())

    # layer 2
    outputs = snt.Conv1D(output_channels=12, kernel_shape=10, stride=2)(outputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    if EncodeProcessDecode_v1.convnet_pooling:
        outputs = tf.layers.max_pooling1d(outputs, 2, 2)
    outputs = activation(outputs)
    # print(outputs.get_shape())

    # layer 3
    outputs = snt.Conv1D(output_channels=12, kernel_shape=10, stride=2)(outputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    if EncodeProcessDecode_v1.convnet_pooling:
        outputs = tf.layers.max_pooling1d(outputs, 2, 2)
    outputs = activation(outputs)
    # print(outputs.get_shape())

    # layer 4
    outputs = snt.Conv1D(output_channels=12, kernel_shape=10, stride=2)(outputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)  # todo: deal with train/test time
    if EncodeProcessDecode_v1.convnet_pooling:
        outputs = tf.layers.max_pooling1d(outputs, 2, 2)
    outputs = activation(outputs)
    # print(outputs.get_shape())

    # layer 5
    outputs = snt.BatchFlatten()(outputs)
    # outputs = tf.nn.dropout(outputs, keep_prob=tf.constant(1.0))  # todo: deal with train/test time
    outputs = snt.Linear(
        output_size=EncodeProcessDecode_v1.dimensions_latent_repr)(outputs)
    # print(outputs.get_shape())
    return outputs
def _build(self, input_image):
    leaky_relu_activation = lambda x: tf.maximum(self._leaky_relu_coeff * x, x)
    init_dict = {
        'w': tf.truncated_normal_initializer(seed=547, stddev=0.02),
        'b': tf.constant_initializer(0.3)
    }

    layer1 = snt.Conv2D(output_channels=64, kernel_shape=[4, 4], stride=2,
                        initializers=init_dict)(input_image)
    layer2 = leaky_relu_activation(layer1)
    layer3 = snt.Conv2D(output_channels=128, kernel_shape=[4, 4], stride=2,
                        initializers=init_dict)(layer2)
    layer4 = snt.BatchNorm(offset=1, scale=1, decay_rate=0.9)(
        layer3, is_training=True, test_local_stats=True)
    layer5 = leaky_relu_activation(layer4)
    layer6 = snt.BatchFlatten()(layer5)
    layer7 = snt.Linear(output_size=1024, initializers=init_dict)(layer6)
    layer8 = snt.BatchNorm(offset=1, scale=1, decay_rate=0.9)(
        layer7, is_training=True, test_local_stats=True)
    layer9 = leaky_relu_activation(layer8)
    classification_logits = snt.Linear(2, initializers=init_dict)(layer9)

    # conv2d = snt.nets.ConvNet2D(
    #     output_channels=[8, 16, 32, 64, 128],
    #     kernel_shapes=[[5, 5]],
    #     strides=[2, 1, 2, 1, 2],
    #     paddings=[snt.SAME],
    #     activate_final=True,
    #     activation=leaky_relu_activation,
    #     use_batch_norm=False,
    #     initializers=init_dict)
    # convolved = conv2d(input_image)
    # # Flatten the data to 2D for the classification layer
    # flat_data = snt.BatchFlatten()(convolved)
    # # We have two classes: one for real, and one for fake data.
    # classification_logits = snt.Linear(2, initializers=init_dict)(flat_data)
    return classification_logits
def __init__(self, name='Encoder', latent_size=50, image_size=64, ndf=64,
             regularization=1.e-4):
    super(Encoder, self).__init__(name=name)
    reg = {'w': l2_regularizer(scale=regularization)}
    self.convs = []
    self.batch_norms = []
    self.latent_size = latent_size
    csize, cndf = image_size / 2, ndf
    with self._enter_variable_scope():
        self.convs.append(snt.Conv2D(name='conv2d_1',
                                     output_channels=ndf,
                                     kernel_shape=4,
                                     stride=2,
                                     padding='SAME',
                                     regularizers=reg,
                                     use_bias=False))
        self.batch_norms.append(snt.BatchNorm(name='batch_norm_1'))
        n_layer = 2
        while csize > 4:
            self.convs.append(snt.Conv2D(name='conv2d_{}'.format(n_layer),
                                         output_channels=cndf * 2,
                                         kernel_shape=4,
                                         stride=2,
                                         padding='SAME',
                                         regularizers=reg,
                                         use_bias=False))
            self.batch_norms.append(
                snt.BatchNorm(name='batch_norm_{}'.format(n_layer)))
            cndf = cndf * 2
            csize = csize // 2
            n_layer += 1  # advance the index so layer names stay distinct
        self.mean = snt.Conv2D(name='conv_mean',
                               output_channels=latent_size,
                               kernel_shape=4,
                               stride=1,
                               padding='VALID',
                               regularizers=reg,
                               use_bias=False)
        self.variance = snt.Conv2D(name='conv_variance',
                                   output_channels=latent_size,
                                   kernel_shape=4,
                                   stride=1,
                                   padding='VALID',
                                   regularizers=reg,
                                   use_bias=False)
def _build(self, inputs, is_training):
    net = Unit3D(output_channels=self._output_channels[0],
                 kernel_shape=[1, 1, 1],
                 padding=self.padding,
                 use_bias=self._use_bias,
                 name=self.name + "_1")(inputs, is_training=is_training)
    net = SepConv(output_channels=self._output_channels[1],
                  kernel_shape=self.kernel_shape,
                  padding=snt.SAME,
                  name=self.name + "_2")(net, is_training=is_training)
    net = snt.Conv3D(output_channels=self._output_channels[2],
                     kernel_shape=[1, 1, 1],
                     padding=self.padding,
                     use_bias=self._use_bias,
                     name=self.name + "_3")(net)
    net = snt.BatchNorm()(net, is_training=is_training, test_local_stats=False)
    net = layers.add([net, inputs])
    net = self._activation_fn(net)
    return net
def _build(self, x, is_training):
    def _build_layer(x, batch_norm, dropout, premade_batchnorm, premade_linear):
        if batch_norm:
            bn = premade_batchnorm(x, is_training=is_training)
        else:
            bn = x
        lin = premade_linear(bn)
        out = self._nonlinearity(lin)
        if dropout < 1:
            out = tf.nn.dropout(out, dropout)
        return out

    current_in = x
    bn = [snt.BatchNorm() for _ in range(self._n)]
    lin = [
        snt.Linear(output_size=self._dim,
                   regularizers=self._regularizers,
                   initializers=self._initializers) for i in range(self._n)
    ]
    projection = snt.Linear(output_size=self._out_dim,
                            regularizers=self._regularizers,
                            initializers=self._initializers)

    for i in range(self._n):
        current_in = _build_layer(current_in, self._batch_norm, self._dropout,
                                  bn[i], lin[i])
    out_projection = projection(current_in)
    return out_projection
def testRegularizersInRegularizationLosses(self, offset, scale):
    regularizers = {}
    if offset:
        regularizers["beta"] = tf.contrib.layers.l1_regularizer(scale=0.5)
    if scale:
        regularizers["gamma"] = tf.contrib.layers.l2_regularizer(scale=0.5)

    inputs_shape = [10, 10]
    inputs = tf.placeholder(tf.float32, shape=[None] + inputs_shape)
    bn = snt.BatchNorm(offset=offset, scale=scale, regularizers=regularizers)
    self.assertEqual(bn.regularizers, regularizers)
    bn(inputs, is_training=True)

    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if not offset and not scale:
        self.assertFalse(graph_regularizers)
    if offset and not scale:
        self.assertRegexpMatches(graph_regularizers[0].name,
                                 ".*l1_regularizer.*")
    if scale and not offset:
        self.assertRegexpMatches(graph_regularizers[0].name,
                                 ".*l2_regularizer.*")
    if scale and offset:
        self.assertRegexpMatches(graph_regularizers[0].name,
                                 ".*l1_regularizer.*")
        self.assertRegexpMatches(graph_regularizers[1].name,
                                 ".*l2_regularizer.*")
def testInitializers(self, offset, scale):
    initializers = {
        "moving_mean": tf.constant_initializer(2.0),
        "moving_variance": tf.constant_initializer(3.0),
    }
    if scale:
        initializers["gamma"] = tf.constant_initializer(4.0)
    if offset:
        initializers["beta"] = tf.constant_initializer(5.0)

    inputs_shape = [10, 10]
    inputs = tf.placeholder(tf.float32, shape=[None] + inputs_shape)
    bn = snt.BatchNorm(offset=offset, scale=scale, initializers=initializers)
    self.assertEqual(bn.initializers, initializers)
    bn(inputs, is_training=True)

    init = tf.global_variables_initializer()
    with self.test_session() as sess:
        sess.run(init)
        ones_v = np.ones([1, 1, inputs_shape[-1]])
        self.assertAllClose(bn.moving_mean.eval(), ones_v * 2.0)
        self.assertAllClose(bn.moving_variance.eval(), ones_v * 3.0)
        if scale:
            self.assertAllClose(bn.gamma.eval(), ones_v * 4.0)
        if offset:
            self.assertAllClose(bn.beta.eval(), ones_v * 5.0)
def testUpdatesInsideCond(self):
    """Demonstrate that updates inside a cond fail."""
    _, input_v, inputs = self._get_inputs()
    bn = snt.BatchNorm(offset=False, scale=False, decay_rate=0.5)
    condition = tf.placeholder(tf.bool)
    cond = tf.cond(condition,
                   lambda: bn(inputs, is_training=True),
                   lambda: inputs)

    init = tf.global_variables_initializer()

    with self.test_session() as sess:
        sess.run(init)
        out_v = sess.run(cond, feed_dict={condition: False})
        self.assertAllClose(input_v, out_v)

        out_v = sess.run(cond, feed_dict={condition: True})
        self.assertAllClose(np.zeros([7, 6]), out_v, rtol=1e-4, atol=1e-4)

        # Variables are accessible outside the tf.cond()
        mm, mv = sess.run([bn.moving_mean, bn.moving_variance])
        self.assertAllClose(np.zeros([1, 6]), mm)
        self.assertAllClose(np.ones([1, 6]), mv)

        # Tensors are not accessible outside the tf.cond()
        with self.assertRaisesRegexp(ValueError, "Operation"):
            sess.run(tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))
def __init__(self,
             make_gnn_fn,
             num_timesteps,
             weight_sharing=False,
             use_batch_norm=False,
             residual=True,
             test_local_stats=False,
             use_layer_norm=False,
             name="TimestepGNN"):
    super(TimestepGNN, self).__init__(name=name)
    self._weight_sharing = weight_sharing
    self._num_timesteps = num_timesteps
    self._use_batch_norm = use_batch_norm
    self._residual = residual
    self._bns = []
    self._lns = []
    self._test_local_stats = test_local_stats
    self._use_layer_norm = use_layer_norm

    with self._enter_variable_scope():
        if not weight_sharing:
            self._gnn = [make_gnn_fn() for _ in range(num_timesteps)]
        else:
            self._gnn = make_gnn_fn()
        if use_batch_norm:
            self._bns = [snt.BatchNorm(scale=True)
                         for _ in range(num_timesteps)]
        if use_layer_norm:
            self._lns = [snt.LayerNorm() for _ in range(num_timesteps)]
def _build(self, input, is_training=False):
    """Adds the network into the graph."""
    # TODO(drewjaegle): add initializers, etc.
    input_shape = input.get_shape().as_list()
    input = tf.reshape(input, [input_shape[0], -1])

    if self._use_batchnorm:
        bn = lambda x: snt.BatchNorm(axis=[0])(x, is_training)
    else:
        bn = lambda x: x

    discriminator = snt.nets.MLP(
        self._layers,
        activation=lambda x: tf.nn.leaky_relu(bn(x)),
        activate_final=False,
        regularizers=self._regularizers,
        initializers={
            "w": tf.variance_scaling_initializer(scale=1e-4),
            "b": tf.constant_initializer(value=0.01)
        })
    logits = discriminator(input)
    if self._final_activation is not None:
        logits = self._final_activation(logits)
    return logits
def func(name, data_format, custom_getter=None):
    conv = snt.Conv3D(
        name=name,
        output_channels=self.OUT_CHANNELS,
        kernel_shape=self.KERNEL_SHAPE,
        use_bias=use_bias,
        initializers=create_initializers(use_bias),
        data_format=data_format,
        custom_getter=custom_getter)

    if data_format == "NDHWC":
        batch_norm = snt.BatchNorm(scale=True, update_ops_collection=None)
    else:  # data_format == "NCDHW"
        batch_norm = snt.BatchNorm(scale=True, update_ops_collection=None,
                                   axis=(0, 2, 3, 4))

    return snt.Sequential([conv,
                           functools.partial(batch_norm, is_training=True)])
def testCheckStatsDouble(self, dtype):
    """The correct statistics are being computed for double connection.

    Connected in parallel, it's ill-defined what order the updates will happen
    in. A double update could happen, or two sequential updates. E.g. If
    decay_rate is 0.9, the start value is 1.0, and the target value is 0.0,
    the value could progress as 1.00 -> 0.90 -> 0.81, if the second update
    uses the fresh second value. Or as 1.00 -> 0.90 -> 0.80 if the second
    update uses the stale first value.

    We fix this here by running them in sequential run calls to ensure that
    this test is deterministic.

    The two situations are minimally different, especially if decay_rate is
    close to one (e.g. the default of 0.999).

    Args:
      dtype: TensorFlow datatype of input test batch.
    """
    v, _, inputs = self._get_inputs(dtype)

    bn = snt.BatchNorm(offset=False, scale=False, decay_rate=0.9)

    with tf.name_scope("net1"):
        bn(inputs, is_training=True)

    with tf.name_scope("net2"):
        bn(inputs, is_training=True)

    update_ops_1 = tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS, "net1"))
    self.assertEqual(len(update_ops_1), 2)
    update_ops_2 = tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS, "net2"))
    self.assertEqual(len(update_ops_2), 2)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())

        mm, mv = sess.run([bn.moving_mean, bn.moving_variance])
        self.assertAllClose(np.zeros([1, 6]), mm)
        self.assertAllClose(np.ones([1, 6]), mv)

        sess.run(update_ops_1)
        sess.run(update_ops_2)

        mm, mv = sess.run([bn.moving_mean, bn.moving_variance])

        correct_mm = (1.0 - bn._decay_rate) * v
        correct_mm = (1.0 - bn._decay_rate) * v + bn._decay_rate * correct_mm
        correct_mv = np.ones([1, 6]) * bn._decay_rate**2

        self.assertAllClose(np.reshape(correct_mm, [1, 6]), mm)
        self.assertAllClose(np.reshape(correct_mv, [1, 6]), mv)
def func(name, data_format="NHWC"): conv = snt.Conv2D(name=name, output_channels=self.OUT_CHANNELS, kernel_shape=self.KERNEL_SHAPE, use_bias=use_bias, initializers=create_constant_initializers( 1.0, 1.0, use_bias), data_format=data_format) if data_format == "NHWC": bn = snt.BatchNorm(scale=True, update_ops_collection=None) else: bn = snt.BatchNorm(scale=True, update_ops_collection=None, fused=True, axis=(0, 2, 3)) return snt.Sequential([conv, bn])
def _build(self, inputs, is_training):
    """Connects the module to inputs.

    Args:
      inputs: Inputs to the SepConv component.
      is_training: whether to use training mode for snt.BatchNorm (boolean).

    Returns:
      Outputs from the module.
    """
    intermediate = snt.Conv3D(output_channels=self._output_channels,
                              kernel_shape=self._sp_kernel_shape,
                              stride=self._sp_stride_shape,
                              padding=self.padding,
                              use_bias=self._use_bias)(inputs)
    net = snt.Conv3D(output_channels=self._output_channels,
                     kernel_shape=self._temp_kernel_shape,
                     stride=self._temp_stride_shape,
                     padding=self.padding,
                     use_bias=self._use_bias)(intermediate)
    if self._use_batch_norm:
        bn = snt.BatchNorm()
        net = bn(net, is_training=is_training, test_local_stats=False)
    if self._activation_fn is not None:
        net = self._activation_fn(net)
    return net
def testCollectionGetSaver(self):
    with tf.variable_scope("prefix") as s1:
        input_ = tf.placeholder(tf.float32, shape=[3, 4])
        net = snt.Linear(10)(input_)
        net = snt.BatchNorm()(net, is_training=True)

    saver1 = snt.get_saver(s1)
    saver2 = snt.get_saver(s1, collections=(tf.GraphKeys.TRAINABLE_VARIABLES,))

    self.assertIsInstance(saver1, tf.train.Saver)
    self.assertIsInstance(saver2, tf.train.Saver)

    self.assertEqual(len(saver1._var_list), 5)
    self.assertIn("linear/w", saver1._var_list)
    self.assertIn("linear/b", saver1._var_list)
    self.assertIn("batch_norm/beta", saver1._var_list)
    self.assertIn("batch_norm/moving_mean", saver1._var_list)
    self.assertIn("batch_norm/moving_variance", saver1._var_list)

    self.assertEqual(len(saver2._var_list), 3)
    self.assertIn("linear/w", saver2._var_list)
    self.assertIn("linear/b", saver2._var_list)
    self.assertIn("batch_norm/beta", saver2._var_list)
    self.assertNotIn("batch_norm/moving_mean", saver2._var_list)
    self.assertNotIn("batch_norm/moving_variance", saver2._var_list)
def _build(self, inputs, spatial_cnn, motion_cnn, is_training):
    if spatial_cnn is not None and motion_cnn is not None:
        # print('spatial cnn and motion cnn is not None')
        ST = self._concat(spatial_cnn, motion_cnn, is_training=is_training,
                          activation_fn=None)
        ST = snt.Conv2D(output_channels=self._output_channels,
                        kernel_shape=[1, 1])(ST)
        if inputs is not None:
            # net = self._mix(inputs, ST, is_training=is_training)
            net = tf.concat([inputs, ST], axis=-1)
        else:
            # print('inputs is none')
            net = ST
    else:
        # NOTE: assumes `inputs` is not None when spatial_cnn/motion_cnn are
        # absent; otherwise `net` would be undefined below.
        if inputs is not None:
            net = inputs

    net = snt.Conv2D(output_channels=self._output_channels,
                     kernel_shape=self._kernel_shape,
                     stride=self._stride,
                     padding=snt.SAME,
                     use_bias=self._use_bias)(net)
    if self._use_batch_norm:
        bn = snt.BatchNorm()
        net = bn(net, is_training=is_training, test_local_stats=False)
    if self._activation_fn is not None:
        net = self._activation_fn(net)
    return net