Example #1
File: ppo.py Project: modyharshit23/trax
def policy_and_value_net(n_actions, n_controls, vocab_size, bottom_layers_fn,
                         two_towers):
    """A policy and value net function."""

    # Layers.

    # Now, with the current logits, one head computes action probabilities and the
    # other computes the value function.
    # NOTE: We use LogSoftmax instead of Softmax for numerical stability.

    @tl.layer()
    def FlattenControlsIntoTime(x, **unused_kwargs):  # pylint: disable=invalid-name
        """Splits logits for actions in different controls and flattens controls."""
        return np.reshape(x, (x.shape[0], -1, n_actions))

    if vocab_size is None:
        # In continuous policies every element of the output sequence corresponds to
        # an observation.
        n_preds_per_input = n_controls
        kwargs = {}
    else:
        # In discrete policies every element of the output sequence corresponds to
        # a symbol in the discrete representation, and each control takes 1 symbol.
        n_preds_per_input = 1
        kwargs = {"vocab_size": vocab_size}

    if two_towers:
        layers = [
            tl.Dup(),
            tl.Parallel(
                [
                    bottom_layers_fn(**kwargs),
                    tl.Dense(n_preds_per_input * n_actions),
                    FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                    tl.LogSoftmax()
                ],
                [
                    bottom_layers_fn(**kwargs),
                    tl.Dense(n_preds_per_input),
                    tl.Flatten()
                ],
            )
        ]
    else:
        layers = [
            bottom_layers_fn(**kwargs),
            tl.Dup(),
            tl.Parallel(
                [
                    tl.Dense(n_preds_per_input * n_actions),
                    FlattenControlsIntoTime(),  # pylint: disable=no-value-for-parameter
                    tl.LogSoftmax()
                ],
                [tl.Dense(n_preds_per_input),
                 tl.Flatten()],
            )
        ]
    return tl.Model(layers)
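The NOTE above concerns numerical stability; here is a pure-NumPy sketch (not trax code) of why the log-softmax is computed from shifted logits rather than by exponentiating raw ones:

import numpy as np

def log_softmax(x, axis=-1):
    # Subtracting the max first keeps exp() from overflowing.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

logits = np.array([1000.0, 1001.0, 1002.0])
print(log_softmax(logits))  # finite: about [-2.41, -1.41, -0.41]
print(np.log(np.exp(logits) / np.exp(logits).sum()))  # naive path -> nan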
Example #2
File: resnet.py Project: zsunpku/trax
def WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10, bn_momentum=0.9,
               mode='train'):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group. total layers = 6n + 4.
    widen_factor: int, widening factor of each group. k=1 is vanilla resnet.
    n_output_classes: int, number of distinct output classes.
    bn_momentum: float, momentum in BatchNorm.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
  return tl.Serial(
      tl.ToFloat(),
      tl.Conv(16, (3, 3), padding='SAME'),
      WideResnetGroup(n_blocks, 16 * widen_factor, bn_momentum=bn_momentum,
                      mode=mode),
      WideResnetGroup(n_blocks, 32 * widen_factor, (2, 2),
                      bn_momentum=bn_momentum, mode=mode),
      WideResnetGroup(n_blocks, 64 * widen_factor, (2, 2),
                      bn_momentum=bn_momentum, mode=mode),
      tl.BatchNorm(momentum=bn_momentum, mode=mode),
      tl.Relu(),
      tl.AvgPool(pool_size=(8, 8)),
      tl.Flatten(),
      tl.Dense(n_output_classes),
      tl.LogSoftmax(),
  )
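A usage sketch modeled on the init pattern of Example #5 below; it assumes a trax version where trax.shapes.signature and direct model calls work, and a CIFAR-10-shaped input:

import numpy as np
from trax import shapes

model = WideResnet(n_blocks=3, widen_factor=2, n_output_classes=10)
x = np.zeros((8, 32, 32, 3), dtype=np.uint8)  # CIFAR-10-shaped batch
model.init(shapes.signature(x))
log_probs = model(x)  # (8, 10); the two stride-(2, 2) groups take 32x32
                      # down to 8x8, which AvgPool((8, 8)) reduces to 1x1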
Example #3
def AtariCnnBody(n_frames=4,
                 hidden_sizes=(32, 64, 64),
                 output_size=512,
                 mode='train',
                 kernel_initializer=None,
                 padding='VALID'):
    """An Atari CNN."""
    del mode

    # TODO(jonni): Include link to paper?
    # Input shape: (B, T, H, W, C)
    # Output shape: (B, T, output_size)
    return tl.Serial(
        _BytesToFloats(),
        _FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
        tl.Conv(hidden_sizes[0], (8, 8), (4, 4),
                padding=padding,
                kernel_initializer=kernel_initializer),
        tl.Relu(),
        tl.Conv(hidden_sizes[1], (4, 4), (2, 2),
                padding=padding,
                kernel_initializer=kernel_initializer),
        tl.Relu(),
        tl.Conv(hidden_sizes[2], (3, 3), (1, 1),
                padding=padding,
                kernel_initializer=kernel_initializer),
        tl.Relu(),
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    )
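With padding='VALID', these kernel/stride pairs are the classic DQN stack. A pure-Python check of the spatial sizes for an 84x84 frame (84x84 is an assumption; any H and W work the same way):

def conv_out(size, kernel, stride):
    # Output size of a VALID convolution along one spatial axis.
    return (size - kernel) // stride + 1

side = 84
for kernel, stride in ((8, 4), (4, 2), (3, 1)):
    side = conv_out(side, kernel, stride)
    print(side)  # 20, then 9, then 7 -> Flatten sees 7 * 7 * 64 = 3136 features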
Example #4
    def test_train_mnist(self):
        """Train MNIST model (almost) fully, to compare to other implementations.

    Evals for cross-entropy loss and accuracy are run every 50 steps;
    their values are visible in the test log.
    """
        gin.parse_config([
            'batch_fn.batch_size_per_device = 256',
            'batch_fn.eval_batch_size = 256',
        ])

        mnist_model = tl.Serial(
            tl.Flatten(),
            tl.Dense(512),
            tl.Relu(),
            tl.Dense(512),
            tl.Relu(),
            tl.Dense(10),
            tl.LogSoftmax(),
        )
        task = training.TrainTask(
            itertools.cycle(_mnist_dataset().train_stream(1)),
            tl.CrossEntropyLoss(), adafactor.Adafactor(.02))
        eval_task = training.EvalTask(
            itertools.cycle(_mnist_dataset().eval_stream(1)),
            [tl.CrossEntropyLoss(), tl.AccuracyScalar()],
            names=['CrossEntropyLoss', 'AccuracyScalar'],
            eval_at=lambda step_n: step_n % 50 == 0,
            eval_N=10)

        training_session = training.Loop(mnist_model,
                                         task,
                                         eval_task=eval_task)
        training_session.run(n_steps=1000)
        self.assertEqual(training_session.current_step(), 1000)
Example #5
    def test_policy_and_value_net(self):
        observation_shape = (3, 4, 5)
        n_actions = 2
        n_controls = 3
        batch = 2
        time_steps = 10
        observations = np.random.uniform(size=(batch, time_steps) +
                                         observation_shape)
        actions = np.random.randint(n_actions,
                                    size=(batch, time_steps - 1, n_controls))
        (pnv_model, _) = policy_based_utils.policy_and_value_net(
            bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
            observation_space=gym.spaces.Box(shape=observation_shape,
                                             low=0,
                                             high=1),
            action_space=gym.spaces.MultiDiscrete((n_actions, ) * n_controls),
            vocab_size=None,
            two_towers=True,
        )
        input_signature = shapes.signature((observations, actions))
        _, _ = pnv_model.init(input_signature)

        (action_logits, values) = pnv_model((observations, actions))

        # Output is a pair: first the action log-probabilities, then the value output.
        self.assertEqual((batch, time_steps, n_controls, n_actions),
                         action_logits.shape)
        self.assertEqual((batch, time_steps), values.shape)
Example #6
File: ppo_test.py Project: wangleiphy/trax
  def test_policy_and_value_net(self):
    observation_shape = (3, 4, 5)
    batch_observation_shape = (1, 1) + observation_shape
    n_actions = 2
    n_controls = 3
    pnv_model = ppo.policy_and_value_net(
        n_controls=n_controls,
        n_actions=n_actions,
        vocab_size=None,
        bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
        two_towers=True,
    )
    input_signature = ShapeDtype(batch_observation_shape)
    _, _ = pnv_model.init(input_signature)

    batch = 2
    time_steps = 10
    batch_of_observations = np.random.uniform(
        size=(batch, time_steps) + observation_shape)
    pnv_output = pnv_model(batch_of_observations)

    # Output is a pair: first the action log-probabilities, then the value output.
    self.assertEqual(2, len(pnv_output))
    self.assertEqual(
        (batch, time_steps * n_controls, n_actions), pnv_output[0].shape)
    self.assertEqual((batch, time_steps * n_controls), pnv_output[1].shape)
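The expected shapes here follow directly from FlattenControlsIntoTime in Example #1; a pure-NumPy sketch of that reshape:

import numpy as np

batch, time_steps, n_controls, n_actions = 2, 10, 3, 2
dense_out = np.zeros((batch, time_steps, n_controls * n_actions))
flat = dense_out.reshape(batch, -1, n_actions)
print(flat.shape)  # (2, 30, 2) == (batch, time_steps * n_controls, n_actions)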
Example #7
    def test_train_mnist(self):
        """Train MNIST model (almost) fully, to compare to other implementations.

    Evals for cross-entropy loss and accuracy are run every 50 steps;
    their values are visible in the test log.
    """
        mnist_model = tl.Serial(
            tl.Flatten(),
            tl.Dense(512),
            tl.Relu(),
            tl.Dense(512),
            tl.Relu(),
            tl.Dense(10),
            tl.LogSoftmax(),
        )
        task = training.TrainTask(
            itertools.cycle(_mnist_dataset().train_stream(1)),
            tl.CrossEntropyLoss(), adafactor.Adafactor(.02))
        eval_task = training.EvalTask(
            itertools.cycle(_mnist_dataset().eval_stream(1)),
            [tl.CrossEntropyLoss(), tl.Accuracy()],
            n_eval_batches=10)

        training_session = training.Loop(
            mnist_model, [task],
            eval_tasks=[eval_task],
            eval_at=lambda step_n: step_n % 50 == 0)

        training_session.run(n_steps=1000)
        self.assertEqual(training_session.step, 1000)
Example #8
File: atari_cnn.py Project: shadowkun/trax
def AtariCnnBody(n_frames=4,
                 hidden_sizes=(32, 64, 64),
                 output_size=512,
                 mode='train',
                 kernel_initializer=None):
    """An Atari CNN."""
    del mode

    # TODO(jonni): Include link to paper?
    # Input shape: (B, T, H, W, C)
    # Output shape: (B, T, output_size)
    return tl.Serial(
        tl.Fn(lambda x: x / 255.0),  # Convert unsigned bytes to float.
        _FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
        tl.Conv(hidden_sizes[0], (8, 8), (4, 4),
                padding='SAME',
                kernel_initializer=kernel_initializer),
        tl.Relu(),
        tl.Conv(hidden_sizes[1], (4, 4), (2, 2),
                'SAME',
                kernel_initializer=kernel_initializer),
        tl.Relu(),
        tl.Conv(hidden_sizes[2], (3, 3), (1, 1),
                'SAME',
                kernel_initializer=kernel_initializer),
        tl.Relu(),
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    )
Example #9
 def test_two_outputs_pass(self):
     layer = tl.AssertFunction(
         '...cd->...x,...cd',
         tl.Branch(
             tl.Flatten(n_axes_to_keep=2),
             tl.Dropout(rate=0.1),
         ))
     x = np.ones((1, 2, 3, 4))
     layer(x)
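A pure-NumPy reading of the spec '...cd->...x,...cd' (a sketch of the shape bookkeeping only; AssertFunction checks shapes, it does not change them):

import numpy as np

x = np.ones((1, 2, 3, 4))   # '...' binds to (1, 2), 'c' to 3, 'd' to 4
out1 = x.reshape(1, 2, -1)  # Flatten(n_axes_to_keep=2) -> (1, 2, 12): '...x'
out2 = x                    # Dropout preserves the shape -> (1, 2, 3, 4): '...cd'
print(out1.shape, out2.shape)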
Example #10
def Resnet50(d_hidden=64,
             n_output_classes=1001,
             mode='train',
             norm=tl.BatchNorm,
             non_linearity=tl.Relu):
    """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.
    norm: `Layer` used for normalization, e.g. BatchNorm or
      FilterResponseNorm.
    non_linearity: `Layer` used as a non-linearity, e.g. Relu when norm is
      BatchNorm, or ThresholdedLinearUnit when norm is FilterResponseNorm.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """

    # A ConvBlock configured with the given norm, non-linearity and mode.
    def Resnet50ConvBlock(filter_multiplier=1, strides=(2, 2)):
        filters = ([
            filter_multiplier * dim
            for dim in [d_hidden, d_hidden, 4 * d_hidden]
        ])
        return ConvBlock(3, filters, strides, norm, non_linearity, mode)

    # Same as above for IdentityBlock.
    def Resnet50IdentityBlock(filter_multiplier=1):
        filters = ([
            filter_multiplier * dim
            for dim in [d_hidden, d_hidden, 4 * d_hidden]
        ])
        return IdentityBlock(3, filters, norm, non_linearity, mode)

    return tl.Serial(
        tl.ToFloat(),
        tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
        norm(mode=mode),
        non_linearity(),
        tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
        Resnet50ConvBlock(strides=(1, 1)),
        [Resnet50IdentityBlock() for _ in range(2)],
        Resnet50ConvBlock(2),
        [Resnet50IdentityBlock(2) for _ in range(3)],
        Resnet50ConvBlock(4),
        [Resnet50IdentityBlock(4) for _ in range(5)],
        Resnet50ConvBlock(8),
        [Resnet50IdentityBlock(8) for _ in range(2)],
        tl.AvgPool(pool_size=(7, 7)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
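A quick check of why this is a 50-layer network: four stages of one ConvBlock plus (2, 3, 5, 2) IdentityBlocks give 16 bottleneck blocks of 3 convolutions each, plus the stem convolution and the final Dense:

stages = (1 + 2, 1 + 3, 1 + 5, 1 + 2)  # ConvBlock + IdentityBlocks per stage
print(sum(stages) * 3 + 2)             # 16 blocks * 3 convs + stem + Dense = 50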
Example #11
def get_model(num_classes):
    return tl.Serial(
        tl.Flatten(),
        tl.Dense(512),
        tl.Relu(),
        tl.Dense(512),
        tl.Relu(),
        tl.Dense(num_classes),
        tl.LogSoftmax(),
    )
Example #12
 def test_multi_output_rank_fail(self):
     layer = tl.AssertFunction(
         '...34->...x,...y',
         tl.Branch(
             tl.Flatten(n_axes_to_keep=3),
             tl.Serial(),
         ))
     x = np.ones((1, 2, 3, 4))
     with self.assertRaises(tl.LayerError):
         layer(x)
Example #13
 def test_too_many_outputs_fail(self):
     layer = tl.AssertFunction(
         '...cd->...x,...cd,...cd,...cd',
         tl.Branch(
             tl.Flatten(n_axes_to_keep=2),
             tl.Dropout(rate=0.1),
             tl.Serial(),
         ))
     x = np.ones((1, 2, 3, 4))
     with self.assertRaises(tl.LayerError):
         layer(x)
Example #14
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train'):
    """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """
    return tl.Model(
        tl.ToFloat(),
        tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
        tl.BatchNorm(mode=mode),
        tl.Relu(),
        tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
        ConvBlock(3, [d_hidden, d_hidden, 4 * d_hidden], (1, 1), mode=mode),
        IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
        IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden], mode=mode),
        ConvBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden],
                      mode=mode),
        ConvBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden],
                      mode=mode),
        ConvBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], (2, 2),
                  mode=mode),
        IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden],
                      mode=mode),
        IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden],
                      mode=mode),
        tl.AvgPool(pool_size=(7, 7)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
Example #15
def RecommenderTransformer(n_classes_in, embedding_size, n_out_classes,
                           dropout_rate):
    transformer = tl.Serial(
        tl.Embedding(n_classes_in, d_feature=embedding_size),
        tl.Dropout(dropout_rate),
        tl.SelfAttention(2),
        tl.Flatten(),
        tl.Dropout(dropout_rate),
        #tl.DotProductCausalAttention(4),
        tl.Dense(n_out_classes),
        tl.LogSoftmax())

    print(str(transformer))
    return transformer
Example #16
File: mlp.py Project: modyharshit23/trax
def MLP(n_hidden_layers=2,
        d_hidden=512,
        activation_fn=tl.Relu,
        n_output_classes=10,
        mode="train"):
    """A multi-layer feedforward (perceptron) network."""
    del mode

    return tl.Model(
        tl.Flatten(),
        [[tl.Dense(d_hidden), activation_fn()]
         for _ in range(n_hidden_layers)],
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
Example #17
 def model(mode):
     del mode
     return layers.Serial(
         layers.Parallel(
             layers.Flatten(),  # Observation stack.
             layers.Embedding(d_feature=1,
                              vocab_size=n_actions),  # Action.
         ),
         layers.Concatenate(),
         layers.Dense(n_units=1),
         layers.Dup(),
         layers.Parallel(
             layers.Dense(n_units=obs_shape[1]),  # New observation.
             None,  # Reward.
         ))
Example #18
File: mlp.py Project: zhaoqiuye/trax
def PureMLP(
    layer_widths=(128, 64),
    activation_fn=tl.Relu,
    out_activation=False,
    flatten=True,
    mode='train'):
  """A "multilayer perceptron" (MLP) network.

  This is a classic fully connected feedforward network, with one or more
  layers and a (nonlinear) activation function between each layer. For
  historical reasons, such networks are often called multilayer perceptrons;
  but they are more accurately described as multilayer networks, where
  each layer + activation function is a perceptron-like unit (see, e.g.,
  [https://en.wikipedia.org/wiki/Multilayer_perceptron#Terminology]).

  Args:
    layer_widths: Tuple of ints telling the number of layers and the width of
        each layer. For example, setting `layer_widths=(128, 64, 32)` would
        yield 3 layers with successive widths of 128, 64, and 32.
    activation_fn: Layer that computes a nonlinear activation between pairs
        of fully connected layers. An activation function typically acts
        elementwise, and its output has the same shape and dtype as its input.
    out_activation: If True, include a copy of the activation function as the
        last layer in the network.
    flatten: If True, insert a layer at the head of the network to flatten the
        input tensor into a matrix of shape (batch_size, -1).
    mode: Ignored.

  Returns:
    An assembled MLP network with the specified layers. This network can either
    be initialized and trained as a full model, or can be used as a building
    block in a larger network.
  """
  del mode

  layers = []
  for width in layer_widths:
    layers.append(tl.Dense(width))
    layers.append(activation_fn())

  if not out_activation:
    # Don't need the last activation.
    layers.pop()

  return tl.Serial(
      [tl.Flatten()] if flatten else [],
      layers,
  )
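A brief usage sketch of the function above (the layer listings are written out by hand from the code, not printed output):

mlp = PureMLP(layer_widths=(128, 64, 32))
# Builds: Serial[ Flatten, Dense(128), Relu, Dense(64), Relu, Dense(32) ]
# -- with out_activation=False (the default), the trailing Relu is popped.
head = PureMLP(layer_widths=(16,), flatten=False, out_activation=True)
# Builds: Serial[ Dense(16), Relu ]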
Example #19
def _build_model(two_heads):
    cls_head = tl.Serial(tl.Dense(10), tl.LogSoftmax())
    if two_heads:
        reg_head = tl.Dense(1)
        heads = tl.Branch(cls_head, reg_head)
    else:
        heads = cls_head
    return tl.Serial(
        tl.Fn('ScaleInput', lambda x: x / 255),
        tl.Flatten(),
        tl.Dense(512),
        tl.Relu(),
        tl.Dense(512),
        tl.Relu(),
        heads,
    )
Example #20
File: mlp.py Project: youngjt/trax
def PureMLP(
    hidden_dims=(128, 64), activation_fn=tl.Relu, flatten=True, mode='train'):
  """A multi-layer feedforward (perceptron) network."""
  del mode

  layers = []
  for hidden_dim in hidden_dims:
    layers.append(tl.Dense(hidden_dim))
    layers.append(activation_fn())

  # Don't need the last activation.
  layers.pop()

  return tl.Serial(
      [tl.Flatten()] if flatten else [],
      layers,
  )
Example #21
File: atari_cnn.py Project: shadowkun/trax
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
    """An Atari CNN."""
    del mode

    # TODO(jonni): Include link to paper?
    # Input shape: (B, T, H, W, C)
    # Output shape: (B, T, output_size)
    return tl.Serial(
        tl.Fn(lambda x: x / 255.0),  # Convert unsigned bytes to float.
        _FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
        tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    )
Example #22
File: mlp.py Project: zzszmyf/trax
def MLP(d_hidden=512,
        n_hidden_layers=2,
        activation_fn=tl.Relu,
        n_output_classes=10,
        mode='train'):
    """A multi-layer feedforward (perceptron) network."""
    del mode

    # Define a function rather than a variable, so that multiple copies will
    # each be their own object with their own weights.
    def DensePlusActivation():
        return [tl.Dense(d_hidden), activation_fn()]

    return tl.Serial(
        tl.Flatten(),
        [DensePlusActivation() for _ in range(n_hidden_layers)],
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
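The comment about defining a function rather than a variable is worth making concrete; a small sketch (assuming trax is importable, as in the examples above):

from trax import layers as tl

def DensePlusActivation():
    return [tl.Dense(512), tl.Relu()]

a = DensePlusActivation()
b = DensePlusActivation()
print(a[0] is b[0])  # False: two Dense objects, each with its own weights

shared = [tl.Dense(512), tl.Relu()]
# Using `shared` twice in one model would reuse a single Dense object --
# and hence a single weight matrix -- in both positions.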
Example #23
def RawPolicy(seq_model, n_controls, n_actions):
    """Wraps a sequence model in a policy interface.

  The resulting model takes as input observation and action sequences, but only
  uses the observations. Adds output heads for action logits and value
  predictions.

  Args:
    seq_model: Trax sequence model taking as input and outputting a sequence of
      continuous vectors.
    n_controls: Number of controls.
    n_actions: Number of action categories in each control.

  Returns:
    A model of signature (obs, act) -> (act_logits, values), with shapes:
      obs: (batch_size, length + 1, obs_depth)
      act: (batch_size, length, n_controls)
      act_logits: (batch_size, length, n_controls, n_actions)
      values: (batch_size, length)
  """
    @tl.layer()
    def SplitControls(x, **unused_kwargs):  # pylint: disable=invalid-name
        """Splits logits for actions in different controls."""
        return np.reshape(x, x.shape[:2] + (n_controls, n_actions))

    action_head = [
        # Predict all action logits at the same time.
        tl.Dense(n_controls * n_actions),
        # Then group them into separate controls, adding a new dimension.
        SplitControls(),  # pylint: disable=no-value-for-parameter
        # Needed because there is one fewer action than observations.
        DropLastTimestep(),  # pylint: disable=no-value-for-parameter
        tl.LogSoftmax(),
    ]
    return tl.Serial([  # (obs, act)
        tl.Select([0], n_in=2),  # (obs,)
        seq_model,  # (obs_hidden,)
        tl.Dup(),  # (obs_hidden, obs_hidden)
        tl.Parallel(
            action_head,
            [tl.Dense(1), tl.Flatten()],
        )  # (act_logits, values)
    ])
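A pure-NumPy sketch of the SplitControls reshape and the shapes given in the docstring:

import numpy as np

batch, length, n_controls, n_actions = 2, 10, 3, 4
dense_out = np.zeros((batch, length + 1, n_controls * n_actions))
split = dense_out.reshape(dense_out.shape[:2] + (n_controls, n_actions))
print(split.shape)  # (2, 11, 3, 4); DropLastTimestep then trims axis 1 to 10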
Example #24
def get_model(n_output_classes=10):
    """
    Simple CNN to classify Fashion MNIST
    """
    model = tl.Serial(
        tl.ToFloat(),
        tl.Conv(32, (3, 3), (1, 1), "SAME"),
        tl.LayerNorm(),
        tl.Relu(),
        tl.MaxPool(),
        tl.Conv(64, (3, 3), (1, 1), "SAME"),
        tl.LayerNorm(),
        tl.Relu(),
        tl.MaxPool(),
        tl.Flatten(),
        tl.Dense(n_output_classes),
    )

    return model
Example #25
def AtariCnn(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train'):
    """An Atari CNN."""
    del mode

    # TODO(jonni): Include link to paper?
    # Input shape: (B, T, H, W, C)
    # Output shape: (B, T, output_size)
    return tl.Model(
        tl.ToFloat(),
        tl.Div(divisor=255.0),

        # Set up n_frames successive game frames, concatenated on the last axis.
        FrameStack(n_frames=n_frames),  # (B, T, H, W, 4C)
        tl.Conv(hidden_sizes[0], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Conv(hidden_sizes[1], (5, 5), (2, 2), 'SAME'),
        tl.Relu(),
        tl.Flatten(n_axes_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
    )
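A pure-NumPy sketch of the frame-stacking idea behind FrameStack (edge handling here pads with zeros; the real layer's behavior at the start of an episode may differ):

import numpy as np

def frame_stack(x, n_frames=4):
    # x: (B, T, H, W, C). For each step t, concatenate frames
    # t - n_frames + 1 .. t along the channel axis.
    shifted = [
        np.pad(x, ((0, 0), (k, 0), (0, 0), (0, 0), (0, 0)))[:, :x.shape[1]]
        for k in reversed(range(n_frames))  # oldest frame first
    ]
    return np.concatenate(shifted, axis=-1)

x = np.ones((2, 5, 3, 3, 1))
print(frame_stack(x).shape)  # (2, 5, 3, 3, 4): channels go C -> 4C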
Example #26
 def test_reduce_rank_explicit_fail2(self):
     layer = tl.AssertFunction('abcde->abcd', tl.Flatten(n_axes_to_keep=3))
     x = np.ones((1, 2, 3, 4, 5))
     with self.assertRaises(tl.LayerError):
         layer(x)
Example #27
 def test_reduce_rank_ellipsis_pass(self):
     layer = tl.AssertFunction('...ab->...c', tl.Flatten(n_axes_to_keep=3))
     x = np.ones((1, 2, 3, 4, 5))
     layer(x)
Example #28
 def test_reduce_rank_explicit_pass(self):
     layer = tl.AssertFunction('xyzab->xyzc', tl.Flatten(n_axes_to_keep=3))
     x = np.ones((1, 2, 3, 4, 5))
     layer(x)
Example #29
 def test_reduce_rank_to_one_pass(self):
     layer = tl.AssertFunction('abcde->x', tl.Flatten(n_axes_to_keep=0))
     x = np.ones((1, 2, 3, 4, 5))
     layer(x)
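All four tests above hinge on how tl.Flatten(n_axes_to_keep=k) changes rank; a pure-NumPy model of it:

import numpy as np

def flatten(x, n_axes_to_keep):
    # Keep the first k axes; collapse the rest into one trailing axis.
    return x.reshape(x.shape[:n_axes_to_keep] + (-1,))

x = np.ones((1, 2, 3, 4, 5))
print(flatten(x, 3).shape)  # (1, 2, 3, 20)
print(flatten(x, 0).shape)  # (120,): whole tensor to rank 1, as in
                            # test_reduce_rank_to_one_pass ('abcde->x')
# 'xyzab->xyzc' passes: 'c' is a fresh symbol, free to be 20.
# 'abcde->abcd' fails: 'd' was already bound to 4, but the output's
# trailing axis is 20.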
Example #30
File: ppo_test.py Project: wangleiphy/trax
  def test_combined_loss(self):
    B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
    batch_observation_shape = (1, 1) + OBS

    net = ppo.policy_and_value_net(
        n_controls=1,
        n_actions=A,
        vocab_size=None,
        bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
        two_towers=True,
    )

    input_signature = ShapeDtype(batch_observation_shape)
    old_params, _ = net.init(input_signature)
    new_params, state = net.init(input_signature)

    # Generate a batch of observations.

    observations = np.random.uniform(size=(B, T + 1) + OBS)
    actions = np.random.randint(0, A, size=(B, T + 1))
    rewards = np.random.uniform(0, 1, size=(B, T))
    mask = np.ones_like(rewards)

    # Just test that this computes at all.
    (new_log_probabs, value_predictions_new) = (
        net(observations, weights=new_params, state=state))
    (old_log_probabs, value_predictions_old) = (
        net(observations, weights=old_params, state=state))

    gamma = 0.99
    lambda_ = 0.95
    epsilon = 0.2
    value_weight = 1.0
    entropy_weight = 0.01

    nontrainable_params = {
        'gamma': gamma,
        'lambda': lambda_,
        'epsilon': epsilon,
        'value_weight': value_weight,
        'entropy_weight': entropy_weight,
    }

    rewards_to_actions = np.eye(value_predictions_old.shape[1])
    (value_loss_1, _) = ppo.value_loss_given_predictions(
        value_predictions_new, rewards, mask, gamma=gamma,
        value_prediction_old=value_predictions_old, epsilon=epsilon)
    (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(
        new_log_probabs,
        old_log_probabs,
        value_predictions_old,
        actions,
        rewards_to_actions,
        rewards,
        mask,
        gamma=gamma,
        lambda_=lambda_,
        epsilon=epsilon)

    (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _, state) = (
        ppo.combined_loss(new_params,
                          old_log_probabs,
                          value_predictions_old,
                          net,
                          observations,
                          actions,
                          rewards_to_actions,
                          rewards,
                          mask,
                          nontrainable_params=nontrainable_params,
                          state=state)
    )

    # Test that these compute at all and are self consistent.
    self.assertGreater(entropy_bonus, 0.0)
    self.assertNear(value_loss_1, value_loss_2, 1e-6)
    self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
    self.assertNear(
        combined_loss,
        ppo_loss_2 + (value_weight * value_loss_2) -
        (entropy_weight * entropy_bonus),
        1e-6
    )
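For reference, a pure-NumPy sketch of the textbook clipped surrogate that ppo_loss_given_predictions implements in spirit (this is the standard form from the PPO paper, not the trax code, and it omits the GAE advantage computation):

import numpy as np

def clipped_ppo_loss(new_log_probs, old_log_probs, advantages, epsilon=0.2):
    # Probability ratio of the new policy to the old one.
    ratio = np.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # Pessimistic (lower) bound, negated so that smaller loss is better.
    return -np.mean(np.minimum(unclipped, clipped))

The final assertNear in the test then checks exactly the combination combined = ppo_loss + value_weight * value_loss - entropy_weight * entropy_bonus.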