Esempio n. 1
0
def train_net(configs):
    """Train a classification model and checkpoint its weights every epoch.

    Builds the model, optimizer and cross-entropy criterion from ``configs``,
    reads the training images with ``datasets.ImageFolder`` and runs
    ``configs.Train.nepochs`` epochs.  The state dict is written to
    ``./weights/<backbone>_<epoch>.pth`` after every epoch and to
    ``./weights/final.pth`` once training finishes.

    Args:
        configs: Configuration object.  The attributes read here are
            ``backbone``, ``Num_Classes``, ``Pretrained``, ``cuda``,
            ``img_aug``, ``mean``, ``std``, ``train_root``, ``shuffle``,
            ``num_workers``, ``Train.batch_size`` and ``Train.nepochs``.
            It is also forwarded to ``build_optimizer`` and ``adjust_lr``.
    """
    import os  # function-scope import: only needed to create the weights dir

    model = build_model(configs.backbone,
                        num_classes=configs.Num_Classes,
                        pretrained=configs.Pretrained)
    optimizer = build_optimizer(model.parameters(), configs)
    criterion = nn.CrossEntropyLoss()

    # Resolve the device once instead of rebuilding it for every batch.
    device = torch.device("cuda") if configs.cuda else None
    if device is not None:
        model.to(device)
        criterion.to(device)

    # The crop/tensor/normalize steps are always required: without them
    # ImageFolder yields PIL images, which the DataLoader cannot batch.
    # ``img_aug`` only toggles the random horizontal flip.
    steps = [
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=configs.mean, std=configs.std),
    ]
    if configs.img_aug:
        steps.insert(0, transforms.RandomHorizontalFlip(0.5))
    train_set = datasets.ImageFolder(configs.train_root,
                                     transform=transforms.Compose(steps))
    train_loader = data.DataLoader(train_set,
                                   batch_size=configs.Train.batch_size,
                                   shuffle=configs.shuffle,
                                   num_workers=configs.num_workers,
                                   pin_memory=True)

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs("./weights", exist_ok=True)
    batches_per_epoch = len(train_set) // configs.Train.batch_size

    for epoch in range(configs.Train.nepochs):
        # Adjust the learning rate every second epoch (the previous
        # ``epoch // 2 == 0`` test only ever fired at epoch 1).
        if epoch > 0 and epoch % 2 == 0:
            adjust_lr(optimizer, configs)
        for idx, (img, target) in enumerate(train_loader):
            if device is not None:
                img = img.to(device)
                target = target.to(device)
            out = model(img)
            loss = criterion(out, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print("|Epoch|: {}, {}/{}, loss{}".format(
                epoch, idx, batches_per_epoch, loss.item()))

        pth_path = "./weights/{}_{}.pth".format(configs.backbone, epoch)
        torch.save(model.state_dict(), pth_path)
        print("Save weights to ---->{}<-----".format(pth_path))

    torch.save(model.state_dict(), "./weights/final.pth")
    print("Final model saved!!!")
Esempio n. 2
0
def train(args):
    """Total training procedure.

    Intended to be launched once per GPU.  Initializes the NCCL process
    group, builds a distributed data loader over ``ImageDataset``, creates
    the backbone/head ``FaceModel``, wraps it with ``amp`` (opt level O1)
    and ``DistributedDataParallel``, then calls ``train_one_epoch`` for
    ``args.epoches`` epochs before tearing the process group down.

    Args:
        args: Namespace of run options.  Attributes read here are
            ``local_rank``, ``tensorboardx_logdir``, ``out_dir``,
            ``data_root``, ``train_file``, ``batch_size``,
            ``backbone_type``, ``backbone_conf_file``, ``head_type``,
            ``head_conf_file``, ``lr``, ``epoches`` and
            ``warm_up_epoches``.  ``writer`` (local rank 0 only), ``rank``
            and ``world_size`` are set on it as a side effect.
    """
    print("Use GPU: {} for training".format(args.local_rank))
    if args.local_rank == 0:
        writer = SummaryWriter(log_dir=args.tensorboardx_logdir)
        args.writer = writer
        # exist_ok avoids the check-then-create race between processes that
        # share the same output directory.
        os.makedirs(args.out_dir, exist_ok=True)
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    args.rank = dist.get_rank()
    args.world_size = dist.get_world_size()

    trainset = ImageDataset(args.data_root, args.train_file)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              sampler=train_sampler,
                              num_workers=0,
                              pin_memory=True,
                              drop_last=False)

    backbone_factory = BackboneFactory(args.backbone_type,
                                       args.backbone_conf_file)
    head_factory = HeadFactory(args.head_type, args.head_conf_file)
    model = FaceModel(backbone_factory, head_factory)
    model = model.to(args.local_rank)
    # Start every process from rank 0's initial weights.
    for ps in model.parameters():
        dist.broadcast(ps, 0)
    optimizer = build_optimizer(model, args.lr)
    lr_schedule = build_scheduler(optimizer, len(train_loader), args.epoches,
                                  args.warm_up_epoches)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    # DDP
    model = torch.nn.parallel.DistributedDataParallel(
        module=model, broadcast_buffers=False, device_ids=[args.local_rank])
    criterion = torch.nn.CrossEntropyLoss().to(args.local_rank)
    loss_meter = AverageMeter()
    model.train()
    ori_epoch = 0
    for epoch in range(ori_epoch, args.epoches):
        # Reseed the sampler: without set_epoch a shuffling
        # DistributedSampler replays the same ordering every epoch.
        train_sampler.set_epoch(epoch)
        train_one_epoch(train_loader, model, optimizer, lr_schedule, criterion,
                        epoch, loss_meter, args)
    dist.destroy_process_group()
