Example #1
    def _init_host_and_devices(self, n_devices=None, random_seed=None):
        """Initializes host and device attributes for this trainer.

    Args:
      n_devices: Number of devices this trainer will use. If `None`, get the
          number from the backend.
      random_seed: Random seed as the starting point for all random numbers used
          by the trainer. If `None`, calculate one from system time and host id.

    Returns:
      is_chief: True if this trainer has special chief responsibilities.
      n_devices: The passed in value of n_devices or a computed default.
      random_seed: The passed in value of random_seed or a computed default.
    """
        if math.backend_name() == 'jax':
            host_id = jax.host_id()
            host_count = jax.host_count()
        else:
            host_id = 0
            host_count = 1
        is_chief = (host_id == 0)

        device_count = math.device_count()
        n_devices = n_devices or device_count
        # TODO(lukaszkaiser): remove this restriction when possible.
        if n_devices != device_count and math.backend_name() == 'jax':
            raise ValueError(
                'JAX cannot work yet with n_devices != all devices: '
                '%d != %d' % (n_devices, device_count))

        if random_seed is None and host_count > 1:
            random_seed = int(1e6 * (host_id + time.time())) % 2**32
        return is_chief, n_devices, init_random_number_generators(random_seed)
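
A minimal standalone sketch of the multi-host seeding logic above, assuming `jax` is installed; the helper name `per_host_seed` is hypothetical and only reproduces the seed derivation, not the `init_random_number_generators` call:

import time

import jax


def per_host_seed(random_seed=None):
    # Mirrors the derivation in _init_host_and_devices: only multi-host setups
    # mix the host id and wall-clock time into a 32-bit seed; single-host runs
    # return the given seed unchanged (possibly None).
    host_id, host_count = jax.host_id(), jax.host_count()
    if random_seed is None and host_count > 1:
        random_seed = int(1e6 * (host_id + time.time())) % 2**32
    return random_seed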
Example #2
    def _run_value_model(self, observations, dist_inputs):
        if dist_inputs is None:
            dist_inputs = jnp.zeros(observations.shape[:2] +
                                    (self._policy_dist.n_inputs, ))

        actions = None
        if self._q_value:
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            # Swap the n_samples and batch_size axes so the input is split
            # across accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
            actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            obs = observations
            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
            inputs = (obs, actions)
        else:
            log_probs = None
            inputs = (observations, )

        n_devices = math.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        values = jnp.squeeze(values,
                             axis=-1)  # Remove the singleton depth dim.
        return (values, actions, log_probs)
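
The broadcast-and-swap in the `_q_value` branch can be checked in isolation. A small shape sketch assuming plain `jax.numpy`; the sizes for `batch_size`, `seq_len`, `n_inputs` and `n_samples` are illustrative:

import jax.numpy as jnp

batch_size, seq_len, n_inputs, n_samples = 4, 8, 3, 16
dist_inputs = jnp.zeros((batch_size, seq_len, n_inputs))

# Replicate the distribution parameters once per sample ...
tiled = jnp.broadcast_to(dist_inputs, (n_samples,) + dist_inputs.shape)
# ... then move batch_size to the front, so sharding across accelerators
# splits along the batch axis, as in _run_value_model.
tiled = jnp.swapaxes(tiled, 0, 1)
assert tiled.shape == (batch_size, n_samples, seq_len, n_inputs)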
Example #3
    def forward_with_state(self, xs, weights, state, rng):
        self._validate_forward_inputs(xs)
        (step, layers_state) = state
        # Get N+1 rngs, N for running layers and one extra.
        rngs = _split_rngs(rng, self._n_layers + 1)
        rng0, rngs = rngs[0], rngs[1:]
        if not self.sublayers:  # No-op: leave args unchanged.
            return (xs, (step + 1, layers_state))

        # Prepare the stack and do some safety checks as in the parent class.
        stack = xs
        new_state = []
        n_layers = self._n_layers
        if n_layers != 1 and len(weights) != n_layers:
            raise ValueError(
                'number of weights ({}) not equal to number of layers '
                '({})'.format(len(weights), n_layers))
        if n_layers != 1 and len(layers_state) != n_layers:
            raise ValueError(
                'length of state ({}) not equal to number of layers '
                '({})'.format(len(layers_state), n_layers))

        # TODO(chowdhery): try different strategies, also try running not all
        # layers backwards by using math.stop_gradient where needed.

        # Calculate how many layers to run forward.
        if self._mode == 'train':
            # Warmup goes from 1.0 at the start to 0.0 at skipping_warmup_steps and after.
            w_steps = float(self._skipping_warmup_steps)
            warmup = np.maximum(0.0,
                                (w_steps - step.astype(np.float32)) / w_steps)
            # low is the minimum number of layers to *not* skip, from n_layers to 0
            low = warmup * float(n_layers)
            # high should be so that (high - n_layers) / high = 1.0 - skip_fraction
            # because (high - n_layers) / high is the probability we're not skipping
            # (after warmup); so high - n_layers = high - high * skip_fraction
            high = float(n_layers) / self._skip_fraction
            # We want the same rng0 on all cores.
            if math.device_count() > 1:
                rng0 = math.psum(rng0, 'batch')
            n_forward_layers = random.uniform(rng0, (), np.float32, low, high)
        else:
            n_forward_layers = float(n_layers)
        # Run layers skipping after a certain number.
        cur_layer_idx = 0.0
        for layer, p, s, rng in zip(self.sublayers, weights, layers_state,
                                    rngs):
            inputs = _inputs_from_stack(layer, stack)
            outputs, s = math.cond(  # Skip (do identity) if >= n_forward_layers.
                pred=(math.lt(cur_layer_idx, n_forward_layers)),
                true_operand=(inputs, p, s, rng),  # This tuple is t below.
                true_fun=(lambda t: layer.pure_fn(t[0], t[1], t[2], t[3])),  # pylint: disable=cell-var-from-loop
                false_operand=(inputs, p, s, rng),
                false_fun=(lambda t: (t[0], t[2])),  # return (inputs, state)
            )
            stack = _outputs_onto_stack(layer, outputs, stack)
            new_state.append(s)
            cur_layer_idx += 1.0
        return stack, (step + 1, new_state)
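
A numeric sketch of the warmup arithmetic in the `train` branch above; the helper name `forward_layer_range` and the values chosen for `n_layers`, `skipping_warmup_steps` and `skip_fraction` are illustrative assumptions:

import numpy as np


def forward_layer_range(step, n_layers=6, skipping_warmup_steps=1000,
                        skip_fraction=0.1):
    # Warmup decays linearly from 1.0 at step 0 to 0.0 at skipping_warmup_steps.
    w_steps = float(skipping_warmup_steps)
    warmup = np.maximum(0.0, (w_steps - step) / w_steps)
    # low is the minimum number of layers not to skip; high is chosen so that
    # (high - n_layers) / high == 1.0 - skip_fraction, the post-warmup
    # probability of not skipping any layer.
    low = float(warmup) * float(n_layers)
    high = float(n_layers) / skip_fraction
    return low, high


print(forward_layer_range(step=0))     # (6.0, 60.0): no layers skipped yet.
print(forward_layer_range(step=2000))  # (0.0, 60.0): trailing layers may be skipped.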
Example #4
    def _run_value_model(self, observations, dist_inputs):
        if dist_inputs is None:
            dist_inputs = jnp.zeros(observations.shape[:2] +
                                    (self._policy_dist.n_inputs, ))

        actions = None
        if self._q_value:
            if self._sample_all_discrete_actions:
                # Since we want to sample all actions, start by creating their list.
                act = np.arange(self._vocab_size)
                # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
                # Add extra dimensions so it's the same dimensionality as dist_inputs.
                act = jnp.reshape(act,
                                  [-1] + [1] * (len(dist_inputs.shape) - 1))
                # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
            dist_inputs = jnp.broadcast_to(
                dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
            if self._sample_all_discrete_actions:
                actions = act + jnp.zeros(dist_inputs.shape[:-1],
                                          dtype=jnp.int32)
                actions = jnp.swapaxes(actions, 0, 1)
            # Swap the n_samples and batch_size axes so the input is split
            # across accelerators along the batch_size axis.
            dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
            if not self._sample_all_discrete_actions:
                actions = self._policy_dist.sample(dist_inputs)
            log_probs = self._policy_dist.log_prob(dist_inputs, actions)
            obs = observations
            obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
            inputs = (obs, actions)
        else:
            log_probs = None
            inputs = (observations, )

        n_devices = math.device_count()
        weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
        state = tl.for_n_devices(self._value_eval_model.state, n_devices)
        rng = self._value_eval_model.rng
        values, _ = self._value_eval_jit(inputs, weights, state, rng)
        values *= self._value_network_scale
        values = jnp.squeeze(values,
                             axis=-1)  # Remove the singleton depth dim.
        return (values, actions, log_probs)
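
A shape sketch of the all-actions enumeration in the `_sample_all_discrete_actions` branch, assuming plain `jax.numpy` and that the number of samples equals `vocab_size`; the concrete sizes are illustrative:

import jax.numpy as jnp
import numpy as np

vocab_size, batch_size, seq_len, n_inputs = 5, 4, 8, 3
dist_inputs = jnp.zeros((batch_size, seq_len, n_inputs))

# act is [0, ..., vocab_size - 1], reshaped to [vocab_size, 1, 1] so it has the
# same number of axes as dist_inputs.
act = np.arange(vocab_size)
act = jnp.reshape(act, [-1] + [1] * (len(dist_inputs.shape) - 1))

# One copy of the distribution parameters per action.
dist_inputs = jnp.broadcast_to(dist_inputs, (vocab_size,) + dist_inputs.shape)

# Broadcasting act against a zero tensor enumerates every action at every
# (batch, time) position; swapping axes then puts batch_size first, as above.
actions = act + jnp.zeros(dist_inputs.shape[:-1], dtype=jnp.int32)
actions = jnp.swapaxes(actions, 0, 1)
assert actions.shape == (batch_size, vocab_size, seq_len)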
Example #5
    def __init__(self,
                 task,
                 value_model=None,
                 value_optimizer=None,
                 value_lr_schedule=lr.MultifactorSchedule,
                 value_batch_size=64,
                 value_train_steps_per_epoch=500,
                 value_evals_per_epoch=1,
                 value_eval_steps=1,
                 n_shared_layers=0,
                 added_policy_slice_length=0,
                 n_replay_epochs=1,
                 scale_value_targets=False,
                 q_value=False,
                 q_value_aggregate_max=True,
                 q_value_n_samples=1,
                 vocab_size=2,
                 **kwargs):  # Arguments of PolicyTrainer come here.
        """Configures the actor-critic trainer.

    Args:
      task: `RLTask` instance to use.
      value_model: Model to use for the value function.
      value_optimizer: Optimizer to train the value model.
      value_lr_schedule: lr schedule for value model training.
      value_batch_size: Batch size for value model training.
      value_train_steps_per_epoch: Number of steps used to train the value
          model in each epoch.
      value_evals_per_epoch: Number of value trainer evaluations per RL epoch;
          only affects metric reporting.
      value_eval_steps: Number of value trainer steps per evaluation; only
          affects metric reporting.
      n_shared_layers: Number of layers to share between value and policy
          models.
      added_policy_slice_length: How much longer slices of trajectories
          should be for policy training than for value training; this is
          useful for TD calculations and only affects the length of
          elements produced for policy batches; value batches have maximum
          length set by `max_slice_length` in `**kwargs`.
      n_replay_epochs: Number of last epochs to take into the replay buffer;
          only makes sense for off-policy algorithms.
      scale_value_targets: If `True`, scale value function targets by
          `1 / (1 - gamma)`.
      q_value: If `True`, use Q-values as baselines.
      q_value_aggregate_max: If `True`, aggregate Q-values with max; if
          `False`, with mean.
      q_value_n_samples: Number of samples to average over when calculating
          baselines based on Q-values.
      vocab_size: Embedding vocabulary size (passed to `tl.Embedding`); used
          only with discrete actions and when `q_value` is `True`.
      **kwargs: Arguments for `PolicyTrainer` superclass.
    """
        self._n_shared_layers = n_shared_layers
        self._value_batch_size = value_batch_size
        self._value_train_steps_per_epoch = value_train_steps_per_epoch
        self._value_evals_per_epoch = value_evals_per_epoch
        self._value_eval_steps = value_eval_steps

        # The two attributes below will be initialized in super().__init__
        # anyway, but they are needed to construct value batches, which are
        # needed before the PolicyTrainer init since policy input creation
        # calls the value model -- hence this code.
        self._task = task
        self._max_slice_length = kwargs.get('max_slice_length', 1)
        self._added_policy_slice_length = added_policy_slice_length
        self._n_replay_epochs = n_replay_epochs
        task.set_n_replay_epochs(n_replay_epochs)

        if scale_value_targets:
            self._value_network_scale = 1 / (1 - self._task.gamma)
        else:
            self._value_network_scale = 1

        self._q_value = q_value
        self._q_value_aggregate_max = q_value_aggregate_max
        self._q_value_n_samples = q_value_n_samples
        self._vocab_size = vocab_size

        is_discrete = isinstance(self._task.action_space, gym.spaces.Discrete)
        # TODO(henrykm) handle the case other than Discrete/Gaussian

        if q_value:
            value_model = functools.partial(value_model,
                                            inject_actions=True,
                                            is_discrete=is_discrete,
                                            vocab_size=self._vocab_size)
        self._value_eval_model = value_model(mode='eval')
        self._value_eval_model.init(self._value_model_signature)
        self._value_eval_jit = tl.jit_forward(self._value_eval_model.pure_fn,
                                              math.device_count(),
                                              do_mean=False)

        # Initialize policy training.
        super(ActorCriticTrainer, self).__init__(task, **kwargs)

        # Initialize training of the value function.
        value_output_dir = kwargs.get('output_dir', None)
        if value_output_dir is not None:
            value_output_dir = os.path.join(value_output_dir, 'value')
            # If needed, create value_output_dir and missing parent directories.
            if not tf.io.gfile.isdir(value_output_dir):
                tf.io.gfile.makedirs(value_output_dir)
        self._value_inputs = supervised.Inputs(
            train_stream=lambda _: self.value_batches_stream())
        self._value_trainer = supervised.Trainer(
            model=value_model,
            optimizer=value_optimizer,
            lr_schedule=value_lr_schedule,
            loss_fn=tl.L2Loss(),
            inputs=self._value_inputs,
            output_dir=value_output_dir,
            metrics={'value_loss': tl.L2Loss()})
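
A tiny sketch of the target-scaling factor chosen in `__init__` above; the helper name `value_target_scale` and the example `gamma` are illustrative:

def value_target_scale(gamma, scale_value_targets):
    # Mirrors the branch above: with scaling enabled, _run_value_model
    # multiplies the network's raw predictions by 1 / (1 - gamma); otherwise
    # the factor is 1.
    return 1.0 / (1.0 - gamma) if scale_value_targets else 1.0


print(value_target_scale(0.99, True))   # ~100.0
print(value_target_scale(0.99, False))  # 1.0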
Example #6
    def __init__(self,
                 task,
                 value_model=None,
                 value_optimizer=None,
                 value_lr_schedule=lr.MultifactorSchedule,
                 value_batch_size=64,
                 value_train_steps_per_epoch=500,
                 value_evals_per_epoch=1,
                 value_eval_steps=1,
                 n_shared_layers=0,
                 added_policy_slice_length=0,
                 n_replay_epochs=1,
                 scale_value_targets=False,
                 q_value=False,
                 q_value_aggregate_max=True,
                 q_value_n_samples=1,
                 **kwargs):  # Arguments of PolicyTrainer come here.
        """Configures the actor-critic Trainer.

    Args:
      task: RLTask instance to use
      value_model: the model to use for the value function
      value_optimizer: the optimizer to train the value model
      value_lr_schedule: lr schedule for value model training
      value_batch_size: batch size for value model training
      value_train_steps_per_epoch: how many steps to use to train the value
          model in each epoch
      value_evals_per_epoch: number of value trainer evaluations per RL
          epoch - only affects metric reporting.
      value_eval_steps: number of value trainer steps per evaluation -
          only affects metric reporting.
      n_shared_layers: how many layers to share between value and
          policy models
      added_policy_slice_length: how much longer slices of trajectories
          should be for policy than for value training; this is useful
          for TD calculations and only affects the length of elements
          produced for policy batches; value batches have maximum length
          set by max_slice_length in **kwargs
      n_replay_epochs: how many last epochs to take into the replay buffer;
          only makes sense for off-policy algorithms
      scale_value_targets: whether to scale targets for the value function by
          1 / (1 - gamma)
      q_value: whether to use Q-values as baselines
      q_value_aggregate_max: whether to aggregate Q-values with max (if
          False, with mean)
      q_value_n_samples: number of samples to average over when calculating
          baselines based on Q-values
      **kwargs: arguments for the PolicyTrainer superclass
    """
        self._n_shared_layers = n_shared_layers
        self._value_batch_size = value_batch_size
        self._value_train_steps_per_epoch = value_train_steps_per_epoch
        self._value_evals_per_epoch = value_evals_per_epoch
        self._value_eval_steps = value_eval_steps

        # The two attributes below will be initialized in super().__init__
        # anyway, but they are needed to construct value batches, which are
        # needed before the PolicyTrainer init since policy input creation
        # calls the value model -- hence this code.
        self._task = task
        self._max_slice_length = kwargs.get('max_slice_length', 1)
        self._added_policy_slice_length = added_policy_slice_length
        self._n_replay_epochs = n_replay_epochs
        task.set_n_replay_epochs(n_replay_epochs)

        if scale_value_targets:
            self._value_network_scale = 1 / (1 - self._task.gamma)
        else:
            self._value_network_scale = 1

        self._q_value = q_value
        self._q_value_aggregate_max = q_value_aggregate_max
        self._q_value_n_samples = q_value_n_samples
        if q_value:
            value_model = functools.partial(value_model, inject_actions=True)
        self._value_eval_model = value_model(mode='eval')
        self._value_eval_model.init(self._value_model_signature)
        self._value_eval_jit = tl.jit_forward(self._value_eval_model.pure_fn,
                                              math.device_count(),
                                              do_mean=False)

        # Initialize policy training.
        super(ActorCriticTrainer, self).__init__(task, **kwargs)

        # Initialize training of the value function.
        value_output_dir = kwargs.get('output_dir', None)
        if value_output_dir is not None:
            value_output_dir = os.path.join(value_output_dir, 'value')
            # If needed, create value_output_dir and missing parent directories.
            if not tf.io.gfile.isdir(value_output_dir):
                tf.io.gfile.makedirs(value_output_dir)
        self._value_inputs = supervised.Inputs(
            train_stream=lambda _: self.value_batches_stream())
        self._value_trainer = supervised.Trainer(
            model=value_model,
            optimizer=value_optimizer,
            lr_schedule=value_lr_schedule,
            loss_fn=tl.L2Loss(),
            inputs=self._value_inputs,
            output_dir=value_output_dir,
            metrics={'value_loss': tl.L2Loss()})
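
The `q_value` branch above specializes the value-model constructor with `functools.partial` before instantiating it in eval mode. A minimal sketch of the same pattern with a hypothetical stand-in constructor:

import functools


def value_model(mode='train', inject_actions=False):
    # Stand-in for a Trax model constructor; returns a string instead of layers.
    return 'value model (mode=%s, inject_actions=%s)' % (mode, inject_actions)


# Bind the Q-value-specific argument once; build concrete instances later.
value_model = functools.partial(value_model, inject_actions=True)
print(value_model(mode='eval'))  # value model (mode=eval, inject_actions=True)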