Example #1
class BaseAlgo(ABC):
    """The base class for RL algorithms."""
    def __init__(self, envs, acmodel, device, num_frames_per_proc, discount,
                 lr, gae_lambda, entropy_coef, value_loss_coef, max_grad_norm,
                 recurrence, preprocess_obss, reshape_reward):
        """
        Initializes a `BaseAlgo` instance.

        Parameters
        ----------
        envs : list
            a list of environments that will be run in parallel
        acmodel : torch.nn.Module
            the model
        device : torch.device
            the device on which the model and rollout tensors are allocated
        num_frames_per_proc : int
            the number of frames collected by every process for an update
        discount : float
            the discount for future rewards
        lr : float
            the learning rate for optimizers
        gae_lambda : float
            the lambda coefficient in the GAE formula
            ([Schulman et al., 2015](https://arxiv.org/abs/1506.02438))
        entropy_coef : float
            the weight of the entropy cost in the final objective
        value_loss_coef : float
            the weight of the value loss in the final objective
        max_grad_norm : float
            gradient will be clipped to be at most this value
        recurrence : int
            the number of steps the gradient is propagated back in time
        preprocess_obss : function
            a function that takes observations returned by the environment
            and converts them into the format that the model can handle
        reshape_reward : function
            a function that shapes the reward, takes an
            (observation, action, reward, done) tuple as an input
        """

        # Store parameters

        self.env = ParallelEnv(envs)
        self.acmodel = acmodel
        self.device = device
        self.num_frames_per_proc = num_frames_per_proc
        self.discount = discount
        self.lr = lr
        self.gae_lambda = gae_lambda
        self.entropy_coef = entropy_coef
        self.value_loss_coef = value_loss_coef
        self.max_grad_norm = max_grad_norm
        self.recurrence = recurrence
        self.preprocess_obss = preprocess_obss or default_preprocess_obss
        self.reshape_reward = reshape_reward

        # Control parameters

        assert self.acmodel.recurrent or self.recurrence == 1
        assert self.num_frames_per_proc % self.recurrence == 0

        # Configure acmodel

        self.acmodel.to(self.device)
        self.acmodel.train()

        # Store helpers values

        self.num_procs = len(envs)
        self.num_frames = self.num_frames_per_proc * self.num_procs

        # Initialize experience values

        shape = (self.num_frames_per_proc, self.num_procs)

        self.obs = self.env.reset()
        self.obss = [None] * (shape[0])
        if self.acmodel.recurrent:
            self.memory = torch.zeros(shape[1],
                                      self.acmodel.memory_size,
                                      device=self.device)
            self.memories = torch.zeros(*shape,
                                        self.acmodel.memory_size,
                                        device=self.device)
        self.mask = torch.ones(shape[1], device=self.device)
        self.masks = torch.zeros(*shape, device=self.device)
        self.actions = torch.zeros(*shape, device=self.device, dtype=torch.int)
        self.values = torch.zeros(*shape, device=self.device)
        self.rewards = torch.zeros(*shape, device=self.device)
        self.advantages = torch.zeros(*shape, device=self.device)
        self.log_probs = torch.zeros(*shape, device=self.device)

        # Initialize log values

        self.log_episode_return = torch.zeros(self.num_procs,
                                              device=self.device)
        self.log_episode_reshaped_return = torch.zeros(self.num_procs,
                                                       device=self.device)
        self.log_episode_num_frames = torch.zeros(self.num_procs,
                                                  device=self.device)

        self.log_done_counter = 0
        self.log_return = [0] * self.num_procs
        self.log_reshaped_return = [0] * self.num_procs
        self.log_num_frames = [0] * self.num_procs
        self.step_counter = [0] * self.num_procs

    def collect_experiences(self):
        """Collects rollouts and computes advantages.

        Runs several environments concurrently. The next actions are computed
        in a batch mode for all environments at the same time. The rollouts
        and advantages from all environments are concatenated together.

        Returns
        -------
        exps : DictList
            Contains actions, rewards, advantages etc as attributes.
            Each attribute, e.g. `exps.reward` has a shape
            (self.num_frames_per_proc * num_envs, ...). k-th block
            of consecutive `self.num_frames_per_proc` frames contains
            data obtained from the k-th environment. Be careful not to mix
            data from different environments!
        logs : dict
            Useful stats about the training process, including the average
            reward, policy loss, value loss, etc.
        """

        for i in range(self.num_frames_per_proc):
            # Do one agent-environment interaction

            preprocessed_obs = self.preprocess_obss(self.obs,
                                                    device=self.device)
            with torch.no_grad():
                if self.acmodel.recurrent:
                    dist, value, memory = self.acmodel(
                        preprocessed_obs, self.memory * self.mask.unsqueeze(1))
                else:
                    dist, value = self.acmodel(preprocessed_obs)
            action = dist.sample()

            obs, reward, done, _ = self.env.step(action.cpu().numpy())

            self.step_counter = [x + 1 for x in self.step_counter]

            # Update experiences values

            self.obss[i] = self.obs
            self.obs = obs
            if self.acmodel.recurrent:
                self.memories[i] = self.memory
                self.memory = memory
            self.masks[i] = self.mask
            self.mask = 1 - torch.tensor(
                done, device=self.device, dtype=torch.float)
            self.actions[i] = action
            self.values[i] = value
            if self.reshape_reward is not None:
                self.rewards[i] = torch.tensor([
                    self.reshape_reward(obs_, action_, reward_, done_)
                    for obs_, action_, reward_, done_ in zip(
                        obs, action, reward, done)
                ],
                                               device=self.device)
            else:
                self.rewards[i] = torch.tensor(reward,
                                               device=self.device,
                                               dtype=torch.float)
            self.log_probs[i] = dist.log_prob(action)

            # get discounts
            discounts = [self.discount**(x - 1) for x in self.step_counter]
            discounts = torch.tensor(discounts,
                                     device=self.device,
                                     dtype=torch.float)

            # Update log values

            self.log_episode_return += torch.tensor(
                reward, device=self.device, dtype=torch.float) * discounts
            self.log_episode_reshaped_return += self.rewards[i] * discounts
            self.log_episode_num_frames += torch.ones(self.num_procs,
                                                      device=self.device)

            for i, done_ in enumerate(done):
                if done_:
                    self.log_done_counter += 1
                    self.log_return.append(self.log_episode_return[i].item())
                    self.log_reshaped_return.append(
                        self.log_episode_reshaped_return[i].item())
                    self.log_num_frames.append(
                        self.log_episode_num_frames[i].item())
                    self.step_counter[i] = 0

            self.log_episode_return *= self.mask
            self.log_episode_reshaped_return *= self.mask
            self.log_episode_num_frames *= self.mask

        # Add advantage and return to experiences
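        # The advantages below follow Generalized Advantage Estimation (GAE):
        #   delta_t = r_t + discount * V(s_{t+1}) * mask_{t+1} - V(s_t)
        #   A_t     = delta_t + discount * gae_lambda * mask_{t+1} * A_{t+1}
        # where mask_{t+1} is 0 if the episode ended at step t, which cuts both
        # the bootstrap term and the advantage recursion at episode boundaries.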

        preprocessed_obs = self.preprocess_obss(self.obs, device=self.device)
        with torch.no_grad():
            if self.acmodel.recurrent:
                _, next_value, _ = self.acmodel(
                    preprocessed_obs, self.memory * self.mask.unsqueeze(1))
            else:
                _, next_value = self.acmodel(preprocessed_obs)

        for i in reversed(range(self.num_frames_per_proc)):
            next_mask = self.masks[
                i + 1] if i < self.num_frames_per_proc - 1 else self.mask
            next_value = self.values[
                i + 1] if i < self.num_frames_per_proc - 1 else next_value
            next_advantage = self.advantages[
                i + 1] if i < self.num_frames_per_proc - 1 else 0

            delta = self.rewards[
                i] + self.discount * next_value * next_mask - self.values[i]
            self.advantages[
                i] = delta + self.discount * self.gae_lambda * next_advantage * next_mask

        # Define experiences:
        #   the whole experience is the concatenation of the experience
        #   of each process.
        # In comments below:
        #   - T is self.num_frames_per_proc,
        #   - P is self.num_procs,
        #   - D is the dimensionality.

        exps = DictList()
        exps.obs = [
            self.obss[i][j] for j in range(self.num_procs)
            for i in range(self.num_frames_per_proc)
        ]
        if self.acmodel.recurrent:
            # T x P x D -> P x T x D -> (P * T) x D
            exps.memory = self.memories.transpose(0, 1).reshape(
                -1, *self.memories.shape[2:])
            # T x P -> P x T -> (P * T) x 1
            exps.mask = self.masks.transpose(0, 1).reshape(-1).unsqueeze(1)
        # for all tensors below, T x P -> P x T -> P * T
        exps.action = self.actions.transpose(0, 1).reshape(-1)
        exps.value = self.values.transpose(0, 1).reshape(-1)
        exps.reward = self.rewards.transpose(0, 1).reshape(-1)
        exps.advantage = self.advantages.transpose(0, 1).reshape(-1)
        exps.returnn = exps.value + exps.advantage
        exps.log_prob = self.log_probs.transpose(0, 1).reshape(-1)

        # Preprocess experiences

        exps.obs = self.preprocess_obss(exps.obs, device=self.device)

        # Log some values

        keep = max(self.log_done_counter, self.num_procs)
        logs = {
            "return_per_episode": self.log_return[-keep:],
            "reshaped_return_per_episode": self.log_reshaped_return[-keep:],
            "num_frames_per_episode": self.log_num_frames[-keep:],
            "num_frames": self.num_frames
        }

        self.log_done_counter = 0
        self.log_return = self.log_return[-self.num_procs:]
        self.log_reshaped_return = self.log_reshaped_return[-self.num_procs:]
        self.log_num_frames = self.log_num_frames[-self.num_procs:]

        return exps, logs

    @abstractmethod
    def update_parameters(self):
        pass
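
`DictList` is not defined in any of these examples. A minimal sketch that is consistent with how it is used here (attribute-style assignment such as `exps.action = ...`, plus row indexing such as `exps[inds + i]` in Example #2) could look like the following; this is an illustration, not necessarily the actual class these snippets import:

import torch


class DictList(dict):
    """A dictionary whose keys can also be read and written as attributes and
    which can be indexed like a tensor, returning a new DictList whose values
    are all indexed the same way (a sketch, not the actual imported class)."""

    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

    def __len__(self):
        # all stored values are assumed to share the same leading dimension
        return len(next(iter(dict.values(self))))

    def __getitem__(self, index):
        return DictList({key: value[index] for key, value in dict.items(self)})


exps = DictList()
exps.action = torch.zeros(10)
sub = exps[3:5]  # a DictList whose sub.action has length 2
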
Example #2
class MultiQAlgo(ABC):
    """The base class for RL algorithms."""
    def __init__(self,
                 envs,
                 model,
                 device=None,
                 num_frames_per_proc=None,
                 discount=0.99,
                 lr=0.001,
                 recurrence=4,
                 adam_eps=1e-8,
                 buffer_size=10000,
                 preprocess_obss=None,
                 reshape_reward=None):
        """
        Initializes a `MultiQAlgo` instance.

        Parameters
        ----------
        envs : list
            a list of environments that will be run in parallel
        model : torch.nn.Module
            the Q-value model
        device : torch.device
            the device on which tensors should be allocated
        num_frames_per_proc : int
            the number of frames collected by every process for an update
        discount : float
            the discount for future rewards
        lr : float
            the learning rate for the optimizer
        recurrence : int
            the number of steps the gradient is propagated back in time
        adam_eps : float
            the epsilon term of the Adam optimizer (unused while the RMSprop
            optimizer below is selected)
        buffer_size : int
            the capacity of the replay buffer
        preprocess_obss : function
            a function that takes observations returned by the environment
            and converts them into the format that the model can handle
        reshape_reward : function
            a function that shapes the reward, takes an
            (observation, action, reward, done) tuple as an input
        """

        # Store parameters

        num_frames_per_proc = num_frames_per_proc or 128  # is 128 correct here?

        self.env = ParallelEnv(envs)
        self.model = model
        self.eval_model = deepcopy(model)
        self.device = device
        self.num_frames_per_proc = num_frames_per_proc
        self.discount = discount
        self.lr = lr
        self.recurrence = recurrence
        self.preprocess_obss = preprocess_obss or default_preprocess_obss
        self.reshape_reward = reshape_reward

        self.reward_size = self.model.reward_size

        # Control parameters

        assert self.model.recurrent or self.recurrence == 1
        assert self.num_frames_per_proc % self.recurrence == 0

        # Configure the models

        self.model.to(self.device)
        self.model.train()
        self.eval_model.to(self.device)

        # Store helpers values

        self.num_procs = len(envs)
        self.num_frames = self.num_frames_per_proc * self.num_procs

        # Initialize experience values

        shape = (self.num_frames_per_proc, self.num_procs)

        self.obs = self.env.reset()
        self.obss = [None] * (shape[0])
        if self.model.recurrent:
            self.memory = torch.zeros(shape[1],
                                      self.model.memory_size,
                                      device=self.device)
            self.memories = torch.zeros(*shape,
                                        self.model.memory_size,
                                        device=self.device)
        self.mask = torch.ones(shape[1], device=self.device)
        self.masks = torch.zeros(*shape, device=self.device)
        #        self.masks = torch.zeros(*shape, self.reward_size, device=self.device)
        self.actions = torch.zeros(*shape, device=self.device, dtype=torch.int)
        self.values = torch.zeros(*shape,
                                  self.env.action_space.n,
                                  self.reward_size,
                                  device=self.device)
        self.expected_values = torch.zeros(*shape,
                                           self.reward_size,
                                           device=self.device)
        self.rewards = torch.zeros(*shape,
                                   self.reward_size,
                                   device=self.device)
        self.log_probs = torch.zeros(*shape, device=self.device)

        # initialize the pareto weights

        self.weights = torch.ones(
            shape[1], self.reward_size, device=self.device) / self.reward_size

        # Initialize log values

        self.log_episode_return = torch.zeros(self.num_procs,
                                              self.reward_size,
                                              device=self.device)
        self.log_episode_reshaped_return = torch.zeros(self.num_procs,
                                                       self.reward_size,
                                                       device=self.device)
        self.log_episode_num_frames = torch.zeros(self.num_procs,
                                                  device=self.device)

        self.log_done_counter = 0
        self.log_return = [0] * self.num_procs
        self.log_reshaped_return = [0] * self.num_procs
        self.log_num_frames = [0] * self.num_procs

        self.buffer = ReplayBuffer(capacity=buffer_size)

        self.eps = 0.05

        #self.optimizer = torch.optim.Adam(self.model.parameters(), lr, eps=adam_eps)
        self.optimizer = torch.optim.RMSprop(params=self.model.parameters(),
                                             lr=self.lr)

    def collect_experiences(self):
        for n in range(self.num_frames_per_proc):
            # calculate the prediction based on current state/obs
            preprocessed_obs = self.preprocess_obss(self.obs,
                                                    device=self.device)
            with torch.no_grad():
                if self.model.recurrent:
                    q_value, memory = self.model(
                        preprocessed_obs, self.memory * self.mask.unsqueeze(1))
                else:
                    q_value = self.model(preprocessed_obs)

            # select an action based on the q-values
            action = self.pareto_action(q_value, self.weights)
            # overwrite the action based on epsilon
            eps_mask = torch.rand(action.shape) < self.eps
            action[eps_mask] = torch.randint(0, self.env.action_space.n,
                                             (sum(eps_mask), ))

            # step the environment based on the predicted action
            next_obs, reward, done, _ = self.env.step(action.cpu().numpy())

            next_preprocessed_obs = self.preprocess_obss(next_obs,
                                                         device=self.device)
            with torch.no_grad():
                if self.model.recurrent:
                    # bootstrap the targets from the *next* observation
                    next_q, next_memory = self.model(
                        next_preprocessed_obs,
                        self.memory * self.mask.unsqueeze(1))
                else:
                    next_q = self.model(next_preprocessed_obs)

        return exps, logs

    def update_parameters(self, exps):
        self.buffer.push(exps)

        # no optimization step is performed here yet; return empty logs
        logs = {}

        return logs

    def collect_experiences_old(self):
        """Collects rollouts and computes advantages.

        Runs several environments concurrently. The next actions are computed
        in a batch mode for all environments at the same time. The rollouts
        and advantages from all environments are concatenated together.

        Returns
        -------
        exps : DictList
            Contains actions, rewards, advantages etc as attributes.
            Each attribute, e.g. `exps.reward` has a shape
            (self.num_frames_per_proc * num_envs, ...). k-th block
            of consecutive `self.num_frames_per_proc` frames contains
            data obtained from the k-th environment. Be careful not to mix
            data from different environments!
        logs : dict
            Useful stats about the training process, including the average
            reward, policy loss, value loss, etc.
        """

        for i in range(self.num_frames_per_proc):
            # Do one agent-environment interaction

            preprocessed_obs = self.preprocess_obss(self.obs,
                                                    device=self.device)
            with torch.no_grad():
                if self.model.recurrent:
                    value, memory = self.model(
                        preprocessed_obs, self.memory * self.mask.unsqueeze(1))
                else:
                    value = self.model(preprocessed_obs)
            action = self.pareto_action(value, self.weights)
            eps_mask = torch.rand(action.shape) < self.eps
            action[eps_mask] = torch.randint(0, self.env.action_space.n,
                                             (sum(eps_mask), ))

            obs, reward, done, _ = self.env.step(action.cpu().numpy())

            # Update experiences values

            self.obss[i] = self.obs
            self.obs = obs
            if self.model.recurrent:
                self.memories[i] = self.memory
                self.memory = memory
            self.masks[i] = self.mask
            self.mask = 1 - torch.tensor(
                done, device=self.device, dtype=torch.float)
            self.actions[i] = action
            self.values[i] = value
            if self.reshape_reward is not None:
                self.rewards[i] = torch.tensor([
                    self.reshape_reward(obs_, action_, reward_, done_)
                    for obs_, action_, reward_, done_ in zip(
                        obs, action, reward, done)
                ],
                                               device=self.device)
            else:
                self.rewards[i] = torch.tensor(reward, device=self.device)

            # Update log values

            self.log_episode_return += torch.tensor(reward,
                                                    device=self.device,
                                                    dtype=torch.float)
            self.log_episode_reshaped_return += self.rewards[i]
            self.log_episode_num_frames += torch.ones(self.num_procs,
                                                      device=self.device)

            for i, done_ in enumerate(done):
                if done_:
                    self.log_done_counter += 1
                    self.log_return.append(
                        self.log_episode_return[i])  #.item())
                    self.log_reshaped_return.append(
                        self.log_episode_reshaped_return[i])  #.item())
                    self.log_num_frames.append(
                        self.log_episode_num_frames[i].item())

                    # reroll the weights for that episode
                    if self.reward_size == 1:
                        self.weights[i, 0] = 1
                    elif self.reward_size == 2:
                        self.weights[i, 0] = torch.rand(1)
                        self.weights[i, 1] = 1 - self.weights[i, 0]
                    else:
                        raise NotImplementedError

            self.log_episode_return = (self.log_episode_return.T * self.mask).T
            self.log_episode_reshaped_return = (
                self.log_episode_reshaped_return.T * self.mask).T
            self.log_episode_num_frames *= self.mask

        preprocessed_obs = self.preprocess_obss(self.obs, device=self.device)
        with torch.no_grad():
            if self.model.recurrent:
                next_value, _ = self.eval_model(
                    preprocessed_obs, self.memory * self.mask.unsqueeze(1))
            else:
                next_value = self.eval_model(preprocessed_obs)
            next_value_clipped = torch.clip(next_value,
                                            *self.env.envs[0].reward_range)

        for i in reversed(range(self.num_frames_per_proc)):
            next_mask = self.masks[
                i + 1] if i < self.num_frames_per_proc - 1 else self.mask
            next_mask = torch.vstack([next_mask] * self.reward_size).T

            next_value = self.values[
                i + 1] if i < self.num_frames_per_proc - 1 else next_value
            # clip the bootstrap values for this step, not just the final one
            next_value_clipped = torch.clip(next_value,
                                            *self.env.envs[0].reward_range)

            self.expected_values[i] = self.rewards[i] + (
                self.pareto_rewards(next_value_clipped, self.weights) *
                (self.discount * next_mask))
            #   self.advantages[i] = delta + (next_advantage.T * (self.discount * self.gae_lambda * next_mask)).T

        # Define experiences:
        #   the whole experience is the concatenation of the experience
        #   of each process.
        # In comments below:
        #   - T is self.num_frames_per_proc,
        #   - P is self.num_procs,
        #   - D is the dimensionality.

        exps = DictList()
        exps.obs = [
            self.obss[i][j] for j in range(self.num_procs)
            for i in range(self.num_frames_per_proc)
        ]
        if self.model.recurrent:
            # T x P x D -> P x T x D -> (P * T) x D
            exps.memory = self.memories.transpose(0, 1).reshape(
                -1, *self.memories.shape[2:])
            # T x P -> P x T -> (P * T) x 1
            exps.mask = self.masks.transpose(0, 1).reshape(-1).unsqueeze(1)
        # for all tensors below, T x P -> P x T -> P * T
        exps.action = self.actions.transpose(0, 1).reshape(-1)
        exps.value = self.values.transpose(0, 1).reshape(-1, self.reward_size)
        exps.reward = self.rewards.transpose(0,
                                             1).reshape(-1, self.reward_size)
        exps.exp_value = self.expected_values.transpose(0, 1).reshape(
            -1, self.reward_size)
        exps.log_prob = self.log_probs.transpose(0, 1).reshape(-1)

        # Preprocess experiences

        exps.obs = self.preprocess_obss(exps.obs, device=self.device)

        # Log some values

        keep = max(self.log_done_counter, self.num_procs)

        logs = {
            "return_per_episode": self.log_return[-keep:],
            "reshaped_return_per_episode": self.log_reshaped_return[-keep:],
            "num_frames_per_episode": self.log_num_frames[-keep:],
            "num_frames": self.num_frames
        }

        self.log_done_counter = 0
        self.log_return = self.log_return[-self.num_procs:]
        self.log_reshaped_return = self.log_reshaped_return[-self.num_procs:]
        self.log_num_frames = self.log_num_frames[-self.num_procs:]

        return exps, logs

    def update_parameters_old(self, exps):
        # Compute starting indexes

        inds = self._get_starting_indexes()

        # Initialize update values

        update_entropy = 0
        update_value = 0
        update_policy_loss = 0
        update_value_loss = 0
        update_loss = 0

        # Initialize memory

        if self.model.recurrent:
            memory = exps.memory[inds]

        for i in range(self.recurrence):
            # Create a sub-batch of experience

            sb = exps[inds + i]

            # Compute loss

            if self.model.recurrent:
                value, memory = self.model(sb.obs, memory * sb.mask)
            else:
                value = self.model(sb.obs)


#            entropy = dist.entropy().mean()

#            policy_loss = -(dist.log_prob(sb.action) * sb.advantage).mean()

            loss = (value - sb.exp_value.unsqueeze(1)).pow(2).mean()

            # Update batch values

            update_loss += loss

        # Update update values

        update_value /= self.recurrence
        update_loss /= self.recurrence

        # Update actor-critic

        self.optimizer.zero_grad()
        update_loss.backward()
        update_grad_norm = sum(
            p.grad.data.norm(2)**2 for p in self.model.parameters())**0.5
        #        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        # Log some values

        logs = {
            "entropy": update_entropy,
            "value": update_value,
            "policy_loss": update_policy_loss,
            "value_loss": update_value_loss,
            "grad_norm": update_grad_norm
        }

        return logs

    def _get_starting_indexes(self):
        """Gives the indexes of the observations given to the model and the
        experiences used to compute the loss at first.

        The indexes are the integers from 0 to `self.num_frames` with a step of
        `self.recurrence`. If the model is not recurrent, they are all the
        integers from 0 to `self.num_frames`.

        Returns
        -------
        starting_indexes : list of int
            the indexes of the experiences to be used at first
        """

        starting_indexes = np.arange(0, self.num_frames, self.recurrence)
        return starting_indexes

    def pareto_action(self, values, weights):
        #col = torch.randint(0,self.reward_size,(1,))
        #return torch.max(values[:,:,col], dim=1).indices.squeeze()
        #print(torch.tensor([torch.argmax(torch.matmul(values[i,:,:],weights[i,:])) for i in range(values.shape[0])]))
        return torch.tensor([
            torch.argmax(torch.matmul(values[i, :, :], weights[i, :]))
            for i in range(values.shape[0])
        ])

    def pareto_rewards(self, values, weights):
        #col = torch.randint(0,self.reward_size,(1,))
        #inds = torch.max(values[:,:,col], dim=1).indices
        #return values.gather(inds)

        actions = self.pareto_action(values, weights)
        return torch.vstack(
            [values[i, actions[i], :] for i in range(values.shape[0])])
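
The per-process Python loop in `pareto_action` above can be expressed as one batched matrix product. A minimal sketch, assuming `values` has shape `(num_procs, num_actions, reward_size)` (as `self.values[i]` does) and `weights` has shape `(num_procs, reward_size)`:

import torch


def pareto_action_batched(values, weights):
    """Scalarize multi-objective Q-values with per-process weights and return
    the greedy action index for each process (sketch equivalent of
    MultiQAlgo.pareto_action)."""
    # (P, A, R) @ (P, R, 1) -> (P, A, 1) -> (P, A)
    scalarized = torch.bmm(values, weights.unsqueeze(-1)).squeeze(-1)
    # greedy action per process, shape (P,)
    return scalarized.argmax(dim=1)

Unlike the loop version, which builds a new CPU tensor, this keeps the result on the same device as its inputs.
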
Example #3
# Load agent

model_dir = utils.get_model_dir(args.model)
agent = utils.Agent(env.observation_space, env.action_space, model_dir, device,
                    args.argmax, args.procs)
print("Agent loaded\n")

# Initialize logs

logs = {"num_frames_per_episode": [], "return_per_episode": []}

# Run agent

start_time = time.time()

obss = env.reset()

log_done_counter = 0
log_episode_return = torch.zeros(args.procs, device=device)
log_episode_num_frames = torch.zeros(args.procs, device=device)
positions = []
while log_done_counter < args.episodes:
    actions = agent.get_actions(obss)
    obss, rewards, dones, infos = env.step(actions)
    positions.extend([info["agent_pos"] for info in infos])
    agent.analyze_feedbacks(rewards, dones)

    log_episode_return += torch.tensor(rewards,
                                       device=device,
                                       dtype=torch.float)
    log_episode_num_frames += torch.ones(args.procs, device=device)
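
    # The excerpt above is truncated here; the complete evaluation loop in
    # Example #4 below continues with the same variable names as follows:
    for i, done in enumerate(dones):
        if done:
            log_done_counter += 1
            logs["return_per_episode"].append(log_episode_return[i].item())
            logs["num_frames_per_episode"].append(
                log_episode_num_frames[i].item())

    # zero the accumulators of environments whose episode just ended
    mask = 1 - torch.tensor(dones, device=device, dtype=torch.float)
    log_episode_return *= mask
    log_episode_num_frames *= mask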
Example #4
def run_eval():
    envs = []
    for i in range(1):
        env = utils.make_env(args.env, args.seed + 10000 * i)
        env.is_teaching = False
        env.end_pos = args.eval_goal
        envs.append(env)
    env = ParallelEnv(envs)

    # Load agent

    model_dir = utils.get_model_dir(args.model)
    agent = utils.Agent(env.observation_space, env.action_space, model_dir, device, args.argmax, args.procs)

    # Initialize logs

    logs = {"num_frames_per_episode": [], "return_per_episode": []}

    # Run agent

    start_time = time.time()

    obss = env.reset()

    log_done_counter = 0
    log_episode_return = torch.zeros(args.procs, device=device)
    log_episode_num_frames = torch.zeros(args.procs, device=device)
    positions = []
    while log_done_counter < args.episodes:
        actions = agent.get_actions(obss)
        obss, rewards, dones, infos = env.step(actions)
        positions.extend([info["agent_pos"] for info in infos])
        agent.analyze_feedbacks(rewards, dones)

        log_episode_return += torch.tensor(rewards, device=device, dtype=torch.float)
        log_episode_num_frames += torch.ones(args.procs, device=device)

        for i, done in enumerate(dones):
            if done:
                log_done_counter += 1
                logs["return_per_episode"].append(log_episode_return[i].item())
                logs["num_frames_per_episode"].append(log_episode_num_frames[i].item())

        mask = 1 - torch.tensor(dones, device=device, dtype=torch.float)
        log_episode_return *= mask
        log_episode_num_frames *= mask

    end_time = time.time()

    # Print logs

    num_frames = sum(logs["num_frames_per_episode"])
    fps = num_frames/(end_time - start_time)
    duration = int(end_time - start_time)
    return_per_episode = utils.synthesize(logs["return_per_episode"])
    num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])

    print("Eval: F {} | FPS {:.0f} | D {} | R:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {}"
          .format(num_frames, fps, duration,
                  *return_per_episode.values(),
                  *num_frames_per_episode.values()))
    return return_per_episode
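
`utils.synthesize` is not shown in these examples; the format string above consumes four values per statistic (mean, std, min, max), so a minimal stand-in consistent with that usage might be:

import collections

import numpy as np


def synthesize(values):
    """Plausible stand-in for utils.synthesize: summarize a list of numbers as
    an ordered dict of mean/std/min/max, matching the four format slots
    (mu, sigma, m, M) consumed by the print statement above."""
    stats = collections.OrderedDict()
    stats["mean"] = np.mean(values)
    stats["std"] = np.std(values)
    stats["min"] = np.min(values)
    stats["max"] = np.max(values)
    return stats
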
Example #5
class BaseSRAlgo(ABC):
    """The base class for RL algorithms."""
    def __init__(self,
                 envs,
                 model,
                 target,
                 device,
                 num_frames_per_proc,
                 discount,
                 lr,
                 gae_lambda,
                 max_grad_norm,
                 recurrence,
                 memory_cap,
                 preprocess_obss,
                 reshape_reward=None):
        """
        Initializes a `BaseSRAlgo` instance.

        Parameters
        ----------
        envs : list
            a list of environments that will be run in parallel
        model : torch.nn.Module
            the model
        target : torch.nn.Module
            the target network used for the bootstrapped value and
            successor-feature estimates
        device : torch.device
            the device on which tensors should be allocated
        num_frames_per_proc : int
            the number of frames collected by every process for an update
        discount : float
            the discount for future rewards
        lr : float
            the learning rate for optimizers
        gae_lambda : float
            the lambda coefficient in the GAE formula
            ([Schulman et al., 2015](https://arxiv.org/abs/1506.02438))
        max_grad_norm : float
            gradient will be clipped to be at most this value
        recurrence : int
            the number of steps the gradient is propagated back in time
        memory_cap : int
            the capacity of the replay memory
        preprocess_obss : function
            a function that takes observations returned by the environment
            and converts them into the format that the model can handle
        reshape_reward : function
            a function that shapes the reward, takes an
            (observation, action, reward, done) tuple as an input
        """

        # Store parameters
        self.env = ParallelEnv(envs)
        self.model = model
        self.target = target
        self.device = device
        self.num_frames_per_proc = num_frames_per_proc
        self.discount = discount
        self.lr = lr
        self.gae_lambda = gae_lambda
        self.max_grad_norm = max_grad_norm
        self.recurrence = recurrence

        self.replay_memory = ReplayMemory(memory_cap)

        use_cuda = torch.cuda.is_available()
        self.FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

        self.total_updates = 0
        self.preprocess_obss = preprocess_obss or default_preprocess_obss
        self.reshape_reward = reshape_reward
        self.continuous_action = model.continuous_action

        # Control parameters

        assert self.model.recurrent or self.recurrence == 1
        assert self.num_frames_per_proc % self.recurrence == 0

        # Configure the model and the target network

        self.model.to(self.device)
        self.model.train()
        self.target.to(self.device)
        self.target.train()

        # Store helpers values

        self.num_procs = len(envs)
        self.num_frames = self.num_frames_per_proc * self.num_procs

        # Initialize experience values

        shape = (self.num_frames_per_proc, self.num_procs)
        vec_shape = (self.num_frames_per_proc, self.num_procs,
                     self.model.embedding_size)

        self.obs = self.env.reset()
        self.obss = [None] * (shape[0])
        if self.model.recurrent:
            self.memory = torch.zeros(shape[1],
                                      self.model.memory_size,
                                      device=self.device)
            self.target_memory = torch.zeros(shape[1],
                                             self.model.memory_size,
                                             device=self.device)
            self.memories = torch.zeros(*shape,
                                        self.model.memory_size,
                                        device=self.device)
        self.mask = torch.ones(shape[1], device=self.device)
        self.masks = torch.zeros(*shape, device=self.device)
        if self.continuous_action:
            self.actions = torch.zeros(*shape, device=self.device)
        else:
            self.actions = torch.zeros(*shape,
                                       device=self.device,
                                       dtype=torch.int)
        self.values = torch.zeros(*shape, device=self.device)
        self.target_values = torch.zeros(*shape, device=self.device)
        self.rewards = torch.zeros(*shape, device=self.device)
        self.SR_advantages = torch.zeros(*vec_shape, device=self.device)
        self.V_advantages = torch.zeros(*shape, device=self.device)
        self.log_probs = torch.zeros(*shape, device=self.device)
        self.embeddings = torch.zeros(*vec_shape, device=self.device)
        self.target_embeddings = torch.zeros(*vec_shape, device=self.device)
        self.successors = torch.zeros(*vec_shape, device=self.device)
        self.target_successors = torch.zeros(*vec_shape, device=self.device)

        # Initialize log values

        self.log_episode_return = torch.zeros(self.num_procs,
                                              device=self.device)
        self.log_episode_reshaped_return = torch.zeros(self.num_procs,
                                                       device=self.device)
        self.log_episode_num_frames = torch.zeros(self.num_procs,
                                                  device=self.device)

        self.log_done_counter = 0
        self.log_return = [0] * self.num_procs
        self.log_reshaped_return = [0] * self.num_procs
        self.log_num_frames = [0] * self.num_procs

    def collect_experiences(self):
        """Collects rollouts and computes advantages.

        Runs several environments concurrently. The next actions are computed
        in a batch mode for all environments at the same time. The rollouts
        and advantages from all environments are concatenated together.

        Returns
        -------
        exps : DictList
            Contains actions, rewards, advantages etc as attributes.
            Each attribute, e.g. `exps.reward` has a shape
            (self.num_frames_per_proc * num_envs, ...). k-th block
            of consecutive `self.num_frames_per_proc` frames contains
            data obtained from the k-th environment. Be careful not to mix
            data from different environments!
        logs : dict
            Useful stats about the training process, including the average
            reward, policy loss, value loss, etc.
        """

        for i in range(self.num_frames_per_proc):
            # Do one agent-environment interaction

            if self.continuous_action:
                self.obs = [
                    self.model.scaler.transform(self.obs[0].reshape(
                        1, -1)).reshape(-1).astype('float64')
                ]

            preprocessed_obs = self.preprocess_obss(self.obs,
                                                    device=self.device)

            with torch.no_grad():
                if self.model.recurrent:
                    dist, value, embedding, _, successor, _, memory = self.model(
                        preprocessed_obs,
                        memory=self.memory * self.mask.unsqueeze(1))
                    _, target_value, target_embedding, _, target_successor, _, _ = self.target(
                        preprocessed_obs,
                        memory=self.memory * self.mask.unsqueeze(1))
                # target
                else:
                    dist, value, embedding, _, successor, _, _ = self.model(
                        preprocessed_obs)
                    _, target_value, target_embedding, _, target_successor, _, _ = self.target(
                        preprocessed_obs)

            if self.continuous_action:
                # Should this (eps + stochastic policy) be done? Or use (eps + det policy) or just stochastic policy?
                epsample = random.random()
                eps_threshold = 0.02 + (0.9 - 0.02) * math.exp(
                    -1. * self.total_updates / 200)
                if epsample > eps_threshold:
                    noise_dist = torch.distributions.normal.Normal(0, 0.03)
                    action = dist.sample() + noise_dist.sample()
                    action = torch.clamp(action, self.env.envs[0].min_action,
                                         self.env.envs[0].max_action)
                else:
                    action = torch.Tensor(
                        self.env.envs[0].action_space.sample())

                obs, reward, done, _ = self.env.step([action.cpu().numpy()])
                obs = (obs[0].reshape(1, -1))
            else:
                action = dist.sample()
                obs, reward, done, _ = self.env.step(action.cpu().numpy())

            # Update experiences values
            self.replay_memory.push((self.FloatTensor([obs[0]['image']]),
                                     self.FloatTensor([reward])))

            self.obss[i] = self.obs
            self.obs = obs
            if self.model.recurrent:
                self.memories[i] = self.memory
                self.memory = memory
            self.masks[i] = self.mask
            self.mask = 1 - torch.tensor(
                done, device=self.device, dtype=torch.float)
            self.actions[i] = action
            self.values[i] = value
            self.target_values[i] = target_value
            self.embeddings[i] = embedding
            self.target_embeddings[i] = target_embedding
            self.successors[i] = successor
            self.target_successors[i] = target_successor
            if self.reshape_reward is not None:
                self.rewards[i] = torch.tensor([
                    self.reshape_reward(obs_, action_, reward_, done_)
                    for obs_, action_, reward_, done_ in zip(
                        obs, action, reward, done)
                ],
                                               device=self.device)
            else:
                self.rewards[i] = torch.tensor(reward, device=self.device)
            self.log_probs[i] = dist.log_prob(action)

            # Update log values

            self.log_episode_return += torch.tensor(reward,
                                                    device=self.device,
                                                    dtype=torch.float)
            self.log_episode_reshaped_return += self.rewards[i]
            self.log_episode_num_frames += torch.ones(self.num_procs,
                                                      device=self.device)

            for i, done_ in enumerate(done):
                if done_:
                    self.log_done_counter += 1
                    self.log_return.append(self.log_episode_return[i].item())
                    self.log_reshaped_return.append(
                        self.log_episode_reshaped_return[i].item())
                    self.log_num_frames.append(
                        self.log_episode_num_frames[i].item())

            self.log_episode_return *= self.mask
            self.log_episode_reshaped_return *= self.mask
            self.log_episode_num_frames *= self.mask

        # Add advantage and return to experiences

        if self.continuous_action:
            # Assuming flat observations in the continuous-action case:
            # this holds for the Mountain Car example but may not in general.
            # Ideally the continuous-action code should be modified to handle
            # flat or image input, the use of a scaler should be an option to
            # train.py, and either checks here or a wrapper environment set up
            # in train.py should take care of the scaling.
            self.obs[0] = self.model.scaler.transform(self.obs[0].reshape(
                1, -1)).reshape(-1)
            self.obs[0] = self.obs[0].astype('float32')

        preprocessed_obs = self.preprocess_obss(self.obs, device=self.device)
        with torch.no_grad():
            if self.model.recurrent:
                _, next_value, _, _, next_successor, _, _ = self.target(
                    preprocessed_obs,
                    memory=self.memory * self.mask.unsqueeze(1))  #target
            else:
                _, next_value, _, _, next_successor, _, _ = self.target(
                    preprocessed_obs)
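        # Two GAE-style backward recursions are computed below: one over
        # successor features,
        #   SR_delta_t = phi_target(s_t) + discount * mask_{t+1} * psi_target(s_{t+1}) - psi(s_t)
        #   SR_A_t     = SR_delta_t + discount * gae_lambda * mask_{t+1} * SR_A_{t+1}
        # and the usual one over scalar values,
        #   V_delta_t  = r_t + discount * mask_{t+1} * V_target(s_{t+1}) - V(s_t)
        #   V_A_t      = V_delta_t + discount * gae_lambda * mask_{t+1} * V_A_{t+1}
        # with the bootstrap terms taken from the target network.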

        for i in reversed(range(self.num_frames_per_proc)):
            next_mask = self.masks[
                i + 1] if i < self.num_frames_per_proc - 1 else self.mask
            next_successor = self.target_successors[
                i + 1] if i < self.num_frames_per_proc - 1 else next_successor
            next_value = self.target_values[
                i + 1] if i < self.num_frames_per_proc - 1 else next_value
            next_SR_advantage = self.SR_advantages[
                i + 1] if i < self.num_frames_per_proc - 1 else 0
            next_V_advantage = self.V_advantages[
                i + 1] if i < self.num_frames_per_proc - 1 else 0

            SR_delta = self.target_embeddings[i] + (
                self.discount * next_successor *
                next_mask.reshape(-1, 1)) - self.successors[i]
            self.SR_advantages[i] = SR_delta + (
                self.discount * self.gae_lambda * next_SR_advantage *
                next_mask.reshape(-1, 1))

            V_delta = self.rewards[
                i] + self.discount * next_value * next_mask - self.values[i]
            self.V_advantages[
                i] = V_delta + self.discount * self.gae_lambda * next_V_advantage * next_mask

        # Define experiences:
        #   the whole experience is the concatenation of the experience
        #   of each process.
        # In comments below:
        #   - T is self.num_frames_per_proc,
        #   - P is self.num_procs,
        #   - D is the dimensionality.

        exps = DictList()
        exps.obs = [
            self.obss[i][j] for j in range(self.num_procs)
            for i in range(self.num_frames_per_proc)
        ]
        if self.model.recurrent:
            # T x P x D -> P x T x D -> (P * T) x D
            exps.memory = self.memories.transpose(0, 1).reshape(
                -1, *self.memories.shape[2:])
            # T x P -> P x T -> (P * T) x 1
            exps.mask = self.masks.transpose(0, 1).reshape(-1).unsqueeze(1)
        # for all tensors below, T x P -> P x T -> P * T
        exps.action = self.actions.transpose(0, 1).reshape(-1)
        exps.value = self.values.transpose(0, 1).reshape(-1)
        exps.reward = self.rewards.transpose(0, 1).reshape(-1)
        exps.SR_advantage = self.SR_advantages.transpose(0, 1).reshape(
            -1, self.model.embedding_size)
        exps.successor = self.successors.transpose(0, 1).reshape(
            -1, self.model.embedding_size)
        exps.successorn = exps.successor + exps.SR_advantage
        exps.V_advantage = self.V_advantages.transpose(0, 1).reshape(-1)
        exps.returnn = exps.value + exps.V_advantage
        exps.log_prob = self.log_probs.transpose(0, 1).reshape(-1)

        # Preprocess experiences

        exps.obs = self.preprocess_obss(exps.obs, device=self.device)

        # Log some values

        keep = max(self.log_done_counter, self.num_procs)

        logs = {
            "return_per_episode": self.log_return[-keep:],
            "reshaped_return_per_episode": self.log_reshaped_return[-keep:],
            "num_frames_per_episode": self.log_num_frames[-keep:],
            "num_frames": self.num_frames
        }

        self.log_done_counter = 0
        self.log_return = self.log_return[-self.num_procs:]
        self.log_reshaped_return = self.log_reshaped_return[-self.num_procs:]
        self.log_num_frames = self.log_num_frames[-self.num_procs:]

        return exps, logs

    @abstractmethod
    def update_parameters(self):
        pass
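
`ReplayMemory` is not defined in this example either; the constructor takes a capacity and `push` stores one tuple of tensors per step. A minimal sketch consistent with that usage (the `sample` method is an assumption, since no sampling code appears in this snippet):

import random
from collections import deque


class ReplayMemory:
    """Bounded FIFO buffer of transitions (sketch of the class used above)."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, transition):
        # e.g. the (observation tensor, reward tensor) pair pushed in
        # collect_experiences above
        self.buffer.append(transition)

    def sample(self, batch_size):
        # assumed interface: uniform sampling without replacement
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)
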
Example #6
class BaseAlgo(ABC):
    """The base class for RL algorithms."""
    def __init__(self,
                 envs,
                 acmodel,
                 device,
                 num_frames_per_proc,
                 discount,
                 lr,
                 gae_lambda,
                 entropy_coef,
                 value_loss_coef,
                 max_grad_norm,
                 recurrence,
                 preprocess_obss,
                 reshape_reward,
                 use_entropy_reward=False):
        """
        Initializes a `BaseAlgo` instance.

        Parameters
        ----------
        envs : list
            a list of environments that will be run in parallel
        acmodel : torch.nn.Module
            the model
        device : torch.device
            the device on which the model and rollout tensors are allocated
        num_frames_per_proc : int
            the number of frames collected by every process for an update
        discount : float
            the discount for future rewards
        lr : float
            the learning rate for optimizers
        gae_lambda : float
            the lambda coefficient in the GAE formula
            ([Schulman et al., 2015](https://arxiv.org/abs/1506.02438))
        entropy_coef : float
            the weight of the entropy cost in the final objective
        value_loss_coef : float
            the weight of the value loss in the final objective
        max_grad_norm : float
            gradient will be clipped to be at most this value
        recurrence : int
            the number of steps the gradient is propagated back in time
        preprocess_obss : function
            a function that takes observations returned by the environment
            and converts them into the format that the model can handle
        reshape_reward : function
            a function that shapes the reward, takes an
            (observation, action, reward, done) tuple as an input
        use_entropy_reward : bool
            whether to add a state-entropy bonus to the reward
            (see `compute_state_entropy`)
        """
        self.use_entropy_reward = use_entropy_reward

        # Store parameters

        self.env = ParallelEnv(envs)
        self.acmodel = acmodel
        self.device = device
        self.num_frames_per_proc = num_frames_per_proc
        self.discount = discount
        self.lr = lr
        self.gae_lambda = gae_lambda
        self.entropy_coef = entropy_coef
        self.value_loss_coef = value_loss_coef
        self.max_grad_norm = max_grad_norm
        self.recurrence = recurrence
        self.preprocess_obss = preprocess_obss or default_preprocess_obss
        self.reshape_reward = reshape_reward

        # Control parameters

        assert self.acmodel.recurrent or self.recurrence == 1
        assert self.num_frames_per_proc % self.recurrence == 0

        # Configure acmodel

        self.acmodel.to(self.device)
        self.acmodel.train()

        self.k = 3
        self.s_ent_stats = TorchRunningMeanStd(shape=[1], device=self.device)

        self.random_encoder = nn.Sequential(nn.Conv2d(3, 16,
                                                      (2, 2)), nn.ReLU(),
                                            nn.MaxPool2d((2, 2)),
                                            nn.Conv2d(16, 32,
                                                      (2, 2)), nn.ReLU(),
                                            nn.Conv2d(32, 64, (2, 2)),
                                            nn.ReLU())
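        # The encoder above is fixed and randomly initialized; it is presumably
        # used to embed observations into the feature space in which
        # compute_state_entropy() measures k-NN distances when
        # use_entropy_reward is enabled. Its parameters are never added to an
        # optimizer in this class.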

        self.random_encoder.to(self.device)
        self.random_encoder.train()

        # Store helpers values

        self.num_procs = len(envs)
        self.num_frames = self.num_frames_per_proc * self.num_procs

        # Initialize experience values

        shape = (self.num_frames_per_proc, self.num_procs)

        self.obs = self.env.reset()
        self.obss = [None] * (shape[0])
        if self.acmodel.recurrent:
            self.memory = torch.zeros(shape[1],
                                      self.acmodel.memory_size,
                                      device=self.device)
            self.memories = torch.zeros(*shape,
                                        self.acmodel.memory_size,
                                        device=self.device)
        self.mask = torch.ones(shape[1], device=self.device)
        self.masks = torch.zeros(*shape, device=self.device)
        self.actions = torch.zeros(*shape, device=self.device, dtype=torch.int)
        self.values = torch.zeros(*shape, device=self.device)
        self.rewards = torch.zeros(*shape, device=self.device)
        self.advantages = torch.zeros(*shape, device=self.device)
        self.log_probs = torch.zeros(*shape, device=self.device)

        # Initialize log values

        self.log_episode_return = torch.zeros(self.num_procs,
                                              device=self.device)
        self.log_episode_reshaped_return = torch.zeros(self.num_procs,
                                                       device=self.device)
        self.log_episode_num_frames = torch.zeros(self.num_procs,
                                                  device=self.device)

        self.log_done_counter = 0
        self.log_return = [0] * self.num_procs
        self.log_reshaped_return = [0] * self.num_procs
        self.log_num_frames = [0] * self.num_procs

        self.agent_pos_visits = dict()

    def collect_experiences(self):
        """Collects rollouts and computes advantages.

        Runs several environments concurrently. The next actions are computed
        in a batch mode for all environments at the same time. The rollouts
        and advantages from all environments are concatenated together.

        Returns
        -------
        exps : DictList
            Contains actions, rewards, advantages etc as attributes.
            Each attribute, e.g. `exps.reward` has a shape
            (self.num_frames_per_proc * num_envs, ...). k-th block
            of consecutive `self.num_frames_per_proc` frames contains
            data obtained from the k-th environment. Be careful not to mix
            data from different environments!
        logs : dict
            Useful stats about the training process, including the average
            reward, policy loss, value loss, etc.
        """

        for i in range(self.num_frames_per_proc):
            # Do one agent-environment interaction

            preprocessed_obs = self.preprocess_obss(self.obs,
                                                    device=self.device)
            with torch.no_grad():
                if self.acmodel.recurrent:
                    dist, value, memory = self.acmodel(
                        preprocessed_obs, self.memory * self.mask.unsqueeze(1))
                else:
                    dist, value = self.acmodel(preprocessed_obs)
            action = dist.sample()

            obs, reward, done, _ = self.env.step(action.cpu().numpy())

            # Update experiences values

            self.obss[i] = self.obs
            self.obs = obs
            if self.acmodel.recurrent:
                self.memories[i] = self.memory
                self.memory = memory
            self.masks[i] = self.mask
            self.mask = 1 - torch.tensor(
                done, device=self.device, dtype=torch.float)
            self.actions[i] = action
            self.values[i] = value
            if self.reshape_reward is not None:
                self.rewards[i] = torch.tensor([
                    self.reshape_reward(obs_, action_, reward_, done_)
                    for obs_, action_, reward_, done_ in zip(
                        obs, action, reward, done)
                ],
                                               device=self.device)
            else:
                self.rewards[i] = torch.tensor(reward, device=self.device)
            self.log_probs[i] = dist.log_prob(action)

            # Update log values

            self.log_episode_return += torch.tensor(reward,
                                                    device=self.device,
                                                    dtype=torch.float)
            self.log_episode_reshaped_return += self.rewards[i]
            self.log_episode_num_frames += torch.ones(self.num_procs,
                                                      device=self.device)

            for i, done_ in enumerate(done):
                if done_:
                    self.log_done_counter += 1
                    self.log_return.append(self.log_episode_return[i].item())
                    self.log_reshaped_return.append(
                        self.log_episode_reshaped_return[i].item())
                    self.log_num_frames.append(
                        self.log_episode_num_frames[i].item())

            self.log_episode_return *= self.mask
            self.log_episode_reshaped_return *= self.mask
            self.log_episode_num_frames *= self.mask

        # Add advantage and return to experiences

        preprocessed_obs = self.preprocess_obss(self.obs, device=self.device)
        with torch.no_grad():
            if self.acmodel.recurrent:
                _, next_value, _ = self.acmodel(
                    preprocessed_obs, self.memory * self.mask.unsqueeze(1))
            else:
                _, next_value = self.acmodel(preprocessed_obs)

        for i in reversed(range(self.num_frames_per_proc)):
            next_mask = self.masks[
                i + 1] if i < self.num_frames_per_proc - 1 else self.mask
            next_value = self.values[
                i + 1] if i < self.num_frames_per_proc - 1 else next_value
            next_advantage = self.advantages[
                i + 1] if i < self.num_frames_per_proc - 1 else 0

            delta = self.rewards[
                i] + self.discount * next_value * next_mask - self.values[i]
            self.advantages[
                i] = delta + self.discount * self.gae_lambda * next_advantage * next_mask

        # Define experiences:
        #   the whole experience is the concatenation of the experience
        #   of each process.
        # In comments below:
        #   - T is self.num_frames_per_proc,
        #   - P is self.num_procs,
        #   - D is the dimensionality.

        exps = DictList()
        exps.obs = [
            self.obss[i][j] for j in range(self.num_procs)
            for i in range(self.num_frames_per_proc)
        ]

        if self.acmodel.recurrent:
            # T x P x D -> P x T x D -> (P * T) x D
            exps.memory = self.memories.transpose(0, 1).reshape(
                -1, *self.memories.shape[2:])
            # T x P -> P x T -> (P * T) x 1
            exps.mask = self.masks.transpose(0, 1).reshape(-1).unsqueeze(1)
        # for all tensors below, T x P -> P x T -> P * T
        exps.action = self.actions.transpose(0, 1).reshape(-1)
        exps.value = self.values.transpose(0, 1).reshape(-1)
        exps.reward = self.rewards.transpose(0, 1).reshape(-1)
        exps.advantage = self.advantages.transpose(0, 1).reshape(-1)
        exps.returnn = exps.value + exps.advantage
        exps.log_prob = self.log_probs.transpose(0, 1).reshape(-1)

        # Preprocess experiences

        exps.obs = self.preprocess_obss(exps.obs, device=self.device)

        # Log some values

        keep = max(self.log_done_counter, self.num_procs)

        logs = {
            "return_per_episode": self.log_return[-keep:],
            "reshaped_return_per_episode": self.log_reshaped_return[-keep:],
            "num_frames_per_episode": self.log_num_frames[-keep:],
            "num_frames": self.num_frames
        }

        self.log_done_counter = 0
        self.log_return = self.log_return[-self.num_procs:]
        self.log_reshaped_return = self.log_reshaped_return[-self.num_procs:]
        self.log_num_frames = self.log_num_frames[-self.num_procs:]

        return exps, logs

    def soft_update_params(self, net, target_net, tau):
        for param, target_param in zip(net.parameters(),
                                       target_net.parameters()):
            target_param.data.copy_(tau * param.data +
                                    (1 - tau) * target_param.data)

    def compute_logits(self, z_a, z_pos):
        """
        Uses logits trick for CURL:
        - compute (B,B) matrix z_a (W z_pos.T)
        - positives are all diagonal elements
        - negatives are all other elements
        - to compute loss use multiclass cross entropy with identity matrix for labels
        """
        Wz = torch.matmul(self.W, z_pos.T)  # (z_dim,B)
        logits = torch.matmul(z_a, Wz)  # (B,B)
        logits = logits - torch.max(logits, 1)[0][:, None]
        return logits
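        # Typical use of these logits, as the docstring describes: the positive
        # pair for row b is column b, so the labels are the row indices and the
        # loss is a plain cross entropy. Assuming F is torch.nn.functional and
        # self.W is a learnable (z_dim, z_dim) matrix defined by the concrete
        # algorithm (it is not created in this base class):
        #   logits = self.compute_logits(z_a, z_pos)               # (B, B)
        #   labels = torch.arange(logits.shape[0], device=logits.device)
        #   loss = F.cross_entropy(logits, labels)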

    @abstractmethod
    def update_parameters(self):
        pass

    def compute_state_entropy(self,
                              src_feats,
                              tgt_feats,
                              average_entropy=False):
        with torch.no_grad():
            dists = []
            for idx in range(len(tgt_feats) // 10000 + 1):
                start = idx * 10000
                end = (idx + 1) * 10000
                dist = torch.norm(src_feats[:, None, :] -
                                  tgt_feats[None, start:end, :],
                                  dim=-1,
                                  p=2)
                dists.append(dist)

            dists = torch.cat(dists, dim=1)
            knn_dists = 0.0
            if average_entropy:
                for k in range(5):
                    knn_dists += torch.kthvalue(dists, k + 1, dim=1).values
                knn_dists /= 5
            else:
                knn_dists = torch.kthvalue(dists, k=self.k + 1, dim=1).values
            state_entropy = knn_dists
        return state_entropy.unsqueeze(1)
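
A small, self-contained check of `compute_state_entropy` on random features. The stub subclass below is hypothetical and exists only so the method can be called (`BaseAlgo` itself is abstract); how the real algorithm feeds the method features from `self.random_encoder` and normalizes the result with `self.s_ent_stats` is not shown in this example, so it is not reproduced here.

import torch


class _EntropyOnly(BaseAlgo):
    """Hypothetical minimal subclass used only to exercise
    compute_state_entropy; it skips BaseAlgo.__init__ and sets the single
    attribute the method reads (self.k = 3, as in BaseAlgo.__init__)."""

    def __init__(self):
        self.k = 3

    def update_parameters(self):
        pass


algo = _EntropyOnly()
src_feats = torch.randn(4, 8)    # features of 4 "source" states
tgt_feats = torch.randn(100, 8)  # features of 100 candidate neighbours

# distance of each source state to its (k + 1 = 4)-th closest target feature,
# returned with a trailing singleton dimension
state_entropy = algo.compute_state_entropy(src_feats, tgt_feats)
print(state_entropy.shape)  # torch.Size([4, 1])

# with average_entropy=True the distances to the 5 nearest neighbours are
# averaged instead of taking a single k-th value
avg_entropy = algo.compute_state_entropy(src_feats, tgt_feats,
                                         average_entropy=True)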