Esempio n. 3
0
def main(config):
    """Build the training components and fit them with a Lightning trainer.

    Creates the data loaders, model and optimizer from ``config``, picks the
    trainer/scheduler pair that matches ``config.TRAIN.MODE``, selects the
    loss function, wraps everything with ``lightning_train_wrapper`` and
    calls ``trainer.fit``.

    Args:
        config: Experiment configuration.  Reads ``TRAIN.MODE``
            (``'epoch'`` or ``'step'``), ``AUG.MIXUP`` and
            ``MODEL.LABEL_SMOOTHING``; it is also forwarded to the
            ``build_*`` helpers.

    Raises:
        ValueError: If ``config.TRAIN.MODE`` is neither ``'epoch'`` nor
            ``'step'``.
    """
    _, _, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
    model = build_model(config)
    optimizer = build_optimizer(config, model)

    mode = config.TRAIN.MODE
    if mode == 'epoch':
        trainer = build_epoch_trainer(config)
        lr_scheduler = build_epoch_scheduler(config, optimizer,
                                             len(data_loader_train))
    elif mode == 'step':
        trainer = build_finetune_trainer(config)
        lr_scheduler = build_finetune_scheduler(config, optimizer)
    else:
        # Previously an unknown mode fell through and failed later with an
        # unbound-variable error on ``lr_scheduler``.
        raise ValueError(
            f"Unsupported TRAIN.MODE {mode!r}; expected 'epoch' or 'step'")

    # Mixup is only kept on when the soft-target loss is in use.
    mixup = True
    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        # close mixup
        mixup = False
        criterion = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        # close mixup
        mixup = False
        criterion = torch.nn.CrossEntropyLoss()

    lightning_train_engine = lightning_train_wrapper(model, criterion,
                                                     optimizer, lr_scheduler,
                                                     mixup_fn, mixup)
    lightning_model = lightning_train_engine(config)
    trainer.fit(
        model=lightning_model,
        train_dataloader=data_loader_train,
        val_dataloaders=data_loader_val,
    )
