Exemplo n.º 1
0
    # Run validation only on every `eval_every_n_epochs`-th epoch.
    if epoch_counter % config['eval_every_n_epochs'] == 0:

        # validation steps
        with torch.no_grad():
            model.eval()

            valid_loss = 0.0
            counter = 0
            # start=1 makes `counter` end up as the number of batches seen.
            # The previous 0-based index divided the summed loss by (n - 1),
            # inflating the mean and dividing by zero for a one-batch loader.
            for counter, ((xis, xjs), _) in enumerate(valid_loader, start=1):

                if train_gpu:
                    xis = xis.cuda()
                    xjs = xjs.cuda()
                loss = step(xis, xjs)
                valid_loss += loss.item()

            if counter == 0:
                # No batches means there is no loss to average; fail with a
                # clear message instead of an opaque arithmetic error.
                raise ValueError('valid_loader produced no batches; '
                                 'cannot compute the validation loss')

            # Mean loss per validation batch.
            valid_loss /= counter

            if valid_loss < best_valid_loss:
                # save the model weights
                best_valid_loss = valid_loss
                torch.save(model.state_dict(),
                           os.path.join(model_checkpoints_folder, 'model.pth'))

            train_writer.add_scalar('validation_loss',
                                    valid_loss,
                                    global_step=valid_n_iter)
            valid_n_iter += 1

        # Switch back to training mode for the next training epoch.
        model.train()
Exemplo n.º 2
0
    def train(self):
        """Train the model, validating and checkpointing periodically.

        Builds the model and an Adam optimizer whose learning rate is the
        configured rate scaled by ``batch_size / 256``. The rate is ramped
        linearly during the warm-up steps and then follows a cosine
        annealing schedule for the remaining steps. Every
        ``eval_every_n_epochs`` epochs the validation loss is computed and
        the weights are saved to ``<log_dir>/checkpoints/model.pth`` whenever
        it improves on the best value seen so far.

        Reads ``self.dataset``, ``self.config``, ``self.device``,
        ``self.batch_size`` and ``self.writer``. Returns ``None``.
        """
        # Data
        train_loader, valid_loader = self.dataset.get_data_loaders()

        # Model
        model = ResNetSimCLR(**self.config["model"])
        if self.device == 'cuda':
            gpu_ids = list(range(self.config['gpu']['gpunum']))
            model = nn.DataParallel(model, device_ids=gpu_ids)
        # NOTE(review): the model and the batches are moved with .cuda()
        # regardless of self.device -- confirm whether CPU runs are meant to
        # be supported before switching to .to(self.device).
        model = model.cuda()
        print(model)
        model = self._load_pre_trained_weights(model)

        train_cfg = self.config['train']

        # NOTE(security): 'lr' and 'weight_decay' are evaluated with eval();
        # this is only acceptable while the config file is trusted input.
        # Each value is evaluated once here and reused below.
        base_lr = eval(train_cfg['lr'])
        weight_decay = eval(train_cfg['weight_decay'])

        each_epoch_steps = len(train_loader)
        total_steps = each_epoch_steps * train_cfg['epochs']
        # Truncate to an integer so the `n_iter == warmup_steps` switch below
        # can actually fire when warmup_epochs is fractional.
        warmup_steps = int(each_epoch_steps * train_cfg['warmup_epochs'])
        # Keep T_max >= 1: a zero period makes CosineAnnealingLR divide by
        # zero when the warm-up covers every training step.
        cosine_steps = max(1, total_steps - warmup_steps)
        scaled_lr = base_lr * self.batch_size / 256.

        optimizer = torch.optim.Adam(
            model.parameters(),
            scaled_lr,
            weight_decay=weight_decay)
        # NOTE: a LARS optimizer was previously sketched here as an
        # alternative to Adam; it was never enabled.

        if warmup_steps > 0:
            # scheduler during warmup stage: linear ramp, stepped per batch
            scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                lr_lambda=lambda step: step * 1.0 / warmup_steps)
        else:
            # No warm-up requested: start directly on the cosine schedule.
            # (A warm-up lambda would divide by zero as soon as it is built.)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=cosine_steps, eta_min=0, last_epoch=-1)

        use_amp = apex_support and train_cfg['fp16_precision']
        if use_amp:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level='O2',
                                              keep_batchnorm_fp32=True)

        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

        # save config file
        _save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf
        lr = base_lr  # value shown in the progress line before the first step

        end = time.time()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()

        for epoch_counter in range(train_cfg['epochs']):
            model.train()
            for i, ((xis, xjs), _) in enumerate(train_loader):
                data_time.update(time.time() - end)
                optimizer.zero_grad()

                xis = xis.cuda()
                xjs = xjs.cuda()

                loss = self._step(model, xis, xjs, n_iter)
                loss_value = loss.item()

                # Weight the running average by the two views of each sample.
                losses.update(loss_value, 2 * xis.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                print('Epoch: [{epoch}][{step}/{each_epoch_steps}] Loss {loss.val:.4f} Avg Loss {loss.avg:.4f} DataTime {datatime.val:.4f} BatchTime {batchtime.val:.4f} LR {lr})'.format(epoch=epoch_counter, step=i, each_epoch_steps=each_epoch_steps, loss=losses, datatime=data_time, batchtime=batch_time, lr=lr))

                if n_iter % train_cfg['log_every_n_steps'] == 0:
                    self.writer.add_scalar('train_loss', loss_value, global_step=n_iter)

                if use_amp:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()
                n_iter += 1

                # adjust lr
                if n_iter == warmup_steps:
                    # scheduler after warmup stage
                    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                        optimizer, T_max=cosine_steps, eta_min=0, last_epoch=-1)
                scheduler.step()
                # Read the rate the optimizer will actually use. Calling
                # scheduler.get_lr() outside of step() re-applies the
                # recursive cosine update and reports a wrong value.
                lr = optimizer.param_groups[0]['lr']
                self.writer.add_scalar('cosine_lr_decay', lr, global_step=n_iter)
                sys.stdout.flush()

            # validate the model if requested
            if epoch_counter % train_cfg['eval_every_n_epochs'] == 0:
                valid_loss = self._validate(model, valid_loader)
                if valid_loss < best_valid_loss:
                    # save the model weights
                    best_valid_loss = valid_loss
                    torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))

                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
                valid_n_iter += 1