Example #1
 def test_pure_mlp_forward_shape(self):
     """Run the PureMLP model forward and check output shape."""
     input_signature = ShapeDtype((7, 28, 28, 3))
     model = mlp.PureMLP(layer_widths=(32, 16, 8))
     final_shape = tl.check_shape_agreement(model, input_signature)
     self.assertEqual((7, 8), final_shape)
Example #2
 def test_lstm_cell(self):
     self._test_cell_runs(
         rnn.LSTMCell(9),
         input_signature=(ShapeDtype((8, 9)), ShapeDtype((8, 18))),
         output_shape=((8, 9), (8, 18)))
Example #3
 def test_rnnlm_forward_shape(self):
     """Runs the RNN LM forward and checks output shape."""
     input_signature = ShapeDtype((3, 28), dtype=math.numpy.int32)
     model = rnn.RNNLM(vocab_size=20, d_model=16)
     final_shape = tl.check_shape_agreement(model, input_signature)
     self.assertEqual((3, 28, 20), final_shape)
Example #4
 def test_branch_one_layer(self):
   layer = cb.Branch(divide_by(0.5))
   input_signature = ShapeDtype((3, 2))
   expected_shape = (3, 2)
   output_shape = base.check_shape_agreement(layer, input_signature)
   self.assertEqual(output_shape, expected_shape)
Example #5
 def test_conv_gru_cell(self):
     self._test_cell_runs(rnn.ConvGRUCell(9, kernel_size=(3, 3)),
                          input_signature=ShapeDtype((8, 1, 7, 9)),
                          output_shape=(8, 1, 7, 9))
Example #6
 def test_serial_no_op_list(self):
   layer = cb.Serial([])
   input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
   expected_shape = ((3, 2), (4, 7))
   output_shape = base.check_shape_agreement(layer, input_signature)
   self.assertEqual(output_shape, expected_shape)
Example #7
 def test_branch_noop_dup(self):
   layer = cb.Branch([], cb.Dup())
   input_signature = ShapeDtype((3, 2))
   expected_shape = ((3, 2), (3, 2), (3, 2))
   output_shape = base.check_shape_agreement(layer, input_signature)
   self.assertEqual(output_shape, expected_shape)
Example #8
 def test_accuracy_scalar(self):
     input_signature = (ShapeDtype((29, 4, 4, 20)), ShapeDtype((29, 4, 4)),
                        ShapeDtype((29, 4, 4)))
     result_shape = base.check_shape_agreement(metrics.AccuracyScalar(),
                                               input_signature)
     self.assertEqual(result_shape, ())
Example #9
 def test_resnet(self):
     input_signature = ShapeDtype((3, 256, 256, 3))
     model = resnet.Resnet50(d_hidden=8, n_output_classes=10)
     final_shape = tl.check_shape_agreement(model, input_signature)
     self.assertEqual((3, 10), final_shape)
Example #10
 def test_weighted_mean_shape(self):
     input_signature = (ShapeDtype((29, 4, 4, 20)), ShapeDtype((29, 4, 4, 20)))
     result_shape = base.check_shape_agreement(metrics._WeightedMean(),
                                               input_signature)
     self.assertEqual(result_shape, ())
Example #11
 def test_cross_entropy_loss(self):
     input_signature = (ShapeDtype((29, 4, 4, 20)), ShapeDtype((29, 4, 4)),
                        ShapeDtype((29, 4, 4)))
     result_shape = base.check_shape_agreement(metrics.CrossEntropyLoss(),
                                               input_signature)
     self.assertEqual(result_shape, ())
Example #12
 def test_layer_norm_shape(self):
   input_signature = ShapeDtype((29, 5, 7, 20))
   result_shape = base.check_shape_agreement(
       normalization.LayerNorm(), input_signature)
   self.assertEqual(result_shape, input_signature.shape)