Esempio n. 4
0
def main(config):
    """Run supervised training with periodic checkpointing and validation.

    Builds loaders, model, optimizer and scheduler from ``config``, wraps
    the model in ``DistributedDataParallel`` (after ``amp.initialize`` when
    ``AMP_OPT_LEVEL`` is not ``"O0"``), optionally resumes from a checkpoint
    or loads pretrained weights, then trains from ``TRAIN.START_EPOCH`` to
    ``TRAIN.EPOCHS``.  Returns early in eval mode (after a resume) and in
    throughput mode.

    Args:
        config: Experiment configuration node; it is defrosted and re-frozen
            when auto-resume replaces ``MODEL.RESUME``.
    """
    (dataset_train, dataset_val, data_loader_train, data_loader_val,
     mixup_fn) = build_loader(config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    amp_level = config.AMP_OPT_LEVEL
    if amp_level != "O0":
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=amp_level)
    model = torch.nn.parallel.DistributedDataParallel(
        model, broadcast_buffers=False, device_ids=[config.LOCAL_RANK])
    model_without_ddp = model.module

    # Count only the parameters that will receive gradients.
    n_parameters = 0
    for param in model.parameters():
        if param.requires_grad:
            n_parameters += param.numel()
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    # Loss selection: mixup takes precedence over label smoothing.
    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if not resume_file:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume'
            )
        else:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
                                       lr_scheduler, logger)
        acc1, _, _ = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and not config.MODEL.RESUME:
        load_pretrained(config, model_without_ddp, logger)
        acc1, _, _ = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, criterion, data_loader_train, optimizer,
                        epoch, mixup_fn, lr_scheduler)

        # Only rank 0 writes checkpoints: every SAVE_FREQ epochs and on the
        # final epoch.
        if dist.get_rank() == 0:
            if (epoch % config.SAVE_FREQ == 0
                    or epoch == (config.TRAIN.EPOCHS - 1)):
                save_checkpoint(config, epoch, model_without_ddp,
                                max_accuracy, optimizer, lr_scheduler, logger)

        acc1, _, _ = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
