Esempio n. 1
0
    def testBuild(self):
        """Checks the distribution type and sample shape for a shape-[2] spec.

        A `CategoricalProjectionNetwork` is built for an int32 spec of shape
        [2] bounded to [0, 1], called on a batch of 3 inputs with 5 input
        dimensions, and one sample is drawn from the resulting distribution.
        """
        output_spec = tensor_spec.BoundedTensorSpec([2], tf.int32, 0, 1)
        network = categorical_projection_network.CategoricalProjectionNetwork(
            output_spec)

        inputs = _get_inputs(batch_size=3, num_input_dims=5)

        distribution = network(inputs, outer_rank=1)
        # Use the compat.v1 endpoint: the bare `tf.global_variables_initializer`
        # symbol is not exported by TensorFlow 2.x, whereas `tf.compat.v1` is
        # available in both late 1.x and 2.x releases.
        self.evaluate(tf.compat.v1.global_variables_initializer())
        sample = self.evaluate(distribution.sample())

        # The exact class is asserted on purpose (not just isinstance).
        self.assertEqual(tfp.distributions.Categorical, type(distribution))
        # Batch of 3, one drawn value per entry of the shape-[2] spec.
        self.assertAllEqual((3, 2), sample.shape)
    def testTrainableVariables(self):
        """Checks the count and shapes of the network's trainable variables.

        After the network has been called once on a batch of 3 inputs with 5
        input dimensions, exactly two trainable variables are expected: a
        (5, 4) kernel and a (4,) bias.
        """
        spec = tensor_spec.BoundedTensorSpec([2], tf.int32, 0, 1)
        net = categorical_projection_network.CategoricalProjectionNetwork(spec)

        # Calling the network is what creates its variables; the returned
        # value itself is not needed here.
        net(_get_inputs(batch_size=3, num_input_dims=5), outer_rank=1)
        self.evaluate(tf.compat.v1.global_variables_initializer())

        variables = net.trainable_variables
        # Dense kernel followed by dense bias.
        self.assertEqual(2, len(variables))
        kernel = variables[0]
        bias = variables[1]
        # 5 matches num_input_dims; 4 is presumably 2 spec entries x 2
        # possible values (bounds 0..1) -- confirm against the network.
        self.assertEqual((5, 4), kernel.shape)
        self.assertEqual((4, ), bias.shape)
    def testBuild(self):
        """Checks distribution type, logits shape and sample shape.

        Uses an int32 action spec of shape [2, 3] bounded to [0, 1] and a
        batch of 3 inputs with 5 input dimensions.
        """
        action_spec = tensor_spec.BoundedTensorSpec([2, 3], tf.int32, 0, 1)
        projection = categorical_projection_network.CategoricalProjectionNetwork(
            action_spec)

        batch = _get_inputs(batch_size=3, num_input_dims=5)

        # The network returns a pair; only the first element (the
        # distribution) is inspected by this test.
        dist, _ = projection(batch, outer_rank=1)
        self.evaluate(tf.compat.v1.global_variables_initializer())
        drawn = self.evaluate(dist.sample())

        self.assertEqual(tfp.distributions.Categorical, type(dist))
        # Batch = 3; 2x3 action choices, 2 possible actions per choice.
        self.assertEqual((3, 2, 3, 2), dist.logits.shape)
        self.assertAllEqual((3, 2, 3), drawn.shape)
Esempio n. 4
0
  def __init__(self,
               input_tensor_spec,
               output_tensor_spec,
               multiplier=1):
    """Creates the layers of the network and registers them on the instance.

    Three stages are built with base channel counts 16, 32 and 32, each
    scaled by `multiplier`. Every stage contributes one head (a 3x3 'same'
    convolution followed by a stride-2 max-pool) and two residual blocks
    (ReLU, conv, ReLU, conv, Add). A Flatten / ReLU / Dense(256) tail and a
    `CategoricalProjectionNetwork` built from `output_tensor_spec` complete
    the set of layers.

    Every Conv2D and the final Dense layer is named
    'distribution_layer_<k>', where <k> counts those layers from 0 in
    creation order.

    Args:
      input_tensor_spec: Spec forwarded as the first argument of the parent
        constructor.
      output_tensor_spec: Spec used to build the categorical projection
        network.
      multiplier: Factor applied to the base channel count of every stage.
    """
    created = 0

    def next_name():
      # Hands out 'distribution_layer_0', 'distribution_layer_1', ... in
      # call order.
      nonlocal created
      name = 'distribution_layer_' + str(created)
      created += 1
      return name

    def conv3x3(filters):
      # 3x3 'same'-padded convolution carrying the next sequential name.
      return tf.keras.layers.Conv2D(
          filters, 3, padding='same', name=next_name())

    head_layers = []
    residual_layers = []
    for base_channels in (16, 32, 32):
      filters = base_channels * multiplier

      head_layers.append([
          conv3x3(filters),
          tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same'),
      ])

      # Two structurally identical residual blocks per stage. Layers are
      # instantiated in list order so the sequential names stay stable.
      for _ in range(2):
        residual_layers.append([
            tf.keras.layers.ReLU(),
            conv3x3(filters),
            tf.keras.layers.ReLU(),
            conv3x3(filters),
            tf.keras.layers.Add(),
        ])

    tail_layers = [
        tf.keras.layers.Flatten(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(256, activation=tf.nn.relu, name=next_name()),
    ]

    projection = categorical_projection_network.CategoricalProjectionNetwork(
        output_tensor_spec)

    # The parent constructor runs before any layer is stored on `self`.
    # NOTE(review): the empty tuple is presumably the state spec -- confirm
    # against the base class signature.
    super().__init__(input_tensor_spec, (), projection.output_spec,
                     'impala_distribution_network')

    self.heads = head_layers
    self.resid_blocks = residual_layers
    self.tail = tail_layers
    self.projection_network = projection
def _categorical_projection_net(action_spec, logits_init_output_factor=0.1):
  """Returns a `CategoricalProjectionNetwork` built for `action_spec`.

  Args:
    action_spec: Spec passed as the first argument to the network.
    logits_init_output_factor: Forwarded unchanged under the same keyword.

  Returns:
    The newly constructed `CategoricalProjectionNetwork`.
  """
  projection_cls = categorical_projection_network.CategoricalProjectionNetwork
  return projection_cls(
      action_spec, logits_init_output_factor=logits_init_output_factor)