Example #13
 def test_image_fec(self):
     input_signature = ShapeDtype((3, 256, 256, 3))
     model = ImageFEC(d_hidden=8, n_output_classes=10)
     final_shape = tl.check_shape_agreement(model, input_signature)
     self.assertEqual((3, 10), final_shape)
Example #14
 def test_mlp_forward_shape(self):
     """Run the MLP model forward and check output shape."""
     input_signature = ShapeDtype((3, 28, 28, 1))
     model = mlp.MLP(d_hidden=32, n_output_classes=10)
     final_shape = tl.check_shape_agreement(model, input_signature)
     self.assertEqual((3, 10), final_shape)
Example #15
 def test_drop(self):
   layer = cb.Drop()
   input_signature = ShapeDtype((3, 2))
   expected_shape = ()
   output_shape = base.check_shape_agreement(layer, input_signature)
   self.assertEqual(output_shape, expected_shape)
Example #16
 def test_wide_resnet(self):
     input_signature = ShapeDtype((3, 32, 32, 3))
     model = resnet.WideResnet(n_blocks=1, n_output_classes=10)
     final_shape = tl.check_shape_agreement(model, input_signature)
     self.assertEqual((3, 10), final_shape)
Example #17
 def test_swap(self):
   layer = cb.Swap()
   input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
   expected_shape = ((4, 7), (3, 2))
   output_shape = base.check_shape_agreement(layer, input_signature)
   self.assertEqual(output_shape, expected_shape)
Example #18
    def test_combined_loss(self):
        B, T, A, OBS = 2, 10, 2, (28, 28, 3)  # pylint: disable=invalid-name
        batch_observation_shape = (1, 1) + OBS

        make_net = lambda: ppo.policy_and_value_net(  # pylint: disable=g-long-lambda
            n_controls=1,
            n_actions=A,
            vocab_size=None,
            bottom_layers_fn=lambda: [layers.Flatten(n_axes_to_keep=2)],
            two_towers=True,
        )
        net = make_net()

        input_signature = ShapeDtype(batch_observation_shape)
        old_params, _ = net.init(input_signature)
        new_params, state = make_net().init(input_signature)

        # Generate a batch of observations.

        observations = np.random.uniform(size=(B, T + 1) + OBS)
        actions = np.random.randint(0, A, size=(B, T + 1))
        rewards = np.random.uniform(0, 1, size=(B, T))
        mask = np.ones_like(rewards)

        # Just test that this computes at all.
        (new_log_probabs, value_predictions_new) = (net(observations,
                                                        weights=new_params,
                                                        state=state))
        (old_log_probabs, value_predictions_old) = (net(observations,
                                                        weights=old_params,
                                                        state=state))

        gamma = 0.99
        lambda_ = 0.95
        epsilon = 0.2
        value_weight = 1.0
        entropy_weight = 0.01

        nontrainable_params = {
            'gamma': gamma,
            'lambda': lambda_,
            'epsilon': epsilon,
            'value_weight': value_weight,
            'entropy_weight': entropy_weight,
        }

        rewards_to_actions = np.eye(value_predictions_old.shape[1])
        (value_loss_1, _) = ppo.value_loss_given_predictions(
            value_predictions_new,
            rewards,
            mask,
            gamma=gamma,
            value_prediction_old=value_predictions_old,
            epsilon=epsilon)
        (ppo_loss_1, _) = ppo.ppo_loss_given_predictions(new_log_probabs,
                                                         old_log_probabs,
                                                         value_predictions_old,
                                                         actions,
                                                         rewards_to_actions,
                                                         rewards,
                                                         mask,
                                                         gamma=gamma,
                                                         lambda_=lambda_,
                                                         epsilon=epsilon)

        (combined_loss, (ppo_loss_2, value_loss_2, entropy_bonus), _,
         state) = (ppo.combined_loss(new_params,
                                     old_log_probabs,
                                     value_predictions_old,
                                     net,
                                     observations,
                                     actions,
                                     rewards_to_actions,
                                     rewards,
                                     mask,
                                     nontrainable_params=nontrainable_params,
                                     state=state))

        # Test that these compute at all and are self-consistent.
        self.assertGreater(entropy_bonus, 0.0)
        self.assertNear(value_loss_1, value_loss_2, 1e-6)
        self.assertNear(ppo_loss_1, ppo_loss_2, 1e-6)
        self.assertNear(
            combined_loss, ppo_loss_2 + (value_weight * value_loss_2) -
            (entropy_weight * entropy_bonus), 1e-6)