def main(config):
    """Run self-supervised pre-training with periodic checkpointing.

    Records the training-set size in ``config.DATA.TRAINING_IMAGES``, builds
    the model, optimizer and scheduler, wraps the model in
    ``DistributedDataParallel`` (after ``amp.initialize`` when
    ``AMP_OPT_LEVEL`` is not ``"O0"``), optionally resumes from a checkpoint
    and trains from ``TRAIN.START_EPOCH`` to ``TRAIN.EPOCHS``.  No
    validation is performed; checkpoints are saved with an accuracy of 0.0.

    Args:
        config: Experiment configuration node; it is defrosted and re-frozen
            whenever this function writes to it.
    """
    dataset_train, _, data_loader_train, _, _ = build_loader(config)

    # Publish the dataset size through the (normally frozen) config.
    config.defrost()
    config.DATA.TRAINING_IMAGES = len(dataset_train)
    config.freeze()

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    amp_level = config.AMP_OPT_LEVEL
    if amp_level != "O0":
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=amp_level)
    model = torch.nn.parallel.DistributedDataParallel(
        model, broadcast_buffers=False, device_ids=[config.LOCAL_RANK])
    model_without_ddp = model.module

    # Count only the parameters that will receive gradients.
    n_parameters = 0
    for param in model.parameters():
        if param.requires_grad:
            n_parameters += param.numel()
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if not resume_file:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume'
            )
        else:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')

    if config.MODEL.RESUME:
        # The value returned by load_checkpoint is not needed here.
        load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler,
                        logger)

    logger.info("Start self-supervised pre-training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, data_loader_train, optimizer, epoch,
                        lr_scheduler)

        # Only rank 0 writes checkpoints: every SAVE_FREQ epochs and on the
        # final epoch.
        if dist.get_rank() == 0:
            if (epoch % config.SAVE_FREQ == 0
                    or epoch == (config.TRAIN.EPOCHS - 1)):
                save_checkpoint(config, epoch, model_without_ddp, 0.0,
                                optimizer, lr_scheduler, logger)

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
Esempio n. 6
0
def main(config):
    """Train a student model, optionally distilling from a teacher model.

    When ``config.DISTILL.DO_DISTILL`` is set, a teacher is loaded first:
    a ``regnety_160`` created via ``create_model`` if the checkpoint file
    name contains that string, otherwise a model from
    ``load_teacher_model``.  The student is then built, wrapped in
    ``DistributedDataParallel``, optionally resumed from a checkpoint, and
    trained from ``TRAIN.START_EPOCH`` to ``TRAIN.EPOCHS`` with periodic
    checkpointing and validation.  Returns early in eval mode (after a
    resume) and in throughput mode.

    Args:
        config: Experiment configuration node; it is defrosted and re-frozen
            when auto-resume replaces ``MODEL.RESUME``.
    """
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    if config.DISTILL.DO_DISTILL:
        logger.info(
            f"Loading teacher model:{config.MODEL.TYPE}/{config.DISTILL.TEACHER}"
        )
        # The teacher architecture is inferred from the checkpoint file name.
        model_checkpoint_name = os.path.basename(config.DISTILL.TEACHER)
        if 'regnety_160' in model_checkpoint_name:
            model_teacher = create_model(
                'regnety_160',
                pretrained=False,
                num_classes=config.MODEL.NUM_CLASSES,
                global_pool='avg',
            )
            # The teacher checkpoint may be a URL or a local path.
            if config.DISTILL.TEACHER.startswith('https'):
                checkpoint = torch.hub.load_state_dict_from_url(
                    config.DISTILL.TEACHER,
                    map_location='cpu',
                    check_hash=True)
            else:
                checkpoint = torch.load(config.DISTILL.TEACHER,
                                        map_location='cpu')
            model_teacher.load_state_dict(checkpoint['model'])
            model_teacher.cuda()
            model_teacher.eval()
            # Drop the CPU copy of the weights and release cached GPU memory.
            del checkpoint
            torch.cuda.empty_cache()
        else:
            if 'base' in model_checkpoint_name:
                teacher_type = 'base'
            elif 'large' in model_checkpoint_name:
                teacher_type = 'large'
            else:
                # NOTE(review): None is forwarded to load_teacher_model;
                # confirm that helper accepts it.
                teacher_type = None
            model_teacher = load_teacher_model(type=teacher_type)
            model_teacher.cuda()
            model_teacher = torch.nn.parallel.DistributedDataParallel(
                model_teacher,
                device_ids=[config.LOCAL_RANK],
                broadcast_buffers=False)
            checkpoint = torch.load(config.DISTILL.TEACHER, map_location='cpu')
            # strict=False: key mismatches are reported in ``msg`` and logged
            # instead of raising.
            msg = model_teacher.module.load_state_dict(checkpoint['model'],
                                                       strict=False)
            logger.info(msg)
            # NOTE(review): unlike the regnety_160 branch, this teacher is
            # never switched to eval() here; confirm that
            # train_one_epoch_distill takes care of it.
            del checkpoint
            torch.cuda.empty_cache()

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    # Mixed precision is skipped for opt level "O0".
    if config.AMP_OPT_LEVEL != "O0":
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=config.AMP_OPT_LEVEL)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[config.LOCAL_RANK],
        broadcast_buffers=False,
        find_unused_parameters=True)

    # Unwrapped module, used for FLOPs reporting and checkpoint load/save.
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
    # Distillation loss callables handed to train_one_epoch_distill.
    criterion_soft = soft_cross_entropy
    criterion_attn = cal_relation_loss
    criterion_hidden = cal_hidden_relation_loss if config.DISTILL.HIDDEN_RELATION else cal_hidden_loss

    # Ground-truth loss: mixup takes precedence over label smoothing.
    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion_truth = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion_truth = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion_truth = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            # The config is frozen; unlock it to record the resume settings.
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.DISTILL.RESUME_WEIGHT_ONLY = False
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume'
            )

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
                                       lr_scheduler, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model, logger)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        # Evaluation-only runs stop after validating the resumed weights.
        if config.EVAL_MODE:
            return

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        if config.DISTILL.DO_DISTILL:
            train_one_epoch_distill(config,
                                    model,
                                    model_teacher,
                                    data_loader_train,
                                    optimizer,
                                    epoch,
                                    mixup_fn,
                                    lr_scheduler,
                                    criterion_soft=criterion_soft,
                                    criterion_truth=criterion_truth,
                                    criterion_attn=criterion_attn,
                                    criterion_hidden=criterion_hidden)
        else:
            train_one_epoch(config, model, criterion_truth, data_loader_train,
                            optimizer, epoch, mixup_fn, lr_scheduler)

        # Only rank 0 saves: every SAVE_FREQ epochs and on the final epoch.
        # The checkpoint is written before this epoch's validation, so it
        # stores the max_accuracy known before that validation.
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0
                                     or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy,
                            optimizer, lr_scheduler, logger)

        # Validate every EVAL_FREQ epochs and on the final epoch.
        if epoch % config.EVAL_FREQ == 0 or epoch == config.TRAIN.EPOCHS - 1:
            acc1, acc5, loss = validate(config, data_loader_val, model, logger)
            logger.info(
                f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
            )
            max_accuracy = max(max_accuracy, acc1)
            logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
