def _init_host_and_devices(self, n_devices=None, random_seed=None):
  """Initializes host and device attributes for this trainer.

  Args:
    n_devices: Number of devices this trainer will use. If `None`, get the
        number from the backend.
    random_seed: Random seed as the starting point for all random numbers used
        by the trainer. If `None`, calculate one from system time and host id.

  Returns:
    is_chief: True if this trainer has special chief responsibilities.
    n_devices: The passed in value of n_devices or a computed default.
    random_seed: The passed in value of random_seed or a computed default.
  """
  if math.backend_name() == 'jax':
    host_id = jax.host_id()
    host_count = jax.host_count()
  else:
    host_id = 0
    host_count = 1
  is_chief = (host_id == 0)

  device_count = math.device_count()
  n_devices = n_devices or device_count
  # TODO(lukaszkaiser): remove this restriction when possible.
  if n_devices != device_count and math.backend_name() == 'jax':
    raise ValueError('JAX cannot work yet with n_devices != all devices: '
                     '%d != %d' % (n_devices, device_count))

  if random_seed is None and host_count > 1:
    random_seed = int(1e6 * (host_id + time.time())) % 2**32
  return is_chief, n_devices, init_random_number_generators(random_seed)
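# The per-host seeding above is easy to check in isolation. A minimal,
# self-contained sketch of the same formula (plain Python; `derive_seed` and
# the 4-host setup are hypothetical, not part of the trainer):
import time

def derive_seed(host_id, now):
  """Mirrors the formula above: per-host distinct, bounded to 32 bits."""
  return int(1e6 * (host_id + now)) % 2**32

now = time.time()
seeds = [derive_seed(host, now) for host in range(4)]
assert len(set(seeds)) == 4  # hosts 0..3 derive different seeds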
def _run_value_model(self, observations, dist_inputs):
  if dist_inputs is None:
    dist_inputs = jnp.zeros(
        observations.shape[:2] + (self._policy_dist.n_inputs,))
  actions = None
  if self._q_value:
    dist_inputs = jnp.broadcast_to(
        dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape)
    # Swapping the n_samples and batch_size axes, so the input is split
    # between accelerators along the batch_size axis.
    dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
    actions = self._policy_dist.sample(dist_inputs)
    log_probs = self._policy_dist.log_prob(dist_inputs, actions)
    obs = observations
    obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
    inputs = (obs, actions)
  else:
    log_probs = None
    inputs = (observations,)

  n_devices = math.device_count()
  weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
  state = tl.for_n_devices(self._value_eval_model.state, n_devices)
  rng = self._value_eval_model.rng
  values, _ = self._value_eval_jit(inputs, weights, state, rng)
  values *= self._value_network_scale
  values = jnp.squeeze(values, axis=-1)  # Remove the singleton depth dim.
  return (values, actions, log_probs)
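# The broadcast-then-swap at the top of the `_q_value` branch is pure shape
# manipulation, so it can be sanity-checked with plain numpy. A minimal
# sketch with hypothetical sizes (batch 8, time 5, 3 distribution inputs,
# 4 Q-value samples):
import numpy as np

batch_size, n_timesteps, n_inputs, n_samples = 8, 5, 3, 4
dist_inputs = np.zeros((batch_size, n_timesteps, n_inputs))

# Replicate the distribution parameters once per Q-value sample...
tiled = np.broadcast_to(dist_inputs, (n_samples,) + dist_inputs.shape)
assert tiled.shape == (4, 8, 5, 3)

# ...then swap axes so batch_size leads and accelerators split along it.
swapped = np.swapaxes(tiled, 0, 1)
assert swapped.shape == (8, 4, 5, 3)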
def forward_with_state(self, xs, weights, state, rng):
  self._validate_forward_inputs(xs)
  (step, layers_state) = state
  # Get N+1 rngs, N for running layers and one extra.
  rngs = _split_rngs(rng, self._n_layers + 1)
  rng0, rngs = rngs[0], rngs[1:]
  if not self.sublayers:  # No-op: leave args unchanged.
    return (xs, (step + 1, layers_state))

  # Prepare the stack and do some safety checks as in the parent class.
  stack = xs
  new_state = []
  n_layers = self._n_layers
  if n_layers != 1 and len(weights) != n_layers:
    raise ValueError('number of weights ({}) not equal to number of layers '
                     '({})'.format(len(weights), n_layers))
  if n_layers != 1 and len(layers_state) != n_layers:
    raise ValueError('length of state ({}) not equal to number of layers '
                     '({})'.format(len(layers_state), n_layers))

  # TODO(chowdhery): try different strategies, also try running not all
  # layers backwards by using math.stop_gradient where needed.

  # Calculate how many layers to run forward.
  if self._mode == 'train':
    # warmup goes from 1.0 at start to 0.0 at skipping_warmup_steps and after
    w_steps = float(self._skipping_warmup_steps)
    warmup = np.maximum(0.0, (w_steps - step.astype(np.float32)) / w_steps)
    # low is the minimum number of layers to *not* skip, from n_layers to 0
    low = warmup * float(n_layers)
    # high should be so that (high - n_layers) / high = 1.0 - skip_fraction
    # because (high - n_layers) / high is the probability we're not skipping
    # (after warmup); so high - n_layers = high - high * skip_fraction
    high = float(n_layers) / self._skip_fraction
    # We want the same rng0 on all cores.
    if math.device_count() > 1:
      rng0 = math.psum(rng0, 'batch')
    n_forward_layers = random.uniform(rng0, (), np.float32, low, high)
  else:
    n_forward_layers = float(n_layers)

  # Run layers, skipping the ones past n_forward_layers.
  cur_layer_idx = 0.0
  for layer, p, s, rng in zip(self.sublayers, weights, layers_state, rngs):
    inputs = _inputs_from_stack(layer, stack)
    outputs, s = math.cond(  # Skip (do identity) if > n_forward_layers.
        pred=(math.lt(cur_layer_idx, n_forward_layers)),
        true_operand=(inputs, p, s, rng),  # This tuple is t below.
        true_fun=(lambda t: layer.pure_fn(t[0], t[1], t[2], t[3])),  # pylint: disable=cell-var-from-loop
        false_operand=(inputs, p, s, rng),
        false_fun=(lambda t: (t[0], t[2])),  # return (inputs, state)
    )
    stack = _outputs_onto_stack(layer, outputs, stack)
    new_state.append(s)
    cur_layer_idx += 1.0
  return stack, (step + 1, new_state)
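# The `low`/`high` bounds above are the heart of the schedule: `low` anneals
# from n_layers down to 0 over the warmup (no skipping at first), while
# `high = n_layers / skip_fraction` keeps the post-warmup probability of
# running all layers at (high - n_layers) / high = 1 - skip_fraction, per the
# comments in the code. A minimal numpy sketch of the bounds alone, assuming
# hypothetical settings of 6 layers, a 0.1 skip fraction, 100 warmup steps:
import numpy as np

n_layers, skip_fraction, w_steps = 6.0, 0.1, 100.0

def skip_bounds(step):
  warmup = np.maximum(0.0, (w_steps - step) / w_steps)  # 1.0 -> 0.0
  low = warmup * n_layers          # minimum number of layers that must run
  high = n_layers / skip_fraction  # fixed upper bound (60.0 here)
  return low, high

print(skip_bounds(0.0))    # (6.0, 60.0): at the start, nothing can skip
print(skip_bounds(50.0))   # (3.0, 60.0): halfway, up to 3 layers may skip
print(skip_bounds(200.0))  # (0.0, 60.0): past warmup, any suffix may skip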
def _run_value_model(self, observations, dist_inputs):
  if dist_inputs is None:
    dist_inputs = jnp.zeros(
        observations.shape[:2] + (self._policy_dist.n_inputs,))
  actions = None
  if self._q_value:
    if self._sample_all_discrete_actions:
      # Since we want to sample all actions, start by creating their list.
      act = np.arange(self._vocab_size)
      # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
      # Add extra dimensions so it's the same dimensionality as dist_inputs.
      act = jnp.reshape(act, [-1] + [1] * (len(dist_inputs.shape) - 1))
      # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
    dist_inputs = jnp.broadcast_to(
        dist_inputs, (self._q_value_n_samples,) + dist_inputs.shape)
    if self._sample_all_discrete_actions:
      actions = act + jnp.zeros(dist_inputs.shape[:-1], dtype=jnp.int32)
      actions = jnp.swapaxes(actions, 0, 1)
    # Swapping the n_samples and batch_size axes, so the input is split
    # between accelerators along the batch_size axis.
    dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
    if not self._sample_all_discrete_actions:
      actions = self._policy_dist.sample(dist_inputs)
    log_probs = self._policy_dist.log_prob(dist_inputs, actions)
    obs = observations
    obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
    inputs = (obs, actions)
  else:
    log_probs = None
    inputs = (observations,)

  n_devices = math.device_count()
  weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
  state = tl.for_n_devices(self._value_eval_model.state, n_devices)
  rng = self._value_eval_model.rng
  values, _ = self._value_eval_jit(inputs, weights, state, rng)
  values *= self._value_network_scale
  values = jnp.squeeze(values, axis=-1)  # Remove the singleton depth dim.
  return (values, actions, log_probs)
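# The reshape-and-broadcast in the `_sample_all_discrete_actions` branch
# builds an actions tensor pairing every (batch, time) position with every
# discrete action. A numpy sketch with hypothetical sizes, assuming the
# number of Q-value samples equals vocab_size (which the broadcast of `act`
# against dist_inputs.shape[:-1] needs):
import numpy as np

batch_size, n_timesteps, n_inputs, vocab_size = 2, 4, 5, 3
dist_inputs = np.zeros((batch_size, n_timesteps, n_inputs))

# [0, 1, 2] reshaped to (3, 1, 1) -- same rank as dist_inputs.
act = np.arange(vocab_size).reshape([-1] + [1] * (dist_inputs.ndim - 1))

tiled = np.broadcast_to(dist_inputs, (vocab_size,) + dist_inputs.shape)
# Broadcasting act against tiled.shape[:-1] enumerates every action at
# every (sample, batch, time) position: actions[a, b, t] == a.
actions = act + np.zeros(tiled.shape[:-1], dtype=np.int32)
actions = np.swapaxes(actions, 0, 1)  # -> (batch, vocab, time)
assert actions.shape == (2, 3, 4)
assert (actions[:, 1, :] == 1).all()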
def __init__(self,
             task,
             value_model=None,
             value_optimizer=None,
             value_lr_schedule=lr.MultifactorSchedule,
             value_batch_size=64,
             value_train_steps_per_epoch=500,
             value_evals_per_epoch=1,
             value_eval_steps=1,
             n_shared_layers=0,
             added_policy_slice_length=0,
             n_replay_epochs=1,
             scale_value_targets=False,
             q_value=False,
             q_value_aggregate_max=True,
             q_value_n_samples=1,
             vocab_size=2,
             **kwargs):  # Arguments of PolicyTrainer come here.
  """Configures the actor-critic trainer.

  Args:
    task: `RLTask` instance to use.
    value_model: Model to use for the value function.
    value_optimizer: Optimizer to train the value model.
    value_lr_schedule: lr schedule for value model training.
    value_batch_size: Batch size for value model training.
    value_train_steps_per_epoch: Number of steps used to train the value
        model in each epoch.
    value_evals_per_epoch: Number of value trainer evaluations per RL epoch;
        only affects metric reporting.
    value_eval_steps: Number of value trainer steps per evaluation; only
        affects metric reporting.
    n_shared_layers: Number of layers to share between value and policy
        models.
    added_policy_slice_length: How much longer slices of trajectories should
        be for policy than for value training; this is useful for TD
        calculations and only affects the length of elements produced for
        policy batches; value batches have maximum length set by
        `max_slice_length` in `**kwargs`.
    n_replay_epochs: Number of last epochs to take into the replay buffer;
        only makes sense for off-policy algorithms.
    scale_value_targets: If `True`, scale value function targets by
        `1 / (1 - gamma)`.
    q_value: If `True`, use Q-values as baselines.
    q_value_aggregate_max: If `True`, aggregate Q-values with max; otherwise
        with mean.
    q_value_n_samples: Number of samples to average over when calculating
        baselines based on Q-values.
    vocab_size: Embedding vocabulary size (passed to `tl.Embedding`); used
        only with discrete actions and when `q_value` is `True`.
    **kwargs: Arguments for `PolicyTrainer` superclass.
  """
  self._n_shared_layers = n_shared_layers
  self._value_batch_size = value_batch_size
  self._value_train_steps_per_epoch = value_train_steps_per_epoch
  self._value_evals_per_epoch = value_evals_per_epoch
  self._value_eval_steps = value_eval_steps
  # The 2 below will be initialized in super.__init__ anyway, but are needed
  # to construct value batches, which are needed before PolicyTrainer init
  # since policy input creation calls the value model -- hence this code.
  self._task = task
  self._max_slice_length = kwargs.get('max_slice_length', 1)
  self._added_policy_slice_length = added_policy_slice_length
  self._n_replay_epochs = n_replay_epochs
  task.set_n_replay_epochs(n_replay_epochs)
  if scale_value_targets:
    self._value_network_scale = 1 / (1 - self._task.gamma)
  else:
    self._value_network_scale = 1
  self._q_value = q_value
  self._q_value_aggregate_max = q_value_aggregate_max
  self._q_value_n_samples = q_value_n_samples
  self._vocab_size = vocab_size
  is_discrete = isinstance(self._task.action_space, gym.spaces.Discrete)
  # TODO(henrykm): handle cases other than Discrete/Gaussian.
  if q_value:
    value_model = functools.partial(value_model,
                                    inject_actions=True,
                                    is_discrete=is_discrete,
                                    vocab_size=self._vocab_size)
  self._value_eval_model = value_model(mode='eval')
  self._value_eval_model.init(self._value_model_signature)
  self._value_eval_jit = tl.jit_forward(self._value_eval_model.pure_fn,
                                        math.device_count(), do_mean=False)

  # Initialize policy training.
  super(ActorCriticTrainer, self).__init__(task, **kwargs)

  # Initialize training of the value function.
  value_output_dir = kwargs.get('output_dir', None)
  if value_output_dir is not None:
    value_output_dir = os.path.join(value_output_dir, 'value')
    # If needed, create value_output_dir and missing parent directories.
    if not tf.io.gfile.isdir(value_output_dir):
      tf.io.gfile.makedirs(value_output_dir)
  self._value_inputs = supervised.Inputs(
      train_stream=lambda _: self.value_batches_stream())
  self._value_trainer = supervised.Trainer(
      model=value_model,
      optimizer=value_optimizer,
      lr_schedule=value_lr_schedule,
      loss_fn=tl.L2Loss(),
      inputs=self._value_inputs,
      output_dir=value_output_dir,
      metrics={'value_loss': tl.L2Loss()})
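# For intuition on the `1 / (1 - gamma)` scale configured above: network
# outputs are multiplied by this factor in _run_value_model, so the network
# itself presumably regresses correspondingly down-scaled targets of order 1.
# A tiny arithmetic check (gamma = 0.99 is a hypothetical setting):
gamma = 0.99
value_network_scale = 1 / (1 - gamma)  # == 100.0
# A return with reward 1 at every step approaches 1 / (1 - gamma), so
# dividing such targets by the scale keeps them near 1.0.
approx_max_return = sum(gamma**t for t in range(10_000))
assert abs(approx_max_return - value_network_scale) < 1e-6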
def __init__(self,
             task,
             value_model=None,
             value_optimizer=None,
             value_lr_schedule=lr.MultifactorSchedule,
             value_batch_size=64,
             value_train_steps_per_epoch=500,
             value_evals_per_epoch=1,
             value_eval_steps=1,
             n_shared_layers=0,
             added_policy_slice_length=0,
             n_replay_epochs=1,
             scale_value_targets=False,
             q_value=False,
             q_value_aggregate_max=True,
             q_value_n_samples=1,
             **kwargs):  # Arguments of PolicyTrainer come here.
  """Configures the actor-critic Trainer.

  Args:
    task: RLTask instance to use.
    value_model: the model to use for the value function.
    value_optimizer: the optimizer to train the value model.
    value_lr_schedule: lr schedule for value model training.
    value_batch_size: batch size for value model training.
    value_train_steps_per_epoch: how many steps to use for training the
        value model in each epoch.
    value_evals_per_epoch: number of value trainer evaluations per RL epoch;
        only affects metric reporting.
    value_eval_steps: number of value trainer steps per evaluation; only
        affects metric reporting.
    n_shared_layers: how many layers to share between value and policy
        models.
    added_policy_slice_length: how much longer slices of trajectories should
        be for policy than for value training; this is useful for TD
        calculations and only affects the length of elements produced for
        policy batches; value batches have maximum length set by
        max_slice_length in **kwargs.
    n_replay_epochs: how many last epochs to take into the replay buffer;
        only makes sense for off-policy algorithms.
    scale_value_targets: whether to scale targets for the value function by
        1 / (1 - gamma).
    q_value: whether to use Q-values as baselines.
    q_value_aggregate_max: whether to aggregate Q-values with max (as
        opposed to mean).
    q_value_n_samples: number of samples to average over when calculating
        baselines based on Q-values.
    **kwargs: arguments for the PolicyTrainer superclass.
  """
  self._n_shared_layers = n_shared_layers
  self._value_batch_size = value_batch_size
  self._value_train_steps_per_epoch = value_train_steps_per_epoch
  self._value_evals_per_epoch = value_evals_per_epoch
  self._value_eval_steps = value_eval_steps
  # The 2 below will be initialized in super.__init__ anyway, but are needed
  # to construct value batches, which are needed before PolicyTrainer init
  # since policy input creation calls the value model -- hence this code.
  self._task = task
  self._max_slice_length = kwargs.get('max_slice_length', 1)
  self._added_policy_slice_length = added_policy_slice_length
  self._n_replay_epochs = n_replay_epochs
  task.set_n_replay_epochs(n_replay_epochs)
  if scale_value_targets:
    self._value_network_scale = 1 / (1 - self._task.gamma)
  else:
    self._value_network_scale = 1
  self._q_value = q_value
  self._q_value_aggregate_max = q_value_aggregate_max
  self._q_value_n_samples = q_value_n_samples
  if q_value:
    value_model = functools.partial(value_model, inject_actions=True)
  self._value_eval_model = value_model(mode='eval')
  self._value_eval_model.init(self._value_model_signature)
  self._value_eval_jit = tl.jit_forward(self._value_eval_model.pure_fn,
                                        math.device_count(), do_mean=False)

  # Initialize policy training.
  super(ActorCriticTrainer, self).__init__(task, **kwargs)

  # Initialize training of the value function.
  value_output_dir = kwargs.get('output_dir', None)
  if value_output_dir is not None:
    value_output_dir = os.path.join(value_output_dir, 'value')
    # If needed, create value_output_dir and missing parent directories.
    if not tf.io.gfile.isdir(value_output_dir):
      tf.io.gfile.makedirs(value_output_dir)
  self._value_inputs = supervised.Inputs(
      train_stream=lambda _: self.value_batches_stream())
  self._value_trainer = supervised.Trainer(
      model=value_model,
      optimizer=value_optimizer,
      lr_schedule=value_lr_schedule,
      loss_fn=tl.L2Loss(),
      inputs=self._value_inputs,
      output_dir=value_output_dir,
      metrics={'value_loss': tl.L2Loss()})
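# The functools.partial call in the q_value branch above pre-binds the
# action-injection flag while leaving `mode` to be supplied at instantiation
# time, as in value_model(mode='eval'). A toy illustration of the pattern
# (toy_value_model is a hypothetical stand-in for a Trax model factory):
import functools

def toy_value_model(mode, inject_actions=False):
  return ('toy-model', mode, inject_actions)

q_model = functools.partial(toy_value_model, inject_actions=True)
assert q_model(mode='eval') == ('toy-model', 'eval', True)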