Example #1

# Standard framework imports used by this snippet. RAdam, LabelSmoothingLoss,
# VATLoss, CutMix and AvgMeter are assumed to be project-local helpers
# importable from the surrounding codebase.
import os

import torch
import torch.nn as nn
import apex.amp
import apex.parallel
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

class Trainer:
    def __init__(self,
                 encoder,
                 decoder,
                 optimizer_params={},
                 amp_params={},
                 n_jobs=0,
                 rank=0):

        lr = optimizer_params.get('lr', 1e-3)
        weight_decay = optimizer_params.get('weight_decay', 0)
        warmup = optimizer_params.get('warmup', 100)  # read but never used below
        amsgrad = optimizer_params.get('amsgrad', False)
        opt_level = amp_params.get('opt_level', 'O0')
        loss_scale = amp_params.get('loss_scale', None)

        self.device = torch.device('cuda:' + str(rank))
        self.encoder = encoder.to(self.device)
        #self.decoder = decoder.to(self.device)
        self.num_classes = decoder.num_classes
        self.mse_criterion = nn.L1Loss()  # reconstruction loss (an L1 loss, despite the name)
        self.ce_criterion = LabelSmoothingLoss(self.num_classes,
                                               smoothing=0.1,
                                               reduction='none').to(
                                                   self.device)
        self.vat_criterion = VATLoss()
        self.cutmix = CutMix(self.num_classes)

        param_optimizer = list(self.encoder.named_parameters()
                               )  #+ list(self.decoder.named_parameters())
        no_decay = ['bn', 'bias']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': weight_decay
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
        ]

        # The per-group 'weight_decay' values above take precedence over this default.
        self.optimizer = RAdam(optimizer_grouped_parameters,
                               lr=lr,
                               weight_decay=weight_decay)

        self.is_master = torch.distributed.get_rank() == 0
        torch.cuda.set_device(rank)
        # apex.amp.initialize accepts a list of models; unpack the single encoder back out.
        [self.encoder], self.optimizer = apex.amp.initialize(
            [self.encoder],
            self.optimizer,
            opt_level=opt_level,
            loss_scale=loss_scale,
            verbosity=1)

        self.scheduler = StepLR(self.optimizer, step_size=20, gamma=0.5)

        self.encoder = apex.parallel.DistributedDataParallel(
            self.encoder, delay_allreduce=True)
        #self.decoder = apex.parallel.DistributedDataParallel(self.decoder, delay_allreduce=True)

        self.last_epoch = 0
        self.n_jobs = n_jobs

    def _train_epoch(self, train_dataloader):
        if self.is_master:
            pbar = tqdm(desc=f'Train, epoch #{self.last_epoch}',
                        total=len(train_dataloader))

        self.encoder.train()
        #self.decoder.train()

        sum_loss, cls_loss = AvgMeter(), AvgMeter()
        for images, labels in train_dataloader:
            # lam is the CutMix mixing coefficient between the two label sets.
            images, labels, shuffled_labels, lam = self.cutmix(images, labels)
            images = images.to(self.device)
            labels = labels.to(self.device)
            shuffled_labels = shuffled_labels.to(self.device)
            lam = lam.to(self.device)

            self.optimizer.zero_grad()

            #loss_vat = self.vat_criterion(self.encoder, images)

            label_preds = self.encoder(images)
            #reconsts_l = self.decoder(latents, labels)
            #with disable_tracking_bn_stats(self.encoder):
            #    latents_l, label_preds_l = self.encoder(reconsts_l)
            #labels_r = torch.randint_like(labels, low=0, high=self.num_classes)
            #reconsts_r = self.decoder(latents, labels_r)
            #with disable_tracking_bn_stats(self.encoder):
            #    latents_r, label_preds_r = self.encoder(reconsts_r)

            loss_c = (lam * self.ce_criterion(label_preds, labels) +
                      (1 - lam) * self.ce_criterion(label_preds, shuffled_labels)).mean()
            #loss_r = self.mse_criterion(reconsts_l, images)
            #loss_e = self.ce_criterion(label_preds_r, labels_r)
            #loss_i = self.mse_criterion(latents_l, latents_r)

            losses = loss_c  #+ loss_vat # + loss_r + loss_e + loss_i

            with apex.amp.scale_loss(losses, self.optimizer) as scaled_loss:
                scaled_loss.backward()

            self.optimizer.step()

            sum_loss.update(losses.item())
            cls_loss.update(loss_c.item())

            info_tensor = torch.tensor([sum_loss(), cls_loss()],
                                       device=self.device)
            torch.distributed.reduce(info_tensor, dst=0)

            if self.is_master:
                info_tensor = info_tensor / torch.distributed.get_world_size()
                pbar.update(1)
                pbar.set_postfix({
                    'sum_loss': info_tensor[0].item(),
                    'cls_loss': info_tensor[1].item()
                })

        self.scheduler.step()

    def _test_epoch(self, test_dataloader):
        with torch.no_grad():
            if self.is_master:
                pbar = tqdm(desc=f'Test, epoch #{self.last_epoch}',
                            total=len(test_dataloader))

            self.encoder.eval()

            loss, acc, quality_metric = AvgMeter(), AvgMeter(), 0
            for images, labels in test_dataloader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                label_preds = self.encoder(images)
                loss_val = self.ce_criterion(label_preds, labels).mean()
                preds = torch.argmax(label_preds, dim=-1)
                acc_val = (preds == labels).float().mean()

                loss.update(loss_val.item())
                acc.update(acc_val.item())

                info_tensor = torch.tensor([loss(), acc()], device=self.device)
                torch.distributed.reduce(info_tensor, dst=0)

                if self.is_master:
                    info_tensor = (info_tensor /
                                   torch.distributed.get_world_size())
                    quality_metric = info_tensor[1].item()
                    pbar.update(1)
                    pbar.set_postfix({
                        'loss': info_tensor[0].item(),
                        'acc': info_tensor[1].item()
                    })

            return quality_metric

    def _save_checkpoint(self, checkpoint_path):
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(self.encoder.module.state_dict(), checkpoint_path)

    def train(self,
              train_data,
              n_epochs,
              batch_size,
              test_data=None,
              last_checkpoint_path=None,
              best_checkpoint_path=None):

        num_replicas = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        batch_size = batch_size // num_replicas

        train_sampler = DistributedSampler(train_data,
                                           shuffle=True,
                                           num_replicas=num_replicas,
                                           rank=rank)
        train_dataloader = DataLoader(train_data,
                                      batch_size=batch_size,
                                      sampler=train_sampler,
                                      num_workers=self.n_jobs)

        if test_data is not None:
            test_sampler = DistributedSampler(test_data,
                                              shuffle=False,
                                              num_replicas=num_replicas,
                                              rank=rank)
            test_dataloader = DataLoader(test_data,
                                         batch_size=batch_size,
                                         sampler=test_sampler,
                                         num_workers=self.n_jobs)

        best_metric = float('-inf')
        for epoch in range(n_epochs):
            torch.cuda.empty_cache()
            self._train_epoch(train_dataloader)

            if last_checkpoint_path is not None and self.is_master:
                self._save_checkpoint(last_checkpoint_path)

            if test_data is not None:
                torch.cuda.empty_cache()
                metric = self._test_epoch(test_dataloader)

                if best_checkpoint_path is not None and self.is_master:
                    if metric > best_metric:
                        best_metric = metric
                        self._save_checkpoint(best_checkpoint_path)

            self.last_epoch += 1
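
For orientation, here is a minimal launch sketch for this Trainer, assuming one process per GPU. The constructor calls torch.distributed.get_rank(), so the process group must be initialised first; MyEncoder, MyDecoder, train_dataset and test_dataset are illustrative placeholders rather than part of the original code.

def main_worker(rank, world_size):
    # The Trainer assumes torch.distributed is already initialised.
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://',
                                         world_size=world_size,
                                         rank=rank)

    encoder = MyEncoder()                  # placeholder model
    decoder = MyDecoder(num_classes=10)    # placeholder; only .num_classes is read

    trainer = Trainer(encoder,
                      decoder,
                      optimizer_params={'lr': 1e-3, 'weight_decay': 1e-4},
                      amp_params={'opt_level': 'O1'},
                      n_jobs=4,
                      rank=rank)
    trainer.train(train_data=train_dataset,   # placeholder dataset
                  n_epochs=50,
                  batch_size=256,             # global batch size, split across replicas
                  test_data=test_dataset,     # placeholder dataset
                  last_checkpoint_path='checkpoints/last.pt',
                  best_checkpoint_path='checkpoints/best.pt')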
Example #2

# Imports used by this snippet. Model and RAdam are assumed to be
# project-local classes defined alongside it.
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn

class IQN(nn.Module):
    def __init__(self, state_space, act_n, quantile_dim, num_quantiles,
                 hidden_dim, num_hidden, optim_params):
        """
        Rainbow Recurrent IQN

        IQN: https://arxiv.org/pdf/1806.06923.pdf
        R2D2: https://openreview.net/pdf?id=r1lyTjAqYX
        R2D3: https://arxiv.org/abs/1909.01387
        """
        nn.Module.__init__(self)

        self.online = Model(state_space, act_n, quantile_dim, num_quantiles,
                            hidden_dim, num_hidden)
        self.target = deepcopy(self.online)
        self.num_quantiles = num_quantiles

        # Element-wise loss; it is weighted and reduced manually in train_batch.
        self.loss_func = nn.SmoothL1Loss(reduction="none")
        self.optim = RAdam(self.online.parameters(), **optim_params)

    def forward(self, inp):
        return self.online(inp)

    def step(self, state, greedy=False):
        """
        Takes a step into the environment
        """
        return self.online.step(state, greedy)

    def train_batch(self, rollouts, burn_in_length, sequence_length):
        """
        Trains for a batch of rollouts with the given burn in length and
        training sequence length
        """
        self.optim.zero_grad()

        states, actions, rewards, next_states, terminals, hidden_state = rollouts

        # Add burn in here #######
        
        next_q_vals, next_quantile_vals, next_quantiles, next_hidden = self.target(next_states)
        num_quantiles = next_quantile_vals.shape[1]

        # Greedy next actions come from the expected Q-values, expanded across
        # quantiles so the chosen action's quantile values can be gathered.
        next_actions = next_q_vals.argmax(-1, keepdim=True)
        next_actions = next_actions.unsqueeze(1).repeat(1, num_quantiles, 1)
        next_values = next_quantile_vals.gather(-1, next_actions).squeeze(-1)

        q_vals, quantile_vals, quantiles, hidden = self.online(states)
        # Assumes actions arrive as [batch, 1]; expand them the same way.
        actions = actions.unsqueeze(1).repeat(1, num_quantiles, 1)
        action_values = quantile_vals.gather(-1, actions).squeeze(-1)

        td_error = next_values.unsqueeze(2) - action_values.unsqueeze(1)
        quantile_loss = self.loss_func(next_values.unsqueeze(2),
                                       action_values.unsqueeze(1))

        quantiles = quantiles.unsqueeze(1).repeat(1, self.num_quantiles, 1)
        penalty = torch.abs(quantiles - (td_error < 0).float().detach())

        loss = penalty * quantile_loss # Divide by huber kappa
        loss = loss.sum(2).mean(1)      # per-sample loss, shape [batch]
        meaned_loss = loss.mean()       # scalar loss for backprop

        meaned_loss.backward()
        self.optim.step()

        # Per-sample losses are returned so the caller can update replay priorities.
        return meaned_loss, loss.detach()

    def train(self, num_batches, batch_size, burn_in_length, sequence_length,
              online_replay_buffer=None, supervised_replay_buffer=None,
              supervised_chance=0.25, writer=None):
        """
        Trains R2D3 style with 2 replay buffers
        """
        assert online_replay_buffer is not None or supervised_replay_buffer is not None

        for batch in range(1, num_batches + 1):
            buff_choice = np.random.rand()
            if online_replay_buffer is None or buff_choice < supervised_chance:
                replay_buffer = supervised_replay_buffer
            else:
                replay_buffer = online_replay_buffer

            # Busy-wait until the chosen buffer has enough samples.
            while not replay_buffer.ready_to_sample(batch_size):
                pass

            rollouts, idxs, is_weights = replay_buffer.sample(batch_size)

            loss, new_errors = self.train_batch(rollouts, burn_in_length,
                                                sequence_length)
            replay_buffer.update_priorities(new_errors, idxs)

            if writer is not None:
                if buff_choice < supervised_chance:
                    writer.add_summary("Supervised Loss", loss, batch)
                else:
                    writer.add_summary("Online Loss", loss, batch)

                writer.add_summary("Loss", loss, batch)

    def update_target(self):
        """
        Updates the target network
        """
        self.target.load_state_dict(self.online.state_dict())
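
A rough usage sketch for the IQN agent above; the constructor arguments, the replay buffers (online_buffer, demo_buffer) and the target-update cadence are illustrative assumptions, not values taken from the original code.

agent = IQN(state_space=(84, 84, 4),   # placeholder environment shapes
            act_n=6,
            quantile_dim=64,
            num_quantiles=32,
            hidden_dim=512,
            num_hidden=2,
            optim_params={'lr': 1e-4})

for _ in range(100):
    # R2D3-style training on a mix of online and demonstration rollouts,
    # followed by a periodic target-network refresh.
    agent.train(num_batches=400,
                batch_size=64,
                burn_in_length=40,
                sequence_length=80,
                online_replay_buffer=online_buffer,     # placeholder buffer
                supervised_replay_buffer=demo_buffer,   # placeholder buffer
                supervised_chance=0.25)
    agent.update_target()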