Esempio n. 7
0
def main(args, config):
    """Train and validate a model, optionally distilling from saved logits.

    Builds loaders, model (with optional SyncBatchNorm conversion),
    optimizer, loss scaler and scheduler, wraps the model in
    ``DistributedDataParallel``, optionally resumes or loads pretrained
    weights, then trains from ``TRAIN.START_EPOCH`` to ``TRAIN.EPOCHS`` with
    checkpointing, per-epoch validation and optional wandb logging.
    Returns early in eval mode (after a resume) and in throughput mode.

    Args:
        args: Command-line options; ``use_sync_bn`` and ``use_wandb`` are
            read here and the object is forwarded to the train/validate
            helpers.
        config: Experiment configuration node; it is defrosted and re-frozen
            when auto-resume replaces ``MODEL.RESUME``.
    """
    (dataset_train, dataset_val, data_loader_train, data_loader_val,
     mixup_fn) = build_loader(config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()

    if args.use_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    model = torch.nn.parallel.DistributedDataParallel(
        model, broadcast_buffers=False, device_ids=[config.LOCAL_RANK])
    loss_scaler = NativeScalerWithGradNormCount()
    model_without_ddp = model.module

    # Count only the parameters that will receive gradients.
    n_parameters = 0
    for param in model.parameters():
        if param.requires_grad:
            n_parameters += param.numel()
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    # The scheduler is sized by loader length divided by accumulation steps.
    steps_per_epoch = (len(data_loader_train) //
                       config.TRAIN.ACCUMULATION_STEPS)
    lr_scheduler = build_scheduler(config, optimizer, steps_per_epoch)

    if config.DISTILL.ENABLED:
        # we disable MIXUP and CUTMIX when knowledge distillation
        assert len(config.DISTILL.TEACHER_LOGITS_PATH
                   ) > 0, "Please fill in DISTILL.TEACHER_LOGITS_PATH"
        criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    elif config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if not resume_file:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume'
            )
        else:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
                                       lr_scheduler, loss_scaler, logger)
        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and not config.MODEL.RESUME:
        load_pretrained(config, model_without_ddp, logger)
        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # set_epoch for dataset_train when distillation
        if hasattr(dataset_train, 'set_epoch'):
            dataset_train.set_epoch(epoch)
        data_loader_train.sampler.set_epoch(epoch)

        if config.DISTILL.ENABLED:
            train_one_epoch_distill_using_saved_logits(
                args, config, model, criterion, data_loader_train, optimizer,
                epoch, mixup_fn, lr_scheduler, loss_scaler)
        else:
            train_one_epoch(args, config, model, criterion, data_loader_train,
                            optimizer, epoch, mixup_fn, lr_scheduler,
                            loss_scaler)

        # Only rank 0 writes checkpoints: every SAVE_FREQ epochs and on the
        # final epoch.
        if dist.get_rank() == 0:
            if (epoch % config.SAVE_FREQ == 0
                    or epoch == (config.TRAIN.EPOCHS - 1)):
                save_checkpoint(config, epoch, model_without_ddp,
                                max_accuracy, optimizer, lr_scheduler,
                                loss_scaler, logger)

        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

        if is_main_process() and args.use_wandb:
            metrics = {
                "val/acc@1": acc1,
                "val/acc@5": acc5,
                "val/loss": loss,
                "epoch": epoch,
            }
            wandb.log(metrics)
            wandb.run.summary['epoch'] = epoch
            wandb.run.summary['best_acc@1'] = max_accuracy

    elapsed_seconds = int(time.time() - start_time)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    logger.info('Training time {}'.format(total_time_str))
