Example #1
 def forward(self):
     a = torch.randn(3, 2)
     b = torch.rand(3, 2)
     c = torch.rand(3)
     log_probs = torch.randn(50, 16, 20).log_softmax(2).detach()
     targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
     input_lengths = torch.full((16, ), 50, dtype=torch.long)
     target_lengths = torch.randint(10, 30, (16, ), dtype=torch.long)
     return len((
         F.binary_cross_entropy(torch.sigmoid(a), b),
         F.binary_cross_entropy_with_logits(torch.sigmoid(a), b),
         F.poisson_nll_loss(a, b),
         F.cosine_embedding_loss(a, b, c),
         F.cross_entropy(a, b),
         F.ctc_loss(log_probs, targets, input_lengths, target_lengths),
         # F.gaussian_nll_loss(a, b, torch.ones(5, 1)), # ENTER is not supported in mobile module
         F.hinge_embedding_loss(a, b),
         F.kl_div(a, b),
         F.l1_loss(a, b),
         F.mse_loss(a, b),
         F.margin_ranking_loss(c, c, c),
         F.multilabel_margin_loss(self.x, self.y),
         F.multilabel_soft_margin_loss(self.x, self.y),
         F.multi_margin_loss(self.x, torch.tensor([3])),
         F.nll_loss(a, torch.tensor([1, 0, 1])),
         F.huber_loss(a, b),
         F.smooth_l1_loss(a, b),
         F.soft_margin_loss(a, b),
         F.triplet_margin_loss(a, b, -b),
         # F.triplet_margin_with_distance_loss(a, b, -b), # can't take variable number of arguments
     ))
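For reference, a minimal self-contained sketch of the F.huber_loss call used above (tensor names and the delta value here are illustrative). With delta=1.0 the result matches F.smooth_l1_loss; in general F.huber_loss(..., delta=d) equals d * F.smooth_l1_loss(..., beta=d).

import torch
import torch.nn.functional as F

pred = torch.randn(8, 4)    # hypothetical predictions
target = torch.randn(8, 4)  # hypothetical targets

# Quadratic for |pred - target| <= delta, linear beyond it.
loss = F.huber_loss(pred, target, reduction="mean", delta=1.0)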
Example #2
    def optimize(self):
        if len(self.memory._storage) < self.batch_size:
            return

        beta = self.beta_scheduler.value(self.optim_steps)
        state, action, reward, new_state, done, _, indices = self.memory.sample(
            self.batch_size, beta)

        state = torch.as_tensor(np.vstack(state),
                                dtype=torch.float32,
                                device=device)
        action = torch.as_tensor(np.vstack(action),
                                 dtype=torch.float32,
                                 device=device)
        done = torch.as_tensor(np.vstack(1 - done),
                               dtype=torch.float32,
                               device=device)
        reward = torch.as_tensor(np.vstack(reward),
                                 dtype=torch.float32,
                                 device=device)
        new_state = torch.as_tensor(np.vstack(new_state),
                                    dtype=torch.float32,
                                    device=device)

        self.target_actor.eval()
        self.target_critic.eval()
        self.critic.train()
        self.actor.train()

        Q_target = self.target_critic.forward(
            new_state, self.target_actor.forward(new_state))
        Y = reward + (done * self.gamma * Q_target)
        Q = self.critic.forward(state, action)
        TD_errors = torch.sub(Y, Q).squeeze(dim=-1)

        # Not weighting the TD errors by the PER importance weights here;
        # treating all 'PER' weights as 1.0 is itself a hyperparameter choice.
        critic_loss = F.huber_loss(TD_errors, torch.zeros_like(TD_errors))
        self.critic.optimizer.zero_grad()
        critic_loss.backward()
        self.critic.optimizer.step()

        # Compute and apply the actor loss
        actor_loss = torch.mean(-1.0 *
                                self.critic.forward(state, self.actor(state)))
        self.actor.optimizer.zero_grad()
        actor_loss.backward()
        self.actor.optimizer.step()

        td_errors: np.ndarray = TD_errors.detach().cpu().numpy()
        new_priorities = np.abs(td_errors) + 1e-6
        self.memory.update_priorities(indices, new_priorities)

        self._update_networks(self.tau)
        self.optim_steps += 1
Example #3
    def optimize(self):
        if len(self.memory._storage) < self.batch_size:
            return

        state, action, reward, new_state, done = self.memory.sample(
            self.batch_size)

        state = torch.as_tensor(np.vstack(state),
                                dtype=torch.float32,
                                device=device)
        action = torch.as_tensor(np.vstack(action),
                                 dtype=torch.float32,
                                 device=device)
        done = torch.as_tensor(np.vstack(1 - done),
                               dtype=torch.float32,
                               device=device)
        reward = torch.as_tensor(np.vstack(reward),
                                 dtype=torch.float32,
                                 device=device)
        new_state = torch.as_tensor(np.vstack(new_state),
                                    dtype=torch.float32,
                                    device=device)

        self.target_actor.eval()
        self.target_critic.eval()
        self.critic.train()
        self.actor.train()

        Q_target = self.target_critic.forward(
            new_state, self.target_actor.forward(new_state))
        Y = reward + (done * self.gamma * Q_target)
        Q = self.critic.forward(state, action)
        TD_errors = torch.sub(Y, Q).squeeze(dim=-1)

        critic_loss = F.huber_loss(TD_errors, torch.zeros_like(TD_errors))
        self.critic.optimizer.zero_grad()
        critic_loss.backward()
        self.critic.optimizer.step()

        # Compute and apply the actor loss
        actor_loss = torch.mean(-1.0 *
                                self.critic.forward(state, self.actor(state)))
        self.actor.optimizer.zero_grad()
        actor_loss.backward()
        self.actor.optimizer.step()

        self._update_networks(self.tau)
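A note on the two DDPG optimize() snippets above: the Huber loss depends only on the difference between its arguments, so taking F.huber_loss of the TD errors against a zero target is numerically the same as comparing Q with Y directly. A minimal sketch (names are illustrative):

import torch
import torch.nn.functional as F

Q = torch.randn(64, 1)
Y = torch.randn(64, 1)
td_errors = Y - Q

loss_a = F.huber_loss(td_errors, torch.zeros_like(td_errors))  # pattern used above
loss_b = F.huber_loss(Q, Y)                                    # equivalent direct form
assert torch.allclose(loss_a, loss_b)

One caveat: Y is built from the target networks and is not detached in the snippets above, so backward() also computes (unused) gradients for the target networks if their parameters require grad; computing Y under torch.no_grad() or detaching it is the safer pattern.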
Example #4
    def _train(self, BATCH):
        q_dist = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)  # [T, B, A, N] => [T, B, N]

        target_q_dist = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        target_q = target_q_dist.mean(-1)  # [T, B, A, N] => [T, B, A]
        _a = target_q.argmax(-1)  # [T, B]
        next_max_action = F.one_hot(_a, self.a_dim).float().unsqueeze(-1)  # [T, B, A, 1]
        # [T, B, A, N] => [T, B, N]
        target_q_dist = (target_q_dist * next_max_action).sum(-2)

        target = n_step_return(BATCH.reward.repeat(1, 1, self.nums),
                               self.gamma,
                               BATCH.done.repeat(1, 1, self.nums),
                               target_q_dist,
                               BATCH.begin_mask.repeat(1, 1, self.nums)).detach()  # [T, B, N]

        q_eval = q_dist.mean(-1, keepdim=True)  # [T, B, 1]
        q_target = target.mean(-1, keepdim=True)  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1], used for PER

        target = target.unsqueeze(-2)  # [T, B, 1, N]
        q_dist = q_dist.unsqueeze(-1)  # [T, B, N, 1]

        # [T, B, 1, N] - [T, B, N, 1] => [T, B, N, N]
        quantile_error = target - q_dist
        huber = F.huber_loss(target, q_dist, reduction="none", delta=self.huber_delta)  # [T, B, N, N]
        # [N,] - [T, B, N, N] => [T, B, N, N]
        huber_abs = (self.quantiles - quantile_error.detach().le(0.).float()).abs()
        loss = (huber_abs * huber).mean(-1)  # [T, B, N, N] => [T, B, N]
        loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]
        loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1

        self.oplr.optimize(loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }
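Example #4 (and Example #5 below) use F.huber_loss with reduction="none" as the element-wise term of the quantile-Huber loss from QR-DQN/IQN: each Huber term is weighted by |tau - 1{u < 0}|, where u is the pairwise difference between target and predicted quantiles. The two inputs are deliberately given broadcastable shapes ([T, B, 1, N'] against [T, B, N, 1]); F.huber_loss broadcasts them to the pairwise [T, B, N, N'] matrix (recent PyTorch versions emit a size-mismatch warning when doing so). A stripped-down sketch under the usual QR-DQN convention, where tau_i is paired with the i-th predicted quantile (names and shapes are illustrative):

import torch
import torch.nn.functional as F

B, N = 32, 8                                               # batch size, number of quantiles
taus = (torch.arange(N, dtype=torch.float32) + 0.5) / N   # [N] quantile fractions
q_dist = torch.randn(B, N, 1)                              # predicted quantiles [B, N, 1]
target = torch.randn(B, 1, N)                              # target quantiles    [B, 1, N']

u = target - q_dist                                                 # pairwise errors [B, N, N']
huber = F.huber_loss(target, q_dist, reduction="none", delta=1.0)   # [B, N, N']
weight = (taus.view(1, N, 1) - (u.detach() < 0).float()).abs()      # [B, N, N']
loss = (weight * huber).mean(-1).sum(-1).mean()                     # scalar quantile-Huber loss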
Example #5
    def _train(self, BATCH):
        time_step = BATCH.reward.shape[0]
        batch_size = BATCH.reward.shape[1]

        quantiles, quantiles_tiled = self._generate_quantiles(  # [T*B, N, 1], [N*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.online_quantiles)
        # [T*B, N, 1] => [T, B, N, 1]
        quantiles = quantiles.view(time_step, batch_size, -1, 1)
        quantiles_tiled = quantiles_tiled.view(time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]

        quantiles_value = self.q_net(BATCH.obs, quantiles_tiled, begin_mask=BATCH.begin_mask)  # [T, N, B, A]
        # [T, N, B, A] => [N, T, B, A] * [T, B, A] => [N, T, B, 1]
        quantiles_value = (quantiles_value.swapaxes(0, 1) * BATCH.action).sum(-1, keepdim=True)
        q_eval = quantiles_value.mean(0)  # [N, T, B, 1] => [T, B, 1]

        _, select_quantiles_tiled = self._generate_quantiles(  # [N*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.select_quantiles)
        select_quantiles_tiled = select_quantiles_tiled.view(
            time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]

        q_values = self.q_net(
            BATCH.obs_, select_quantiles_tiled, begin_mask=BATCH.begin_mask)  # [T, N, B, A]
        q_values = q_values.mean(1)  # [T, N, B, A] => [T, B, A]
        next_max_action = q_values.argmax(-1)  # [T, B]
        next_max_action = F.one_hot(
            next_max_action, self.a_dim).float()  # [T, B, A]

        _, target_quantiles_tiled = self._generate_quantiles(  # [N'*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.target_quantiles)
        target_quantiles_tiled = target_quantiles_tiled.view(
            time_step, -1, self.quantiles_idx)  # [N'*T*B, X] => [T, N'*B, X]
        target_quantiles_value = self.q_net.t(BATCH.obs_, target_quantiles_tiled,
                                              begin_mask=BATCH.begin_mask)  # [T, N', B, A]
        target_quantiles_value = target_quantiles_value.swapaxes(0, 1)  # [T, N', B, A] => [N', T, B, A]
        target_quantiles_value = (target_quantiles_value * next_max_action).sum(-1, keepdim=True)  # [N', T, B, 1]

        target_q = target_quantiles_value.mean(0)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward,  # [T, B, 1]
                                 self.gamma,
                                 BATCH.done,  # [T, B, 1]
                                 target_q,  # [T, B, 1]
                                 BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]

        # [N', T, B, 1] => [N', T, B]
        target_quantiles_value = target_quantiles_value.squeeze(-1)
        target_quantiles_value = target_quantiles_value.permute(
            1, 2, 0)  # [N', T, B] => [T, B, N']
        quantiles_value_target = n_step_return(BATCH.reward.repeat(1, 1, self.target_quantiles),
                                               self.gamma,
                                               BATCH.done.repeat(1, 1, self.target_quantiles),
                                               target_quantiles_value,
                                               BATCH.begin_mask.repeat(1, 1,
                                                                       self.target_quantiles)).detach()  # [T, B, N']
        # [T, B, N'] => [T, B, 1, N']
        quantiles_value_target = quantiles_value_target.unsqueeze(-2)
        quantiles_value_online = quantiles_value.permute(1, 2, 0, 3)  # [N, T, B, 1] => [T, B, N, 1]
        # [T, B, N, 1] - [T, B, 1, N'] => [T, B, N, N']
        quantile_error = quantiles_value_online - quantiles_value_target
        huber = F.huber_loss(quantiles_value_online, quantiles_value_target,
                             reduction="none", delta=self.huber_delta)  # [T, B, N, N']
        # [T, B, N, 1] - [T, B, N, N'] => [T, B, N, N']
        huber_abs = (quantiles - quantile_error.detach().le(0.).float()).abs()
        loss = (huber_abs * huber).mean(-1)  # [T, B, N, N'] => [T, B, N]
        loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]

        loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1
        self.oplr.optimize(loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }
Example #6
    def compute_loss(self):
        if len(self.memory) < 5:
            return None
        transitions = self.memory.sample(min(self.batch_size, len(self.memory)))
        batch = Transition(*zip(*transitions))
        state_batch = dict()
        for k in batch.state[0].keys():
            state_batch[k] = torch.stack([data[k] for data in batch.state])

        assert state_batch['image'].max() < 1.1
        for i, img in enumerate(state_batch['image']):
            img_trans = torch.as_tensor(self.transform(img.permute(1,2,0).numpy() * 255)).permute(2, 0, 1) / 255
            state_batch['image'][i] = img_trans
            # import cv2
            # cv2.imshow('old', img.numpy().transpose(1,2,0))
            # cv2.imshow('new', state_batch['image'][i].numpy().transpose(1,2,0))
            # cv2.waitKey(1000)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.stack(batch.reward)
        next_states = batch.next_state

        # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
        # columns of actions taken. These are the actions which would've been taken
        # for each batch state according to policy_net
        state_action_values = self.policy_net(state_batch)
        Q_values = state_action_values[numpy.arange(len(state_action_values)), action_batch]

        # Compute a mask of non-final states and concatenate the batch elements
        # (a final state would've been the one after which simulation ended)
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                              batch.next_state)), dtype=torch.bool)
        non_final_next_states = [s for s in batch.next_state if s is not None]

        non_final_state = dict()
        if non_final_next_states:
            for k in non_final_next_states[0].keys():
                non_final_state[k] = torch.stack([data[k] for data in non_final_next_states])

        device = next(self.target_net.parameters()).device
        next_Q_values = torch.zeros(len(non_final_mask)).to(device)

        # max_a' Q(s', a')
        next_Q_values[non_final_mask] = self.target_net(non_final_state).max(1)[0].detach()

        # r + gamma * max_a' Q(s', a'; theta_target)
        expected_Q_values = (next_Q_values * self.gamma) + reward_batch.to(device)

        # Compute Huber loss
        #loss = F.smooth_l1_loss(Q_values, expected_Q_values, beta=101)
        #loss = F.mse_loss(Q_values, expected_Q_values)
        loss = F.huber_loss(Q_values, expected_Q_values, delta=10)
        if torch.isnan(loss):
            import pdb;pdb.set_trace()
        self.iteration += 1
        # Soft-update the target network by blending policy and target weights
        if self.iteration % self.target_update == 0:
            logging.debug('update target network')
            with torch.no_grad():
                for pol_param, target_param in zip(self.policy_net.parameters(),
                                                   self.target_net.parameters()):
                    mean = 0.4 * pol_param.detach().cpu().numpy() + 0.6 * target_param.detach().cpu().numpy()
                    target_param[:] = torch.as_tensor(mean).to(pol_param)
        if self.iteration % self.save_interval == 0 and self.iteration > 0:
            self.save_memory()
        return loss
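The target-network blend at the end of Example #6 round-trips through NumPy; the same soft (Polyak-style) update can be written with in-place tensor ops. A minimal sketch, assuming both networks live on the same device and the same 0.6/0.4 mixing as above (the helper name and the tau default are illustrative):

import torch

@torch.no_grad()
def soft_update(policy_net: torch.nn.Module, target_net: torch.nn.Module, tau: float = 0.4):
    # target <- (1 - tau) * target + tau * policy, in place on the target parameters
    for pol_param, target_param in zip(policy_net.parameters(), target_net.parameters()):
        target_param.mul_(1.0 - tau).add_(pol_param, alpha=tau)

Called as soft_update(self.policy_net, self.target_net, tau=0.4) inside the target-update branch, this avoids the detach/cpu/numpy round trip while producing the same blended weights.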