Example #1
File: mlp.py Project: Xingyu-Lin/softagent
    def sample(self, dist_info):
        logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist
        u = torch.rand_like(logits)
        u = torch.clamp(u, 1e-5, 1 - 1e-5)
        gumbel = -torch.log(-torch.log(u))
        prob = F.softmax((logits + gumbel) / 10, dim=-1)

        cat_sample = torch.argmax(prob, dim=-1)
        one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)

        if len(prob.shape) == 1: # Edge case for when it gets buffer shapes
            cat_sample = cat_sample.unsqueeze(0)

        if self._all_corners:
            mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
            mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
            mu = select_at_indexes(cat_sample, mu)
            log_std = select_at_indexes(cat_sample, log_std)

            if len(prob.shape) == 1: # Edge case for when it gets buffer shapes
                mu, log_std = mu.squeeze(0), log_std.squeeze(0)

            new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
        else:
            new_dist_info = delta_dist_info

        if self.training:
            self.delta_distribution.set_std(None)
        else:
            self.delta_distribution.set_std(0)
        delta_sample = self.delta_distribution.sample(new_dist_info)
        return torch.cat((one_hot, delta_sample), dim=-1)
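Examples #1 through #4 sample the discrete (corner) choice with the Gumbel trick: uniform noise is transformed into Gumbel(0, 1) noise, added to the logits, and a tempered softmax gives a relaxed sample whose argmax is an exact categorical draw. Below is a minimal self-contained sketch of just that sampling step; the function name and default temperature are illustrative, not part of the repo.

import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, temperature=10.0, eps=1e-5):
    # Uniform noise, clamped away from 0 and 1 so the double log stays finite.
    u = torch.rand_like(logits).clamp(eps, 1 - eps)
    gumbel = -torch.log(-torch.log(u))  # Gumbel(0, 1) noise
    # argmax(logits + gumbel) is an exact sample from Categorical(softmax(logits));
    # the tempered softmax is its differentiable relaxation.
    prob = F.softmax((logits + gumbel) / temperature, dim=-1)
    return torch.argmax(prob, dim=-1), prob

cat_sample, prob = gumbel_softmax_sample(torch.tensor([1.0, 0.5, -0.2, 2.0]))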
Example #2
File: mlp.py Project: Xingyu-Lin/softagent
    def sample_loglikelihood(self, dist_info):
        logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist

        u = torch.rand_like(logits)
        u = torch.clamp(u, 1e-5, 1 - 1e-5)
        gumbel = -torch.log(-torch.log(u))
        prob = F.softmax((logits + gumbel) / 10, dim=-1)

        cat_sample = torch.argmax(prob, dim=-1)
        cat_loglikelihood = select_at_indexes(cat_sample, prob)

        one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
        one_hot = (one_hot - prob).detach() + prob # Make action differentiable through prob

        if self._all_corners:
            mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
            mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
            mu = mu[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
            log_std = log_std[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
            new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
        else:
            new_dist_info = delta_dist_info

        delta_sample, delta_loglikelihood = self.delta_distribution.sample_loglikelihood(new_dist_info)
        action = torch.cat((one_hot, delta_sample), dim=-1)
        log_likelihood = cat_loglikelihood + delta_loglikelihood
        return action, log_likelihood
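The line `one_hot = (one_hot - prob).detach() + prob` in examples #2 and #4 is a straight-through estimator: the forward value is the hard one-hot action, while gradients flow backward through the relaxed probabilities. A small sketch of the trick in isolation (the tensors below are made up for illustration):

import torch
import torch.nn.functional as F

logits = torch.tensor([0.2, 1.5, -0.3, 0.7], requires_grad=True)
prob = F.softmax(logits, dim=-1)

hard = F.one_hot(prob.argmax(dim=-1), num_classes=4).float()  # non-differentiable on its own
action = (hard - prob).detach() + prob  # forward: hard one-hot; backward: gradient of prob

loss = (action * torch.tensor([0.0, 1.0, 2.0, 3.0])).sum()
loss.backward()
print(logits.grad)  # non-zero: gradients reached the logits through `prob`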
Example #3
File: mlp.py Project: Xingyu-Lin/softagent
    def sample(self, dist_info):
        if isinstance(dist_info, DistInfoStd):
            if self.training:
                self.delta_distribution.set_std(None)
            else:
                self.delta_distribution.set_std(0)
            action = self.delta_distribution.sample(dist_info)
        else:
            logits = dist_info
            u = torch.rand_like(logits)
            u = torch.clamp(u, 1e-5, 1 - 1e-5)
            gumbel = -torch.log(-torch.log(u))
            prob = F.softmax((logits + gumbel) / 10, dim=-1)

            cat_sample = torch.argmax(prob, dim=-1)
            action = to_onehot(cat_sample, 4, dtype=torch.float32)

        return action
Example #4
File: mlp.py Project: Xingyu-Lin/softagent
    def sample_loglikelihood(self, dist_info):
        if isinstance(dist_info, DistInfoStd):
            action, log_likelihood = self.delta_distribution.sample_loglikelihood(dist_info)
        else:
            logits = dist_info

            u = torch.rand_like(logits)
            u = torch.clamp(u, 1e-5, 1 - 1e-5)
            gumbel = -torch.log(-torch.log(u))
            prob = F.softmax((logits + gumbel) / 10, dim=-1)

            cat_sample = torch.argmax(prob, dim=-1)
            log_likelihood = select_at_indexes(cat_sample, prob)

            one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
            action = (one_hot - prob).detach() + prob  # Make action differentiable through prob

        return action, log_likelihood
Example #5
    def forward(self,
                observation: torch.Tensor,
                prev_action: torch.Tensor = None,
                prev_state: RSSMState = None):
        lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
        observation = observation.reshape(T * B, *img_shape).type(
            self.dtype) / 255.0 - 0.5
        prev_action = to_onehot(prev_action.reshape(T * B, ),
                                self.action_size,
                                dtype=self.dtype)
        if prev_state is None:
            prev_state = self.representation.initial_state(
                prev_action.size(0),
                device=prev_action.device,
                dtype=self.dtype)
        state = self.get_state_representation(observation, prev_action,
                                              prev_state)

        action, action_dist = self.policy(state)
        action = from_onehot(action)
        return_spec = ModelReturnSpec(action, state)
        return_spec = buffer_func(return_spec, restore_leading_dims, lead_dim,
                                  T, B)
        return return_spec
Example #6
    def to_onehot(self, indexes, dtype=None):
        """Convert from integer indexes to one-hot, preserving leading dimensions."""
        return to_onehot(indexes, self._dim, dtype=dtype or self.onehot_dtype)
Example #7
    def to_onehot(self, indexes, dtype=None):
        return to_onehot(indexes, self._dim, dtype=dtype or self.onehot_dtype)
Example #8
    def cpc_loss(self, samples):
        ##################################
        # Compute all the network outputs:

        observation = samples.observation

        prev_action = samples.prev_action
        if self.onehot_actions:
            prev_action = to_onehot(prev_action,
                                    self._act_dim,
                                    dtype=torch.float)
        prev_reward = samples.prev_reward
        observation, prev_action, prev_reward = buffer_to(
            (observation, prev_action, prev_reward), device=self.device)

        z_latent, conv_output = self.encoder(observation)  # [T,B,..]
        rnn_input = torch.cat(
            [z_latent, prev_action,
             prev_reward.unsqueeze(-1)],  # [T,B,..]
            dim=-1)
        context, _ = self.prediction_rnn(rnn_input)

        valid = valid_from_done(samples.done).type(torch.bool)

        # Extract only the ones to train (all were needed to compute).
        z_latent = z_latent[self.warmup_T:]
        conv_output = conv_output[self.warmup_T:]
        context = context[self.warmup_T:]
        valid = valid[self.warmup_T:]

        ###############################
        # Contrast the network outputs:

        # Should have T,B,C=context.shape, T,B=valid.shape, T,B,Z=z_latent.shape
        T, B, Z = z_latent.shape
        target_trans = z_latent.view(-1, Z).transpose(
            1, 0)  # [T,B,H]->[T*B,H]->[H,T*B]

        # Draw from base_labels according to the location of the corresponding
        # positive latent for contrast, using [T,B]; will give the location
        # within T*B.
        base_labels = torch.arange(T * B, dtype=torch.long,
                                   device=self.device).view(T, B)
        base_labels[~valid] = IGNORE_INDEX  # By location of z_latent.

        # All predictions and labels into one tensor for efficient contrasting.
        prediction_list = list()
        label_list = list()
        for delta_t in range(1, T):
            # Predictions based on context starting from t=0 up to the point where
            # there isn't a future latent within the timesteps of the minibatch.
            # [T-dt,B,C] -> [T-dt,B,H] -> [(T-dt)*B,H]
            prediction_list.append(self.transforms[delta_t](
                context[:-delta_t]).view(-1, Z))
            # The correct latent is delta_t time steps ahead:
            # [T-dt,B] -> [(T-dt)*B]
            label_list.append(base_labels[delta_t:].view(-1))

        # Before cat, to isolate delta_t for diagnostic accuracy check later:
        dt_lengths = [0] + [len(label) for label in label_list]
        dtb = torch.cumsum(torch.tensor(dt_lengths),
                           dim=0)  # delta_t_boundaries

        # Total number of predictions: P = T*(T-1)/2*B
        # from: \sum_{dt=1}^T ((T-dt) * B)
        predictions = torch.cat(prediction_list)  # [P,H]
        labels = torch.cat(label_list)  # [P]
        # contrast against ALL latents, not just the "future" ones:
        logits = torch.matmul(predictions,
                              target_trans)  # [P,H]*[H,T*B] -> [P,T*B]
        logits = logits - torch.max(logits, dim=1,
                                    keepdim=True)[0]  # [P,T*B] normalize
        cpc_loss = self.c_e_loss(logits,
                                 labels)  # every logit weighted equally

        ##################################################
        # Compute some downsampled accuracies for diagnostics:

        logits_d = logits.detach()
        # begin, end, step (downsample):
        b, e, s = dtb[0], dtb[1], 4  # delta_t = 1
        logits1, labels1 = logits_d[b:e:s], labels[b:e:s]
        correct1 = torch.argmax(logits1, dim=1) == labels1
        accuracy1 = valid_mean(correct1.float(),
                               valid=labels1 >= 0)  # IGNORE=-100

        b, e, s = dtb[1], dtb[2], 4  # delta_t = 2
        logits2, labels2 = logits_d[b:e:s], labels[b:e:s]
        correct2 = torch.argmax(logits2, dim=1) == labels2
        accuracy2 = valid_mean(correct2.float(), valid=labels2 >= 0)

        b, e, s = dtb[-2], dtb[-1], 1  # delta_t = T - 1
        logitsT1, labelsT1 = logits_d[b:e:s], labels[b:e:s]
        correctT1 = torch.argmax(logitsT1, dim=1) == labelsT1
        accuracyT1 = valid_mean(correctT1.float(), valid=labelsT1 >= 0)

        b, e, s = dtb[-3], dtb[-2], 1  # delta_t = T - 2
        logitsT2, labelsT2 = logits_d[b:e:s], labels[b:e:s]
        correctT2 = torch.argmax(logitsT2, dim=1) == labelsT2
        accuracyT2 = valid_mean(correctT2.float(), valid=labelsT2 >= 0)

        accuracies = (accuracy1, accuracy2, accuracyT1, accuracyT2)

        return cpc_loss, accuracies, conv_output
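The contrasting block in example #8 is an InfoNCE-style objective: each prediction is scored against every latent in the minibatch with a dot product, and cross-entropy pushes the score of the matching ("positive") latent above all others, with invalid positions masked through the ignore index. A minimal sketch of that core step with made-up shapes (the repo's helpers and exact shapes are not reproduced here):

import torch
import torch.nn.functional as F

P, N, H = 6, 10, 16                  # predictions, candidate latents, feature size
predictions = torch.randn(P, H)
latents = torch.randn(N, H)
labels = torch.randint(0, N, (P,))   # index of the positive latent for each prediction

logits = predictions @ latents.t()                     # [P, N] similarity scores
logits = logits - logits.max(dim=1, keepdim=True)[0]   # subtract row max for numerical stability
cpc_loss = F.cross_entropy(logits, labels, ignore_index=-100)  # positives contrasted against all latents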
Example #9
    def to_onehot(self, indexes, dtype=None):
        """
        The `or` expression in the argument handling means that when dtype=None the
        value used is self.onehot_dtype, and when dtype is not None the value used is
        the dtype that was passed in.
        """
        return to_onehot(indexes, self._dim, dtype=dtype or self.onehot_dtype)
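The same `dtype or self.onehot_dtype` fallback described in example #9's docstring is demonstrated by the toy class below (the class and values are illustrative only, not from the repo):

import torch
import torch.nn.functional as F

class ToyDiscreteSpace:
    def __init__(self, dim, onehot_dtype=torch.float32):
        self._dim = dim
        self.onehot_dtype = onehot_dtype

    def to_onehot(self, indexes, dtype=None):
        # `dtype or self.onehot_dtype`: use the stored default when dtype is None.
        return F.one_hot(indexes, self._dim).to(dtype or self.onehot_dtype)

space = ToyDiscreteSpace(dim=4)
print(space.to_onehot(torch.tensor([1, 3])).dtype)                     # torch.float32 (default)
print(space.to_onehot(torch.tensor([1, 3]), dtype=torch.int64).dtype)  # torch.int64 (override)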
Example #10
    def loss(self, samples: Samples):
        """
        Compute the loss for a batch of data.  This includes computing the model and reward losses on the given data,
        as well as using the dynamics model to generate additional rollouts, which are used for the actor and value
        components of the loss.
        :param samples: rlpyt Samples object, containing a batch of data.  All vectors are [timestep, batch, vector_dim]
        :return: FloatTensor containing the loss
        """
        model = self.agent.model
        device = next(model.parameters()).device

        # Extract tensors from the Samples object
        # They all have the batch_t dimension first, but we'll put the batch_b dimension first.
        # Also, we convert all tensors to floats so they can be fed into our models.

        lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(
            samples.env.observation, 3)
        # squeeze batch sizes to single batch dimension for imagination roll-out
        batch_size = batch_t * batch_b

        observation = samples.env.observation.to(device)
        # normalize image
        observation = observation.type(self.type) / 255.0 - 0.5
        # embed the image
        embed = model.observation_encoder(observation)

        # get action
        action = samples.agent.action
        # make actions one-hot
        action = to_onehot(action, model.action_size,
                           dtype=self.type).to(device)

        reward = samples.env.reward.to(device)

        # if we want to continue the agent state from the previous time steps, we can do it like so:
        # prev_state = samples.agent.agent_info.prev_state[0]
        prev_state = model.representation.initial_state(batch_b, device=device)
        # Rollout model by taking the same series of actions as the real model
        post, prior = model.rollout.rollout_representation(
            batch_t, embed, action, prev_state)
        # Flatten our data (so first dimension is batch_t * batch_b = batch_size)
        # since we're going to do a new rollout starting from each state visited in each batch.
        flat_post = RSSMState(mean=post.mean.reshape(batch_size, -1),
                              std=post.std.reshape(batch_size, -1),
                              stoch=post.stoch.reshape(batch_size, -1),
                              deter=post.deter.reshape(batch_size, -1))
        flat_action = action.reshape(batch_size, -1)
        # Rollout the policy for self.horizon steps. Variable names with imag_ indicate this data is imagined not real.
        # imag_feat shape is [horizon, batch_t * batch_b, feature_size]
        imag_dist, _ = model.rollout.rollout_policy(self.horizon, model.policy,
                                                    flat_action, flat_post)

        # Use state features (deterministic and stochastic) to predict the image and reward
        imag_feat = get_feat(
            imag_dist)  # [horizon, batch_t * batch_b, feature_size]
        # Assumes these are normal distributions. In the TF code it would be the mode, but for a normal distribution mean = mode.
        # If we want to use other distributions we'll have to fix this.
        imag_reward = model.reward_model(imag_feat).mean
        value = model.value_model(imag_feat).mean

        # Compute the exponential discounted sum of rewards
        discount_arr = self.discount * torch.ones_like(imag_reward)
        returns = self.compute_return(imag_reward[:-1],
                                      value[:-1],
                                      discount_arr[:-1],
                                      bootstrap=value[-1],
                                      lambda_=self.discount_lambda)
        discount = torch.cumprod(discount_arr[:-1], 1).detach()

        # Compute losses for each component of the model
        model_loss = self.model_loss(observation, prior, post, reward)
        actor_loss = self.actor_loss(discount, returns)
        value_loss = self.value_loss(imag_feat, discount, returns)
        loss = self.model_weight * model_loss + self.actor_weight * actor_loss + self.value_weight * value_loss
        return loss
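The `compute_return` call in example #10 produces lambda-returns over the imagined horizon. Since that helper is not shown in the listing, the sketch below is an assumption based on the conventional Dreamer recursion R_t = r_t + gamma_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}), bootstrapped from the value model at the final step; it is not the repo's verbatim code.

import torch

def lambda_return(reward, value, discount, bootstrap, lambda_):
    # reward, value, discount: [horizon, batch]; bootstrap: [batch].
    # Assumed recursion: R_t = r_t + discount_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}), R_H = bootstrap.
    next_values = torch.cat([value[1:], bootstrap.unsqueeze(0)], dim=0)
    target = reward + discount * (1 - lambda_) * next_values
    outputs = []
    accumulated = bootstrap
    for t in reversed(range(reward.shape[0])):
        accumulated = target[t] + discount[t] * lambda_ * accumulated
        outputs.append(accumulated)
    return torch.stack(list(reversed(outputs)), dim=0)  # [horizon, batch]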