Esempio n. 8
0
def main():
    """Parse arguments, build data/model/optimizer and start training.

    Reads the train and validation CSVs named by the parsed arguments,
    draws three differently seeded validation subsets of up to 40000 rows
    each, builds the loaders, model, optimizer and scheduler, and hands
    everything to ``train_model``.

    Raises:
        NotImplementedError: If ``args.model`` is not one of ``resnet50``,
            ``resnext`` or ``xception``.
    """
    args = arg_parser()

    seed_everything(args.seed)

    device = torch.device("cuda:0" if cuda.is_available() else "cpu")

    train_df = pd.read_csv(args.train_df_path)
    valid_df = pd.read_csv(args.valid_df_path)
    # Three shuffles of the same validation frame, one per fixed seed.
    valid_df_sub, valid_df_sub1, valid_df_sub2 = (
        valid_df.sample(frac=1.0,
                        random_state=seed).reset_index(drop=True)[:40000]
        for seed in (42, 52, 62))
    # The full validation frame is no longer needed; free it eagerly.
    del valid_df
    gc.collect()

    if args.DEBUG:
        train_df = train_df[:1000]
        valid_df_sub = valid_df_sub[:1000]
        valid_df_sub1 = valid_df_sub1[:1000]
        valid_df_sub2 = valid_df_sub2[:1000]

    train_loader = build_dataset(args, train_df, is_train=True)
    batch_num = len(train_loader)
    valid_loader = build_dataset(args, valid_df_sub, is_train=False)
    valid_loader1 = build_dataset(args, valid_df_sub1, is_train=False)
    valid_loader2 = build_dataset(args, valid_df_sub2, is_train=False)

    model = build_model(args, device)

    # Fail loudly on an unknown model name.  The old code evaluated a bare
    # ``NotImplementedError`` (a no-op) and crashed later with a NameError
    # on ``save_path``.
    supported_models = ('resnet50', 'resnext', 'xception')
    if args.model not in supported_models:
        raise NotImplementedError(
            f"unsupported model {args.model!r}; "
            f"expected one of {supported_models}")
    save_path = os.path.join(args.PATH, 'weights', f'{args.model}_best.pt')

    optimizer = build_optimizer(args, model)
    scheduler = build_scheduler(args, optimizer, batch_num)

    train_cfg = {
        'train_loader': train_loader,
        'valid_loader': valid_loader,
        'valid_loader1': valid_loader1,
        'valid_loader2': valid_loader2,
        'model': model,
        'criterion': nn.BCEWithLogitsLoss(),
        'optimizer': optimizer,
        'scheduler': scheduler,
        'save_path': save_path,
        'device': device
    }

    train_model(args, train_cfg)