Example #19
 def test_serial_div_div(self):
   layer = cb.Serial(divide_by(2.0), divide_by(5.0))
   input_signature = ShapeDtype((3, 2))
   expected_shape = (3, 2)
   output_shape = base.check_shape_agreement(layer, input_signature)
   self.assertEqual(output_shape, expected_shape)
Example #20
  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
               output_dir=None, random_seed=None, n_devices=None,
               save_steps=None, should_save_checkpoints=True,
               should_write_summaries=True, has_weights=False,
               nontrainable_param_map=None, mask_id=None, metrics=None):
    if backend.get_name() == 'jax':
      self._host_id = jax.host_id()
      self._host_count = jax.host_count()
    else:
      self._host_id = 0
      self._host_count = 1
    self._is_chief = (self._host_id == 0)

    if save_steps is None:
      save_steps = []
    self._save_steps = save_steps
    self._should_save_checkpoints = should_save_checkpoints
    self._should_write_summaries = should_write_summaries
    self._has_weights = has_weights
    self._mask_id = mask_id
    self._metrics_dict = _METRICS if metrics is None else metrics
    loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
    device_count = backend.device_count()
    n_devices = n_devices or device_count
    # TODO(lukaszkaiser): remove this restriction when possible.
    if n_devices != device_count and backend.get_name() == 'jax':
      raise ValueError('JAX cannot work yet with n_devices != all devices: '
                       '%d != %d' % (n_devices, device_count))
    self._n_devices = n_devices

    # Simple differential seeding of RNG across hosts by host_id and time.
    if random_seed is None and self._host_count > 1:
      _, random_seed = divmod(int(time.time() * 1e6) +
                              int(self._host_id * 1e6), 2**32)
    rng = get_random_number_generator_and_set_seed(random_seed)
    inputs = inputs(n_devices)
    self._inputs = inputs

    # Initialize the learning rate to a dummy value. It will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')

    # Setup state.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, n_devices))
    first_shape = inputs.input_shape[0]
    # If the inputs are a tuple/list, add [None] (batch) to each element.
    if isinstance(first_shape, (list, tuple)):
      model_input_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.input_shape)
      model_target_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.target_shape)
    else:  # Otherwise just add [None] to the input shape.
      model_input_shape = tuple([None] + list(inputs.input_shape))
      model_target_shape = tuple([None] + list(inputs.target_shape))
    # Change all None to 1 in input and target shape.
    model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
    model_target_shape = backend.nested_map(lambda x: x or 1,
                                            model_target_shape)
    def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                      target_dtype, rng):
      """Returns optimizer and model states suitable for training a model."""
      # Combine inputs and targets on the stack.
      if not isinstance(input_dtype, (list, tuple)):
        input_dtype = [input_dtype]
        input_shape = [input_shape]
      if not isinstance(target_dtype, (list, tuple)):
        target_dtype = [target_dtype]
        target_shape = [target_shape]
      dtypes = list(input_dtype) + list(target_dtype)
      shapes = list(input_shape) + list(target_shape)
      if self._has_weights:
        shapes += list(target_shape)
        dtypes += [np.float32 for _ in target_dtype]
      input_signature = tuple(ShapeDtype(s, d)
                              for (s, d) in zip(shapes, dtypes))
      # We need to create a new model instance and not reuse `model_train` here,
      # because `m.initialize` puts cached parameter values in `m` and hence the
      # next call of `m.initialize` will give wrong results.
      m = tl.Serial(model(mode='train'), loss_fn)
      m._set_rng_recursive(rng)  # pylint: disable=protected-access
      weights, state = m.init(input_signature)
      (slots, opt_params) = opt.tree_init(weights)
      return (OptState(weights, slots, opt_params), state)
    if _is_jit_init():
      # JIT parameter initialization to avoid memory fragmentation
      new_opt_state_and_model_state = backend.jit(new_opt_state_and_model_state,
                                                  static_argnums=(0, 1, 2, 3))
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
            model_input_shape, self._inputs.input_dtype,
            model_target_shape, self._inputs.target_dtype, init_rng))

    # jit model_predict and update so they're fast
    # TODO(lukaszkaiser): the code below creates a layer computing
    # multiple metrics from a single model output; re-factor for clarity.
    dup_layer = tl.Dup3() if self._has_weights else tl.Dup2()
    def lower(layer):
      """Apply layer below the current inputs, targets, and possibly weights."""
      if self._has_weights:
        # Apply layer below inputs, targets, and loss weights.
        return tl.Parallel([], [], [], layer)
      else:
        # Apply layer below inputs and targets.
        return tl.Parallel([], [], layer)
    metrics_layer = []
    self._metrics = list(sorted(self._metrics_dict.keys()))
    for i, m in enumerate(reversed(self._metrics)):
      metric = self._metrics_dict[m](has_weights=self._has_weights,
                                     mask_id=self._mask_id)
      if i != len(self._metrics) - 1:
        metrics_layer.append(dup_layer)
        metrics_layer.append(lower(metric))
      else:
        metrics_layer.append(metric)
    # TODO(lukaszkaiser): clean this up once layer API stabilizes.
    # For now, we need to initialize metric layers somehow, so here we go.
    # We assume that they do not have any parameters, so this is a dummy.
    dummy_shapes = ((1, 2), (1,), (1,)) if self._has_weights else ((1, 2), (1,))
    dummy_dtypes = [np.float32] * (3 if self._has_weights else 2)
    dummy_signature = tuple(ShapeDtype(s, d)
                            for s, d in zip(dummy_shapes, dummy_dtypes))
    metrics_layer = tl.Serial(metrics_layer)
    metrics_layer._set_rng_recursive(init_rng)  # pylint: disable=protected-access
    metrics_weights, metrics_state = (
        metrics_layer.init(dummy_signature))
    self._metrics_weights = self._for_n_devices(metrics_weights)
    self._metrics_state = self._for_n_devices(metrics_state)
    self._jit_eval = _jit_predict_fn(
        model_predict_eval, metrics_layer, n_devices)
    self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt, n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    # TODO(pkozakowski): "Learning rate schedules" are currently able to
    # control all optimizer parameters and model state, so let's rename them
    # accordingly.
    self._lr_schedule = lr_schedule

    if nontrainable_param_map is None:
      nontrainable_param_map = {}
    self._nontrainable_param_map = nontrainable_param_map

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None
    self._model_state = None

    if output_dir is not None:
      self.reset(output_dir)
