Example #1
 def net_fn(iqn_inputs):
     """Function representing IQN-DQN Q-network."""
     state = iqn_inputs.state  # batch x state_shape
     taus = iqn_inputs.taus  # batch x samples
     # Apply DQN convnet to embed state.
     state_embedding = dqn_torso()(state)
     state_dim = state_embedding.shape[-1]
     # Embed taus with cosine embedding + linear layer.
     # cos(pi * i * tau) for i = 1,...,latents for each batch_element x sample.
     # Broadcast everything to batch x samples x latent_dim.
     pi_multiples = jnp.arange(1, latent_dim + 1,
                               dtype=jnp.float32) * jnp.pi
     tau_embedding = jnp.cos(pi_multiples[None, None, :] * taus[:, :, None])
     # Map tau embedding onto state_dim via linear layer.
     embedding_layer = linear(state_dim)
     tau_embedding = hk.BatchApply(embedding_layer)(tau_embedding)
     tau_embedding = jax.nn.relu(tau_embedding)
     # Reshape/broadcast both embeddings to batch x num_samples x state_dim
     # and multiply together, before applying value head.
     head_input = tau_embedding * state_embedding[:, None, :]
     value_head = dqn_value_head(num_actions)
     q_dist = hk.BatchApply(value_head)(head_input)
     q_values = jnp.mean(q_dist, axis=1)
     q_values = jax.lax.stop_gradient(q_values)
     return IqnOutputs(q_dist=q_dist, q_values=q_values)
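For reference, a minimal self-contained sketch of the hk.BatchApply behaviour relied on above: it folds the leading dimensions (here batch x samples) into a single batch dimension, applies the wrapped module, then restores the original leading dimensions. All shapes and layer sizes below are assumed for illustration only.

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def apply_linear(x):
        # Fold [batch, samples] into one dimension, apply the Linear, unfold.
        return hk.BatchApply(hk.Linear(16))(x)

    fn = hk.without_apply_rng(hk.transform(apply_linear))
    x = jnp.zeros([4, 8, 32])                  # batch x samples x latent_dim (assumed)
    params = fn.init(jax.random.PRNGKey(0), x)
    y = fn.apply(params, x)                    # shape (4, 8, 16)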
Example #2
 def unroll(
     self,
     inputs: observation_action_reward.OAR,  # [T, B, ...]
     state: hk.LSTMState  # [B, ...]
 ) -> Tuple[base.QValues, hk.LSTMState]:
   """Efficient unroll that applies torso, core, and duelling mlp in one pass."""
   embeddings = hk.BatchApply(self._embed)(inputs)  # [T, B, D+A+1]
   core_outputs, new_states = hk.static_unroll(self._core, embeddings, state)
   q_values = hk.BatchApply(self._duelling_head)(core_outputs)  # [T, B, A]
   return q_values, new_states
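A minimal runnable sketch of the same pattern, with toy shapes and layer sizes that are assumed rather than taken from the snippet above: embed each timestep with hk.BatchApply, unroll an LSTM core over the time axis with hk.static_unroll, then apply the head per timestep.

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def unroll_fn(x):                                   # x: [T, B, obs_dim]
        embed, core, head = hk.Linear(32), hk.LSTM(16), hk.Linear(4)
        embeddings = hk.BatchApply(embed)(x)            # [T, B, 32]
        initial_state = core.initial_state(x.shape[1])  # per-batch-element state
        core_out, final_state = hk.static_unroll(core, embeddings, initial_state)
        q_values = hk.BatchApply(head)(core_out)        # [T, B, 4]
        return q_values, final_state

    fn = hk.without_apply_rng(hk.transform(unroll_fn))
    x = jnp.zeros([5, 2, 8])                            # [T, B, obs_dim] (assumed)
    params = fn.init(jax.random.PRNGKey(0), x)
    q_values, state = fn.apply(params, x)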
Example #3
    def __call__(self,
                 inputs: jnp.ndarray,
                 *,
                 is_training: bool,
                 pos: Optional[jnp.ndarray] = None,
                 network_input_is_1d: bool = True) -> PreprocessorOutputT:
        if self._prep_type == 'conv':
            # Convnet image featurization.
            # Downsamples spatially by a factor of 4
            conv = self.convnet
            if len(inputs.shape) == 5:
                conv = hk.BatchApply(conv)

            inputs = conv(inputs, is_training=is_training)
        elif self._prep_type == 'conv1x1':
            # maps inputs to 64d

            conv = self.convnet_1x1

            if len(inputs.shape) == 5:
                conv = hk.BatchApply(conv)

            inputs = conv(inputs)
        elif self._prep_type == 'patches':
            # Space2depth featurization.
            # Video: B x T x H x W x C
            inputs = space_to_depth(
                inputs,
                temporal_block_size=self._temporal_downsample,
                spatial_block_size=self._spatial_downsample)

            if inputs.ndim == 5 and inputs.shape[1] == 1:
                # for flow
                inputs = jnp.squeeze(inputs, axis=1)

            if self._conv_after_patching:
                inputs = hk.Linear(self._num_channels,
                                   name='patches_linear')(inputs)
        elif self._prep_type == 'pixels':
            # if requested, downsamples in the crudest way
            if inputs.ndim == 4:
                inputs = inputs[:, ::self._spatial_downsample,
                                ::self._spatial_downsample]
            elif inputs.ndim == 5:
                inputs = inputs[:, ::self._temporal_downsample,
                                ::self._spatial_downsample,
                                ::self._spatial_downsample]
            else:
                raise ValueError('Unsupported data format for pixels.')

        inputs, inputs_without_pos = self._build_network_inputs(
            inputs, pos, network_input_is_1d)
        modality_sizes = None  # Size for each modality, only needed for multimodal
        return inputs, modality_sizes, inputs_without_pos
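The 5-D branches above are the recurring trick in this preprocessor: for video inputs of shape [B, T, H, W, C], wrapping the 2-D convnet in hk.BatchApply folds [B, T] into one batch dimension so the conv is applied per frame. A minimal sketch with an assumed 1x1 conv standing in for the actual self.convnet_1x1 module:

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def featurize(video):
        conv = hk.Conv2D(output_channels=64, kernel_shape=1)  # stand-in for conv1x1
        if video.ndim == 5:
            conv = hk.BatchApply(conv)  # fold [B, T] so the 2-D conv sees frames
        return conv(video)

    fn = hk.without_apply_rng(hk.transform(featurize))
    video = jnp.zeros([2, 3, 16, 16, 3])                # [B, T, H, W, C] (assumed)
    params = fn.init(jax.random.PRNGKey(0), video)
    out = fn.apply(params, video)                       # [2, 3, 16, 16, 64]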
Example #4
    def __call__(
            self,
            inputs: jnp.ndarray,
            *,
            is_training: bool,
            pos: Optional[jnp.ndarray] = None,
            modality_sizes: Optional[ModalitySizeT] = None) -> jnp.ndarray:
        if self._input_reshape_size is not None:
            inputs = jnp.reshape(inputs, [inputs.shape[0]] +
                                 list(self._input_reshape_size) +
                                 [inputs.shape[-1]])

        if self._postproc_type == 'conv' or self._postproc_type == 'raft':
            # Convnet image featurization.
            conv = self.convnet
            if len(inputs.shape) == 5 and self._temporal_upsample == 1:
                conv = hk.BatchApply(conv)
            inputs = conv(inputs, is_training=is_training)
        elif self._postproc_type == 'conv1x1':
            inputs = self.conv1x1(inputs)
        elif self._postproc_type == 'patches':
            inputs = reverse_space_to_depth(inputs, self._temporal_upsample,
                                            self._spatial_upsample)

        return inputs
Example #5
    def unroll(self, x, state):
        """Unrolls more efficiently than dynamic_unroll."""
        if self._use_resnet:
            torso = AtariDeepTorso()
        else:
            torso = AtariShallowTorso()

        torso_output = hk.BatchApply(torso)(x.observation)
        if self._use_lstm:
            should_reset = jnp.equal(x.step_type, int(dm_env.StepType.FIRST))
            core_input = (torso_output, should_reset)
            core_output, state = hk.dynamic_unroll(self._core, core_input,
                                                   state)
        else:
            core_output = torso_output
            # state passes through.

        return hk.BatchApply(self._head)(core_output), state
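A minimal runnable sketch of the LSTM branch, with assumed shapes and hk.ResetCore standing in for the actual self._core (which likewise consumes (inputs, should_reset) pairs):

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def unroll_fn(obs, should_reset):                   # obs: [T, B, D], reset: [T, B]
        torso, head = hk.Linear(32), hk.Linear(4)
        core = hk.ResetCore(hk.LSTM(16))                # resets state where requested
        torso_out = hk.BatchApply(torso)(obs)           # [T, B, 32]
        state = core.initial_state(obs.shape[1])
        core_out, state = hk.dynamic_unroll(core, (torso_out, should_reset), state)
        return hk.BatchApply(head)(core_out), state     # [T, B, 4]

    fn = hk.without_apply_rng(hk.transform(unroll_fn))
    obs = jnp.zeros([6, 2, 8])
    should_reset = jnp.zeros([6, 2], dtype=bool)
    params = fn.init(jax.random.PRNGKey(0), obs, should_reset)
    out, state = fn.apply(params, obs, should_reset)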
Example #6
    def loss(self, params: hk.Params, trajs: Transition) -> jnp.ndarray:
        """Computes a loss of trajs wrt params."""
        # Re-run the agent over the trajectories.
        # Due to https://github.com/google/jax/issues/1459, we use hk.BatchApply
        # instead of vmap.
        # BatchApply turns the input tensors from [T, B, ...] into [T*B, ...].
        # We `functools.partial` params in so it does not get transformed.
        net_curried = hk.BatchApply(functools.partial(self._net, params))
        learner_logits, baseline_with_bootstrap = net_curried(trajs.timestep)

        # Separate the bootstrap from the value estimates.
        baseline = baseline_with_bootstrap[:-1]
        baseline_tp1 = baseline_with_bootstrap[1:]

        # Remove bootstrap timestep from non-observations.
        _, actions, behavior_logits = jax.tree_map(lambda t: t[:-1], trajs)
        learner_logits = learner_logits[:-1]

        # Shift step_type/reward/discount back by one, so that actions match the
        # timesteps caused by the action.
        timestep = jax.tree_map(lambda t: t[1:], trajs.timestep)
        discount = timestep.discount * self._discount
        # The step is uninteresting if we transitioned LAST -> FIRST.
        mask = jnp.not_equal(timestep.step_type, int(dm_env.StepType.FIRST))
        mask = mask.astype(jnp.float32)

        # Compute v-trace returns.
        vtrace_td_error_and_advantage = jax.vmap(
            rlax.vtrace_td_error_and_advantage, in_axes=1, out_axes=1)
        rhos = rlax.categorical_importance_sampling_ratios(
            learner_logits, behavior_logits, actions)
        vtrace_returns = vtrace_td_error_and_advantage(baseline, baseline_tp1,
                                                       timestep.reward,
                                                       discount, rhos)

        # Note that we use mean here, rather than sum as in canonical IMPALA.
        # Compute policy gradient loss.
        pg_advantage = jax.lax.stop_gradient(vtrace_returns.pg_advantage)
        tb_pg_loss_fn = jax.vmap(rlax.policy_gradient_loss,
                                 in_axes=1,
                                 out_axes=0)
        pg_loss = tb_pg_loss_fn(learner_logits, actions, pg_advantage, mask)
        pg_loss = jnp.mean(pg_loss)

        # Baseline loss.
        bl_loss = 0.5 * jnp.mean(jnp.square(vtrace_returns.errors) * mask)

        # Entropy regularization.
        ent_loss_fn = jax.vmap(rlax.entropy_loss, in_axes=1, out_axes=0)
        ent_loss = ent_loss_fn(learner_logits, mask)
        ent_loss = jnp.mean(ent_loss)

        total_loss = pg_loss + 0.5 * bl_loss + 0.01 * ent_loss
        return total_loss
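A minimal sketch of the hk.BatchApply-instead-of-vmap trick used at the top of this loss, with an assumed toy network and shapes: the transformed apply function expects a plain batch as its leading dimension, and BatchApply lets it accept [T, B, ...] inputs by temporarily folding T and B together.

    import functools
    import haiku as hk
    import jax
    import jax.numpy as jnp

    def net(x):
        return hk.Linear(5)(x)                         # stand-in for the agent net

    fn = hk.without_apply_rng(hk.transform(net))
    obs = jnp.zeros([20, 4, 8])                        # [T, B, obs_dim] (assumed)
    params = fn.init(jax.random.PRNGKey(0), obs[0])
    # functools.partial binds params so only the array argument gets reshaped.
    logits = hk.BatchApply(functools.partial(fn.apply, params))(obs)  # [T, B, 5]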
Example #7
  def __call__(self, inputs, prev_state):
    current_input, return_target = inputs

    em_state, core_state = prev_state
    (counter, memories) = em_state

    if self._apply_core_to_input:
      current_input, core_state = self._core(current_input, core_state)

    # Synthetic return for the current state
    synth_return = jnp.squeeze(self._synthetic_return(current_input), -1)

    # Current state bias term
    bias = self._bias(current_input)

    # Gate computed from current state
    gate = self._gate(current_input)

    # When counter >= capacity, mask will be all ones
    mask = 1 - jnp.cumsum(jax.nn.one_hot(counter, self._capacity), axis=1)
    mask = jnp.expand_dims(mask, axis=2)

    # Synthetic returns for each state in memory
    past_synth_returns = hk.BatchApply(self._synthetic_return)(memories)

    # Sum of synthetic returns from previous states
    sr_sum = jnp.sum(past_synth_returns * mask, axis=1)

    prediction = jnp.squeeze(sr_sum * gate + bias, -1)
    sr_loss = self._loss(prediction, return_target)

    augmented_return = jax.lax.stop_gradient(
        self._alpha * synth_return + self._beta * return_target)

    # Write current state to memory
    _, em_state = self._em(current_input, em_state)

    if not self._apply_core_to_input:
      output, core_state = self._core(current_input, core_state)
    else:
      output = current_input

    output = SRCoreWrapperOutput(
        output=output,
        synthetic_return=synth_return,
        augmented_return=augmented_return,
        sr_loss=sr_loss,
    )
    return output, (em_state, core_state)
Example #8
 def unroll(self, x, state):
     """Unrolls more efficiently than dynamic_unroll."""
     out, _ = hk.BatchApply(self)(x, None)
     return out, state
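Here the core is feed-forward, so there is no recurrent state to thread through time: the whole [T, B, ...] sequence is processed in a single hk.BatchApply call and the (unused) state is returned unchanged. A minimal sketch of the same idea, with an assumed MLP core and assumed shapes:

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def unroll_fn(x, state):
        # Feed-forward core: fold [T, B] into one batch dim and apply once.
        out = hk.BatchApply(hk.nets.MLP([32, 4]))(x)   # [T, B, 4]
        return out, state                              # state passes through

    fn = hk.without_apply_rng(hk.transform(unroll_fn))
    x = jnp.zeros([10, 8, 16])                         # [T, B, obs_dim] (assumed)
    params = fn.init(jax.random.PRNGKey(0), x, ())
    out, state = fn.apply(params, x, ())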