def main():
    """Parse arguments, build data/model/optimizer and start training.

    Supports two model types: ``cnn`` (CSV inputs, albumentations
    transforms, 40000-row validation subsets) and ``lrcn`` (pickle inputs,
    no transforms, 15000-row validation subsets).  Three differently seeded
    validation subsets are drawn, loaders/model/optimizer/scheduler are
    built, and everything is handed to ``train_model``.

    Raises:
        ValueError: If ``args.model_type`` is neither ``cnn`` nor ``lrcn``.
        NotImplementedError: If ``args.model_type`` is ``cnn`` and
            ``args.model`` is not a supported architecture.
    """
    args = arg_parser()

    seed_everything(args.seed)

    use_cuda = cuda.is_available() and not args.cpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    print(device)

    if args.model_type == 'cnn':
        if args.preprocess:
            train_df = pd.read_csv('../input/preprocessed_train_df.csv')
            valid_df = pd.read_csv('../input/preprocessed_valid_df.csv')
        else:
            train_df = pd.read_csv('../input/train_df.csv')
            valid_df = pd.read_csv('../input/valid_df.csv')
        valid_sample_num = 40000

    elif args.model_type == 'lrcn':
        if args.preprocess:
            train_df = pd.read_pickle(
                '../input/preprocessed_lrcn_train_df.pkl')
            # Fixed: this used to read the *train* pickle again, so the
            # model was validated on its own training data.
            valid_df = pd.read_pickle(
                '../input/preprocessed_lrcn_valid_df.pkl')
        else:
            train_df = pd.read_pickle('../input/lrcn_train_df.pkl')
            valid_df = pd.read_pickle('../input/lrcn_valid_df.pkl')
        valid_sample_num = 15000

    else:
        # Previously an unknown type crashed later with a NameError.
        raise ValueError(
            f"unsupported model_type {args.model_type!r}; "
            "expected 'cnn' or 'lrcn'")

    print("number of train data {}".format(len(train_df)))
    print("number of valid data {}\n".format(len(valid_df)))

    # NOTE: despite its name, train_sample_num is passed as a fraction.
    train_df = train_df.sample(frac=args.train_sample_num,
                               random_state=args.seed).reset_index(drop=True)
    # Three shuffles of the same validation frame, one per fixed seed.
    valid_df_sub, valid_df_sub1, valid_df_sub2 = (
        valid_df.sample(
            frac=1.0,
            random_state=seed).reset_index(drop=True)[:valid_sample_num]
        for seed in (42, 52, 62))
    # The full validation frame is no longer needed; free it eagerly.
    del valid_df
    gc.collect()

    if args.DEBUG:
        train_df = train_df[:1000]
        valid_df_sub = valid_df_sub[:1000]
        valid_df_sub1 = valid_df_sub1[:1000]
        valid_df_sub2 = valid_df_sub2[:1000]

    if args.model_type == 'cnn':
        train_transforms = albumentations.Compose([
            HorizontalFlip(p=0.3),
            #   ShiftScaleRotate(p=0.3, scale_limit=0.25, border_mode=1, rotate_limit=25),
            #   RandomBrightnessContrast(p=0.2, brightness_limit=0.25, contrast_limit=0.5),
            #   MotionBlur(p=0.2),
            GaussNoise(p=0.3),
            JpegCompression(p=0.3, quality_lower=50),
            #   Normalize()
        ])
        valid_transforms = albumentations.Compose([
            HorizontalFlip(p=0.2),
            albumentations.OneOf([
                JpegCompression(quality_lower=8, quality_upper=30, p=1.0),
                GaussNoise(p=1.0),
            ],
                                 p=0.22),
            #   Normalize()
        ])
    else:
        # 'lrcn': no augmentation is applied.
        train_transforms = None
        valid_transforms = None

    train_loader = build_dataset(args,
                                 train_df,
                                 transforms=train_transforms,
                                 is_train=True)
    batch_num = len(train_loader)
    valid_loader = build_dataset(args,
                                 valid_df_sub,
                                 transforms=valid_transforms,
                                 is_train=False)
    valid_loader1 = build_dataset(args,
                                  valid_df_sub1,
                                  transforms=valid_transforms,
                                  is_train=False)
    valid_loader2 = build_dataset(args,
                                  valid_df_sub2,
                                  transforms=valid_transforms,
                                  is_train=False)

    model = build_model(args, device)

    # LRCN runs always save to the same file regardless of args.model, so
    # that case is resolved first and never trips the architecture check.
    supported_models = ('mobilenet_v2', 'resnet18', 'resnet50', 'resnext',
                        'xception')
    if args.model_type == 'lrcn':
        save_path = os.path.join(args.PATH, 'weights', f'lrcn_best.pt')
    elif args.model in supported_models:
        save_path = os.path.join(args.PATH, 'weights',
                                 f'{args.model}_best.pt')
    else:
        # The old code evaluated a bare ``NotImplementedError`` (a no-op)
        # and crashed later with a NameError on ``save_path``.
        raise NotImplementedError(
            f"unsupported model {args.model!r}; "
            f"expected one of {supported_models}")

    optimizer = build_optimizer(args, model)
    scheduler = build_scheduler(args, optimizer, batch_num)

    train_cfg = {
        'train_loader': train_loader,
        'valid_loader': valid_loader,
        'valid_loader1': valid_loader1,
        'valid_loader2': valid_loader2,
        'model': model,
        'criterion': nn.BCEWithLogitsLoss(),
        'optimizer': optimizer,
        'scheduler': scheduler,
        'save_path': save_path,
        'device': device
    }

    train_model(args, train_cfg)