Example #1
import os
from datetime import datetime

import numpy as np
import torch
from torchnet import meter  # provides AverageValueMeter
from tqdm import tqdm

# move_to, detach, and TensorboardLogger are project-level helpers assumed
# to be importable from the surrounding codebase.

class Trainer:
    def __init__(self, 
                device, 
                config,
                model, 
                criterion,
                optimizer,
                scheduler,
                metric
                ):
        super(Trainer, self).__init__() 

        self.config = config 
        self.device = device
        self.model = model 
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.metric = metric 

        self.train_ID = self.config.get('id', 'None')  # string default so the concatenation below cannot fail on a missing id
        self.train_ID += '-' + datetime.now().strftime('%Y_%m_%d-%H_%M_%S')

        self.nepochs = self.config['trainer']['nepochs']
        self.backward_step = self.config['trainer']['backward_step']
        self.val_step = self.config['trainer']['val_step']
        
        self.best_loss = np.inf 
        self.best_metric = {k: 0.0 for k in self.metric.keys()}
        self.val_loss = list() 
        self.val_metric = {k: list() for k in self.metric.keys() }

        self.save_dir = os.path.join(config['trainer']['save_dir'], self.train_ID)
        os.makedirs(self.save_dir, exist_ok=True)  # torch.save below requires the directory to exist
        self.tsboard = TensorboardLogger(path=self.save_dir) 

    def save_checkpoint(self, epoch, val_loss, val_metric):
        data = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config
        }

        if val_loss < self.best_loss:
            print(
                f'Loss is improved from {self.best_loss: .6f} to {val_loss: .6f}. Saving weights...')
            torch.save(data, os.path.join(self.save_dir, 'best_loss.pth'))
            # Update best_loss
            self.best_loss = val_loss
        else:
            print(f'Loss is not improved from {self.best_loss:.6f}.')

        for k in self.metric.keys():
            if val_metric[k] > self.best_metric[k]:
                print(
                    f'{k} is improved from {self.best_metric[k]: .6f} to {val_metric[k]: .6f}. Saving weights...')
                torch.save(data, os.path.join(
                    self.save_dir, f'best_metric_{k}.pth'))
                self.best_metric[k] = val_metric[k]
            else:
                print(
                    f'{k} is not improved from {self.best_metric[k]:.6f}.')

        print('Saving current model...')
        torch.save(data, os.path.join(self.save_dir, 'current.pth'))    

    def save_current_checkpoint(self, epoch):
        data = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config
        }
        print('Saving mini current model...')
        torch.save(data, os.path.join(self.save_dir, 'mini_current.pth'))    
    
    def train_epoch(self, epoch, dataloader):
        total_loss = meter.AverageValueMeter() 
        for m in self.metric.values():
            m.reset()
        self.model.train()
        print("Training..........")
        progress_bar = tqdm(dataloader)
        max_iter = len(dataloader)
        self.optimizer.zero_grad()
        for i, (inp, lbl) in enumerate(progress_bar):
            # 1: Load img_inputs and labels
            inp = move_to(inp, self.device)
            lbl = move_to(lbl, self.device)
            # 2: Get network outputs
            outs = self.model(inp)
            # 3: Calculate the loss
            loss = self.criterion(outs, lbl)
            # 4: Calculate gradients
            loss.backward()
            # 5: Step the optimizer every `backward_step` iterations
            # (gradient accumulation); gradients were zeroed before the loop
            if (i + 1) % self.backward_step == 0 or (i + 1) == max_iter:
                self.optimizer.step()
                self.optimizer.zero_grad()
        
            # 6: Update the epoch loss (once per iteration)
            total_loss.add(loss.item())

            # 7: Update metrics on detached outputs
            outs = detach(outs)
            lbl = detach(lbl)
            for m in self.metric.values():
                m.update(outs, lbl)

            with torch.no_grad():
                desc = 'Iteration: {}/{}. Total loss: {:.5f}. '.format(
                        i + 1, len(dataloader), loss.item())
                for m in self.metric.values():
                    value = m.value()
                    metric = m.__class__.__name__
                    desc += f'{metric}: {value:.5f}, '
                progress_bar.set_description(desc)
                
                self.tsboard.update_scalar(
                    'Loss/train', loss.item(), epoch * len(dataloader) + i
                )

            

            # Optionally save a rolling checkpoint every `checkpoint_mini_step` iterations:
            # if (i + 1) % self.config['trainer']['checkpoint_mini_step'] == 0:
            #     self.save_current_checkpoint(epoch)
                
        print("+ Train result")
        avg_loss = total_loss.value()[0]
        print("Loss:", avg_loss)
        for m in self.metric.values():
            m.summary() 
            m.reset()

    @torch.no_grad() 
    def val_epoch(self, epoch, dataloader):
        total_loss = meter.AverageValueMeter() 
        for m in self.metric.values():
            m.reset()
        self.model.eval() 
        print("Evaluating.....")
        progress_bar = tqdm(dataloader)
        for i, (inp, lbl) in enumerate(progress_bar):
            # 1: Load inputs and labels
            inp = move_to(inp, self.device)
            lbl = move_to(lbl, self.device)
            # 2: Get network outputs
            outs = self.model(inp)
            # 3: Calculate the loss
            loss = self.criterion(outs, lbl)
            # 4: Update metrics on detached outputs
            outs = detach(outs)
            lbl = detach(lbl)
            for m in self.metric.values():
                m.update(outs, lbl)

            # 5: Update loss
            total_loss.add(loss.item())
            desc = 'Iteration: {}/{}. Total loss: {:.5f}. '.format(
                        i + 1, len(dataloader), loss.item())
            for m in self.metric.values():
                value = m.value()
                metric = m.__class__.__name__
                desc += f'{metric}: {value:.5f}, '
            progress_bar.set_description(desc)
            
            
        print("+ Evaluation result")
        avg_loss = total_loss.value()[0]
        print("Loss: ", avg_loss)
        
        self.val_loss.append(avg_loss)
        
        self.tsboard.update_scalar('Loss/val', avg_loss, epoch)
                        
        # Calculate metric here
        for k in self.metric.keys():
            m = self.metric[k].value()
            self.metric[k].summary()
            self.val_metric[k].append(m)
            self.tsboard.update_metric('val', k, m, epoch)

    def train(self, train_dataloader, val_dataloader):
        for epoch in range(self.nepochs):
            print('\nEpoch {:>3d}'.format(epoch))
            print('-----------------------------------')

            # Note learning rate
            for i, group in enumerate(self.optimizer.param_groups):
                self.tsboard.update_lr(i, group['lr'], epoch)

            # 1: Training phase
            self.train_epoch(epoch=epoch, dataloader=train_dataloader)

            print()

            # 2: Evaluation phase
            if (epoch + 1) % self.val_step == 0:
                # 2: Evaluating model
                self.val_epoch(epoch, dataloader=val_dataloader)
                print('-----------------------------------')

                # 3: Learning rate scheduling
                self.scheduler.step(self.val_loss[-1])

                # 4: Saving checkpoints with the latest validation results
                val_loss = self.val_loss[-1]
                val_metric = {k: m[-1] for k, m in self.val_metric.items()}
                self.save_checkpoint(epoch, val_loss, val_metric)
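
A minimal usage sketch for Example #1, not part of the original source: it wires the Trainer to a toy model and random data. The Accuracy class is a hypothetical stand-in for the project's metric objects, which must expose reset/update/value/summary; TensorboardLogger, move_to, and detach are assumed to come from the surrounding codebase.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the metric interface Example #1 expects:
# update(outs, lbl) consumes raw outputs and labels directly.
class Accuracy:
    def reset(self):
        self.correct, self.total = 0, 0

    def update(self, outs, lbl):
        self.correct += (outs.argmax(dim=1) == lbl).sum().item()
        self.total += lbl.numel()

    def value(self):
        return self.correct / max(self.total, 1)

    def summary(self):
        print(f'Accuracy: {self.value():.6f}')

config = {
    'id': 'demo',
    'trainer': {'nepochs': 2, 'backward_step': 1, 'val_step': 1,
                'save_dir': 'runs'},
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(16, 4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# ReduceLROnPlateau matches the scheduler.step(val_loss) call in train()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))
loader = DataLoader(dataset, batch_size=8)

trainer = Trainer(device, config, model, criterion, optimizer, scheduler,
                  metric={'accuracy': Accuracy()})
trainer.train(train_dataloader=loader, val_dataloader=loader)

The constructor creates runs/<train_ID>/ and writes best_loss.pth, best_metric_accuracy.pth, and current.pth there as validation improves.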
Example #2
import os
from datetime import datetime

import numpy as np
import torch
from torchnet import meter  # provides AverageValueMeter
from tqdm import tqdm

# move_to, detach, and TensorboardLogger are project-level helpers assumed
# to be importable from the surrounding codebase.

class Trainer:
    def __init__(self, device, config, model, criterion, optimizer, scheduler,
                 metric):
        super(Trainer, self).__init__()

        self.config = config
        self.device = device
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.metric = metric

        # Train ID
        self.train_id = self.config.get('id', 'None')
        self.train_id += '-' + datetime.now().strftime('%Y_%m_%d-%H_%M_%S')

        # Get arguments
        self.nepochs = self.config['trainer']['nepochs']
        self.log_step = self.config['trainer']['log_step']
        self.val_step = self.config['trainer']['val_step']
        self.debug = self.config['debug']

        # Instantiate global variables
        self.best_loss = np.inf
        self.best_metric = {k: 0.0 for k in self.metric.keys()}
        self.val_loss = list()
        self.val_metric = {k: list() for k in self.metric.keys()}

        # Instantiate loggers
        self.save_dir = os.path.join('runs', self.train_id)
        os.makedirs(self.save_dir, exist_ok=True)  # torch.save below requires the directory to exist
        self.tsboard = TensorboardLogger(path=self.save_dir)

    def save_checkpoint(self, epoch, val_loss, val_metric):

        data = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config
        }

        if val_loss < self.best_loss:
            print(
                f'Loss is improved from {self.best_loss: .6f} to {val_loss: .6f}. Saving weights...'
            )
            torch.save(data, os.path.join(self.save_dir, 'best_loss.pth'))
            # Update best_loss
            self.best_loss = val_loss
        else:
            print(f'Loss is not improved from {self.best_loss:.6f}.')

        for k in self.metric.keys():
            if val_metric[k] > self.best_metric[k]:
                print(
                    f'{k} is improved from {self.best_metric[k]: .6f} to {val_metric[k]: .6f}. Saving weights...'
                )
                torch.save(data,
                           os.path.join(self.save_dir, f'best_metric_{k}.pth'))
                self.best_metric[k] = val_metric[k]
            else:
                print(f'{k} is not improved from {self.best_metric[k]:.6f}.')

        # print('Saving current model...')
        # torch.save(data, os.path.join(self.save_dir, 'current.pth'))

    def train_epoch(self, epoch, dataloader):
        # 0: Record loss during training process
        running_loss = meter.AverageValueMeter()
        total_loss = meter.AverageValueMeter()
        for m in self.metric.values():
            m.reset()
        self.model.train()
        print('Training........')
        progress_bar = tqdm(dataloader)
        for i, (inp, lbl) in enumerate(progress_bar):
            # 1: Load img_inputs and labels
            inp = move_to(inp, self.device)
            lbl = move_to(lbl, self.device)
            # 2: Clear gradients from previous iteration
            self.optimizer.zero_grad()
            # 3: Get network outputs
            outs = self.model(inp)
            # 4: Calculate the loss
            loss = self.criterion(outs, lbl)
            # 5: Calculate gradients
            loss.backward()
            # 6: Performing backpropagation
            self.optimizer.step()
            with torch.no_grad():
                # 7: Update loss
                running_loss.add(loss.item())
                total_loss.add(loss.item())

                if (i + 1) % self.log_step == 0 or (i + 1) == len(dataloader):
                    self.tsboard.update_loss('train',
                                             running_loss.value()[0],
                                             epoch * len(dataloader) + i)
                    running_loss.reset()

                # 8: Update metric
                outs = detach(outs)
                lbl = detach(lbl)
                for m in self.metric.values():
                    value = m.calculate(outs, lbl)
                    m.update(value)

        print('+ Training result')
        avg_loss = total_loss.value()[0]
        print('Loss:', avg_loss)
        for m in self.metric.values():
            m.summary()

    @torch.no_grad()
    def val_epoch(self, epoch, dataloader):
        running_loss = meter.AverageValueMeter()
        for m in self.metric.values():
            m.reset()

        self.model.eval()
        print('Evaluating........')
        progress_bar = tqdm(dataloader)
        for i, (inp, lbl) in enumerate(progress_bar):
            # 1: Load inputs and labels
            inp = move_to(inp, self.device)
            lbl = move_to(lbl, self.device)
            # 2: Get network outputs
            outs = self.model(inp)
            # 3: Calculate the loss
            loss = self.criterion(outs, lbl)
            # 4: Update loss
            running_loss.add(loss.item())
            # 5: Update metric
            outs = detach(outs)
            lbl = detach(lbl)
            for m in self.metric.values():
                value = m.calculate(outs, lbl)
                m.update(value)

        print('+ Evaluation result')
        avg_loss = running_loss.value()[0]
        print('Loss:', avg_loss)
        self.val_loss.append(avg_loss)
        self.tsboard.update_loss('val', avg_loss, epoch)

        for k in self.metric.keys():
            m = self.metric[k].value()
            self.metric[k].summary()
            self.val_metric[k].append(m)
            self.tsboard.update_metric('val', k, m, epoch)

    def train(self, train_dataloader, val_dataloader):
        for epoch in range(self.nepochs):
            print('\nEpoch {:>3d}'.format(epoch))
            print('-----------------------------------')

            # Note learning rate
            for i, group in enumerate(self.optimizer.param_groups):
                self.tsboard.update_lr(i, group['lr'], epoch)

            # 1: Training phase
            self.train_epoch(epoch=epoch, dataloader=train_dataloader)

            print()

            # 2: Evaluation phase
            if (epoch + 1) % self.val_step == 0:
                # 2: Evaluating model
                self.val_epoch(epoch, dataloader=val_dataloader)
                print('-----------------------------------')

                # 3: Learning rate scheduling
                self.scheduler.step(self.val_loss[-1])

                # 4: Saving checkpoints
                if not self.debug:
                    # Get latest val loss here
                    val_loss = self.val_loss[-1]
                    val_metric = {k: m[-1] for k, m in self.val_metric.items()}
                    self.save_checkpoint(epoch, val_loss, val_metric)
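
Note that Example #2 expects a different metric interface from Example #1: calculate(outs, lbl) returns a per-batch value and update(value) accumulates it, rather than update(outs, lbl) doing both. Its config also needs 'log_step' and 'debug' keys. A hypothetical stand-in for illustration, not from the original source:

# Hypothetical metric matching Example #2's calculate/update split.
class BatchAccuracy:
    def reset(self):
        self.values = []

    def calculate(self, outs, lbl):
        # Per-batch accuracy from raw outputs and integer labels
        return (outs.argmax(dim=1) == lbl).float().mean().item()

    def update(self, value):
        self.values.append(value)

    def value(self):
        return sum(self.values) / max(len(self.values), 1)

    def summary(self):
        print(f'BatchAccuracy: {self.value():.6f}')

With that substitution, the same wiring shown after Example #1 applies; setting config['debug'] = True skips checkpointing entirely, since save_checkpoint is guarded by `if not self.debug`.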