Example #1
def train(args, train_dataset, model):
    tb_writer = SummaryWriter(args.tb_writer_dir)
    result_writer = ResultWriter(args.eval_results_dir)

    if args.weighted_sampling == 1:
        # The three pitch types are unevenly distributed, so sample them at equal rates
        # In practice this did not help results, so weighted_sampling ended up unused
        ball_type, counts = np.unique(train_dataset.pitch, return_counts=True)
        count_dict = dict(zip(ball_type, counts))
        weights = [1.0 / count_dict[p] for p in train_dataset.pitch]
        sampler = WeightedRandomSampler(weights,
                                        len(train_dataset),
                                        replacement=True)
        logger.info("Do Weighted Sampling")
    else:
        sampler = RandomSampler(train_dataset)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.train_batch_size,
                                  sampler=sampler)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // len(train_dataloader) + 1
    else:
        t_total = len(train_dataloader) * args.num_train_epochs

    args.warmup_step = int(args.warmup_percent * t_total)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = [
        "bias",
        "layernorm.weight",
    ]

    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]

    optimizer = optim.Adam(optimizer_grouped_parameters,
                           lr=args.learning_rate,
                           eps=args.adam_epsilon)
    if args.warmup_step != 0:
        scheduler_cosine = CosineAnnealingLR(optimizer, t_total)
        scheduler = GradualWarmupScheduler(optimizer,
                                           1,
                                           args.warmup_step,
                                           after_scheduler=scheduler_cosine)
    else:
        scheduler = CosineAnnealingLR(optimizer, t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    loss_fct = torch.nn.NLLLoss()

    # Train!
    logger.info("***** Running Baseball Transformer *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Warmup Steps = %d", args.warmup_step)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.train_batch_size)
    logger.info("  Total train batch size = %d", args.train_batch_size)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    best_step = 0
    steps_trained_in_current_epoch = 0
    tr_loss, logging_loss, logging_val_loss = 0.0, 0.0, 0.0

    best_pitch_micro_f1, best_pitch_macro_f1 = 0, 0
    best_loss = 1e10

    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):

            (
                pitcher,
                batter,
                state,
                pitch,
                label,
                pitch_memory,
                label_memory,
                memory_mask,
            ) = list(map(lambda x: x.to(args.device), batch))
            model.train()
            pitching_score, memories = model(
                pitcher,
                batter,
                state,
                pitch_memory,
                label_memory,
                memory_mask,
            )

            pitching_score = pitching_score.log_softmax(dim=-1)
            loss = loss_fct(pitching_score, pitch)

            if args.n_gpu > 1:
                loss = loss.mean()

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                if args.evaluate_during_training:
                    results, f1_results, f1_log, cm = evaluate(
                        args, args.eval_data_file, model)
                    output_eval_file = os.path.join(args.output_dir,
                                                    "eval_results.txt")
                    print_result(output_eval_file, results, f1_log, cm)

                    for key, value in results.items():
                        tb_writer.add_scalar("eval_{}".format(key), value,
                                             global_step)
                    logging_val_loss = results["loss"]

                tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                     args.logging_steps, global_step)
                logging_loss = tr_loss

                # Select the best model by macro-F1 instead of loss (there is a trade-off)
                # if best_loss > results["loss"]:
                if (args.evaluate_during_training
                        and best_pitch_macro_f1 < results["pitch_macro_f1"]):
                    best_pitch_micro_f1 = results["pitch_micro_f1"]
                    best_pitch_macro_f1 = results["pitch_macro_f1"]
                    best_loss = results["loss"]
                    results["best_step"] = best_step = global_step

                    output_dir = os.path.join(args.output_dir, "best_model/")
                    os.makedirs(output_dir, exist_ok=True)
                    torch.save(model.state_dict(),
                               os.path.join(output_dir, "pytorch_model.bin"))
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving best model to %s", output_dir)

                    result_path = os.path.join(output_dir, "best_results.txt")
                    print_result(result_path,
                                 results,
                                 f1_log,
                                 cm,
                                 off_logger=True)

                    results.update(dict(f1_results))
                    result_writer.update(args, **results)

                logger.info("  best pitch micro f1 : %s", best_pitch_micro_f1)
                logger.info("  best pitch macro f1 : %s", best_pitch_macro_f1)
                logger.info("  best loss : %s", best_loss)
                logger.info("  best step : %s", best_step)

            if args.save_steps > 0 and global_step % args.save_steps == 0:
                checkpoint_prefix = "checkpoint"
                # Save model checkpoint
                output_dir = os.path.join(
                    args.output_dir, "{}-{}".format(checkpoint_prefix,
                                                    global_step))
                os.makedirs(output_dir, exist_ok=True)
                torch.save(model.state_dict(),
                           os.path.join(output_dir, "pytorch_model.bin"))
                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                rotate_checkpoints(args, checkpoint_prefix)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

    tb_writer.close()

    return global_step, tr_loss / global_step
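
The scheduler setup above (warmup into cosine decay, stepped once per optimization step) can be reproduced in isolation. A minimal sketch, assuming GradualWarmupScheduler comes from the warmup_scheduler package (ildoonet/pytorch-gradual-warmup-lr) with the positional (optimizer, multiplier, total_epoch, after_scheduler) signature used above; the toy model and step counts are placeholders:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler  # assumed source of the wrapper

model = torch.nn.Linear(8, 3)                  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

t_total = 1000                                 # total optimization steps
warmup_steps = int(0.1 * t_total)              # e.g. warmup_percent = 0.1

scheduler_cosine = CosineAnnealingLR(optimizer, t_total)
scheduler = GradualWarmupScheduler(optimizer, 1, warmup_steps,
                                   after_scheduler=scheduler_cosine)

for step in range(t_total):
    optimizer.step()       # in real training this follows loss.backward()
    scheduler.step()       # stepped per batch, as in train() above
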
Example #2
def train(name, df, VAL_FOLD=0, resume=False):
    dt_string = datetime.now().strftime("%d|%m_%H|%M|%S")
    print("Starting -->", dt_string)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs('checkpoint', exist_ok=True)
    run = f"{name}_[{dt_string}]"

    wandb.init(project="imanip", config=config_defaults, name=run)
    config = wandb.config

    # model = SRM_Classifer(num_classes=1, encoder_checkpoint='weights/pretrain_[31|03_12|16|32].h5')
    model = SMP_SRM_UPP(classifier_only=True)

    # for name_, param in model.named_parameters():
    #     if 'classifier' in name_:
    #         continue
    #     else:
    #         param.requires_grad = False

    print("Parameters : ",
          sum(p.numel() for p in model.parameters() if p.requires_grad))

    wandb.save('segmentation/smp_srm.py')
    wandb.save('dataset.py')

    train_imgaug, train_geo_aug = get_train_transforms()
    transforms_normalize = get_transforms_normalize()

    #region ########################-- CREATE DATASET and DATALOADER --########################
    train_dataset = DATASET(dataframe=df,
                            mode="train",
                            val_fold=VAL_FOLD,
                            test_fold=TEST_FOLD,
                            transforms_normalize=transforms_normalize,
                            imgaug_augment=train_imgaug,
                            geo_augment=train_geo_aug)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.train_batch_size,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=False)

    valid_dataset = DATASET(
        dataframe=df,
        mode="val",
        val_fold=VAL_FOLD,
        test_fold=TEST_FOLD,
        transforms_normalize=transforms_normalize,
    )
    valid_loader = DataLoader(valid_dataset,
                              batch_size=config.valid_batch_size,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=False)

    test_dataset = DATASET(
        dataframe=df,
        mode="test",
        val_fold=VAL_FOLD,
        test_fold=TEST_FOLD,
        transforms_normalize=transforms_normalize,
    )
    test_loader = DataLoader(test_dataset,
                             batch_size=config.valid_batch_size,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=False)
    #endregion ######################################################################################

    optimizer = get_optimizer(model, config.optimizer, config.learning_rate,
                              config.weight_decay)
    # after_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #     optimizer,
    #     patience=config.schedule_patience,
    #     mode="min",
    #     factor=config.schedule_factor,
    # )
    # T_0/T_mult are parameters of the warm-restarts variant, not of CosineAnnealingLR
    after_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=35, T_mult=2)
    scheduler = GradualWarmupScheduler(optimizer=optimizer,
                                       multiplier=1,
                                       total_epoch=config.warmup + 1,
                                       after_scheduler=after_scheduler)

    # this zero gradient update is needed to avoid a warning message, issue #8.
    # optimizer.zero_grad()
    # optimizer.step()

    criterion = nn.BCEWithLogitsLoss()
    es = EarlyStopping(patience=200, mode="min")

    model = nn.DataParallel(model).to(device)

    # wandb.watch(model, log_freq=50, log='all')

    start_epoch = 0
    if resume:
        checkpoint = torch.load(
            'checkpoint/(using pretrain)COMBO_ALL_FULL_[09|04_12|46|35].pt')
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print("-----------> Resuming <------------")

    for epoch in range(start_epoch, config.epochs):
        print(f"Epoch = {epoch}/{config.epochs-1}")
        print("------------------")

        train_metrics = train_epoch(model, train_loader, optimizer, scheduler,
                                    criterion, epoch)
        valid_metrics = valid_epoch(model, valid_loader, criterion, epoch)

        # The metrics argument is only meaningful when the after-scheduler is
        # ReduceLROnPlateau; with a cosine after-scheduler, step once per epoch.
        scheduler.step()

        print(
            f"TRAIN_ACC = {train_metrics['train_acc_05']}, TRAIN_LOSS = {train_metrics['train_loss']}"
        )
        print(
            f"VALID_ACC = {valid_metrics['valid_acc_05']}, VALID_LOSS = {valid_metrics['valid_loss']}"
        )
        print("Optimizer LR", optimizer.param_groups[0]['lr'])
        print("Scheduler LR", scheduler.get_lr()[0])
        wandb.log({
            'optim_lr': optimizer.param_groups[0]['lr'],
            'schedule_lr': scheduler.get_lr()[0]
        })

        es(
            valid_metrics["valid_loss"],
            model,
            model_path=os.path.join(OUTPUT_DIR, f"{run}.h5"),
        )
        if es.early_stop:
            print("Early stopping")
            break

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }
        torch.save(checkpoint, os.path.join('checkpoint', f"{run}.pt"))

    if os.path.exists(os.path.join(OUTPUT_DIR, f"{run}.h5")):
        print(
            model.load_state_dict(
                torch.load(os.path.join(OUTPUT_DIR, f"{run}.h5"))))
        print("LOADED FOR TEST")

    test_metrics = test(model, test_loader, criterion)
    wandb.save(os.path.join(OUTPUT_DIR, f"{run}.h5"))

    return test_metrics
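
Example #2's resume path expects a checkpoint dictionary holding the finished epoch together with the model, optimizer and scheduler state dicts. A minimal sketch of that save/load round trip on placeholder objects (the file path and toy module are illustrative only):

import os
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# Save after finishing an epoch
os.makedirs('checkpoint', exist_ok=True)
checkpoint = {
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
}
torch.save(checkpoint, 'checkpoint/example.pt')

# Resume: restore every state dict and continue from the next epoch
checkpoint = torch.load('checkpoint/example.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
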
Example #3
def main():

    data_dir = '../data/'
    df_biopsy = pd.read_csv(os.path.join(data_dir, 'train.csv'))
    image_folder = os.path.join(data_dir, 'train_images')

    kernel_type = 'efficientnet-b3_36x256x256'
    enet_type = 'efficientnet-b3'
    num_folds = 5
    fold = 0
    tile_size = 256
    n_tiles = 32
    batch_size = 9
    num_workers = 24
    out_dim = 5
    init_lr = 3e-4
    warmup_factor = 10
    warmup_epo = 1
    n_epochs = 30
    use_amp = True

    writer = SummaryWriter(f'tensorboard_logs/{kernel_type}/fold-{fold}')

    if use_amp and not APEX_AVAILABLE:
        print("Error: could not import APEX module")
        exit()

    skf = StratifiedKFold(num_folds, shuffle=True, random_state=42)
    df_biopsy['fold'] = -1
    for i, (train_idx, valid_idx) in enumerate(
            skf.split(df_biopsy, df_biopsy['isup_grade'])):
        df_biopsy.loc[valid_idx, 'fold'] = i

    mean = [0.90949707, 0.8188697, 0.87795304]
    std = [0.36357649, 0.49984502, 0.40477625]
    transform_train = transforms.Compose([
        transforms.RandomChoice([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            RotationTransform([90, -90])
        ]),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    transform_val = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    df_train = df_biopsy.loc[df_biopsy['fold'] != fold]
    df_valid = df_biopsy.loc[df_biopsy['fold'] == fold]

    dataset_train = PANDADataset(df_train, image_folder, tile_size, n_tiles, \
        out_dim, transform=transform_train)
    dataset_valid = PANDADataset(df_valid, image_folder, tile_size, n_tiles, \
        out_dim, transform=transform_val)

    train_loader = DataLoader(
        dataset_train,
        batch_size=batch_size,
        sampler=RandomSampler(dataset_train),
        num_workers=num_workers,
    )
    valid_loader = DataLoader(dataset_valid,
                              batch_size=batch_size,
                              sampler=SequentialSampler(dataset_valid),
                              num_workers=num_workers)

    model = enetv2(enet_type, out_dim=out_dim)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=init_lr / warmup_factor)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, n_epochs - warmup_epo)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=warmup_factor, \
                                    total_epoch=warmup_epo, after_scheduler=scheduler_cosine)

    criterion = nn.BCEWithLogitsLoss()

    if use_amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          keep_batchnorm_fp32=None,
                                          loss_scale="dynamic")
    model = nn.DataParallel(model)

    print("Number of train samples : {}".format(len(dataset_train)))
    print("Number of validation samples : {}".format(len(dataset_valid)))

    best_model = f'{kernel_type}_fold-{fold}_best.pth'
    save_path = f'../trained_models/{kernel_type}/fold-{fold}/'
    os.makedirs(save_path, exist_ok=True)

    qwk_max = 0.
    for epoch in range(1, n_epochs + 1):
        print(time.ctime(), 'Epoch:', epoch)
        scheduler.step(epoch - 1)

        train_loss = train_epoch(model,
                                 train_loader,
                                 optimizer,
                                 criterion,
                                 use_amp=use_amp)
        val_loss, acc, (qwk, qwk_k, qwk_r) = val_epoch(model, valid_loader,
                                                       criterion, df_valid)

        writer.add_scalars(f'loss', {
            'train': np.mean(train_loss),
            'val': val_loss
        }, epoch)
        writer.add_scalars(f'qwk', {
            'total': qwk,
            'Karolinska': qwk_k,
            'Radboud': qwk_r
        }, epoch)
        content = "{}, Epoch {}, lr: {:.7f}, train loss: {:.5f}," \
                " val loss: {:.5f}, acc: {:.5f}, qwk: {:.5f}".format(
                    time.ctime(), epoch, optimizer.param_groups[0]["lr"],
                    np.mean(train_loss), np.mean(val_loss), acc, qwk
                )
        print(content)

        with open('train_logs/log_{}_fold-{}.txt'.format(kernel_type, fold),
                  'a') as appender:
            appender.write(content + '\n')

        if qwk > qwk_max:
            print('score2 ({:.6f} --> {:.6f}).  Saving current best model ...'.
                  format(qwk_max, qwk))
            torch.save(model.state_dict(), os.path.join(save_path, best_model))
            qwk_max = qwk

        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'qwk_max': qwk_max
            }, os.path.join(save_path,
                            f'{kernel_type}_fold-{fold}_{epoch}.pth'))
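
The fold assignment in Example #3 writes each row's held-out fold index into a 'fold' column and then selects one fold for validation. A minimal sketch of the same pattern on a toy dataframe (column names and sizes are placeholders):

import pandas as pd
from sklearn.model_selection import StratifiedKFold

df = pd.DataFrame({'image_id': range(25),
                   'isup_grade': [0, 1, 2, 3, 4] * 5})

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
df['fold'] = -1
for i, (train_idx, valid_idx) in enumerate(skf.split(df, df['isup_grade'])):
    df.loc[valid_idx, 'fold'] = i      # the fold in which this row is held out

fold = 0
df_train = df.loc[df['fold'] != fold]
df_valid = df.loc[df['fold'] == fold]
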
Example #4
class TrainingLoop():
    def __init__(self,
                 model_kwargs,
                 train_positive_paths,
                 train_negative_paths,
                 train_unlabeled_paths,
                 val_positive_paths,
                 val_negative_paths,
                 val_unlabeled_paths,
                 data_cache_dir: str,
                 notify_callback: Callable[[Dict[str, Any]],
                                           None] = lambda x: None):
        '''The training loop for background splitting models.'''
        self.data_cache_dir = data_cache_dir
        self.notify_callback = notify_callback

        self._setup_model_kwargs(model_kwargs)

        # Setup dataset
        self._setup_dataset(train_positive_paths, train_negative_paths,
                            train_unlabeled_paths, val_positive_paths,
                            val_negative_paths, val_unlabeled_paths)

        # Setup model
        self._setup_model()

        # Setup optimizer

        # Resume if requested
        resume_from = model_kwargs.get('resume_from', None)
        if resume_from:
            resume_training = model_kwargs.get('resume_training', False)
            self.load_checkpoint(resume_from, resume_training=resume_training)

        self.writer = SummaryWriter(log_dir=model_kwargs['log_dir'])

        # Variables for estimating run-time
        self.train_batch_time = EMA(0)
        self.val_batch_time = EMA(0)
        self.train_batches_per_epoch = (len(self.train_dataloader.dataset) /
                                        self.train_dataloader.batch_size)
        self.val_batches_per_epoch = (len(self.val_dataloader.dataset) /
                                      self.val_dataloader.batch_size)
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        self.train_epoch_loss = 0
        self.train_epoch_main_loss = 0
        self.train_epoch_aux_loss = 0

    def _setup_model_kwargs(self, model_kwargs):
        self.model_kwargs = copy.deepcopy(model_kwargs)
        self.num_workers = NUM_WORKERS
        self.val_frequency = model_kwargs.get('val_frequency', 1)
        self.checkpoint_frequency = model_kwargs.get('checkpoint_frequency', 1)
        self.use_cuda = bool(model_kwargs.get('use_cuda', True))
        assert 'model_dir' in model_kwargs
        self.model_dir = model_kwargs['model_dir']
        assert 'aux_labels' in model_kwargs
        self.aux_weight = float(model_kwargs.get('aux_weight', 0.1))
        assert 'log_dir' in model_kwargs

    def _setup_dataset(self, train_positive_paths, train_negative_paths,
                       train_unlabeled_paths, val_positive_paths,
                       val_negative_paths, val_unlabeled_paths):
        assert self.model_kwargs
        aux_labels = self.model_kwargs['aux_labels']
        image_input_size = self.model_kwargs.get('input_size', 224)
        batch_size = int(self.model_kwargs.get('batch_size', 64))
        num_workers = self.num_workers
        restrict_aux_labels = bool(
            self.model_kwargs.get('restrict_aux_labels', True))
        cache_images_on_disk = self.model_kwargs.get('cache_images_on_disk',
                                                     False)

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ConvertImageDtype(torch.float32),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        resize_size = int(image_input_size * 1.15)
        resize_size += int(resize_size % 2)
        val_transform = transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(image_input_size),
            transforms.ConvertImageDtype(torch.float32),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self.train_dataloader = DataLoader(AuxiliaryDataset(
            positive_paths=train_positive_paths,
            negative_paths=train_negative_paths,
            unlabeled_paths=train_unlabeled_paths,
            auxiliary_labels=aux_labels,
            restrict_aux_labels=restrict_aux_labels,
            cache_images_on_disk=cache_images_on_disk,
            data_cache_dir=self.data_cache_dir,
            transform=train_transform),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=num_workers)
        self.val_dataloader = DataLoader(AuxiliaryDataset(
            positive_paths=val_positive_paths,
            negative_paths=val_negative_paths,
            unlabeled_paths=val_unlabeled_paths,
            auxiliary_labels=aux_labels,
            restrict_aux_labels=restrict_aux_labels,
            cache_images_on_disk=cache_images_on_disk,
            data_cache_dir=self.data_cache_dir,
            transform=val_transform),
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=num_workers)

    def _setup_model(self):
        num_classes = 2
        num_aux_classes = self.train_dataloader.dataset.num_auxiliary_classes
        freeze_backbone = self.model_kwargs.get('freeze_backbone', False)
        self.model_kwargs['num_aux_classes'] = num_aux_classes
        self.model = Model(num_main_classes=num_classes,
                           num_aux_classes=num_aux_classes,
                           freeze_backbone=freeze_backbone)
        if self.model_kwargs.get('aux_labels_type', None) == "imagenet":
            # Initialize auxiliary head to imagenet fc
            self.model.auxiliary_head.weight = self.model.backbone.fc.weight
            self.model.auxiliary_head.bias = self.model.backbone.fc.bias
        if self.use_cuda:
            self.model = self.model.cuda()
        self.model = nn.DataParallel(self.model)
        self.main_loss = nn.CrossEntropyLoss()
        self.auxiliary_loss = nn.CrossEntropyLoss()
        self.start_epoch = 0
        self.end_epoch = self.model_kwargs.get('epochs_to_run', 1)
        self.current_epoch = 0
        self.global_train_batch_idx = 0
        self.global_val_batch_idx = 0

        lr = float(self.model_kwargs.get('initial_lr', 0.01))
        endlr = float(self.model_kwargs.get('endlr', 0.0))
        optim_params = dict(
            lr=lr,
            momentum=float(self.model_kwargs.get('momentum', 0.9)),
            weight_decay=float(self.model_kwargs.get('weight_decay', 0.0001)),
        )
        self.optimizer = optim.SGD(self.model.parameters(), **optim_params)
        max_epochs = int(self.model_kwargs.get('max_epochs', 90))
        warmup_epochs = int(self.model_kwargs.get('warmup_epochs', 0))
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
                                                               max_epochs -
                                                               warmup_epochs,
                                                               eta_min=endlr)
        self.optimizer_scheduler = GradualWarmupScheduler(
            optimizer=self.optimizer,
            multiplier=1.0,
            warmup_epochs=warmup_epochs,
            after_scheduler=scheduler)

    def _notify(self):
        epochs_left = self.end_epoch - self.current_epoch - 1
        num_train_batches_left = (
            epochs_left * self.train_batches_per_epoch +
            max(0, self.train_batches_per_epoch - self.train_batch_idx - 1))
        num_val_batches_left = (
            (1 + round(epochs_left / self.val_frequency)) *
            self.val_batches_per_epoch +
            max(0, self.val_batches_per_epoch - self.val_batch_idx - 1))
        time_left = (num_train_batches_left * self.train_batch_time.value +
                     num_val_batches_left * self.val_batch_time.value)
        # notify_callback is typed to take a single dict, so pass it positionally
        self.notify_callback({"training_time_left": time_left})

    def setup_resume(self, train_positive_paths, train_negative_paths,
                     train_unlabeled_paths, val_positive_paths,
                     val_negative_paths, val_unlabeled_paths):
        self._setup_dataset(train_positive_paths, train_negative_paths,
                            train_unlabeled_paths, val_positive_paths,
                            val_negative_paths, val_unlabeled_paths)
        self.start_epoch = self.end_epoch
        self.current_epoch = self.start_epoch
        self.end_epoch = self.start_epoch + self.model_kwargs.get(
            'epochs_to_run', 1)

    def load_checkpoint(self, path: str, resume_training: bool = False):
        checkpoint_state = torch.load(path)
        self.model.load_state_dict(checkpoint_state['state_dict'])
        if resume_training:
            self.global_train_batch_idx = checkpoint_state[
                'global_train_batch_idx']
            self.global_val_batch_idx = checkpoint_state[
                'global_val_batch_idx']
            self.start_epoch = checkpoint_state['epoch'] + 1
            self.current_epoch = self.start_epoch
            self.end_epoch = (self.start_epoch +
                              self.model_kwargs.get('epochs_to_run', 1))
            self.optimizer.load_state_dict(checkpoint_state['optimizer'])
            self.optimizer_scheduler.load_state_dict(
                checkpoint_state['optimizer_scheduler'])
            # Copy tensorboard state
            prev_log_dir = checkpoint_state['model_kwargs']['log_dir']
            curr_log_dir = self.model_kwargs['log_dir']
            shutil.copytree(prev_log_dir, curr_log_dir)

    def save_checkpoint(self, epoch, checkpoint_path: str):
        kwargs = dict(self.model_kwargs)
        del kwargs['aux_labels']
        state = dict(
            global_train_batch_idx=self.global_train_batch_idx,
            global_val_batch_idx=self.global_val_batch_idx,
            model_kwargs=kwargs,
            epoch=epoch,
            state_dict=self.model.state_dict(),
            optimizer=self.optimizer.state_dict(),
            optimizer_scheduler=self.optimizer_scheduler.state_dict(),
        )
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(state, checkpoint_path)

    def _validate(self, dataloader):
        self.model.eval()
        loss_value = 0
        main_gts = []
        aux_gts = []
        main_preds = []
        aux_preds = []
        for batch_idx, (images, main_labels,
                        aux_labels) in enumerate(dataloader):
            batch_start = time.perf_counter()
            self.val_batch_idx = batch_idx
            if self.use_cuda:
                images = images.cuda()
                main_labels = main_labels.cuda()
                aux_labels = aux_labels.cuda()
            main_logits, aux_logits = self.model(images)
            valid_main_labels = main_labels != -1
            valid_aux_labels = aux_labels != -1
            main_loss_value = self.main_loss(main_logits[valid_main_labels],
                                             main_labels[valid_main_labels])
            aux_loss_value = self.aux_weight * self.auxiliary_loss(
                aux_logits[valid_aux_labels], aux_labels[valid_aux_labels])
            batch_loss = torch.zeros_like(main_loss_value)
            if valid_main_labels.sum() > 0:
                batch_loss += main_loss_value
            if valid_aux_labels.sum() > 0:
                batch_loss += aux_loss_value
            # Accumulate per-batch losses so the post-loop average covers all batches
            loss_value += batch_loss.item()

            if valid_main_labels.sum() > 0:
                main_pred = F.softmax(main_logits[valid_main_labels], dim=1)
                main_preds += list(main_pred.argmax(dim=1).cpu().numpy())
                main_gts += list(main_labels[valid_main_labels].cpu().numpy())
            if valid_aux_labels.sum() > 0:
                # Use the auxiliary head's logits here (not the main head's)
                aux_pred = F.softmax(aux_logits[valid_aux_labels], dim=1)
                aux_preds += list(aux_pred.argmax(dim=1).cpu().numpy())
                aux_gts += list(aux_labels[valid_aux_labels].cpu().numpy())
            batch_end = time.perf_counter()
            self.val_batch_time += (batch_end - batch_start)
            self.global_val_batch_idx += 1
        # Compute F1 score
        if len(dataloader) > 0:
            loss_value /= (len(dataloader) + 1e-10)
            main_prec, main_recall, main_f1, _ = \
                sklearn.metrics.precision_recall_fscore_support(
                    main_gts, main_preds, average='binary')
            aux_prec, aux_recall, aux_f1, _ = \
                sklearn.metrics.precision_recall_fscore_support(
                    aux_gts, aux_preds, average='micro')
        else:
            loss_value = 0
            main_prec = -1
            main_recall = -1
            main_f1 = -1
            aux_prec = -1
            aux_recall = -1
            aux_f1 = -1

        summary_data = [
            ('loss', loss_value),
            ('f1/main_head', main_f1),
            ('prec/main_head', main_prec),
            ('recall/main_head', main_recall),
            ('f1/aux_head', aux_f1),
            ('prec/aux_head', aux_prec),
            ('recall/aux_head', aux_recall),
        ]
        for k, v in [('val/epoch/' + tag, v) for tag, v in summary_data]:
            self.writer.add_scalar(k, v, self.current_epoch)

    def validate(self):
        self._validate(self.val_dataloader)

    def train(self):
        self.model.train()
        logger.info('Starting train epoch')
        load_start = time.perf_counter()
        self.train_epoch_loss = 0
        self.train_epoch_main_loss = 0
        self.train_epoch_aux_loss = 0
        main_gts = []
        aux_gts = []
        main_logits_all = []
        main_preds = []
        aux_preds = []
        for batch_idx, (images, main_labels,
                        aux_labels) in enumerate(self.train_dataloader):
            load_end = time.perf_counter()
            batch_start = time.perf_counter()
            self.train_batch_idx = batch_idx
            logger.debug('Train batch')
            if self.use_cuda:
                images = images.cuda()
                main_labels = main_labels.cuda()
                aux_labels = aux_labels.cuda()

            main_logits, aux_logits = self.model(images)
            # Compute loss
            valid_main_labels = main_labels != -1
            valid_aux_labels = aux_labels != -1

            main_loss_value = self.main_loss(main_logits[valid_main_labels],
                                             main_labels[valid_main_labels])
            aux_loss_value = self.aux_weight * self.auxiliary_loss(
                aux_logits[valid_aux_labels], aux_labels[valid_aux_labels])

            loss_value = torch.zeros_like(main_loss_value)
            if valid_main_labels.sum() > 0:
                loss_value += main_loss_value
            if valid_aux_labels.sum() > 0:
                loss_value += aux_loss_value

            self.train_epoch_loss += loss_value.item()
            if torch.sum(valid_main_labels) > 0:
                self.train_epoch_main_loss += main_loss_value.item()
            if torch.sum(valid_aux_labels) > 0:
                self.train_epoch_aux_loss += aux_loss_value.item()
            # Update gradients
            self.optimizer.zero_grad()
            loss_value.backward()
            self.optimizer.step()

            if valid_main_labels.sum() > 0:
                main_pred = F.softmax(main_logits[valid_main_labels], dim=1)
                main_logits_all += list(
                    main_logits[valid_main_labels].detach().cpu().numpy())
                # main_pred is already restricted to valid rows, so index no further
                main_preds += list(main_pred.argmax(dim=1).cpu().numpy())
                main_gts += list(main_labels[valid_main_labels].cpu().numpy())
            if valid_aux_labels.sum() > 0:
                aux_pred = F.softmax(aux_logits[valid_aux_labels], dim=1)
                aux_preds += list(aux_pred.argmax(dim=1).cpu().numpy())
                aux_gts += list(aux_labels[valid_aux_labels].cpu().numpy())

            batch_end = time.perf_counter()
            total_batch_time = (batch_end - batch_start)
            total_load_time = (load_end - load_start)
            self.train_batch_time += total_batch_time + total_load_time
            logger.debug(f'Train batch time: {self.train_batch_time.value}, '
                         f'this batch time: {total_batch_time}, '
                         f'this load time: {total_load_time}, '
                         f'batch epoch loss: {loss_value.item()}, '
                         f'main loss: {main_loss_value.item()}, '
                         f'aux loss: {aux_loss_value.item()}')
            summary_data = [
                ('loss', loss_value.item()),
                ('loss/main_head', main_loss_value.item()),
                ('loss/aux_head', aux_loss_value.item()),
            ]
            for k, v in [('train/batch/' + tag, v) for tag, v in summary_data]:
                self.writer.add_scalar(k, v, self.global_train_batch_idx)

            self._notify()
            self.global_train_batch_idx += 1
            load_start = time.perf_counter()

        model_lr = self.optimizer.param_groups[-1]['lr']
        self.optimizer_scheduler.step()
        logger.debug(f'Train epoch loss: {self.train_epoch_loss}, '
                     f'main loss: {self.train_epoch_main_loss}, '
                     f'aux loss: {self.train_epoch_aux_loss}')
        main_prec, main_recall, main_f1, _ = \
            sklearn.metrics.precision_recall_fscore_support(
                main_gts, main_preds, average='binary')
        aux_prec, aux_recall, aux_f1, _ = \
            sklearn.metrics.precision_recall_fscore_support(
                aux_gts, aux_preds, average='micro')
        logger.debug(
            f'Train epoch main: {main_prec}, {main_recall}, {main_f1}, '
            f'aux: {aux_prec}, {aux_recall}, {aux_f1}, '
            f'main loss: {self.train_epoch_main_loss}, '
            f'aux loss: {self.train_epoch_aux_loss}')
        summary_data = [('lr', model_lr), ('loss', self.train_epoch_loss),
                        ('loss/main_head', self.train_epoch_main_loss),
                        ('loss/aux_head', self.train_epoch_aux_loss),
                        ('f1/main_head', main_f1),
                        ('prec/main_head', main_prec),
                        ('recall/main_head', main_recall),
                        ('f1/aux_head', aux_f1), ('prec/aux_head', aux_prec),
                        ('recall/aux_head', aux_recall)]
        for k, v in [('train/epoch/' + tag, v) for tag, v in summary_data]:
            self.writer.add_scalar(k, v, self.current_epoch)

        if len(main_logits_all):
            self.writer.add_histogram(
                'train/epoch/softmax/main_head',
                scipy.special.softmax(main_logits_all, axis=1)[:, 1])

    def run(self):
        self.last_checkpoint_path = None
        for i in range(self.start_epoch, self.end_epoch):
            logger.info(f'Train: Epoch {i}')
            self.current_epoch = i
            self.train()
            if i % self.val_frequency == 0 or i == self.end_epoch - 1:
                logger.info(f'Validate: Epoch {i}')
                self.validate()
            if i % self.checkpoint_frequency == 0 or i == self.end_epoch - 1:
                logger.info(f'Checkpoint: Epoch {i}')
                self.last_checkpoint_path = os.path.join(
                    self.model_dir, f'checkpoint_{i:03}.pth')
                self.save_checkpoint(i, self.last_checkpoint_path)
        return self.last_checkpoint_path
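
Both loops in Example #4 drop samples whose main or auxiliary label is -1 before computing each head's cross-entropy, and only add a head's loss when it saw at least one labeled sample. A minimal sketch of that masking on dummy tensors:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

logits = torch.randn(6, 2)                       # one head, 2 classes
labels = torch.tensor([0, 1, -1, 1, -1, 0])      # -1 marks unlabeled samples

valid = labels != -1
loss = torch.zeros(())
if valid.sum() > 0:                              # only penalize labeled rows
    loss = loss + criterion(logits[valid], labels[valid])

nn.CrossEntropyLoss(ignore_index=-1) gives a similar effect per batch, but typically returns NaN when every label in the batch is ignored, which the explicit check avoids.
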
Example #5
def train(args, train_dataset, model):
    tb_writer = SummaryWriter()
    result_writer = ResultWriter(args.eval_results_dir)

    # Configure the sampling frequency
    if args.sample_criteria is None:
        # Random sampling following the natural ratio of positive/negative situations and pitch types
        sampler = RandomSampler(train_dataset)
    else:
        # Equalize the sampling frequency across the 7 pitch types
        if args.sample_criteria == "pitcher":
            counts = train_dataset.pitch_counts
            logger.info("  Counts of each ball type : %s", counts)
            pitch_contiguous = [
                i for p in train_dataset.origin_pitch for i, j in enumerate(p)
                if j == 1
            ]
            weights = [
                0 if p == 5 or p == 6 else 1.0 / counts[p]
                for p in pitch_contiguous
            ]
            sampler = WeightedRandomSampler(weights,
                                            len(train_dataset),
                                            replacement=True)
        # Equalize the sampling frequency across the 2 positive/negative label types
        elif args.sample_criteria == "batter":
            counts = train_dataset.label_counts
            logger.info("  Counts of each label type : %s", counts)
            weights = [1.0 / counts[l] for l in train_dataset.label]
            sampler = WeightedRandomSampler(weights,
                                            len(train_dataset),
                                            replacement=True)
        # Equalize the sampling frequency across the 14 (pitch type, label) combinations
        elif args.sample_criteria == "both":
            counts = train_dataset.pitch_and_label_count
            logger.info("  Counts of each both type : %s", counts)
            pitch_contiguous = [
                i for p in train_dataset.origin_pitch for i, j in enumerate(p)
                if j == 1
            ]
            weights = [
                0 if p == 5 or p == 6 else 1.0 / counts[(p, l)]
                for p, l in zip(pitch_contiguous, train_dataset.label)
            ]
            sampler = WeightedRandomSampler(weights,
                                            len(train_dataset),
                                            replacement=True)
        else:
            sampler = RandomSampler(train_dataset)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        sampler=sampler,
    )
    t_total = len(train_dataloader) * args.num_train_epochs
    args.warmup_step = int(args.warmup_percent * t_total)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = [
        "bias",
        "layernorm.weight",
    ]  # LayerNorm.weight -> layernorm.weight (model_parameter name)
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]

    optimizer = optim.Adam(optimizer_grouped_parameters,
                           lr=args.learning_rate,
                           eps=args.adam_epsilon)
    if args.warmup_step != 0:
        scheduler_cosine = CosineAnnealingLR(optimizer, t_total)
        scheduler = GradualWarmupScheduler(optimizer,
                                           1,
                                           args.warmup_step,
                                           after_scheduler=scheduler_cosine)
        # scheduler_plateau = ReduceLROnPlateau(optimizer, "min")
        # scheduler = GradualWarmupScheduler(
        #     optimizer, 1, args.warmup_step, after_scheduler=scheduler_plateau
        # )
    else:
        scheduler = CosineAnnealingLR(optimizer, t_total)
        # scheduler = ReduceLROnPlateau(optimizer, "min")

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
    m = torch.nn.Sigmoid()
    loss_fct = torch.nn.BCELoss()

    # Train!
    logger.info("***** Running Baseball Transformer *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Warmup Steps = %d", args.warmup_step)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.train_batch_size)
    logger.info("  Total train batch size = %d", args.train_batch_size)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    tr_loss, logging_loss, logging_val_loss = 0.0, 0.0, 0.0

    best_pitch_micro_f1, best_pitch_macro_f1 = 0, 0
    best_loss = 1e10

    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):

            (
                pitcher_discrete,
                pitcher_continuous,
                batter_discrete,
                batter_continuous,
                state_discrete,
                state_continuous,
                pitch,
                hit,
                label,
                masked_pitch,
                origin_pitch,
            ) = list(map(lambda x: x.to(args.device), batch))
            model.train()

            # sentiment input
            pitching_score = model(
                pitcher_discrete,
                pitcher_continuous,
                batter_discrete,
                batter_continuous,
                state_discrete,
                state_continuous,
                label,
                args.concat if args.concat else 0,
            )

            pitching_score = pitching_score.contiguous()
            pitch = pitch.contiguous()
            # with sigmoid(m)
            loss = loss_fct(m(pitching_score), pitch)

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                # Log metrics
                if args.evaluate_during_training:
                    results, f1_results, f1_log, cm_pos, cm_neg = evaluate(
                        args, args.eval_data_file, model)
                    output_eval_file = os.path.join(args.output_dir,
                                                    "eval_results_pos.txt")
                    print_result(output_eval_file, results, f1_log, cm_pos)

                    for key, value in results.items():
                        tb_writer.add_scalar("eval_{}".format(key), value,
                                             global_step)
                    logging_val_loss = results["loss"]
                    # scheduler.step(logging_val_loss)

                tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                     args.logging_steps, global_step)
                logging_loss = tr_loss
                if args.evaluate_during_training and best_loss > results["loss"]:
                    best_pitch_micro_f1 = results["pitch_micro_f1"]
                    best_pitch_macro_f1 = results["pitch_macro_f1"]
                    best_loss = results["loss"]

                    output_dir = os.path.join(args.output_dir, "best_model/")
                    os.makedirs(output_dir, exist_ok=True)
                    torch.save(model.state_dict(),
                               os.path.join(output_dir, "pytorch_model.bin"))
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving best model to %s", output_dir)

                    result_path = os.path.join(output_dir, "best_results.txt")
                    print_result(result_path,
                                 results,
                                 f1_log,
                                 cm_pos,
                                 off_logger=True)

                    results.update(dict(f1_results))
                    result_writer.update(args, **results)

                logger.info("  best pitch micro f1 : %s", best_pitch_micro_f1)
                logger.info("  best pitch macro f1 : %s", best_pitch_macro_f1)
                logger.info("  best loss : %s", best_loss)

            if args.save_steps > 0 and global_step % args.save_steps == 0:
                checkpoint_prefix = "checkpoint"
                # Save model checkpoint
                output_dir = os.path.join(
                    args.output_dir, "{}-{}".format(checkpoint_prefix,
                                                    global_step))
                os.makedirs(output_dir, exist_ok=True)
                torch.save(model.state_dict(),
                           os.path.join(output_dir, "pytorch_model.bin"))
                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                rotate_checkpoints(args, checkpoint_prefix)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

    tb_writer.close()

    return global_step, tr_loss / global_step
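
Every weighted-sampling branch above follows the same recipe: weight each example by the inverse frequency of its group so that all groups are drawn at roughly equal rates. A minimal sketch on a toy, imbalanced label array (the tensors and sizes are placeholders):

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

labels = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2])   # heavily imbalanced classes
features = torch.randn(len(labels), 4)
dataset = TensorDataset(features, torch.from_numpy(labels))

classes, counts = np.unique(labels, return_counts=True)
count_dict = dict(zip(classes, counts))
weights = [1.0 / count_dict[y] for y in labels]  # inverse-frequency weights

sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=3, sampler=sampler)
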
Example #6
def main(MODEL_TYPE, LEARNING_RATE, CROP_SIZE, UPSCALE_FACTOR, NUM_EPOCHS,
         BATCH_SIZE, IMAGE_DIR, LAST_EPOCH, MODEL_NAME='', TV_LOSS_RATE=1e-3):
    
    global net, history, lr_img, hr_img, test_lr_img, test_sr_img, valid_hr_img, valid_sr_img, valid_lr_img, scheduler, optimizer
    
    train_set    = TrainDatasetFromFolder(IMAGE_DIR, crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=8, batch_size=BATCH_SIZE, shuffle=True)
    n_iter = (len(train_set) // BATCH_SIZE + 1) * NUM_EPOCHS

    net = eval(f"{MODEL_TYPE}({UPSCALE_FACTOR})")
    criterion = TotalLoss(TV_LOSS_RATE)
    optimizer = optim.RAdam(net.parameters(), lr=LEARNING_RATE)
    scheduler = schedule.StepLR(optimizer, int(n_iter * 0.3), gamma=0.5, last_epoch=LAST_EPOCH)
    if LAST_EPOCH == -1:
        scheduler = GradualWarmupScheduler(optimizer, 1, n_iter // 50, after_scheduler=scheduler)
        
    # plot_scheduler(scheduler, n_iter)

    if MODEL_NAME:
        net.load_state_dict(torch.load('epochs/' + MODEL_NAME))
        print(f"# Loaded model: [{MODEL_NAME}]")
    
    print(f'# {MODEL_TYPE} parameters:', sum(param.numel() for param in net.parameters()))
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)
    criterion.to(device)
    
    tta_transform = tta.Compose([
        tta.HorizontalFlip(),
        tta.VerticalFlip()
    ])
    
    # Train model:
    tta_net = tta.SegmentationTTAWrapper(net, tta_transform)
    history = []
    img_test  = plt.imread(r'data\testing_lr_images\09.png')
    img_valid = plt.imread(r'data\valid_hr_images\t20.png')
    test_lr_img = torch.from_numpy(img_test.transpose(2, 0, 1)).unsqueeze(0).to(device)
    valid_hr_img = ToTensor()(img_valid)
    valid_lr_img = valid_hr_transform(img_valid.shape, UPSCALE_FACTOR)(img_valid).to(device)
    for epoch in range(1, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)
        running_results = {'batch_sizes': 0, 'loss': 0, 'psnr': 0}
        
        # Train one epoch:
        net.train()
        for lr_img, hr_img in train_bar:
            running_results['batch_sizes'] += BATCH_SIZE
            optimizer.zero_grad()
            
            lr_img, hr_img = MixUp()(lr_img, hr_img, lr_img, hr_img)
            lr_img = lr_img.type(torch.FloatTensor).to(device)
            hr_img = hr_img.type(torch.FloatTensor).to(device)
            sr_img = tta_net(lr_img)
            
            loss = criterion(sr_img, hr_img)
            loss.backward()
            optimizer.step()
            scheduler.step()

            running_results['loss'] += loss.item() * BATCH_SIZE
            running_results['psnr'] += psnr(hr_img, sr_img) * BATCH_SIZE
            train_bar.set_description(desc='[%d/%d] Loss: %.4f, PSNR: %.4f' % (
                epoch, NUM_EPOCHS,
                running_results['loss'] / running_results['batch_sizes'],
                running_results['psnr'] / running_results['batch_sizes']
            ))
    
        # Save model parameters:
        psnr_now = running_results['psnr'] / running_results['batch_sizes']
        filename = f'epochs/{MODEL_TYPE}_x%d_epoch=%d_PSNR=%.4f.pth' % (UPSCALE_FACTOR, epoch, psnr_now)
        torch.save(net.state_dict(), filename)
        history.append(running_results)
        
        # Test model:
        if epoch % 5 == 0:
            with torch.no_grad():
                net.eval()
                
                # Plot up-scaled testing image:
                test_sr_img = net(test_lr_img)
                plot_hr_lr(test_sr_img, test_lr_img)
                
                # Compute PSNR of validation image:
                valid_sr_img = net(valid_lr_img)
                psnr_valid = psnr(valid_hr_img, valid_sr_img)
                
                # Print PSNR:
                print('\n' + '-' * 50)
                print(f"PSNR of Validation = {psnr_valid}")
                print('-' * 50 + '\n')
                
    torch.save(optimizer.state_dict(), f'optimizer_{MODEL_TYPE}_epoch={NUM_EPOCHS}.pth')
    torch.save(scheduler.state_dict(), f'scheduler_{MODEL_TYPE}_epoch={NUM_EPOCHS}.pth')
def main_worker(gpu, ngpus, args):
    global best_loss

    print(f'Starting process on GPU: {gpu}')
    dist.init_process_group(backend='nccl',
                            init_method=f'tcp://localhost:{args.port}',
                            world_size=ngpus,
                            rank=gpu)
    total_batch_size = args.batch_size
    args.batch_size = args.batch_size // ngpus

    train_dataset, val_dataset, n_classes = get_datasets(
        args.dataset, args.task)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=16,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             num_workers=16,
                                             pin_memory=True,
                                             drop_last=True,
                                             sampler=val_sampler)

    if args.task == 'context_encoder':
        model = ContextEncoder(args.dataset, n_classes)
    elif args.task == 'rotation':
        model = RotationPrediction(args.dataset, n_classes)
    elif args.task == 'cpc':
        model = CPC(args.dataset, n_classes)
    elif args.task == 'simclr':
        model = SimCLR(args.dataset, n_classes, dist)
    else:
        raise Exception('Invalid task:', args.task)
    args.metrics = model.metrics
    args.metrics_fmt = model.metrics_fmt

    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(gpu)
    model.cuda(gpu)

    args.gpu = gpu

    linear_classifier = model.construct_classifier().cuda(gpu)
    linear_classifier = torch.nn.parallel.DistributedDataParallel(
        linear_classifier, device_ids=[gpu], find_unused_parameters=True)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[gpu], find_unused_parameters=True)

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)
        optimizer_linear = torch.optim.SGD(linear_classifier.parameters(),
                                           lr=args.lr,
                                           momentum=args.momentum,
                                           nesterov=True)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     betas=(args.momentum, 0.999),
                                     weight_decay=args.weight_decay)
        optimizer_linear = torch.optim.Adam(linear_classifier.parameters(),
                                            lr=args.lr,
                                            betas=(args.momentum, 0.999))
    elif args.optimizer == 'lars':
        optimizer = LARS(model.parameters(),
                         lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=args.weight_decay)
        optimizer_linear = LARS(linear_classifier.parameters(),
                                lr=args.lr,
                                momentum=args.momentum)
    else:
        raise Exception('Unsupported optimizer', args.optimizer)

    # Minimize SSL task loss, maximize linear classification accuracy
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, 0, -1)
    scheduler_linear = lr_scheduler.CosineAnnealingLR(optimizer_linear,
                                                      args.epochs, 0, -1)
    if args.warmup_epochs > 0:
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=total_batch_size / 256.,
                                           total_epoch=args.warmup_epochs,
                                           after_scheduler=scheduler)
        scheduler_linear = GradualWarmupScheduler(
            optimizer_linear,
            multiplier=total_batch_size / 256.,
            total_epoch=args.warmup_epochs,
            after_scheduler=scheduler_linear)

    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)

        train(train_loader, model, linear_classifier, optimizer,
              optimizer_linear, epoch, args)

        val_loss, val_acc = validate(val_loader, model, linear_classifier,
                                     args, dist)

        scheduler.step()
        scheduler_linear.step()

        if dist.get_rank() == 0:
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'state_dict_linear': linear_classifier.state_dict(),
                    'optimizer_linear': optimizer_linear.state_dict(),
                    'scheduler_linear': scheduler_linear.state_dict(),
                    'best_loss': best_loss,
                    'best_acc': val_acc
                }, is_best, args)
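
The multiplier=total_batch_size / 256. argument in main_worker() applies the linear scaling rule: during warmup the learning rate ramps from the base value up to base_lr * batch_size / 256 before the cosine schedule takes over. A minimal epoch-level sketch, again assuming the warmup_scheduler package's GradualWarmupScheduler; the model, batch size and epoch counts are placeholders:

import torch
from torch.optim import lr_scheduler
from warmup_scheduler import GradualWarmupScheduler  # assumed source of the wrapper

model = torch.nn.Linear(16, 4)                 # placeholder model
base_lr, batch_size = 0.03, 512
epochs, warmup_epochs = 100, 10

optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9)
cosine = lr_scheduler.CosineAnnealingLR(optimizer, epochs)
scheduler = GradualWarmupScheduler(optimizer,
                                   multiplier=batch_size / 256.,  # linear scaling rule
                                   total_epoch=warmup_epochs,
                                   after_scheduler=cosine)

for epoch in range(epochs):
    optimizer.step()       # placeholder for one epoch of training
    scheduler.step()       # epoch-level stepping, as in main_worker() above
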