Code Example #1
    def _build(self, inputs, is_training):
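        # Octave-convolution-style block: a high-frequency path (x_h) and an
        # optional low-frequency path (x_l) exchange information through the
        # h2h / h2l / l2l / l2h convolutions below.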

        x_h, x_l = inputs if type(inputs) is tuple else (inputs, None)

        if self._stride in (2, (2, 2), [2, 2]):
            x_h = tf.nn.avg_pool(x_h, (1, 2, 2, 1), (1, 2, 2, 1),
                                 'SAME') if x_h is not None else None
            x_l = tf.nn.avg_pool(x_l, (1, 2, 2, 1), (1, 2, 2, 1),
                                 'SAME') if x_l is not None else None

        _, h, w, _ = x_h.shape.as_list()

        # Split the output channels into low- and high-frequency parts.
        l_out = int(self._output_channels * self._ratio)
        h_out = self._output_channels - l_out

        x_h2h, x_h2l, x_l2l, x_l2h = None, None, None, None
        if x_h is not None:
            x_h2h = snt.Conv2D(output_channels=h_out,
                               kernel_shape=self._kernel_shape,
                               padding=snt.SAME,
                               name='h2h')(x_h)
            if l_out > 0:
                x_h2l = tf.nn.avg_pool(x_h, (1, 2, 2, 1), (1, 2, 2, 1),
                                       padding=snt.SAME)
                x_h2l = snt.Conv2D(output_channels=l_out,
                                   kernel_shape=self._kernel_shape,
                                   padding=snt.SAME,
                                   name='h2l')(x_h2l)

        if x_l is not None:
            if l_out > 0:
                x_l2l = snt.Conv2D(output_channels=l_out,
                                   kernel_shape=self._kernel_shape,
                                   padding=snt.SAME,
                                   name='l2l')(x_l)
            x_l2h = snt.Conv2D(output_channels=h_out,
                               kernel_shape=self._kernel_shape,
                               padding=snt.SAME,
                               name='l2h')(x_l)
            x_l2h = tf.image.resize_nearest_neighbor(x_l2h, (h, w))

        y_h = x_h2h + x_l2h if x_l2h is not None else x_h2h
        y_l = x_h2l + x_l2l if x_l2l is not None else x_h2l

        if self._use_bn:
            bn1 = snt.BatchNormV2(name='h_bn')
            bn2 = snt.BatchNormV2(name='l_bn')
            y_h = bn1(y_h, is_training=is_training,
                      test_local_stats=False) if y_h is not None else None
            y_l = bn2(y_l, is_training=is_training,
                      test_local_stats=False) if y_l is not None else None

        if self._actvation_fn is not None:
            y_h = self._actvation_fn(y_h) if y_h is not None else None
            y_l = self._actvation_fn(y_l) if y_l is not None else None

        return y_h if y_l is None else (y_h, y_l)
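All of the examples on this page share the same Sonnet 1.x calling convention for snt.BatchNormV2: construct the module, then connect it with is_training (and optionally test_local_stats). A minimal standalone sketch, assuming TF1 graph mode and NHWC inputs (not taken from any of the listed projects):

import numpy as np
import tensorflow as tf
import sonnet as snt

x = tf.placeholder(tf.float32, [None, 32, 32, 64])   # NHWC feature map
bn = snt.BatchNormV2(scale=True, name='example_bn')  # build the module once
y = bn(x, is_training=True, test_local_stats=False)  # connect it to the graph

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.random.rand(8, 32, 32, 64)})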
Code Example #2
    def factorized_reduction(cls, inputs, is_training, out_filters, stride):
        assert out_filters % 2 == 0, (
            'Need even number of filters when using this factorized reduction')
        if stride == 1:
            with tf.variable_scope('path_conv'):
                net = snt.Conv2D(out_filters, kernel_shape=1)(inputs)
                net = snt.BatchNormV2(scale=True, decay_rate=0.9,
                                      eps=1e-5)(net, is_training=is_training)
                return net

        # Skip path 1
        path1 = slim.avg_pool2d(inputs,
                                kernel_size=1,
                                stride=stride,
                                padding='VALID')
        with tf.variable_scope('path1_conv'):
            path1 = snt.Conv2D(out_filters // 2, kernel_shape=1)(path1)

        # Skip path 2
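        # (pad-then-crop shifts the features by one pixel, so this path's
        #  stride-2 pooling samples the offset pixel grid that path 1 skips)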
        pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
        path2 = tf.pad(inputs, pad_arr)[:, 1:, 1:, :]
        path2 = slim.avg_pool2d(path2,
                                kernel_size=1,
                                stride=stride,
                                padding='VALID')
        with tf.variable_scope('path2_conv'):
            path2 = snt.Conv2D(out_filters // 2, kernel_shape=1)(path2)

        final_path = tf.concat([path1, path2], axis=-1)
        final_path = snt.BatchNormV2(scale=True, decay_rate=0.9,
                                     eps=1e-5)(final_path,
                                               is_training=is_training)
        return final_path
Code Example #3
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testFusedBatchNormV2(self, is_training, test_local_stats, scale,
                           is_training_python_bool):
    input_shape = (32, 9, 9, 8)
    iterations = 5
    x = tf.placeholder(tf.float32, shape=input_shape)
    bn1 = snt.BatchNormV2(scale=scale)
    bn2 = snt.BatchNormV2(fused=False, scale=scale)

    xx = np.random.random(input_shape)
    feed_dict = {x: xx}
    if not is_training_python_bool:
      is_training_node = tf.placeholder(tf.bool, shape=())
      feed_dict.update({is_training_node: is_training})
      is_training = is_training_node
      test_local_stats_node = tf.placeholder(tf.bool, shape=())
      feed_dict.update({test_local_stats_node: test_local_stats})
      test_local_stats = test_local_stats_node

    o1 = bn1(x, is_training=is_training, test_local_stats=test_local_stats)
    o2 = bn2(x, is_training=is_training, test_local_stats=test_local_stats)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      params = [
          o1, o2, bn1._moving_mean, bn1._moving_variance, bn2._moving_mean,
          bn2._moving_variance
      ]
      for _ in range(iterations):
        y1, y2, mean1, var1, mean2, var2 = sess.run(params, feed_dict=feed_dict)
        self.assertAllClose(y1, y2, atol=1e-4)
        self.assertAllClose(mean1, mean2, atol=1e-4)
        self.assertAllClose(var1, var2, atol=1e-4)
Code Example #4
 def __init__(self,
              observation_spaces,
              action_spaces,
              shared_policy=False,
              shared_critic=False,
              hyperparameters=None,
              name=None):
     name = 'maddpg_module' if name is None else name
     super().__init__(name=name)
     hyperparameters = hyperparameters if hyperparameters else {}
     self.policy_group = PolicyGroup(observation_spaces,
                                     action_spaces,
                                     shared=shared_policy,
                                     hyperparameters=hyperparameters.get(
                                         'policy', {}).copy())
     self.critic_group = CriticGroup(observation_spaces,
                                     action_spaces,
                                     shared=shared_critic,
                                     hyperparameters=hyperparameters.get(
                                         'critic', {}).copy())
     self.normalize = hyperparameters.get('normalize', {}).copy()
     if self.normalize:
         if self.normalize.get('reward'):
             self.normalize['reward'] = {
                 name: snt.BatchNormV2(decay_rate=1.0 - 1e-4)
                 for name in action_spaces
             }
         if self.normalize.get('observation'):
             self.normalize['observation'] = {
                 name: snt.BatchNormV2(decay_rate=1.0 - 1e-4)
                 for name in action_spaces
             }
     self.observation_spaces = observation_spaces
     self.action_spaces = action_spaces
Code Example #5
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testCheckpointCompatibility(self):
    save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")

    input_shape_1 = (31, 7, 7, 5)
    input_shape_2 = (31, 5, 7, 7)

    x1 = tf.placeholder(tf.float32, shape=input_shape_1)
    bn1 = snt.BatchNormV2(data_format="NHWC")
    bn1(x1, is_training=True)
    saver1 = snt.get_saver(bn1)

    x2 = tf.placeholder(tf.float32, shape=input_shape_2)
    bn2 = snt.BatchNormV2(data_format="NCHW")
    bn2(x2, is_training=False)
    saver2 = snt.get_saver(bn2)

    x3 = tf.placeholder(tf.float32, shape=input_shape_1)
    bn3 = snt.BatchNormV2(data_format="NCHW")
    bn3(x3, is_training=False)
    saver3 = snt.get_saver(bn3)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      saver1.save(sess, save_path)
      saver2.restore(sess, save_path)
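      # bn3 treats input_shape_1 as NCHW, i.e. 7 channels instead of 5, so its
      # variable shapes do not match the checkpoint written by bn1.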
      with self.assertRaises(tf.errors.InvalidArgumentError):
        saver3.restore(sess, save_path)
Code Example #6
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testInvalidRegularizationParameters(self):
    with self.assertRaisesRegexp(KeyError, "Invalid regularizer keys.*"):
      snt.BatchNormV2(
          regularizers={"not_gamma": tf.contrib.layers.l1_regularizer(0.5)})

    err = "Regularizer for 'gamma' is not a callable function"
    with self.assertRaisesRegexp(TypeError, err):
      snt.BatchNormV2(regularizers={"gamma": tf.zeros([1, 2, 3])})
Code Example #7
File: macro_child.py Project: MichaelChuai/modelzoo
 def enas_layer(self, prev_layers, layer_id, start_idx, sample_arc,
                is_training):
     layer = LayerCollection(self.layer_setups,
                             self.out_filters,
                             self.use_default_layer_func,
                             name='layercol')
     inputs = prev_layers[-1]
     layer_type_id = sample_arc[start_idx]
     out = layer(inputs, is_training, layer_type_id)
     if layer_id > 0:
         skip_start = start_idx + 1
         layer_skip_count = layer_id
         skip = sample_arc[skip_start:skip_start + layer_skip_count]
         with tf.variable_scope('skip'):
             res_layers = []
             for i in range(layer_skip_count):
                 res_layers.append(
                     tf.cond(tf.equal(skip[i], 1),
                             lambda: prev_layers[i + 1],
                             lambda: tf.zeros_like(prev_layers[i + 1])))
             res_layers.append(out)
             out = tf.add_n(res_layers)
             out = snt.BatchNormV2(scale=True, decay_rate=0.9,
                                   eps=1e-5)(out, is_training=is_training)
     return out
Code Example #8
    def _build(self, inputs, is_training, k=-1, r=0):
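        # When k is a tensor rather than a Python int, mixup_process is applied
        # at the position selected by k: 0 = the raw inputs, 1 = after the first
        # Linear layer, i + 2 = after highway block i.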
        if not isinstance(k, int):
            inputs = tf.cond(tf.equal(k, 0),
                             true_fn=lambda: mixup_process(inputs=inputs, r=r),
                             false_fn=lambda: inputs)

        h = snt.Linear(output_size=self.hidden_size)(inputs)
        h = tf.layers.Dropout(rate=self.drop_rate)(h, is_training)

        if not isinstance(k, int):
            h = tf.cond(tf.equal(k, 1),
                        true_fn=lambda: mixup_process(inputs=h, r=r),
                        false_fn=lambda: h)

        for i in range(self.num_highways):
            h = Highway()(h)

            if self.use_batch_norm:
                h = snt.BatchNormV2(data_format='NC')(h, is_training)
            elif self.use_layer_norm:
                h = snt.LayerNorm(axis=1)(h)

            if self.activation != 'linear':
                h = Activation(activation=self.activation)(h)

            if self.use_dropout:
                h = tf.layers.Dropout(rate=self.drop_rate)(h, is_training)

            if not isinstance(k, int):
                h = tf.cond(tf.equal(k, i + 2),
                            true_fn=lambda: mixup_process(inputs=h, r=r),
                            false_fn=lambda: h)

        outputs = snt.Linear(output_size=self.output_size)(h)
        return outputs
Code Example #9
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testUpdatesInsideCond(self):
    """Demonstrate that updates inside a cond fail."""

    _, input_v, inputs = self._get_inputs()
    bn = snt.BatchNormV2(
        offset=False,
        scale=False,
        decay_rate=0.5,
        update_ops_collection=tf.GraphKeys.UPDATE_OPS)
    condition = tf.placeholder(tf.bool)
    cond = tf.cond(condition,
                   lambda: bn(inputs, is_training=True),
                   lambda: inputs)

    init = tf.global_variables_initializer()

    with self.test_session() as sess:
      sess.run(init)
      out_v = sess.run(cond, feed_dict={condition: False})
      self.assertAllClose(input_v, out_v)

      out_v = sess.run(cond, feed_dict={condition: True})
      self.assertAllClose(np.zeros([7, 6]), out_v, rtol=1e-4, atol=1e-4)

      # Variables are accessible outside the tf.cond()
      mm, mv = sess.run([bn.moving_mean, bn.moving_variance])
      self.assertAllClose(np.zeros([1, 6]), mm)
      self.assertAllClose(np.ones([1, 6]), mv)

      # Tensors are not accessible outside the tf.cond()
      with self.assertRaisesRegexp(ValueError, "Operation"):
        sess.run(tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))
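A minimal sketch, not taken from the test file, of the usual workaround this test motivates: build the batch-norm op outside the tf.cond and group its update ops explicitly, the same control_dependencies pattern Code Example #29 uses:

import tensorflow as tf
import sonnet as snt

inputs = tf.placeholder(tf.float32, [None, 6])
bn = snt.BatchNormV2(offset=False, scale=False,
                     update_ops_collection=tf.GraphKeys.UPDATE_OPS)
out = bn(inputs, is_training=True)

# Group the moving-average updates with the output so they always run.
update_ops = tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
with tf.control_dependencies(update_ops):
    out = tf.identity(out)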
Code Example #10
File: Res3D.py Project: imnotk/action_tf
    def _build(self, inputs, is_training):
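        # Factorized 3D convolution: a spatial (1, h, w) conv followed by a
        # temporal (t, 1, 1) conv, in the style of R(2+1)D blocks.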
        t = self.kernel_shape[0]
        h = self.kernel_shape[1]
        w = self.kernel_shape[2]
        t_stride = self.stride[0]
        h_stride = self.stride[1]
        w_stride = self.stride[2]
        net = snt.Conv3D(output_channels=self.output_channels,
                         kernel_shape=(1, h, w),
                         stride=(1, h_stride, w_stride),
                         padding=snt.SAME,
                         initializers=self.initializer,
                         use_bias=self.use_bias,
                         regularizers=regularizers,
                         name='conv_3d')(inputs)
        net = tf.nn.relu(net)
        net = snt.Conv3D(output_channels=self.output_channels,
                         kernel_shape=(t, 1, 1),
                         stride=(t_stride, 1, 1),
                         padding=snt.SAME,
                         initializers=ones_initializer,
                         use_bias=self.use_bias,
                         regularizers=regularizers,
                         name='conv_3d_temporal')(net)

        if self.use_batch_norm:
            bn = snt.BatchNormV2(scale=True)
            net = bn(net, is_training=is_training, test_local_stats=False)
        if self.activation_fn is not None:
            net = self.activation_fn(net)
        return net
Code Example #11
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testInitializers(self, offset, scale):
    initializers = {
        "moving_mean": tf.constant_initializer(2.0),
        "moving_variance": tf.constant_initializer(3.0),
    }

    if scale:
      initializers["gamma"] = tf.constant_initializer(4.0)
    if offset:
      initializers["beta"] = tf.constant_initializer(5.0)

    inputs_shape = [10, 10]
    inputs = tf.placeholder(tf.float32, shape=[None] + inputs_shape)
    bn = snt.BatchNormV2(
        offset=offset,
        scale=scale,
        initializers=initializers)
    self.assertEqual(bn.initializers, initializers)
    bn(inputs, is_training=True)

    init = tf.global_variables_initializer()
    with self.test_session() as sess:
      sess.run(init)

      ones_v = np.ones([1, 1, inputs_shape[-1]])
      self.assertAllClose(bn.moving_mean.eval(), ones_v * 2.0)
      self.assertAllClose(bn.moving_variance.eval(), ones_v * 3.0)

      if scale:
        self.assertAllClose(bn.gamma.eval(), ones_v * 4.0)
      if offset:
        self.assertAllClose(bn.beta.eval(), ones_v * 5.0)
Code Example #12
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testUpdateImproveStatistics(self):
    """Test that updating the moving_mean improves statistics."""

    _, _, inputs = self._get_inputs()

    # Use small decay_rate to update faster.
    bn = snt.BatchNormV2(
        offset=False,
        scale=False,
        decay_rate=0.1,
        update_ops_collection=tf.GraphKeys.UPDATE_OPS)
    out1 = bn(inputs, is_training=False, test_local_stats=False)

    # Build the update ops.
    bn(inputs, is_training=True)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      out_v = sess.run(out1)

      # Before updating the moving_mean the results are off.
      self.assertBetween(np.max(np.abs(np.zeros([7, 6]) - out_v)), 2, 5)

      sess.run(tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))

      # After updating the moving_mean the results are better.
      out_v = sess.run(out1)
      self.assertBetween(np.max(np.abs(np.zeros([7, 6]) - out_v)), 1, 2)
Code Example #13
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testVariableBatchSize(self):
    """Check the inputs batch_size can change."""

    inputs_shape = [10, 10]
    inputs = tf.placeholder(tf.float32, shape=[None] + inputs_shape)
    bn = snt.BatchNormV2(
        offset=False, scale=False)

    # Outputs should be equal to inputs.
    out = bn(inputs,
             is_training=False,
             test_local_stats=False)

    init = tf.global_variables_initializer()
    update_ops = tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    with self.test_session() as sess:
      sess.run(init)

      for batch_size in [1, 3, 10]:
        input_data = np.random.rand(batch_size, *inputs_shape)
        out_v = sess.run(out, feed_dict={inputs: input_data})
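        # Freshly initialized moving statistics are mean 0 / variance 1 and
        # there is no offset or scale, so the layer just divides the input by
        # sqrt(1 + eps).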
        self.assertAllClose(input_data / np.sqrt(1.0 + bn._eps), out_v)

        sess.run(update_ops, feed_dict={inputs: input_data})
Code Example #14
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testRegularizersInRegularizationLosses(self, offset, scale):
    regularizers = {}
    if offset:
      regularizers["beta"] = tf.contrib.layers.l1_regularizer(scale=0.5)
    if scale:
      regularizers["gamma"] = tf.contrib.layers.l2_regularizer(scale=0.5)

    inputs_shape = [10, 10]
    inputs = tf.placeholder(tf.float32, shape=[None] + inputs_shape)
    bn = snt.BatchNormV2(
        offset=offset,
        scale=scale,
        regularizers=regularizers)
    self.assertEqual(bn.regularizers, regularizers)
    bn(inputs, is_training=True)

    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if not offset and not scale:
      self.assertFalse(graph_regularizers)
    if offset and not scale:
      self.assertRegexpMatches(graph_regularizers[0].name, ".*l1_regularizer.*")
    if scale and not offset:
      self.assertRegexpMatches(graph_regularizers[0].name, ".*l2_regularizer.*")
    if scale and offset:
      self.assertRegexpMatches(graph_regularizers[0].name, ".*l1_regularizer.*")
      self.assertRegexpMatches(graph_regularizers[1].name, ".*l2_regularizer.*")
Code Example #15
    def __init__(self, conf, name="encoder_layer"):
        """ Inits the module.

        Args:
            name: The module name.
        """
        super(EncoderLayer, self).__init__(name=name)
        self.conf = conf
        self.training = True
        with self._enter_variable_scope():
            batch_initializer = {
                'gamma': utils.initializer(conf.embedding_dim),
                'moving_mean': utils.initializer(conf.embedding_dim),
                'moving_variance': utils.initializer(conf.embedding_dim),
                'beta': utils.initializer(conf.embedding_dim)
            }

            self._mha = MultiHeadAttentionResidual(conf)

            self._batch_norm0 = snt.BatchNormV2(scale=True,
                                                initializers=batch_initializer,
                                                name="batch_norm0")

            self._batch_norm1 = snt.BatchNormV2(scale=True,
                                                initializers=batch_initializer,
                                                name="batch_norm1")

            self._lin_to_hidden = snt.Linear(
                output_size=conf.ff_hidden_size,
                initializers={
                    'w': utils.initializer(conf.embedding_dim),
                    'b': utils.initializer(conf.embedding_dim)
                },
                name="lin_to_hidden")
            self._hidden_to_output = snt.Linear(
                output_size=conf.embedding_dim,
                initializers={
                    'w': utils.initializer(conf.ff_hidden_size),
                    'b': utils.initializer(conf.ff_hidden_size)
                },
                name="hidden_to_output")
            self._feed_forward = snt.Sequential(
                [self._lin_to_hidden, tf.nn.relu, self._hidden_to_output],
                name="feed_forward")
            self._feed_forward_residual = snt.Residual(
                self._feed_forward, name="feed_forward_residual")
Code Example #16
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def test16Bit(self, dtype):
    inputs = tf.placeholder(dtype, shape=[None, 64, 32, 3])
    batch_norm = snt.BatchNormV2(offset=True, scale=True, fused=False)
    output = batch_norm(inputs, is_training=True)

    self.assertEqual(dtype, output.dtype)
    self.assertEqual(tf.float32, batch_norm.moving_mean.dtype)
    self.assertEqual(tf.float32, batch_norm.moving_variance.dtype)
    self.assertEqual(dtype, batch_norm.gamma.dtype)
    self.assertEqual(dtype, batch_norm.beta.dtype)
Code Example #17
 def _fn(batch):
   net = snt.BatchFlatten()(batch["image"])
   for i, h in enumerate(hidden_units):
     net = snt.Linear(h)(net)
     if i != (len(hidden_units) - 1):
       net = snt.BatchNormV2()(net, is_training=False)
       net = activation(net)
   loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
       labels=batch["label_onehot"], logits=net)
   return tf.reduce_mean(loss_vec)
Code Example #18
    def _build(self, main_inputs, attention_inputs, is_training):
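        # Attention between two feature maps: similarity scores computed from
        # main_theta and attention_theta are used to reweight attention_theta,
        # and the result is reshaped back and batch-normalized.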

        main_theta = main_inputs
        attention_theta = attention_inputs

        height = main_theta.shape.as_list()[1]
        width = main_theta.shape.as_list()[2]
        dim_inner = main_theta.shape.as_list()[-1]

        if self._batch_size is None:
            self._batch_size = -1

        main_theta = tf.reshape(main_theta,
                                [self._batch_size, dim_inner, height * width],
                                name='main_theta_reshape')

        attention_theta = tf.reshape(attention_theta, [
            self._batch_size,
            dim_inner,
            attention_theta.shape.as_list()[1] *
            attention_theta.shape.as_list()[2],
        ],
                                     name='attention_theta_reshape')

        main_attention_theta = tf.matmul(main_theta,
                                         attention_theta,
                                         transpose_a=False,
                                         transpose_b=True,
                                         name='main_attention_theta_matmul')

        if self._use_softmax:
            p = tf.nn.softmax(main_attention_theta,
                              name='main_attention_theta_softmax')
        else:
            # Dot-product normalization: divide each entry by the total number
            # of entries. Use float constants so the division matches
            # main_attention_theta's dtype.
            ones = tf.constant(1.0, shape=main_attention_theta.shape)
            ones = tf.reduce_sum(ones)

            zeros = tf.constant(0.0, shape=main_attention_theta.shape)
            denom = tf.add(zeros, ones)

            # tf.stop_gradient returns a new tensor; keep the result.
            denom = tf.stop_gradient(denom)

            p = tf.div(main_attention_theta,
                       denom,
                       name='main_attention_theta_dot_product')

        main_attention = tf.matmul(p, attention_theta)
        main_attention = tf.reshape(
            main_attention, [self._batch_size, height, width, dim_inner])

        main_attention = snt.BatchNormV2(scale=True)(main_attention,
                                                     is_training=is_training,
                                                     test_local_stats=False)

        return main_attention
Code Example #19
    def conv_layer(cls,
                   inputs,
                   is_training,
                   out_filters,
                   filter_size,
                   channel_multiplier,
                   separable=False):
        with tf.variable_scope('inp_conv_1'):
            net = snt.Conv2D(out_filters, kernel_shape=1)(inputs)
            net = snt.BatchNormV2(scale=True, decay_rate=0.9,
                                  eps=1e-5)(net, is_training=is_training)
            net = tf.nn.relu(net)

        with tf.variable_scope(f'out_conv_{out_filters}'):
            if separable:
                net = snt.SeparableConv2D(out_filters, channel_multiplier,
                                          filter_size)(net)
            else:
                net = snt.Conv2D(out_filters, kernel_shape=filter_size)(net)
            net = snt.BatchNormV2(scale=True, decay_rate=0.9,
                                  eps=1e-5)(net, is_training=is_training)
        return net
Code Example #20
 def __init__(self,
              observation_spaces,
              action_spaces,
              shared_policy=False,
              shared_critic=False,
              normalize=None,
              name='maddpg'):
     name = 'maddpg_module' if name is None else name
     super().__init__(name=name)
     self.best_policy_group = PolicyGroup(observation_spaces,
                                          action_spaces,
                                          shared=shared_policy,
                                          name='best_policy')
     self.worst_policy_group = PolicyGroup(observation_spaces,
                                           action_spaces,
                                           shared=shared_policy,
                                           name='worst_policy')
     self.global_critic_group = CriticGroup(observation_spaces,
                                            action_spaces,
                                            shared=True,
                                            name='global_critic')
     self.personal_critic_group = CriticGroup(observation_spaces,
                                              action_spaces,
                                              shared=False,
                                              name='personal_critic')
     self.normalize = {}
     if normalize:
         if normalize.get('reward'):
             self.normalize['reward'] = {
                 name: snt.BatchNormV2()
                 for name in action_spaces
             }
         if normalize.get('observation'):
             self.normalize['observation'] = {
                 name: snt.BatchNormV2()
                 for name in action_spaces
             }
     self.observation_spaces = observation_spaces
     self.action_spaces = action_spaces
Code Example #21
    def _build(self, main_inputs, attention_inputs, is_training):

        # rgb_inputs and flow_inputs already have the same shape,
        # e.g. (None, frame_counts, 56, 56, 256).

        if self._use_conv:
            main_theta = Unit3d(main_inputs.shape.as_list()[-1],
                                activation_fn=None,
                                use_batch_norm=False)(main_inputs,
                                                      is_training=is_training)
            attention_theta = Unit3d(attention_inputs.shape.as_list()[-1],
                                     activation_fn=None,
                                     use_batch_norm=False)(
                                         attention_inputs,
                                         is_training=is_training)
        else:
            main_theta = main_inputs
            attention_theta = attention_inputs
        
        temporal = main_theta.shape.as_list()[1]
        height = main_theta.shape.as_list()[2]
        width = main_theta.shape.as_list()[3]
        dim_inner = main_theta.shape.as_list()[-1]

        if self._batch_size is None:
            self._batch_size = -1

        if self._space:
            main_theta = tf.reshape(
                main_theta,
                [self._batch_size, temporal, height * width, dim_inner],
                name='main_theta_reshape')
            attention_theta = tf.reshape(
                attention_theta,
                [self._batch_size, temporal,
                 attention_theta.shape.as_list()[2] *
                 attention_theta.shape.as_list()[3], dim_inner],
                name='attention_theta_reshape')
        else:
            main_theta = tf.reshape(
                main_theta,
                [self._batch_size, height * width, dim_inner * temporal],
                name='main_theta_reshape')
            attention_theta = tf.reshape(
                attention_theta,
                [self._batch_size,
                 attention_theta.shape.as_list()[2] *
                 attention_theta.shape.as_list()[3], dim_inner * temporal],
                name='attention_theta_reshape')

        main_attention_theta = tf.matmul(main_theta,
                                         attention_theta,
                                         transpose_a=False,
                                         transpose_b=True,
                                         name='main_attention_theta_matmul')

        if self._use_softmax:
            p = tf.nn.softmax(main_attention_theta,
                              name='main_attention_theta_softmax')
        else:
            # Dot-product normalization: divide each entry by the total number
            # of entries. Use float constants so the division matches
            # main_attention_theta's dtype.
            ones = tf.constant(1.0, shape=main_attention_theta.shape)
            ones = tf.reduce_sum(ones)

            zeros = tf.constant(0.0, shape=main_attention_theta.shape)
            denom = tf.add(zeros, ones)

            # tf.stop_gradient returns a new tensor; keep the result.
            denom = tf.stop_gradient(denom)

            p = tf.div(main_attention_theta,
                       denom,
                       name='main_attention_theta_dot_product')
        
        main_attention = tf.matmul(p, attention_theta)
        main_attention = tf.reshape(
            main_attention,
            [self._batch_size, temporal, height, width, dim_inner])

        if self._use_conv:
            main_attention = Unit3d(dim_inner,
                                    activation_fn=None,
                                    use_batch_norm=False)(
                                        main_attention,
                                        is_training=is_training)

        main_attention = snt.BatchNormV2(scale=True)(main_attention,
                                                     is_training=is_training,
                                                     test_local_stats=False)

        return main_attention
Code Example #22
 def _build(self, inputs, is_training):
     net = snt.Conv2D(output_channels=self.output_channels,
                      kernel_shape=self.kernel_shape,
                      stride=self.stride,
                      padding=snt.SAME,
                      initializers=initializer,
                      use_bias=self.use_bias,
                      regularizers=regularizers)(inputs)
     if self.use_batch_norm:
         bn = snt.BatchNormV2(scale=True)
         net = bn(net, is_training=is_training, test_local_stats=False)
     if self.activation_fn is not None:
         net = self.activation_fn(net)
     return net
Code Example #23
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testFusedBatchNormFloat16(self, is_training, test_local_stats):
    input_shape = (31, 7, 7, 5)
    iterations = 3
    x = tf.placeholder(tf.float16, shape=input_shape)
    bn1 = snt.BatchNormV2(fused=False)
    bn2 = snt.BatchNormV2()

    feed_dict = {x: np.random.random(input_shape)}

    o1 = bn1(x, is_training=is_training, test_local_stats=test_local_stats)
    o2 = bn2(x, is_training=is_training, test_local_stats=test_local_stats)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      params = [
          o1, o2, bn1._moving_mean, bn1._moving_variance, bn2._moving_mean,
          bn2._moving_variance
      ]
      for _ in range(iterations):
        y1, y2, mean1, var1, mean2, var2 = sess.run(params, feed_dict=feed_dict)
        self.assertAllClose(y1, y2, atol=1e-2)
        self.assertAllClose(mean1, mean2, atol=1e-2)
        self.assertAllClose(var1, var2, atol=1e-2)
Code Example #24
    def _build(self, inputs, is_training):
        net = snt.Conv2D(output_channels=self._output_channels,
                         kernel_shape=self._kernel_shape,
                         stride=self._stride,
                         use_bias=self._use_bias,
                         padding=snt.SAME)(inputs)
        if self._use_bn:
            bn = snt.BatchNormV2()
            net = bn(net, is_training=is_training, test_local_stats=False)

        if self._actvation_fn is not None:
            net = self._actvation_fn(net)

        return net
Code Example #25
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testConstruct(self):
    inputs = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])

    batch_norm1 = snt.BatchNormV2(offset=False, scale=False, fused=False)
    batch_norm1(inputs, is_training=True)

    err = "Batch normalization doesn't have an offset, so no beta"
    with self.assertRaisesRegexp(snt.Error, err):
      _ = batch_norm1.beta

    err = "Batch normalization doesn't have a scale, so no gamma"
    with self.assertRaisesRegexp(snt.Error, err):
      _ = batch_norm1.gamma

    batch_norm2 = snt.BatchNormV2(offset=True, scale=False)
    batch_norm2(inputs, is_training=True)
    _ = batch_norm2.beta

    batch_norm3 = snt.BatchNormV2(offset=False, scale=True)
    batch_norm3(inputs, is_training=True)
    _ = batch_norm3.gamma

    batch_norm4 = snt.BatchNormV2(offset=True, scale=True)
    batch_norm4(inputs, is_training=True)
    _ = batch_norm4.beta
    _ = batch_norm4.gamma

    batch_norm4(inputs, is_training=True, test_local_stats=True)
    batch_norm4(inputs,
                is_training=tf.constant(True),
                test_local_stats=tf.constant(True))

    is_training_ph = tf.placeholder(tf.bool)
    test_local_stats_ph = tf.placeholder(tf.bool)
    batch_norm4(inputs,
                is_training=is_training_ph,
                test_local_stats=test_local_stats_ph)
Code Example #26
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testDataFormats(self, data_format):
    """Check that differing data formats give the correct output shape."""
    dim_sizes = {
        "N": None,
        "D": 10,
        "H": 64,
        "W": 32,
        "C": 3
    }
    inputs = tf.placeholder_with_default(
        tf.zeros([dim_sizes[dim_name] or 5 for dim_name in data_format]),
        [dim_sizes[dim_name] for dim_name in data_format])

    bn_data_formats = [data_format]
    if data_format.endswith("C"):
      bn_data_formats.append(None)

    for bn_data_format in bn_data_formats:
      bn = snt.BatchNormV2(data_format=bn_data_format, offset=False)
      bn(inputs, is_training=True)
      mean_shape = bn.moving_mean.get_shape()
      correct_mean_shape = [
          dim_sizes["C"] if dim_name == "C" else 1 for dim_name in data_format
      ]
      self.assertEqual(mean_shape, correct_mean_shape)

    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        for bn_data_format in "NC NWC NHWC NDHWC NCW NCHW NCDHW".split():
          if len(data_format) != len(bn_data_format):
            bn = snt.BatchNormV2(data_format=bn_data_format, offset=False)
            err = r"Incorrect data format {} for input shape .*".format(
                bn_data_format)
            with self.assertRaisesRegexp(snt.IncompatibleShapeError, err):
              outputs = bn(inputs, is_training=True)
              sess.run(outputs)
Code Example #27
    def _build(self, inputs, is_training):
        h = snt.Linear(output_size=self.hidden_size)(inputs)
        h = tf.layers.Dropout(rate=self.drop_rate)(h, is_training)

        for i in range(self.num_highways):
            h = Highway()(h)

            if self.use_batch_norm:
                h = snt.BatchNormV2(data_format='NC')(h, is_training)
            if self.activation != 'linear':
                h = Activation(activation=self.activation)(h)

            # h = tf.layers.Dropout(rate=self.drop_rate)(h, is_training)

        outputs = snt.Linear(output_size=self.num_outputs)(h)
        return outputs
Code Example #28
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testDynamicImageShape(self, shape, data_format, fused, batch_unknown):
    """Check that tensors with unknown spatial dimensions work."""

    if batch_unknown:
      shape[0] = None

    input_ph = tf.placeholder(tf.float32, shape=shape)

    bn = snt.BatchNormV2(data_format=data_format, fused=fused)
    output_train = bn(input_ph, is_training=True)
    output_test = bn(input_ph, is_training=False)
    self.assertEqual(output_train.get_shape().as_list(),
                     output_test.get_shape().as_list())
    # Check that no information about the shape has been erased from the input.
    self.assertEqual(output_train.get_shape().as_list(),
                     input_ph.get_shape().as_list())
Code Example #29
File: batch_norm_v2_test.py Project: zwcdp/sonnet
  def testCheckStatsPython(self):
    """The correct normalization is being used for different Python flags."""

    v, input_v, inputs = self._get_inputs()

    bn = snt.BatchNormV2(
        offset=False,
        scale=False,
        decay_rate=0.5,
        update_ops_collection=tf.GraphKeys.UPDATE_OPS
    )
    out1 = bn(inputs, is_training=True, test_local_stats=True)
    out2 = bn(inputs, is_training=False, test_local_stats=True)
    out3 = bn(inputs, is_training=False, test_local_stats=False)

    update_ops = tuple(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    self.assertLen(update_ops, 2)

    with tf.control_dependencies(update_ops):
      out1 = tf.identity(out1)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())

      out_v = sess.run(out1)
      mm, mv = sess.run([bn.moving_mean, bn.moving_variance])

      # Single moving average steps should have happened.
      correct_mm = (1.0 - bn._decay_rate) * v
      correct_mv = np.ones([1, 6]) * bn._decay_rate

      self.assertAllClose(np.reshape(correct_mm, [1, 6]), mm)
      self.assertAllClose(np.reshape(correct_mv, [1, 6]), mv)
      self.assertAllClose(np.zeros([7, 6]), out_v, rtol=1e-6, atol=1e-5)

      out2_, out3_ = sess.run([out2, out3])

      # Out2: Tested using local batch stats.
      # Better numerical precision due to using shifted estimators.
      self.assertAllClose(np.zeros([7, 6]), out2_, rtol=1e-6, atol=1e-5)

      # Out3: Tested using moving average stats.
      self.assertAllClose(
          (input_v - mm) / np.sqrt(mv + bn._eps),
          out3_)
Code Example #30
    def _build(self, main_inputs, attention_inputs, is_training):

        # rgb_inputs and flow_inputs already have the same shape,
        # e.g. (None, frame_counts, 56, 56, 256).
        
        if self._use_conv:
            main_theta = Unit3d(main_inputs.shape.as_list()[-1],
                                activation_fn=None,
                                use_batch_norm=False,
                                name='main_theta')(main_inputs,
                                                   is_training=is_training)
            attention_theta = Unit3d(attention_inputs.shape.as_list()[-1],
                                     activation_fn=None,
                                     use_batch_norm=False,
                                     name='attention_theta')(
                                         attention_inputs,
                                         is_training=is_training)
        else:
            main_theta = main_inputs
            attention_theta = attention_inputs
        
        temporal = main_theta.shape.as_list()[1]
        height = main_theta.shape.as_list()[2]
        width = main_theta.shape.as_list()[3]
        dim_inner = main_theta.shape.as_list()[-1]

        if self._batch_size is None:
            self._batch_size = -1

        if self._space:
            main_theta = tf.reshape(
                main_theta,
                [self._batch_size, temporal, height * width, dim_inner],
                name='main_theta_reshape')
            attention_theta = tf.reshape(
                attention_theta,
                [self._batch_size, temporal,
                 attention_theta.shape.as_list()[2] *
                 attention_theta.shape.as_list()[3], dim_inner],
                name='attention_theta_reshape')
        else:
            main_theta = tf.reshape(
                main_theta,
                [self._batch_size, height * width, dim_inner * temporal],
                name='main_theta_reshape')
            attention_theta = tf.reshape(
                attention_theta,
                [self._batch_size,
                 attention_theta.shape.as_list()[2] *
                 attention_theta.shape.as_list()[3], dim_inner * temporal],
                name='attention_theta_reshape')

        main_attention_theta = tf.matmul(main_theta,
                                         attention_theta,
                                         transpose_a=False,
                                         transpose_b=True,
                                         name='main_attention_theta_matmul')
        # Apply linear_transform over the last dimension of the attention map
        # (use [-1] so this also works for the 3-D reshape branch above).
        l = main_attention_theta.shape.as_list()[-1]
        main_attention_theta = linear_transform(main_attention_theta, l)
        if self._use_softmax:
            p = tf.nn.softmax(main_attention_theta,
                              name='main_attention_theta_softmax')
        else:
            p = main_attention_theta

        main_attention = tf.matmul(p, attention_theta)
        main_attention = tf.reshape(
            main_attention,
            [self._batch_size, temporal, height, width, dim_inner])

        if self._use_conv:
            main_attention = Unit3d(dim_inner,
                                    activation_fn=None,
                                    use_batch_norm=False)(
                                        main_attention,
                                        is_training=is_training)

        main_attention = snt.BatchNormV2(scale=True)(main_attention,
                                                     is_training=is_training,
                                                     test_local_stats=False)

        return main_attention