Example #1
class Trainer:
    def __init__(self, device, config, model, criterion, dataloader, data_transformer=None, tensorboard=True, meta_data=None):
        super().__init__()

        config.setdefault('running_loss_range', 50)

        self.device = torch.device('cuda' if torch.cuda.is_available() and device == 'cuda' else 'cpu')
        self.config = config

        self.model = model.to(self.device)
        self.criterion = criterion

        self.ds_length = len(dataloader.dataset)
        self.batch_size = dataloader.batch_size
        self.dataloader = dataloader
        self.data_transformer = data_transformer

        self.epoch_loss_history = []
        self.content_loss_history = []
        self.style_loss_history = []
        self.total_variation_loss_history = []
        self.loss_history = []
        self.lr_history = []
        self.meta_data = meta_data

        self.progress_bar = trange(
            math.ceil(self.ds_length / self.batch_size) * self.config['epochs'],
            leave=True
        )

        if self.config['lr_scheduler'] == 'CyclicLR':
            self.optimizer = optim.SGD(self.model.parameters(), lr=config['max_lr'], nesterov=True, momentum=0.9)
            self.scheduler = CyclicLR(
                optimizer=self.optimizer,
                base_lr=config['min_lr'],
                max_lr=config['max_lr'],
                step_size_up=self.config['lr_step_size'],
                mode='triangular2'
            )
        elif self.config['lr_scheduler'] == 'CosineAnnealingLR':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = CosineAnnealingLR(
                optimizer=self.optimizer,
                T_max=self.config['lr_step_size']
            )
        elif self.config['lr_scheduler'] == 'ReduceLROnPlateau':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = ReduceLROnPlateau(
                optimizer=self.optimizer,
                patience=100
            )
        elif self.config['lr_scheduler'] == 'StepLR':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = StepLR(
                optimizer=self.optimizer,
                step_size=self.config['lr_step_size'],
                gamma=float(self.config['lr_multiplicator'])
            )
        elif self.config['lr_scheduler'] == 'CosineAnnealingWarmRestarts':
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = CosineAnnealingWarmRestarts(
                optimizer=self.optimizer,
                T_0=self.config['lr_step_size'],
                eta_min=config['min_lr'],
                T_mult=int(self.config['lr_multiplicator'])
            )
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=config['max_lr'])
            self.scheduler = None

        if tensorboard:
            self.tensorboard_writer = SummaryWriter(log_dir=os.path.join('./runs', config['name']))
        else:
            self.tensorboard_writer = None

    def load_checkpoint(self):
        name = self.config['name']
        path = f'./checkpoints/{name}.pth'

        if os.path.exists(path):
            checkpoint = torch.load(path)

            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            if self.scheduler:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            self.content_loss_history = checkpoint['content_loss_history']
            self.style_loss_history = checkpoint['style_loss_history']
            self.total_variation_loss_history = checkpoint['total_variation_loss_history']
            self.loss_history = checkpoint['loss_history']
            self.lr_history = checkpoint['lr_history']

            del checkpoint
            torch.cuda.empty_cache()

    def train(self):
        start = time()

        for epoch in range(self.config['epochs']):
            self.epoch_loss_history = []

            for i, batch in enumerate(self.dataloader):
                self.do_training_step(i, batch)
                self.do_progress_bar_step(epoch, self.config['epochs'], i)

                if self.config['lr_scheduler'] == 'ReduceLROnPlateau':
                    self.scheduler.step(self.loss_history[-1])
                elif self.scheduler:
                    self.scheduler.step()

                if i % self.config['save_checkpoint_interval'] == 0:
                    self.save_checkpoint(f'./checkpoints/{self.config["name"]}.pth')

                # stop when the configured runtime limit is exceeded (0 disables the limit)
                if self.config['max_runtime'] != 0 and time() - start >= self.config['max_runtime']:
                    break
            else:
                continue
            break  # runtime limit reached: leave the epoch loop as well

        torch.cuda.empty_cache()

    def do_training_step(self, i, batch):
        self.model.train()

        with torch.autograd.detect_anomaly():
            try:

                if self.data_transformer:
                    x, y = self.data_transformer(batch)
                else:
                    x, y = batch

                x = x.to(self.device)
                y = y.to(self.device)
                self.optimizer.zero_grad()

                preds = self.model(x)

                loss = self.criterion(preds, y)
                loss.backward()
                self.optimizer.step()

                self.lr_history.append(self.optimizer.param_groups[0]['lr'])
                self.epoch_loss_history.append(self.criterion.loss_val)

                self.content_loss_history.append(self.criterion.content_loss_val)
                self.style_loss_history.append(self.criterion.style_loss_val)
                self.total_variation_loss_history.append(self.criterion.total_variation_loss_val)
                self.loss_history.append(self.criterion.loss_val)

                if self.tensorboard_writer and i % self.config['save_checkpoint_interval'] == 0:
                    grid_y = torchvision.utils.make_grid(y)
                    grid_preds = torchvision.utils.make_grid(preds)

                    self.tensorboard_writer.add_image('Inputs', grid_y, 0)
                    self.tensorboard_writer.add_image('Predictions', grid_preds, 0)

                    # writer.add_graph(network, images)
                    self.tensorboard_writer.add_scalar(
                        'Content Loss',
                        self.content_loss_history[-1],
                        len(self.content_loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'Style Loss',
                        self.style_loss_history[-1],
                        len(self.style_loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'TV Loss',
                        self.total_variation_loss_history[-1],
                        len(self.total_variation_loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'Total Loss',
                        self.loss_history[-1],
                        len(self.loss_history) - 1
                    )

                    self.tensorboard_writer.add_scalar(
                        'Learning Rate',
                        self.lr_history[-1],
                        len(self.lr_history) - 1
                    )

                    # flush instead of close so later iterations keep writing to the same event file
                    self.tensorboard_writer.flush()
            except Exception:
                # recover from a failed step (e.g. NaN/Inf gradients flagged by detect_anomaly)
                # by restoring the last saved checkpoint
                self.load_checkpoint()

    def do_validation_step(self):
        self.model.eval()

    def do_progress_bar_step(self, epoch, epochs, i):
        avg_epoch_loss = sum(self.epoch_loss_history) / (i + 1)

        if len(self.loss_history) >= self.config['running_loss_range']:
            running_loss = sum(
                self.loss_history[-self.config['running_loss_range']:]
            ) / self.config['running_loss_range']
        else:
            running_loss = 0

        if len(self.loss_history) > 0:
            self.progress_bar.set_description(
                f'Name: {self.config["name"]}, ' +
                f'Loss Network: {self.config["loss_network"]}, ' +
                f'Epoch: {epoch + 1}/{epochs}, ' +
                f'Average Epoch Loss: {avg_epoch_loss:,.2f}, ' +
                f'Running Loss: {running_loss:,.2f}, ' +
                f'Loss: {self.loss_history[-1]:,.2f}, ' +
                f'Learning Rate: {self.lr_history[-1]:,.6f}'
            )
        else:
            self.progress_bar.set_description(
                f'Name: {self.config["name"]}, ' +
                f'Loss Network: {self.config["loss_network"]}, ' +
                f'Epoch: {epoch + 1}/{epochs}, ' +
                f'Average Epoch Loss: {0:,.2f}, ' +
                f'Running Loss: {0:,.2f}, ' +
                f'Loss: {0:,.2f}, ' +
                f'Learning Rate: {0:,.6f}'
            )

        self.progress_bar.update(1)
        self.progress_bar.refresh()

    def save_checkpoint(self, path):
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'content_loss_history': self.content_loss_history,
            'style_loss_history': self.style_loss_history,
            'total_variation_loss_history': self.total_variation_loss_history,
            'loss_history': self.loss_history,
            'lr_history': self.lr_history,
            'content_image_size': self.config['content_image_size'],
            'style_image_size': self.config['style_image_size'],
            'network': str(self.config['network']),
            'content_weight': self.config['content_weight'],
            'style_weight': self.config['style_weight'],
            'total_variation_weight': self.config['total_variation_weight'],
            'bottleneck_size': self.config['bottleneck_size'],
            'bottleneck_type': str(self.config['bottleneck_type']),
            'channel_multiplier': self.config['channel_multiplier'],
            'expansion_factor': self.config['expansion_factor'],
            'intermediate_activation_fn': self.config['intermediate_activation_fn'],
            'final_activation_fn': self.config['final_activation_fn'],
            'meta_data': self.meta_data
        }, path)
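
A minimal usage sketch for the Trainer above. The model, criterion, and dataset names are placeholders (the criterion is assumed to expose the loss_val attributes the class reads); the config keys mirror the ones the constructor and save_checkpoint consult.

config = {
    'name': 'style_run',
    'epochs': 2,
    'lr_scheduler': 'CosineAnnealingWarmRestarts',
    'min_lr': 1e-5,
    'max_lr': 1e-3,
    'lr_step_size': 500,
    'lr_multiplicator': 2,
    'save_checkpoint_interval': 100,
    'max_runtime': 0,  # 0 disables the runtime limit
    'loss_network': 'vgg16',
    # ... plus the keys save_checkpoint() stores (content_image_size, style_weight, ...)
}

dataloader = torch.utils.data.DataLoader(my_dataset, batch_size=4, shuffle=True)  # my_dataset: placeholder
trainer = Trainer('cuda', config, MyTransformerNet(), MyPerceptualLoss(), dataloader)  # placeholder model/criterion
trainer.load_checkpoint()  # resumes from ./checkpoints/style_run.pth if it exists
trainer.train()
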
Example #2
def run_training(opt):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    work_dir, epochs, train_batch, valid_batch, weights = \
        opt.work_dir, opt.epochs, opt.train_bs, opt.valid_bs, opt.weights

    # Directories
    last = os.path.join(work_dir, 'last.pt')
    best = os.path.join(work_dir, 'best.pt')

    # --------------------------------------
    # Setup train and validation set
    # --------------------------------------
    data = pd.read_csv(opt.train_csv)
    images_path = opt.data_dir

    n_classes = 6  # hard-coded number of classes for this dataset

    data['class'] = data.apply(lambda row: categ[row["class"]], axis=1)

    train_loader, val_loader = prepare_dataloader(data,
                                                  opt.fold,
                                                  train_batch,
                                                  valid_batch,
                                                  opt.img_size,
                                                  opt.num_workers,
                                                  data_root=images_path)

    # per-domain validation loaders, needed for the not-opt.ovr_val branch below
    if not opt.ovr_val:
        handwritten_data = pd.read_csv(opt.handwritten_csv)
        printed_data = pd.read_csv(opt.printed_csv)
        handwritten_data['class'] = handwritten_data.apply(lambda row: categ[row["class"]], axis=1)
        printed_data['class'] = printed_data.apply(lambda row: categ[row["class"]], axis=1)
        _, handwritten_val_loader = prepare_dataloader(
            handwritten_data, opt.fold, train_batch, valid_batch, opt.img_size, opt.num_workers, data_root=images_path)

        _, printed_val_loader = prepare_dataloader(
            printed_data, opt.fold, train_batch, valid_batch, opt.img_size, opt.num_workers, data_root=images_path)

    # --------------------------------------
    # Models
    # --------------------------------------

    model = Classifier(model_name=opt.model_name,
                       n_classes=n_classes,
                       pretrained=True).to(device)

    if opt.weights is not None:
        cp = torch.load(opt.weights)
        model.load_state_dict(cp['model'])

    # -------------------------------------------
    # Setup optimizer, scheduler, criterion loss
    # -------------------------------------------

    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
    scheduler = CosineAnnealingWarmRestarts(optimizer,
                                            T_0=10,
                                            T_mult=1,
                                            eta_min=1e-6,
                                            last_epoch=-1)
    scaler = GradScaler()

    loss_tr = nn.CrossEntropyLoss().to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # --------------------------------------
    # Setup training
    # --------------------------------------
    if not os.path.exists(work_dir):
        os.mkdir(work_dir)

    best_loss = 1e5
    start_epoch = 0
    best_epoch = 0  # for early stopping

    if opt.resume:
        checkpoint = torch.load(last)

        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint["scheduler"])
        best_loss = checkpoint["best_loss"]

    # --------------------------------------
    # Start training
    # --------------------------------------
    print("[INFO] Start training...")
    for epoch in range(start_epoch, epochs):
        train_one_epoch(epoch,
                        model,
                        loss_tr,
                        optimizer,
                        train_loader,
                        device,
                        scheduler=scheduler,
                        scaler=scaler)
        with torch.no_grad():
            if opt.ovr_val:
                val_loss = valid_one_epoch_overall(epoch,
                                                   model,
                                                   loss_fn,
                                                   val_loader,
                                                   device,
                                                   scheduler=None)
            else:
                val_loss = valid_one_epoch(epoch,
                                           model,
                                           loss_fn,
                                           handwritten_val_loader,
                                           printed_val_loader,
                                           device,
                                           scheduler=None)

            if val_loss < best_loss:
                best_loss = val_loss
                best_epoch = epoch
                torch.save(
                    {
                        'epoch': epoch,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'best_loss': best_loss
                    }, best)

                print('best model found for epoch {}'.format(epoch + 1))

        torch.save(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_loss': best_loss
            }, last)

        if epoch - best_epoch > opt.patience:
            print("Early stop achieved at", epoch + 1)
            break

    del model, optimizer, train_loader, val_loader, scheduler, scaler
    torch.cuda.empty_cache()
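
A hedged sketch of invoking run_training: the attribute names below are the ones the function reads from opt, while the values (and the model name) are illustrative placeholders.

from argparse import Namespace

opt = Namespace(
    work_dir='./runs/exp0',
    epochs=30,
    train_bs=32,
    valid_bs=64,
    weights=None,              # or a path to a checkpoint containing a 'model' key
    train_csv='./data/train.csv',
    data_dir='./data/images',
    fold=0,
    img_size=384,
    num_workers=4,
    model_name='tf_efficientnet_b0',
    ovr_val=True,              # if False, handwritten_csv/printed_csv are also required
    resume=False,
    patience=5,
)
run_training(opt)
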
Example #3
def train():
    epoch_size = len(trainset) // args.batch_size
    num_epochs = math.ceil(args.max_iter / epoch_size)
    start_epoch = 0
    iteration = 0

    df = pd.read_csv('./data/train.csv')
    tmp = np.sqrt(1 / np.sqrt(df['landmark_id'].value_counts().sort_index().values))
    margins = (tmp - tmp.min()) / (tmp.max() - tmp.min()) * 0.45 + 0.05

    print('Loading model...')
    model = EfficientNetLandmark(args.depth, args.num_classes)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = CosineAnnealingWarmRestarts(optimizer, num_epochs-1)

    if args.resume is not None:
        state_dict = torch.load(args.resume)
        try:
            print('Resume all state...')
            modules_state_dict = state_dict['modules']
            optimizer_state_dict = state_dict['optimizer']
            scheduler_state_dict = state_dict['scheduler']
            optimizer.load_state_dict(optimizer_state_dict)
            scheduler.load_state_dict(scheduler_state_dict)
            start_epoch = state_dict['epoch']
            iteration = state_dict['iteration']
        except KeyError:
            print('Resume only modules...')
            modules_state_dict = state_dict
        
        model_state_dict = {k.replace('module.', ''): v for k, v in modules_state_dict.items() if k.replace('module.', '') in model.state_dict().keys()}
        model.load_state_dict(model_state_dict)

    num_gpus = list(range(torch.cuda.device_count()))
    if len(num_gpus) > 1:
        print('Using data parallel...')
        model = nn.DataParallel(model, device_ids=num_gpus)

    # logger = open('log.txt', 'w')
    losses = AverageMeter()
    scores = AverageMeter()

    start_train = datetime.now()
    print(num_epochs, start_epoch, iteration)
    model.train()
    for epoch in range(start_epoch, num_epochs):
        if (epoch+1)*epoch_size < iteration:
            continue

        if iteration == args.max_iter:
            break
        
        correct = 0
        input_size = 0
        start_time = datetime.now()
        for i, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()

            inputs = inputs.to(device)
            targets = targets.to(device)

            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets, margins)
            
            confs, preds = torch.max(outputs.detach(), dim=1)
            # loss.backward()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            # optimizer.step()
            scaler.update()

            losses.update(loss.item(), inputs.size(0))
            scores.update(gap(preds, confs, targets))
            correct += (preds == targets).float().sum()
            input_size += inputs.size(0)

            iteration += 1

            writer.add_scalar('loss', losses.val, iteration)
            writer.add_scalar('gap', scores.val, iteration)

            # log = {'epoch': epoch+1, 'iteration': iteration, 'loss': losses.val, 'acc': corrects.val, 'gap': scores.val}
            # logger.write(str(log) + '\n')
            if iteration % args.verbose_eval == 0:
                print(
                    f'[{epoch+1}/{iteration}] Loss: {losses.val:.5f} Acc: {correct/input_size:.5f}' \
                    f' GAP: {scores.val:.5f} LR: {optimizer.param_groups[0]["lr"]} Time: {datetime.now() - start_time}')
            
            if iteration > 100000 and iteration % args.save_interval == 0:
                print('Save model...')
                save_checkpoint({
                    'modules': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'epoch': epoch,
                    'iteration': iteration,
                }, f'effnet_b{args.depth}_{args.max_size}_{args.batch_size}_{epoch+1}_{iteration}.pth')

            scheduler.step(epoch + i / len(train_loader))
        print()

    # logger.close()
    writer.close()
    print(datetime.now() - start_train)
    logs = []
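
The scheduler.step(epoch + i / len(train_loader)) call above uses the fractional-epoch form of CosineAnnealingWarmRestarts. A small self-contained sketch of that stepping pattern with toy values:

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

model = nn.Linear(8, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)  # first restart after 5 epochs

criterion = nn.CrossEntropyLoss()
loader = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(10)]  # toy batches

for epoch in range(10):
    for i, (x, y) in enumerate(loader):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        # the fractional epoch lets the cosine schedule advance within an epoch
        scheduler.step(epoch + i / len(loader))
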

Example #4
G = Generator()
D = Discriminator(norm='batch', pool_kernel_size=[4,2,2,2])
G_optimizer = optim.Adam(G.parameters(), lr=MAX_LR, betas=(0.5, 0.999))
G_scheduler = CosineAnnealingWarmRestarts(G_optimizer, T_0=200, T_mult=1, eta_min=MIN_LR)
D_optimizer = optim.Adam(D.parameters(), lr=MAX_LR, betas=(0.5, 0.999))
D_scheduler = CosineAnnealingWarmRestarts(D_optimizer, T_0=200, T_mult=1, eta_min=MIN_LR)

if LOAD_MODEL_EPOCH:
    print('G load: ', f'{model_g_dir}/model_psnr_{LOAD_G_MODEL_SCORE}_epoch{str(LOAD_MODEL_EPOCH).zfill(3)}.pth')
    checkpoint = torch.load(
        f'{model_g_dir}/model_psnr_{LOAD_G_MODEL_SCORE}_epoch{str(LOAD_MODEL_EPOCH).zfill(3)}.pth', map_location=DEVICE)
    G.load_state_dict(checkpoint['model'])
    G_optimizer.load_state_dict(checkpoint['optimizer'])
    G_scheduler.load_state_dict(checkpoint['scheduler'])
    for state in G_optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(DEVICE)

    print('D load: ', f'{model_d_dir}/model_psnr_{LOAD_D_MODEL_SCORE}_epoch{str(LOAD_MODEL_EPOCH).zfill(3)}.pth')
    checkpoint = torch.load(
        f'{model_d_dir}/model_psnr_{LOAD_D_MODEL_SCORE}_epoch{str(LOAD_MODEL_EPOCH).zfill(3)}.pth', map_location=DEVICE)
    D.load_state_dict(checkpoint['model'])
    D_optimizer.load_state_dict(checkpoint['optimizer'])
    D_scheduler.load_state_dict(checkpoint['scheduler'])
    for state in D_optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(DEVICE)    
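
For completeness, a hedged sketch of saving checkpoints with the keys the loading code above expects ('model', 'optimizer', 'scheduler'); the helper name and the psnr_score/epoch variables are assumed, not part of the original snippet.

def save_gan_checkpoint(model, optimizer, scheduler, path):
    # stores exactly the keys read by the loading code above
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }, path)

# e.g. at the end of an epoch (psnr_score and epoch assumed to exist in the training loop):
# save_gan_checkpoint(G, G_optimizer, G_scheduler,
#                     f'{model_g_dir}/model_psnr_{psnr_score}_epoch{str(epoch).zfill(3)}.pth')
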
Example #5
    # reconstructed opening of the truncated call; dataset_train is an assumed name
    # mirroring dataset_validation below
    train_loader = DataLoader(dataset_train,
                              batch_size=args["batch_size"],
                              num_workers=5,
                              shuffle=True,
                              drop_last=True)
    val_loader = DataLoader(dataset_validation,
                            batch_size=args["batch_size"],
                            num_workers=5,
                            shuffle=True,
                            drop_last=True)

    starting_epoch = 0
    if checkpoint:
        print("Loading state dict")
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print("Starting training from evaluation accuracy: %s" %
              checkpoint["evaluation_accuracy"])
        starting_epoch = checkpoint["epoch"] + 1

    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=args["SGDR_T0"],
        T_mult=args["SGDR_T_MULT"],
        eta_min=args["SGDR_ETA_MIN"],
        last_epoch=checkpoint["epoch"] if checkpoint else -1)

    if checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    train(model, criterion, optimizer, scheduler, train_loader, val_loader,
          starting_epoch, args)
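
A hedged sketch of the args mapping this snippet reads; the key names come from the code above, the values are placeholders.

args = {
    'batch_size': 16,
    'SGDR_T0': 10,         # epochs until the first warm restart
    'SGDR_T_MULT': 2,      # restart period multiplier after each restart
    'SGDR_ETA_MIN': 1e-6,  # minimum learning rate of the cosine schedule
}
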
Example #6
class Learner:
    def __init__(self, model, train_loader, valid_loader, config):
        self.config = config
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model = model.to(self.config.device)

        self.logger = init_logger(self.config.log_dir, 'train_main.log')
        self.tb_logger = init_tb_logger(self.config.log_dir, 'train_main')
        self.log('\n'.join(
            [f"{k} = {v}" for k, v in self.config.__dict__.items()]))

        self.summary_loss = AverageMeter()
        self.evaluator = Evaluator()

        self.criterion = torch.nn.CrossEntropyLoss(
            ignore_index=self.config.ignore_index)
        self.u_criterion = torch.nn.CrossEntropyLoss(
            ignore_index=self.config.ignore_index)
        train_params = [{
            'params': getattr(model, 'encoder').parameters(),
            'lr': self.config.lr
        }, {
            'params': getattr(model, 'decoder').parameters(),
            'lr': self.config.lr * 10
        }]
        self.optimizer = RAdam(train_params,
                               weight_decay=self.config.weight_decay)

        self.scheduler = CosineAnnealingWarmRestarts(self.optimizer,
                                                     T_0=2,
                                                     T_mult=2,
                                                     eta_min=1e-6)

        self.n_ensemble = 0
        self.epoch = 0
        self.best_epoch = 0
        self.best_loss = np.inf
        self.best_score = -np.inf

    def train_one_epoch(self):
        self.model.train()
        self.summary_loss.reset()
        iters = len(self.train_loader)
        for step, (images, scribbles, weights) in enumerate(self.train_loader):
            self.tb_logger.add_scalar('Train/lr',
                                      self.optimizer.param_groups[0]['lr'],
                                      iters * self.epoch + step)
            scribbles = scribbles.to(self.config.device).long()
            images = images.to(self.config.device)
            batch_size = images.shape[0]

            self.optimizer.zero_grad()
            outputs = self.model(images)
            if self.epoch < self.config.thr_epoch:
                loss = self.criterion(outputs, scribbles)
            else:
                x_loss = self.criterion(outputs, scribbles)

                scribbles = scribbles.cpu()
                mean = weights[..., 0]
                u_labels = torch.where(
                    ((mean < (1 - self.config.thr_conf)) |
                     (mean > self.config.thr_conf)) &
                    (scribbles == self.config.ignore_index),
                    mean.round().long(),
                    self.config.ignore_index * torch.ones_like(scribbles)).to(
                        self.config.device)
                u_loss = self.u_criterion(outputs, u_labels)
                loss = x_loss + 0.5 * u_loss

            loss.backward()
            self.summary_loss.update(loss.detach().item(), batch_size)
            self.optimizer.step()
            if self.scheduler.__class__.__name__ != 'ReduceLROnPlateau':
                self.scheduler.step()

        return self.summary_loss.avg

    def validation(self):
        self.model.eval()
        self.summary_loss.reset()
        self.evaluator.reset()
        for step, (_, images, _, targets) in enumerate(self.valid_loader):
            with torch.no_grad():
                targets = targets.to(self.config.device).long()
                batch_size = images.shape[0]
                images = images.to(self.config.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, targets)

                targets = targets.cpu().numpy()
                outputs = torch.argmax(outputs, dim=1)
                outputs = outputs.data.cpu().numpy()
                self.evaluator.add_batch(targets, outputs)
                self.summary_loss.update(loss.detach().item(), batch_size)

        if self.scheduler.__class__.__name__ == 'ReduceLROnPlateau':
            self.scheduler.step(self.evaluator.IoU)
        return self.summary_loss.avg, self.evaluator.IoU

    def ensemble_prediction(self):
        ds = self.train_loader.dataset
        transforms = Compose([Normalize(), ToTensorV2()])
        for idx, images in tqdm(ds.images.items(), total=len(ds)):
            augmented = transforms(image=images['image'])
            img = augmented['image'].unsqueeze(0).to(self.config.device)
            with torch.no_grad():
                pred = torch.nn.functional.softmax(self.model(img), dim=1)
            weight = torch.tensor(images['weight'])
            pred = pred.squeeze(0).cpu()
            x = pred[1]
            weight[..., 0] = self.config.alpha * x + (
                1 - self.config.alpha) * weight[..., 0]
            self.train_loader.dataset.images[idx]['weight'] = weight.numpy()
        self.n_ensemble += 1

    def fit(self, epochs):
        for e in range(epochs):
            t = time.time()
            loss = self.train_one_epoch()

            self.log(
                f'[Train] \t Epoch: {self.epoch}, loss: {loss:.5f}, time: {(time.time() - t):.2f}'
            )
            self.tb_log(loss, None, 'Train', self.epoch)

            t = time.time()
            loss, score = self.validation()

            self.log(
                f'[Valid] \t Epoch: {self.epoch}, loss: {loss:.5f}, IoU: {score:.4f}, time: {(time.time() - t):.2f}'
            )
            self.tb_log(loss, score, 'Valid', self.epoch)
            self.post_processing(loss, score)

            if (self.epoch + 1) % self.config.period_epoch == 0:
                self.log(
                    f'[Ensemble] \t the {self.n_ensemble}th Prediction Ensemble ...'
                )
                self.ensemble_prediction()

            self.epoch += 1
        self.log(
            f'best epoch: {self.best_epoch}, best loss: {self.best_loss}, best_score: {self.best_score}'
        )

    def post_processing(self, loss, score):
        if loss < self.best_loss:
            self.best_loss = loss

        if score > self.best_score:
            self.best_score = score
            self.best_epoch = self.epoch

            self.model.eval()
            torch.save(
                {
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'best_score': self.best_score,
                    'epoch': self.epoch,
                }, os.path.join(self.config.log_dir, 'best_model.pth'))
            self.log(f'best model: {self.epoch} epoch - {score:.4f}')

    def load(self, path):
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_score = checkpoint['best_score']
        self.epoch = checkpoint['epoch'] + 1

    def log(self, text):
        self.logger.info(text)

    def tb_log(self, loss, IoU, split, step):
        if loss is not None:
            self.tb_logger.add_scalar(f'{split}/Loss', loss, step)
        if IoU is not None:
            self.tb_logger.add_scalar(f'{split}/IoU', IoU, step)
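
A hedged usage sketch for the Learner above: config is assumed to carry the fields the class reads (device, log_dir, ignore_index, lr, weight_decay, thr_epoch, thr_conf, alpha, period_epoch), and model is assumed to expose encoder and decoder submodules; all objects are placeholders.

learner = Learner(model, train_loader, valid_loader, config)
# learner.load(os.path.join(config.log_dir, 'best_model.pth'))  # optional: resume from the best checkpoint
learner.fit(epochs=40)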