Example #1
 def sample(self, inputs, temperature=1.0):
     # No need for LogSoftmax with Gumbel sampling - softmax normalization is
     # subtracting a constant from every logit, and Gumbel sampling is taking
     # a max over logits plus noise, so it is invariant to adding a constant.
     if temperature == 0.0:
         return jnp.argmax(self._unflatten_inputs(inputs), axis=-1)
     return tl.gumbel_sample(self._unflatten_inputs(inputs), temperature)
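
The comment above leans on the shift-invariance of the Gumbel-max trick. A minimal NumPy-only sketch (not trax code) showing that adding a constant to every logit leaves the sampled indices unchanged when the same noise is reused:

import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))
# Draw Gumbel noise once and reuse it for both cases.
noise = -np.log(-np.log(rng.uniform(1e-6, 1.0 - 1e-6, size=logits.shape)))

a = np.argmax(logits + noise, axis=-1)
b = np.argmax((logits - 7.3) + noise, axis=-1)  # shift every logit by a constant
assert np.array_equal(a, b)  # the sampled indices are identical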
Example #2
def autoregressive_sample(model,
                          prefix=None,
                          inputs=None,
                          batch_size=1,
                          temperature=1.0,
                          start_id=0,
                          eos_id=1,
                          max_length=100,
                          accelerate=True):
    """Perform aturegressive sampling from the provided model.

  Args:
    model: instance of trax.Layer, the model to sample from (at mode='predict')
    prefix: optional tensor [batch_size, L]: prefix for decoding
    inputs: optional tensor [batch_size, M]: inputs to provide to the model
    batch_size: number of sequences to sample in parallel (default: 1)
    temperature: sampling temperature (default: 1.0)
    start_id: int, id for the start symbol fed at the beginning (default: 0)
    eos_id: int, id of the end-of-sequence symbol used to stop (default: 1)
    max_length: maximum length to sample (default: 100)
    accelerate: whether to accelerate the model before decoding (default: True)

  Returns:
    a tensor of ints of shape [batch_size, N] with N <= max_length containing
    the autoregressively sampled output from the model
  """
    if prefix is not None and prefix.shape[0] != batch_size:
        raise ValueError(
            f'Prefix batch size {prefix.shape[0]} != {batch_size}.')
    if inputs is not None and inputs.shape[0] != batch_size:
        raise ValueError(
            f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
    fast_model = tl.Accelerate(model) if accelerate else model
    cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
    result = []
    for i in range(max_length):
        model_input = cur_symbol if inputs is None else (inputs, cur_symbol)
        logits = fast_model(model_input)
        if inputs is not None:
            # Pick first element from model output (a pair here).
            logits = logits[0]
        if prefix is not None and i < prefix.shape[1]:  # Read from prefix.
            cur_prefix_symbol = prefix[:, i]
            sample = cur_prefix_symbol[:, None]
        else:
            sample = tl.gumbel_sample(logits, temperature=temperature)
        result.append(sample)
        # Note: we're using 'predict' mode autoregressive models here, so history
        # is cached in the model state and we only feed the next symbol.
        cur_symbol = sample
        # TODO(lukaszkaiser): extend stopping below to batch_sizes > 1.
        if batch_size == 1 and int(sample[0, 0]) == eos_id:
            break
    return np.concatenate(result, axis=1)
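
A minimal usage sketch for the function above, assuming a small character-level TransformerLM; the model configuration and initialization below are illustrative assumptions, not part of the function:

import numpy as np
import trax
from trax import shapes

# Hypothetical small language model in 'predict' mode (required for this sampler).
model = trax.models.TransformerLM(vocab_size=256, d_model=128, d_ff=512,
                                  n_layers=2, n_heads=4, mode='predict')
model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
# In practice you would instead restore trained weights, e.g. with
# model.init_from_file(..., weights_only=True).

tokens = autoregressive_sample(model, batch_size=1, temperature=1.0,
                               eos_id=1, max_length=30)
print(tokens.shape)  # (1, N) with N <= 30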
Example #3
 def _predict_obs(self, predict_fn, rng):
     obs_repr = np.zeros(
         (self._steps.shape[0], self._obs_repr_length),
         dtype=np.int32,
     )
     for i, subrng in enumerate(jax_random.split(rng, self._obs_repr_length)):
         log_probs = predict_fn(self._last_symbols, rng=subrng)
         self._last_symbols = tl.gumbel_sample(log_probs)
         obs_repr[:, i] = self._last_symbols[:, 0]
     return np.array(self._obs_serializer.deserialize(obs_repr))
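
For context, the loop above splits the incoming rng into one sub-key per symbol of the serialized observation. A tiny standalone sketch of that splitting (the length 5 is just an illustrative value):

import jax

key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 5)  # one independent sub-key per symbol position
print(subkeys.shape)                # (5, 2)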
Example #4
 def policy(self, trajectory):
     model = self._eval_model
     model.weights = self._trainer.model_weights
     pred = model(trajectory.last_observation[None, ...], n_accelerators=1)
     sample = tl.gumbel_sample(pred[0, :])
     return sample, pred[0, sample]
Example #5
def PolicySchedule(
    history,
    observation_metrics=(
        ('train', 'metrics/accuracy'),
        ('train', 'metrics/loss'),
        ('eval', 'metrics/accuracy'),
        ('eval', 'metrics/loss'),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ('learning_rate', 1e-3, (1e-9, 10.0), False), ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
    """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and value
      prediction are computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {'name': float}, the
    step-dependent schedule for nontrainable parameters.
  """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        control_configs if include_controls_in_observation else None)
    logging.vlog(1, 'Building observations took %0.2f sec.',
                 time.time() - start_time)
    if observations.shape[0] == 0:
        controls = {
            name: start_value
            for (name, start_value, _, _) in control_configs
        }
        return lambda _: controls

    # Build the policy network and load its parameters.
    start_time = time.time()
    (low, high) = observation_range
    observation_space = gym.spaces.Box(shape=observations.shape[1:],
                                       low=low,
                                       high=high)
    action_space = gym.spaces.MultiDiscrete(nvec=(len(action_multipliers), ) *
                                            len(control_configs))
    (net, _) = policy_based_utils.policy_and_value_net(
        bottom_layers_fn=policy_and_value_model,
        observation_space=observation_space,
        action_space=action_space,
        vocab_size=policy_and_value_vocab_size,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, 'Building the policy network took %0.2f sec.',
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step, history)
    (opt_state, state, _, _,
     _) = policy_based_utils.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, 'Policy checkpoint not found.'
    (params, _, _) = opt_state
    logging.vlog(1, 'Restoring the policy parameters took %0.2f sec.',
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()

    n_timesteps = observations.shape[0]
    # (log_probs, value_preds, state, rng)
    (log_probs, _, _, _) = policy_based_utils.run_policy(
        policy_and_value_net_apply=net,
        observations=np.array([observations]),
        lengths=np.array([n_timesteps]),
        weights=params,
        state=state,
        rng=rng,
        action_space=action_space,
    )

    logging.vlog(1, 'Running the policy took %0.2f sec.',
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    assert log_probs.shape == (1, len(control_configs),
                               len(action_multipliers))
    action = tl.gumbel_sample(log_probs[0], temperature)

    # Get new controls.
    controls = {
        # name: value
        control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
            control_config, control_action, history, action_multipliers)
        for (control_action, control_config) in zip(action, control_configs)
    }
    return lambda _: controls
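
A small sketch of the final sampling step: log_probs[0] has shape (n_controls, n_actions), and tl.gumbel_sample draws one action index per control along the last axis. The shapes and values below are illustrative only:

import numpy as np
from trax import layers as tl

log_probs = np.log(np.full((2, 5), 0.2))  # 2 controls, 5 multipliers, uniform
action = tl.gumbel_sample(log_probs, temperature=1.0)
print(action.shape)  # (2,): one sampled multiplier index per control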
Example #6
def collect_trajectories(
    env,
    policy_fn,
    n_trajectories=1,
    n_observations=None,
    max_timestep=None,
    reset=True,
    len_history_for_policy=32,
    boundary=32,
    state=None,
    temperature=1.0,
    rng=None,
    abort_fn=None,
    raw_trajectory=False,
):
    """Collect trajectories with the given policy net and behaviour.

  Args:
    env: A gym env interface, for now this is not-batched.
    policy_fn: Callable
      (observations(B,T+1), actions(B, T+1, C)) -> log-probs(B, T+1, C, A).
    n_trajectories: int, number of trajectories.
    n_observations: int, number of non-terminal observations. NOTE: Exactly one
      of `n_trajectories` and `n_observations` should be None.
    max_timestep: int or None, the index of the maximum time-step at which we
      return the trajectory, None for ending a trajectory only when env returns
      done.
    reset: bool, True if we want to reset the envs. The envs are also reset if
      max_timestep is None or < 0.
    len_history_for_policy: int or None, the maximum history to keep for
      applying the policy on. If None, use the full history.
    boundary: int, pad the sequences to the multiples of this number.
    state: state for `policy_fn`.
    temperature: (float) temperature to sample action from policy_fn.
    rng: jax rng, splittable.
    abort_fn: callable or None. If not None, it is called at every env step; if
      it returns True, trajectory collection is aborted, the env is reset, and
      None is returned.
    raw_trajectory: bool, if True a list of trajectory.Trajectory objects is
      returned, otherwise a list of numpy representations of
      `trajectory.Trajectory` is returned.

  Returns:
    A tuple (trajectory, number of trajectories that are done)
    trajectory: list of (observation, action, reward) tuples, where each element
    `i` is a tuple of numpy arrays with shapes as follows:
    observation[i] = (B, T_i + 1)
    action[i] = (B, T_i)
    reward[i] = (B, T_i)
  """

    assert isinstance(env, env_problem.EnvProblem)

    # We need to reset all environments, if we're coming here the first time.
    if reset or max_timestep is None or max_timestep <= 0:
        env.reset()
    else:
        # Clear completed trajectories held internally.
        env.trajectories.clear_completed_trajectories()

    num_done_trajectories = 0

    # The stopping criterion, returns True if we should stop.
    def should_stop():
        if n_trajectories is not None:
            assert n_observations is None
            return env.trajectories.num_completed_trajectories >= n_trajectories
        assert n_observations is not None
        # The number of non-terminal observations is what we want.
        return (env.trajectories.num_completed_time_steps -
                env.trajectories.num_completed_trajectories) >= n_observations

    policy_application_total_time = 0
    env_actions_total_time = 0
    bare_env_run_time = 0
    while not should_stop():
        # Check if we should abort and return nothing.
        if abort_fn and abort_fn():
            # We should also reset the environment, since it will have some
            # trajectories (complete and incomplete) that we want to discard.
            env.reset()
            return None, 0, {}, state

        # Get all the observations for all the active trajectories.
        # Shape is (B, T+1) + OBS
        # Bucket on whatever length is needed.
        padded_observations, lengths = env.trajectories.observations_np(
            boundary=boundary, len_history_for_policy=len_history_for_policy)

        B = padded_observations.shape[0]  # pylint: disable=invalid-name

        assert B == env.batch_size
        assert (B, ) == lengths.shape

        t1 = time.time()
        log_probs, value_preds, state, rng = policy_fn(padded_observations,
                                                       lengths,
                                                       state=state,
                                                       rng=rng)
        policy_application_total_time += (time.time() - t1)

        assert B == log_probs.shape[0]

        actions = tl.gumbel_sample(log_probs, temperature)
        if (isinstance(env.action_space, gym.spaces.Discrete)
                and (actions.shape[1] == 1)):
            actions = onp.squeeze(actions, axis=1)

        # Step through the env.
        t1 = time.time()
        _, _, dones, env_infos = env.step(actions,
                                          infos={
                                              'log_prob_actions': log_probs,
                                              'value_predictions': value_preds,
                                          })
        env_actions_total_time += (time.time() - t1)
        bare_env_run_time += sum(info['__bare_env_run_time__']
                                 for info in env_infos)

        # Count the number of done trajectories; the others could just have been
        # truncated.
        num_done_trajectories += onp.sum(dones)

        # Get the indices where we are done ...
        done_idxs = env_problem_utils.done_indices(dones)

        # ... and reset those.
        t1 = time.time()
        if done_idxs.size:
            env.reset(indices=done_idxs)
        env_actions_total_time += (time.time() - t1)

        if max_timestep is None or max_timestep < 1:
            continue

        # Check whether any trajectories have exceeded the time limit we want.
        lengths = env.trajectories.trajectory_lengths
        exceeded_time_limit_idxs = env_problem_utils.done_indices(
            lengths > max_timestep)

        # If so, reset these as well.
        t1 = time.time()
        if exceeded_time_limit_idxs.size:
            # This just cuts the trajectory, doesn't reset the env, so it continues
            # from where it left off.
            env.truncate(indices=exceeded_time_limit_idxs, num_to_keep=1)
        env_actions_total_time += (time.time() - t1)

    # We have the trajectories we need, return a list of triples:
    # (observations, actions, rewards)
    completed_trajectories = (
        env_problem_utils.get_completed_trajectories_from_env(
            env,
            env.trajectories.num_completed_trajectories,
            raw_trajectory=raw_trajectory))

    timing_info = {
        'trajectory_collection/policy_application':
        policy_application_total_time,
        'trajectory_collection/env_actions': env_actions_total_time,
        'trajectory_collection/env_actions/bare_env': bare_env_run_time,
    }
    timing_info = {k: round(1000 * v, 2) for k, v in timing_info.items()}

    return completed_trajectories, num_done_trajectories, timing_info, state
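
For reference, a hypothetical policy_fn stub matching the call signature used above; the name, shapes, and uniform distribution are assumptions for illustration, not trax code:

import numpy as onp

def uniform_policy_fn(padded_observations, lengths, state=None, rng=None,
                      n_controls=1, n_actions=4):
    # Returns uniform log-probs of shape (B, n_controls, n_actions), zero value
    # predictions, and passes state/rng through unchanged.
    b = padded_observations.shape[0]
    log_probs = onp.log(onp.full((b, n_controls, n_actions), 1.0 / n_actions))
    value_preds = onp.zeros((b, n_controls))
    return log_probs, value_preds, state, rng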
Example #7
def autoregressive_sample(model,
                          prefix=None,
                          inputs=None,
                          batch_size=1,
                          temperature=1.0,
                          start_id=0,
                          eos_id=1,
                          max_length=100,
                          accelerate=True):
    """Perform aturegressive sampling from the provided model.

  Note that the provided model should be an autoregressive model initialized
  in 'predict' mode. In this mode, a model takes the outputs it is generating
  one-by-one (instead of taking them all at once, as, e.g., during training).
  Model state is used to store the intermediate information needed, and usually
  the model performs inference in this mode faster than in 'eval' mode.

  Args:
    model: instance of trax.Layer, the model to sample from (at mode='predict')
    prefix: optional tensor [batch_size, L]: prefix for decoding
    inputs: optional tensor [batch_size, M]: inputs to provide to the model
    batch_size: number of sequences to sample in parallel (default: 1)
    temperature: sampling temperature (default: 1.0)
    start_id: int, id for the start symbol fed at the beginning (default: 0)
    eos_id: int, id of the end-of-sequence symbol used to stop (default: 1)
    max_length: maximum length to sample (default: 100)
    accelerate: whether to accelerate the model before decoding (default: True)

  Returns:
    a tensor of ints of shape [batch_size, N] with N <= max_length containing
    the autoregressively sampled output from the model
  """
    if prefix is not None and prefix.shape[0] != batch_size:
        raise ValueError(
            f'Prefix batch size {prefix.shape[0]} != {batch_size}.')
    if inputs is not None and inputs.shape[0] != batch_size:
        raise ValueError(
            f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
    fast_model = tl.Accelerate(model) if accelerate else model
    cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
    result = []
    eos_seen = []
    for i in range(max_length):
        model_input = cur_symbol if inputs is None else (inputs, cur_symbol)
        logits = fast_model(model_input)
        if inputs is not None:
            # Pick first element from model output (a pair here).
            logits = logits[0]
        if prefix is not None and i < prefix.shape[1]:  # Read from prefix.
            cur_prefix_symbol = prefix[:, i]
            sample = cur_prefix_symbol[:, None]
        else:
            sample = tl.gumbel_sample(logits, temperature=temperature)
        result.append(sample)
        # Note: we're using 'predict' mode autoregressive models here, so history
        # is cached in the model state and we only feed the next symbol.
        cur_symbol = sample
        # Check at which batch positions we have already encountered EOS.
        for j in range(batch_size):
            if int(sample[j, 0]) == eos_id:
                eos_seen.append(j)
        # If EOS has been seen on all positions, stop.
        if all([j in eos_seen for j in range(batch_size)]):
            break
    return np.concatenate(result, axis=1)
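
The only difference from the earlier variant is the per-position EOS bookkeeping. A standalone sketch of that stopping rule, using a set instead of a list (the sampled values are made up):

import numpy as np

eos_id = 1
eos_seen = set()
steps = [np.array([[5], [1]]), np.array([[1], [7]])]  # two decode steps, batch of 2
for step, sample in enumerate(steps):
    for j in range(sample.shape[0]):
        if int(sample[j, 0]) == eos_id:
            eos_seen.add(j)
    if len(eos_seen) == sample.shape[0]:
        print(f'all positions saw EOS by step {step}')  # prints: ... step 1
        break

A set avoids the duplicate entries the list version can accumulate, though both yield the same stopping behaviour.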
Example #8
 def sample(self, inputs):
     return tl.gumbel_sample(inputs)
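
All of these examples rest on the Gumbel-max trick: adding Gumbel noise to log-probabilities and taking an argmax over the last axis draws from the corresponding softmax distribution. A self-contained NumPy check (no trax required):

import numpy as np

logits = np.array([2.0, 1.0, 0.1])
probs = np.exp(logits) / np.exp(logits).sum()

rng = np.random.default_rng(0)
u = rng.uniform(1e-6, 1.0 - 1e-6, size=(100_000, logits.shape[0]))
gumbel = -np.log(-np.log(u))
samples = np.argmax(logits + gumbel, axis=-1)

freqs = np.bincount(samples, minlength=3) / samples.shape[0]
print(np.round(probs, 3), np.round(freqs, 3))  # empirical frequencies ~ softmax probs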