Example #21
 def test_branch_add_div(self):
   layer = cb.Branch(cb.Add(), divide_by(0.5))
   input_signature = (ShapeDtype((3, 2)), ShapeDtype((3, 2)))
   expected_shape = ((3, 2), (3, 2))
   output_shape = base.check_shape_agreement(layer, input_signature)
   self.assertEqual(output_shape, expected_shape)
Example #22
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 has_weights=False,
                 nontrainable_param_map=None,
                 id_to_mask=None,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
            n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._has_weights = has_weights
        self._id_to_mask = id_to_mask
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
         # If we pass a function, e.g., through gin, call it.
         if callable(inputs):
            self._inputs = inputs()

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if self._inputs.input_shape and isinstance(self._inputs.input_shape[0],
                                                   (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.input_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(self._inputs.input_shape))
        # Same for targets.
        if self._inputs.target_shape and isinstance(
                self._inputs.target_shape[0], (list, tuple)):
            model_target_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.target_shape)
        else:
            model_target_shape = tuple([None] +
                                       list(self._inputs.target_shape))
        # Change all None to 1 in input and target shape.
        model_input_shape = math.nested_map(lambda x: x or 1,
                                            model_input_shape)
        model_target_shape = math.nested_map(lambda x: x or 1,
                                             model_target_shape)

        def new_opt_state_and_model_state(shape_dtype, rng):
            """Returns optimizer and model states suitable for training a model."""
            # Combine inputs and targets on the stack.
            shapes, dtypes = shape_dtype
            input_signature = tuple(
                ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.initialize` puts cached parameter values in `m` and hence the
            # next call of `m.initialize` will give wrong results.
            m = tl.Serial(model(mode='train'), loss_fn)
            m._set_rng_recursive(rng)  # pylint: disable=protected-access
            weights, state = m.init(input_signature)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = math.jit(
                new_opt_state_and_model_state, static_argnums=(0, ))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
                self._inputs.example_shape_dtype, init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [
            self._metrics_dict[m](has_weights=self._has_weights,
                                  id_to_mask=self._id_to_mask)
            for m in self._metrics
        ]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        if nontrainable_param_map is None:
            nontrainable_param_map = {}
        self._nontrainable_param_map = nontrainable_param_map

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
Example #23
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, _, self._n_devices, rng = (
            training.init_host_and_devices(n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at if checkpoints_at is not None else []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
         # If we pass a function, e.g., through gin, call it.
         if callable(inputs):
            self._inputs = inputs()
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        shapes, dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if fastmath.is_backend(fastmath.Backend.JAX):
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = (
                fastmath.jit(new_opt_state_and_model_state))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        self._input_signature = example_signature
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
Example #24
 def test_parallel_dup_dup(self):
   layer = cb.Parallel(cb.Dup(), cb.Dup())
   input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
   expected_shape = ((3, 2), (3, 2), (4, 7), (4, 7))
   output_shape = base.check_shape_agreement(layer, input_signature)
   self.assertEqual(output_shape, expected_shape)
Example #25
 def test_gru_cell(self):
     self._test_cell_runs(
         rnn.GRUCell(9),
         input_signature=(ShapeDtype((8, 7, 9)), ShapeDtype((8, 7, 9))),
         output_shape=((8, 7, 9), (8, 7, 9)))
Example #26
 def test_parallel_div_div(self):
   layer = cb.Parallel(divide_by(0.5), divide_by(3.0))
   input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
   expected_shape = ((3, 2), (4, 7))
   output_shape = base.check_shape_agreement(layer, input_signature)
   self.assertEqual(output_shape, expected_shape)
Example #27
 def test_sru(self):
     self._test_cell_runs(rnn.SRU(7),
                          input_signature=ShapeDtype((8, 9, 7)),
                          output_shape=(8, 9, 7))
Example #28
 def test_parallel_no_ops(self):
   layer = cb.Parallel([], None)
   input_signature = (ShapeDtype((3, 2)), ShapeDtype((4, 7)))
   expected_shape = ((3, 2), (4, 7))
   output_shape = base.check_shape_agreement(layer, input_signature)
   self.assertEqual(output_shape, expected_shape)
Example #29
 def test_constructor_and_read_properties(self):
     sd = ShapeDtype((2, 3), onp.int32)
     self.assertEqual(sd.shape, (2, 3))
     self.assertEqual(sd.dtype, onp.int32)
Example #30
  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
               output_dir=None, random_seed=None, n_devices=None,
               checkpoints_at=None, should_save_checkpoints=True,
               should_write_summaries=True, has_weights=False,
               nontrainable_param_map=None, mask_id=None, metrics=None):

    self._is_chief, self._n_devices, rng = (
        self._init_host_and_devices(n_devices, random_seed))
    self._should_save_checkpoints = should_save_checkpoints and self._is_chief
    self._checkpoints_at = checkpoints_at or []
    self._should_write_summaries = should_write_summaries

    self._has_weights = has_weights
    self._mask_id = mask_id
    self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
    inputs = inputs(self._n_devices)
    self._inputs = inputs

    # Initialize the learning rate to a dummy value. It will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')

    # Setup state.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, self._n_devices))
    first_shape = inputs.input_shape[0]
    # If the inputs are a tuple/list, add [None] (batch) to each element.
    if isinstance(first_shape, (list, tuple)):
      model_input_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.input_shape)
      model_target_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.target_shape)
    else:  # Otherwise just add [None] to the input shape.
      model_input_shape = tuple([None] + list(inputs.input_shape))
      model_target_shape = tuple([None] + list(inputs.target_shape))
    # Change all None to 1 in input and target shape.
    model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
    model_target_shape = backend.nested_map(lambda x: x or 1,
                                            model_target_shape)

    def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                      target_dtype, rng):
      """Returns optimizer and model states suitable for training a model."""
      # Combine inputs and targets on the stack.
      if not isinstance(input_dtype, (list, tuple)):
        input_dtype = [input_dtype]
        input_shape = [input_shape]
      if not isinstance(target_dtype, (list, tuple)):
        target_dtype = [target_dtype]
        target_shape = [target_shape]
      dtypes = list(input_dtype) + list(target_dtype)
      shapes = list(input_shape) + list(target_shape)
      if self._has_weights:
        shapes += list(target_shape)
        dtypes += [np.float32 for _ in target_dtype]
      input_signature = tuple(ShapeDtype(s, d)
                              for (s, d) in zip(shapes, dtypes))
      # We need to create a new model instance and not reuse `model_train` here,
      # because `m.initialize` puts cached parameter values in `m` and hence the
      # next call of `m.initialize` will give wrong results.
      m = tl.Serial(model(mode='train'), loss_fn)
      m._set_rng_recursive(rng)  # pylint: disable=protected-access
      weights, state = m.init(input_signature)
      (slots, opt_params) = opt.tree_init(weights)
      return (OptState(weights, slots, opt_params), state)

    if _is_jit_init():
      # JIT parameter initialization to avoid memory fragmentation
      new_opt_state_and_model_state = backend.jit(new_opt_state_and_model_state,
                                                  static_argnums=(0, 1, 2, 3))
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
            model_input_shape, self._inputs.input_dtype,
            model_target_shape, self._inputs.target_dtype, init_rng))

    # Arrange and initialize metrics layers.
    self._metrics = list(sorted(self._metrics_dict.keys()))
    metrics_layers = [self._metrics_dict[m](has_weights=self._has_weights,
                                            mask_id=self._mask_id)
                      for m in self._metrics]
    metrics_in_parallel = tl.Branch(*metrics_layers)
    # TODO(lukaszkaiser): clean this up once layer API stabilizes.
    # For now, we need to initialize metric layers somehow, so here we go.
    # We assume that they do not have any parameters, so this is a dummy.
    dummy_shapes = ((1, 2), (1,), (1,)) if self._has_weights else ((1, 2), (1,))
    dummy_signature = tuple(ShapeDtype(s) for s in dummy_shapes)
    metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
    m_weights, m_state = metrics_in_parallel.init(dummy_signature)
    self._metrics_weights = self._for_n_devices(m_weights)
    self._metrics_state = self._for_n_devices(m_state)

    # Jit model_predict and update so they're fast.
    self._jit_eval = _jit_predict_fn(
        model_predict_eval, metrics_in_parallel, self._n_devices)
    self._jit_update_fn = _jit_update_fn(
        model_train, loss_fn, opt, self._n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    # TODO(pkozakowski): "Learning rate schedules" are currently able to
    # control all optimizer parameters and model state, so let's rename them
    # accordingly.
    self._lr_schedule = lr_schedule

    if nontrainable_param_map is None:
      nontrainable_param_map = {}
    self._nontrainable_param_map = nontrainable_param_map

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None
    self._model_state = None

    if output_dir is not None:
      self.reset(output_dir)