def net_fn(iqn_inputs):
  """Function representing IQN-DQN Q-network."""
  state = iqn_inputs.state  # batch x state_shape
  taus = iqn_inputs.taus  # batch x samples
  # Apply DQN convnet to embed state.
  state_embedding = dqn_torso()(state)
  state_dim = state_embedding.shape[-1]
  # Embed taus with cosine embedding + linear layer.
  # cos(pi * i * tau) for i = 1,...,latents for each batch_element x sample.
  # Broadcast everything to batch x samples x latent_dim.
  pi_multiples = jnp.arange(1, latent_dim + 1, dtype=jnp.float32) * jnp.pi
  tau_embedding = jnp.cos(pi_multiples[None, None, :] * taus[:, :, None])
  # Map tau embedding onto state_dim via linear layer.
  embedding_layer = linear(state_dim)
  tau_embedding = hk.BatchApply(embedding_layer)(tau_embedding)
  tau_embedding = jax.nn.relu(tau_embedding)
  # Reshape/broadcast both embeddings to batch x num_samples x state_dim
  # and multiply together, before applying value head.
  head_input = tau_embedding * state_embedding[:, None, :]
  value_head = dqn_value_head(num_actions)
  q_dist = hk.BatchApply(value_head)(head_input)
  q_values = jnp.mean(q_dist, axis=1)
  q_values = jax.lax.stop_gradient(q_values)
  return IqnOutputs(q_dist=q_dist, q_values=q_values)
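# A minimal usage sketch (illustrative only): drive net_fn through
# hk.transform with stub factories standing in for the real torso/head
# modules. The stubs, sizes, and namedtuples below are assumptions for the
# sketch, not the source's own definitions.
import collections
import haiku as hk
import jax
import jax.numpy as jnp

IqnInputs = collections.namedtuple('IqnInputs', ['state', 'taus'])
IqnOutputs = collections.namedtuple('IqnOutputs', ['q_dist', 'q_values'])
dqn_torso = lambda: hk.Flatten()         # stand-in state embedding
linear = hk.Linear                       # tau-embedding projection
dqn_value_head = lambda n: hk.Linear(n)  # stand-in value head
latent_dim, num_actions = 64, 4

net = hk.without_apply_rng(hk.transform(net_fn))
key = jax.random.PRNGKey(0)
dummy = IqnInputs(
    state=jnp.zeros((2, 12)),              # batch x state_shape
    taus=jax.random.uniform(key, (2, 8)))  # batch x samples in [0, 1)
params = net.init(key, dummy)
out = net.apply(params, dummy)
print(out.q_dist.shape, out.q_values.shape)  # (2, 8, 4) (2, 4)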
def unroll(
    self,
    inputs: observation_action_reward.OAR,  # [T, B, ...]
    state: hk.LSTMState  # [B, ...]
) -> Tuple[base.QValues, hk.LSTMState]:
  """Efficient unroll that applies torso, core, and duelling mlp in one pass."""
  embeddings = hk.BatchApply(self._embed)(inputs)  # [T, B, D+A+1]
  core_outputs, new_states = hk.static_unroll(self._core, embeddings, state)
  q_values = hk.BatchApply(self._duelling_head)(core_outputs)  # [T, B, A]
  return q_values, new_states
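# The time-major [T, B, ...] layout above can be exercised in isolation. This
# stand-alone sketch (hypothetical sizes and modules, not the class this
# method belongs to) shows the BatchApply -> static_unroll -> BatchApply
# pattern end to end:
import haiku as hk
import jax
import jax.numpy as jnp

def toy_unroll(x):  # x: [T, B, D]
  core = hk.LSTM(16)
  state = core.initial_state(batch_size=x.shape[1])  # [B, ...]
  embeddings = hk.BatchApply(hk.Linear(16))(x)       # [T, B, 16]
  core_out, _ = hk.static_unroll(core, embeddings, state)
  return hk.BatchApply(hk.Linear(4))(core_out)       # [T, B, 4]

net = hk.without_apply_rng(hk.transform(toy_unroll))
x = jnp.zeros((5, 2, 8))
params = net.init(jax.random.PRNGKey(0), x)
print(net.apply(params, x).shape)  # (5, 2, 4)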
def __call__(self,
             inputs: jnp.ndarray,
             *,
             is_training: bool,
             pos: Optional[jnp.ndarray] = None,
             network_input_is_1d: bool = True) -> PreprocessorOutputT:
  if self._prep_type == 'conv':
    # Convnet image featurization.
    # Downsamples spatially by a factor of 4.
    conv = self.convnet
    if len(inputs.shape) == 5:
      conv = hk.BatchApply(conv)
    inputs = conv(inputs, is_training=is_training)
  elif self._prep_type == 'conv1x1':
    # Maps inputs to 64d.
    conv = self.convnet_1x1
    if len(inputs.shape) == 5:
      conv = hk.BatchApply(conv)
    inputs = conv(inputs)
  elif self._prep_type == 'patches':
    # Space2depth featurization.
    # Video: B x T x H x W x C
    inputs = space_to_depth(
        inputs,
        temporal_block_size=self._temporal_downsample,
        spatial_block_size=self._spatial_downsample)
    if inputs.ndim == 5 and inputs.shape[1] == 1:
      # For flow.
      inputs = jnp.squeeze(inputs, axis=1)
    if self._conv_after_patching:
      inputs = hk.Linear(self._num_channels, name='patches_linear')(inputs)
  elif self._prep_type == 'pixels':
    # If requested, downsamples in the crudest way.
    if inputs.ndim == 4:
      inputs = inputs[:, ::self._spatial_downsample,
                      ::self._spatial_downsample]
    elif inputs.ndim == 5:
      inputs = inputs[:, ::self._temporal_downsample,
                      ::self._spatial_downsample,
                      ::self._spatial_downsample]
    else:
      raise ValueError('Unsupported data format for pixels.')

  inputs, inputs_without_pos = self._build_network_inputs(
      inputs, pos, network_input_is_1d)
  modality_sizes = None  # Size for each modality, only needed for multimodal.
  return inputs, modality_sizes, inputs_without_pos
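# The 'patches' branch relies on space_to_depth folding pixel blocks into the
# channel axis. A rough 2-D sketch of that semantics (the library's actual
# implementation also handles the temporal axis and may order blocks
# differently):
import jax.numpy as jnp

def space_to_depth_2d(frames, spatial_block_size):
  """B,H,W,C -> B, H/sb, W/sb, C*sb*sb by folding sb x sb blocks into channels."""
  b, h, w, c = frames.shape
  sb = spatial_block_size
  frames = frames.reshape(b, h // sb, sb, w // sb, sb, c)
  frames = frames.transpose(0, 1, 3, 2, 4, 5)
  return frames.reshape(b, h // sb, w // sb, c * sb * sb)

x = jnp.zeros((8, 224, 224, 3))
print(space_to_depth_2d(x, 4).shape)  # (8, 56, 56, 48)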
def __call__(self,
             inputs: jnp.ndarray,
             *,
             is_training: bool,
             pos: Optional[jnp.ndarray] = None,
             modality_sizes: Optional[ModalitySizeT] = None) -> jnp.ndarray:
  if self._input_reshape_size is not None:
    inputs = jnp.reshape(
        inputs,
        [inputs.shape[0]] + list(self._input_reshape_size) +
        [inputs.shape[-1]])

  if self._postproc_type == 'conv' or self._postproc_type == 'raft':
    # Convnet image featurization.
    conv = self.convnet
    if len(inputs.shape) == 5 and self._temporal_upsample == 1:
      conv = hk.BatchApply(conv)
    inputs = conv(inputs, is_training=is_training)
  elif self._postproc_type == 'conv1x1':
    inputs = self.conv1x1(inputs)
  elif self._postproc_type == 'patches':
    inputs = reverse_space_to_depth(
        inputs, self._temporal_upsample, self._spatial_upsample)
  return inputs
def unroll(self, x, state):
  """Unrolls more efficiently than dynamic_unroll."""
  if self._use_resnet:
    torso = AtariDeepTorso()
  else:
    torso = AtariShallowTorso()
  torso_output = hk.BatchApply(torso)(x.observation)
  if self._use_lstm:
    should_reset = jnp.equal(x.step_type, int(dm_env.StepType.FIRST))
    core_input = (torso_output, should_reset)
    core_output, state = hk.dynamic_unroll(self._core, core_input, state)
  else:
    core_output = torso_output  # State passes through.
  return hk.BatchApply(self._head)(core_output), state
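# Feeding (torso_output, should_reset) into the core suggests a reset-aware
# wrapper such as hk.ResetCore, which restores the initial recurrent state at
# episode boundaries. A minimal sketch of that pattern (sizes are made up):
import haiku as hk
import jax
import jax.numpy as jnp

def toy_unroll(x, should_reset):  # x: [T, B, D], should_reset: [T, B]
  core = hk.ResetCore(hk.LSTM(8))
  state = core.initial_state(batch_size=x.shape[1])
  out, _ = hk.dynamic_unroll(core, (x, should_reset), state)
  return out

net = hk.without_apply_rng(hk.transform(toy_unroll))
x = jnp.zeros((7, 3, 5))
should_reset = jnp.zeros((7, 3), dtype=bool)  # episode-boundary mask
params = net.init(jax.random.PRNGKey(0), x, should_reset)
print(net.apply(params, x, should_reset).shape)  # (7, 3, 8)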
def loss(self, params: hk.Params, trajs: Transition) -> jnp.ndarray:
  """Computes a loss of trajs wrt params."""
  # Re-run the agent over the trajectories.
  # Due to https://github.com/google/jax/issues/1459, we use hk.BatchApply
  # instead of vmap.
  # BatchApply turns the input tensors from [T, B, ...] into [T*B, ...].
  # We `functools.partial` params in so it does not get transformed.
  net_curried = hk.BatchApply(functools.partial(self._net, params))
  learner_logits, baseline_with_bootstrap = net_curried(trajs.timestep)

  # Separate the bootstrap from the value estimates.
  baseline = baseline_with_bootstrap[:-1]
  baseline_tp1 = baseline_with_bootstrap[1:]

  # Remove bootstrap timestep from non-observations.
  _, actions, behavior_logits = jax.tree_map(lambda t: t[:-1], trajs)
  learner_logits = learner_logits[:-1]

  # Shift step_type/reward/discount back by one, so that actions match the
  # timesteps caused by the action.
  timestep = jax.tree_map(lambda t: t[1:], trajs.timestep)
  discount = timestep.discount * self._discount
  # The step is uninteresting if we transitioned LAST -> FIRST.
  mask = jnp.not_equal(timestep.step_type, int(dm_env.StepType.FIRST))
  mask = mask.astype(jnp.float32)

  # Compute v-trace returns.
  vtrace_td_error_and_advantage = jax.vmap(
      rlax.vtrace_td_error_and_advantage, in_axes=1, out_axes=1)
  rhos = rlax.categorical_importance_sampling_ratios(
      learner_logits, behavior_logits, actions)
  vtrace_returns = vtrace_td_error_and_advantage(baseline, baseline_tp1,
                                                 timestep.reward, discount,
                                                 rhos)

  # Note that we use mean here, rather than sum as in canonical IMPALA.
  # Compute policy gradient loss.
  pg_advantage = jax.lax.stop_gradient(vtrace_returns.pg_advantage)
  tb_pg_loss_fn = jax.vmap(rlax.policy_gradient_loss, in_axes=1, out_axes=0)
  pg_loss = tb_pg_loss_fn(learner_logits, actions, pg_advantage, mask)
  pg_loss = jnp.mean(pg_loss)

  # Baseline loss.
  bl_loss = 0.5 * jnp.mean(jnp.square(vtrace_returns.errors) * mask)

  # Entropy regularization.
  ent_loss_fn = jax.vmap(rlax.entropy_loss, in_axes=1, out_axes=0)
  ent_loss = ent_loss_fn(learner_logits, mask)
  ent_loss = jnp.mean(ent_loss)

  total_loss = pg_loss + 0.5 * bl_loss + 0.01 * ent_loss
  return total_loss
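# All the vmapped rlax calls above share the time-major [T, B] layout: rlax's
# trajectory functions operate on a single [T] trajectory, and
# in_axes=1/out_axes=1 maps them over the batch axis. A shape-only sketch with
# dummy inputs:
import jax
import jax.numpy as jnp
import rlax

T, B = 10, 4
v_tm1 = jnp.zeros((T, B))             # baseline estimates
v_t = jnp.zeros((T, B))               # bootstrapped baselines
r_t = jnp.zeros((T, B))               # rewards
discount_t = 0.99 * jnp.ones((T, B))
rhos = jnp.ones((T, B))               # importance sampling ratios

vtrace = jax.vmap(rlax.vtrace_td_error_and_advantage, in_axes=1, out_axes=1)
out = vtrace(v_tm1, v_t, r_t, discount_t, rhos)
print(out.errors.shape, out.pg_advantage.shape)  # (10, 4) (10, 4)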
def __call__(self, inputs, prev_state):
  current_input, return_target = inputs
  em_state, core_state = prev_state
  (counter, memories) = em_state

  if self._apply_core_to_input:
    current_input, core_state = self._core(current_input, core_state)

  # Synthetic return for the current state.
  synth_return = jnp.squeeze(self._synthetic_return(current_input), -1)
  # Current state bias term.
  bias = self._bias(current_input)
  # Gate computed from current state.
  gate = self._gate(current_input)

  # When counter > capacity, mask will be all ones.
  mask = 1 - jnp.cumsum(jax.nn.one_hot(counter, self._capacity), axis=1)
  mask = jnp.expand_dims(mask, axis=2)
  # Synthetic returns for each state in memory.
  past_synth_returns = hk.BatchApply(self._synthetic_return)(memories)
  # Sum of synthetic returns from previous states.
  sr_sum = jnp.sum(past_synth_returns * mask, axis=1)

  prediction = jnp.squeeze(sr_sum * gate + bias, -1)
  sr_loss = self._loss(prediction, return_target)

  augmented_return = jax.lax.stop_gradient(
      self._alpha * synth_return + self._beta * return_target)

  # Write current state to memory.
  _, em_state = self._em(current_input, em_state)

  if not self._apply_core_to_input:
    output, core_state = self._core(current_input, core_state)
  else:
    output = current_input

  output = SRCoreWrapperOutput(
      output=output,
      synthetic_return=synth_return,
      augmented_return=augmented_return,
      sr_loss=sr_loss,
  )
  return output, (em_state, core_state)
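# The memory mask built from 1 - cumsum(one_hot(counter, capacity)) keeps only
# slots written before the current step; once counter exceeds capacity, the
# one_hot of an out-of-range index is all zeros, so the mask is all ones, as
# the comment above notes. A small numeric check:
import jax
import jax.numpy as jnp

capacity = 4
counter = jnp.array([2])  # two memories written so far
mask = 1 - jnp.cumsum(jax.nn.one_hot(counter, capacity), axis=1)
print(mask)  # [[1. 1. 0. 0.]] -- only slots 0 and 1 contribute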
def unroll(self, x, state):
  """Unrolls more efficiently than dynamic_unroll."""
  out, _ = hk.BatchApply(self)(x, None)
  return out, state
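# This shortcut works because hk.BatchApply merges the leading [T, B] axes
# into one batch before calling the wrapped module, so a feed-forward core can
# process a whole trajectory in a single pass. A quick shape check:
import haiku as hk
import jax
import jax.numpy as jnp

f = hk.without_apply_rng(
    hk.transform(lambda x: hk.BatchApply(hk.Linear(3))(x)))
x = jnp.zeros((5, 2, 7))  # [T, B, D]
params = f.init(jax.random.PRNGKey(0), x)
print(f.apply(params, x).shape)  # (5, 2, 3)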