Example #1
    def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
        """Calculates phi values (obs, obs', and predicted obs') and ri.

        Also calculates forward and inverse losses and updates the curiosity
        module on the provided batch using our optimizer.
        """
        # Push both observations through feature net to get both phis.
        phis, _ = self.model._curiosity_feature_net({
            SampleBatch.OBS:
            torch.cat([
                torch.from_numpy(sample_batch[SampleBatch.OBS]),
                torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS])
            ])
        })
        phi, next_phi = torch.chunk(phis, 2)
        actions_tensor = torch.from_numpy(
            sample_batch[SampleBatch.ACTIONS]).long().to(policy.device)

        # Predict next phi with forward model.
        predicted_next_phi = self.model._curiosity_forward_fcnet(
            torch.cat(
                [phi, one_hot(actions_tensor, self.action_space).float()],
                dim=-1))

        # Forward loss term (predicted phi', given phi and action vs actually
        # observed phi').
        forward_l2_norm_squared = 0.5 * torch.sum(
            torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1)
        forward_loss = torch.mean(forward_l2_norm_squared)

        # Add the intrinsic reward (scaled by the eta hyper-parameter) to the
        # extrinsic rewards in the batch.
        sample_batch[SampleBatch.REWARDS] = \
            sample_batch[SampleBatch.REWARDS] + \
            self.eta * forward_l2_norm_squared.detach().cpu().numpy()

        # Inverse loss term (predicted action that led from phi to phi' vs
        # actual action taken).
        phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1)
        dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
        action_dist = TorchCategorical(dist_inputs, self.model) if \
            isinstance(self.action_space, Discrete) else \
            TorchMultiCategorical(
                dist_inputs, self.model, self.action_space.nvec)
        # Neg log(p); p=probability of observed action given the inverse-NN
        # predicted action distribution.
        inverse_loss = -action_dist.logp(actions_tensor)
        inverse_loss = torch.mean(inverse_loss)

        # Calculate the ICM loss.
        loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss
        # Perform an optimizer step.
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()

        # Return the postprocessed sample batch (with the corrected rewards).
        return sample_batch
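
This postprocess_trajectory implementation has the shape of RLlib's built-in ICM-style Curiosity exploration module: the intrinsic reward is eta times the squared forward-model error, and the module is trained on (1 - beta) * inverse_loss + beta * forward_loss. Assuming a torch-based RLlib version that ships the "Curiosity" exploration type, the module is normally enabled through the trainer config rather than by calling this method directly; the sketch below shows such a config with illustrative hyper-parameter values, not values taken from the example above.

# Minimal sketch: enabling an ICM-style Curiosity exploration module via the
# RLlib exploration_config. Assumes the torch framework and an RLlib version
# that provides the "Curiosity" exploration type; all values are illustrative.
config = {
    "framework": "torch",
    "exploration_config": {
        "type": "Curiosity",
        "eta": 1.0,        # scale of the intrinsic reward (eta above)
        "beta": 0.2,       # weight of forward vs. inverse loss (beta above)
        "lr": 1e-3,        # learning rate of the curiosity optimizer
        "feature_dim": 288,
        "sub_exploration": {"type": "StochasticSampling"},
    },
}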
Example #2
def build_CAT_vtrace_loss(policy, model, dist_class, train_batch):
    action_space_parts = model.action_space_parts

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    # Repeat the output_hidden_shape depending on the number of actions that have been generated
    # output_hidden_shape = np.tile(output_hidden_shape, action_repeats)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    invalid_action_mask = train_batch['invalid_action_mask']
    autoregressive_actions = policy.config['autoregressive_actions']

    if 'seq_lens' in train_batch:
        max_seq_len = policy.config['rollout_fragment_length']
        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    actions_per_step = policy.config["actions_per_step"]

    # Collect any RNN state inputs carried along in the batch.
    states = []
    i = 0
    while "state_in_{}".format(i) in train_batch:
        states.append(train_batch["state_in_{}".format(i)])
        i += 1

    seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else []

    # Run the observation- and action-feature sub-modules on the batch.
    model.observation_features_module(train_batch, states, seq_lens)
    action_features, _ = model.action_features_module(train_batch, states, seq_lens)

    previous_action = None
    embedded_action = None
    logp_list = []
    entropy_list = []
    logits_list = []

    # Split the flattened per-step actions and invalid-action masks into one
    # chunk per action taken in the step.
    multi_actions = torch.chunk(actions, actions_per_step, dim=1)
    multi_invalid_action_mask = torch.chunk(invalid_action_mask, actions_per_step, dim=1)
    for a in range(actions_per_step):
        if autoregressive_actions:
            # Condition each action head on the previously taken action; use a
            # zero vector for the very first action of the step.
            if a == 0:
                batch_size = action_features.shape[0]
                previous_action = torch.zeros([batch_size, len(action_space_parts)]).to(action_features.device)
            else:
                previous_action = multi_actions[a - 1]

            embedded_action = model.embed_action_module(previous_action)

        logits = model.action_module(action_features, embedded_action)
        # Mask invalid actions by pushing their logits down to the smallest
        # representable float (log(0) -> -inf, clamped by torch.maximum).
        logits += torch.maximum(torch.tensor(torch.finfo().min), torch.log(multi_invalid_action_mask[a]))
        cat = TorchMultiCategorical(logits, model, action_space_parts)

        logits_list.append(logits)
        logp_list.append(cat.logp(multi_actions[a]))
        entropy_list.append(cat.entropy())

    # Sum log-probs and entropies over all actions taken in a step and place
    # the per-action logits side by side.
    logp = torch.stack(logp_list, dim=1).sum(dim=1)
    entropy = torch.stack(entropy_list, dim=1).sum(dim=1)
    target_logits = torch.hstack(logits_list)

    # Split behaviour and target logits back into one tensor per categorical
    # sub-distribution for the V-trace calculation.
    unpack_shape = np.tile(action_space_parts, actions_per_step)

    unpacked_behaviour_logits = torch.split(behaviour_logits, list(unpack_shape), dim=1)
    unpacked_outputs = torch.split(target_logits, list(unpack_shape), dim=1)

    values = model.value_function()

    # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
    policy.loss = VTraceLoss(
        actions=_make_time_major(actions, drop_last=True),
        actions_logp=_make_time_major(logp, drop_last=True),
        actions_entropy=_make_time_major(entropy, drop_last=True),
        dones=_make_time_major(dones, drop_last=True),
        behaviour_action_logp=_make_time_major(
            behaviour_action_logp, drop_last=True),
        behaviour_logits=_make_time_major(
            unpacked_behaviour_logits, drop_last=True),
        target_logits=_make_time_major(unpacked_outputs, drop_last=True),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=True),
        values=_make_time_major(values, drop_last=True),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=True),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])

    return policy.loss.total_loss
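
A loss builder with the (policy, model, dist_class, train_batch) signature is normally not called by hand; it is handed to a policy template so the trainer invokes it on each train batch. Assuming an RLlib version that still exposes build_torch_policy, the wiring might look roughly like the sketch below; the policy name and the IMPALA default config used here are illustrative, not taken from the example.

# Rough sketch of registering the loss builder as a torch policy. Assumes
# ray.rllib.policy.torch_policy_template.build_torch_policy exists in the
# RLlib version being targeted; the names below are illustrative only.
from ray.rllib.agents.impala import impala
from ray.rllib.policy.torch_policy_template import build_torch_policy

CATVTraceTorchPolicy = build_torch_policy(
    name="CATVTraceTorchPolicy",
    loss_fn=build_CAT_vtrace_loss,
    get_default_config=lambda: impala.DEFAULT_CONFIG,
)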