def train_net(configs):
    model = build_model(configs.backbone,
                        num_classes=configs.Num_Classes,
                        pretrained=configs.Pretrained)
    # print(model)
    optimizer = build_optimizer(model.parameters(), configs)
    criterion = nn.CrossEntropyLoss()
    if configs.cuda:
        device = torch.device("cuda")
        model.to(device)
        criterion.to(device)

    if configs.img_aug:
        imgaug = transforms.Compose([
            transforms.RandomHorizontalFlip(0.5),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=configs.mean, std=configs.std),
        ])
        train_set = datasets.ImageFolder(configs.train_root, transform=imgaug)
    else:
        # Even without augmentation, ToTensor is needed so the default collate
        # function can batch the PIL images.
        train_set = datasets.ImageFolder(configs.train_root,
                                         transform=transforms.ToTensor())
    train_loader = data.DataLoader(train_set,
                                   batch_size=configs.Train.batch_size,
                                   shuffle=configs.shuffle,
                                   num_workers=configs.num_workers,
                                   pin_memory=True)

    for epoch in range(configs.Train.nepochs):
        # Decay the learning rate every other epoch.
        if epoch > 0 and epoch % 2 == 0:
            adjust_lr(optimizer, configs)
        for idx, (img, target) in enumerate(train_loader):
            if configs.cuda:
                device = torch.device("cuda")
                img = img.to(device)
                target = target.to(device)
            out = model(img)
            loss = criterion(out, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print("|Epoch|: {}, {}/{}, loss: {}".format(
                epoch, idx, len(train_set) // configs.Train.batch_size,
                loss.item()))

        # Save a checkpoint at the end of every epoch.
        pth_path = "./weights/{}_{}.pth".format(configs.backbone, epoch)
        with open(pth_path, 'wb') as f:
            torch.save(model.state_dict(), f)
        print("Save weights to ---->{}<-----".format(pth_path))

    with open("./weights/final.pth", 'wb') as f:
        torch.save(model.state_dict(), f)
    print("Final model saved!!!")
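# The `adjust_lr` helper called above is not shown in this file. A minimal
# sketch of what it is assumed to do (scale every parameter group's learning
# rate by a decay factor) follows; the attribute name `configs.lr_decay` is an
# assumption for illustration, not the project's actual config key.
def adjust_lr(optimizer, configs, decay=0.1):
    """Scale the learning rate of every parameter group by a decay factor."""
    decay = getattr(configs, 'lr_decay', decay)
    for param_group in optimizer.param_groups:
        param_group['lr'] *= decay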
def train(args):
    """Total training procedure."""
    print("Use GPU: {} for training".format(args.local_rank))
    if args.local_rank == 0:
        writer = SummaryWriter(log_dir=args.tensorboardx_logdir)
        args.writer = writer
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    # Initialize the distributed process group (one process per GPU).
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    args.rank = dist.get_rank()
    # print('args.rank: ', dist.get_rank())
    # print('args.get_world_size: ', dist.get_world_size())
    # print('is_nccl_available: ', dist.is_nccl_available())
    args.world_size = dist.get_world_size()

    trainset = ImageDataset(args.data_root, args.train_file)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset, shuffle=True)
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              sampler=train_sampler,
                              num_workers=0,
                              pin_memory=True,
                              drop_last=False)

    backbone_factory = BackboneFactory(args.backbone_type,
                                       args.backbone_conf_file)
    head_factory = HeadFactory(args.head_type, args.head_conf_file)
    model = FaceModel(backbone_factory, head_factory)
    model = model.to(args.local_rank)
    model.train()

    # Make sure every rank starts from the same initial weights.
    for ps in model.parameters():
        dist.broadcast(ps, 0)

    optimizer = build_optimizer(model, args.lr)
    lr_schedule = build_scheduler(optimizer, len(train_loader), args.epoches,
                                  args.warm_up_epoches)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    # DDP
    model = torch.nn.parallel.DistributedDataParallel(
        module=model,
        broadcast_buffers=False,
        device_ids=[args.local_rank])
    criterion = torch.nn.CrossEntropyLoss().to(args.local_rank)
    loss_meter = AverageMeter()

    ori_epoch = 0
    for epoch in range(ori_epoch, args.epoches):
        train_one_epoch(train_loader, model, optimizer, lr_schedule, criterion,
                        epoch, loss_meter, args)
    dist.destroy_process_group()
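# `AverageMeter` is used above to track the running training loss but is not
# defined in this file; a minimal sketch of the conventional helper is given
# below as an assumption of what it provides.
class AverageMeter:
    """Stores the latest value and a running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count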
def main(config):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)
    model = build_model(config)
    optimizer = build_optimizer(config, model)

    if config.TRAIN.MODE == 'epoch':
        trainer = build_epoch_trainer(config)
        lr_scheduler = build_epoch_scheduler(config, optimizer,
                                             len(data_loader_train))
    elif config.TRAIN.MODE == 'step':
        trainer = build_finetune_trainer(config)
        lr_scheduler = build_finetune_scheduler(config, optimizer)

    mixup = True
    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        # close mixup
        mixup = False
        criterion = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        # close mixup
        mixup = False
        criterion = torch.nn.CrossEntropyLoss()

    lightning_train_engine = lightning_train_wrapper(model, criterion,
                                                     optimizer, lr_scheduler,
                                                     mixup_fn, mixup)
    lightning_model = lightning_train_engine(config)
    trainer.fit(
        model=lightning_model,
        train_dataloader=data_loader_train,
        val_dataloaders=data_loader_val,
    )
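# `lightning_train_wrapper` is assumed to return a LightningModule class that
# wires the model, criterion, optimizer, scheduler and (optional) mixup
# together. The sketch below is illustrative only; every name inside it is an
# assumption, not the project's actual implementation.
import pytorch_lightning as pl


def lightning_train_wrapper(model, criterion, optimizer, lr_scheduler,
                            mixup_fn, mixup):

    class TrainEngine(pl.LightningModule):
        def __init__(self, config):
            super().__init__()
            self.config = config
            self.model = model
            self.criterion = criterion

        def training_step(self, batch, batch_idx):
            images, targets = batch
            if mixup and mixup_fn is not None:
                images, targets = mixup_fn(images, targets)
            outputs = self.model(images)
            loss = self.criterion(outputs, targets)
            self.log('train_loss', loss)
            return loss

        def configure_optimizers(self):
            return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

    return TrainEngine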
def main(config):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    if config.AMP_OPT_LEVEL != "O0":
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=config.AMP_OPT_LEVEL)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
                                       lr_scheduler, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model_without_ddp, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, criterion, data_loader_train, optimizer,
                        epoch, mixup_fn, lr_scheduler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0
                                     or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy,
                            optimizer, lr_scheduler, logger)

        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
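# `auto_resume_helper` is defined elsewhere in the repository; the sketch below
# shows the assumed behaviour (pick the newest *.pth checkpoint in the output
# directory, or return None when there is nothing to resume from).
def auto_resume_helper(output_dir):
    if not os.path.isdir(output_dir):
        return None
    checkpoints = [
        os.path.join(output_dir, f) for f in os.listdir(output_dir)
        if f.endswith('.pth')
    ]
    if not checkpoints:
        return None
    # Resume from the most recently modified checkpoint.
    return max(checkpoints, key=os.path.getmtime)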
def main(config):
    dataset_train, _, data_loader_train, _, _ = build_loader(config)
    config.defrost()
    config.DATA.TRAINING_IMAGES = len(dataset_train)
    config.freeze()

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    if config.AMP_OPT_LEVEL != "O0":
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=config.AMP_OPT_LEVEL)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        _ = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler,
                            logger)

    logger.info("Start self-supervised pre-training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, data_loader_train, optimizer, epoch,
                        lr_scheduler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0
                                     or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, 0.0, optimizer,
                            lr_scheduler, logger)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
def main(config):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    if config.DISTILL.DO_DISTILL:
        # Build and load the teacher model used for knowledge distillation.
        logger.info(
            f"Loading teacher model:{config.MODEL.TYPE}/{config.DISTILL.TEACHER}"
        )
        model_checkpoint_name = os.path.basename(config.DISTILL.TEACHER)
        if 'regnety_160' in model_checkpoint_name:
            model_teacher = create_model(
                'regnety_160',
                pretrained=False,
                num_classes=config.MODEL.NUM_CLASSES,
                global_pool='avg',
            )
            if config.DISTILL.TEACHER.startswith('https'):
                checkpoint = torch.hub.load_state_dict_from_url(
                    config.DISTILL.TEACHER, map_location='cpu', check_hash=True)
            else:
                checkpoint = torch.load(config.DISTILL.TEACHER,
                                        map_location='cpu')
            model_teacher.load_state_dict(checkpoint['model'])
            model_teacher.cuda()
            model_teacher.eval()
            del checkpoint
            torch.cuda.empty_cache()
        else:
            if 'base' in model_checkpoint_name:
                teacher_type = 'base'
            elif 'large' in model_checkpoint_name:
                teacher_type = 'large'
            else:
                teacher_type = None
            model_teacher = load_teacher_model(type=teacher_type)
            model_teacher.cuda()
            model_teacher = torch.nn.parallel.DistributedDataParallel(
                model_teacher,
                device_ids=[config.LOCAL_RANK],
                broadcast_buffers=False)
            checkpoint = torch.load(config.DISTILL.TEACHER, map_location='cpu')
            msg = model_teacher.module.load_state_dict(checkpoint['model'],
                                                       strict=False)
            logger.info(msg)
            del checkpoint
            torch.cuda.empty_cache()

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    if config.AMP_OPT_LEVEL != "O0":
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=config.AMP_OPT_LEVEL)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[config.LOCAL_RANK],
        broadcast_buffers=False,
        find_unused_parameters=True)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    # Distillation losses: soft-logit, attention-relation and hidden-state terms.
    criterion_soft = soft_cross_entropy
    criterion_attn = cal_relation_loss
    criterion_hidden = cal_hidden_relation_loss if config.DISTILL.HIDDEN_RELATION else cal_hidden_loss

    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion_truth = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion_truth = LabelSmoothingCrossEntropy(
            smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion_truth = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.DISTILL.RESUME_WEIGHT_ONLY = False
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
                                       lr_scheduler, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model, logger)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        if config.EVAL_MODE:
            return

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        if config.DISTILL.DO_DISTILL:
            train_one_epoch_distill(config, model, model_teacher,
                                    data_loader_train, optimizer, epoch,
                                    mixup_fn, lr_scheduler,
                                    criterion_soft=criterion_soft,
                                    criterion_truth=criterion_truth,
                                    criterion_attn=criterion_attn,
                                    criterion_hidden=criterion_hidden)
        else:
            train_one_epoch(config, model, criterion_truth, data_loader_train,
                            optimizer, epoch, mixup_fn, lr_scheduler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0
                                     or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy,
                            optimizer, lr_scheduler, logger)
        if epoch % config.EVAL_FREQ == 0 or epoch == config.TRAIN.EPOCHS - 1:
            acc1, acc5, loss = validate(config, data_loader_val, model, logger)
            logger.info(
                f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
            )
            max_accuracy = max(max_accuracy, acc1)
            logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
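# `soft_cross_entropy` is referenced above as the soft-logit distillation loss.
# A common formulation (cross-entropy between the student's log-probabilities
# and the teacher's probabilities) is sketched here as an assumption of what
# the helper computes.
import torch.nn.functional as F


def soft_cross_entropy(student_logits, teacher_logits):
    student_log_prob = F.log_softmax(student_logits, dim=-1)
    teacher_prob = F.softmax(teacher_logits, dim=-1)
    return (-teacher_prob * student_log_prob).sum(dim=-1).mean()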
def main(args, config):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(
        config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    model.cuda()
    if args.use_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    logger.info(str(model))

    optimizer = build_optimizer(config, model)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    loss_scaler = NativeScalerWithGradNormCount()
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(
        config, optimizer,
        len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)

    if config.DISTILL.ENABLED:
        # MIXUP and CUTMIX are disabled when using knowledge distillation.
        assert len(config.DISTILL.TEACHER_LOGITS_PATH
                   ) > 0, "Please fill in DISTILL.TEACHER_LOGITS_PATH"
        criterion = torch.nn.CrossEntropyLoss(reduction='mean')
    else:
        if config.AUG.MIXUP > 0.:
            # smoothing is handled with mixup label transform
            criterion = SoftTargetCrossEntropy()
        elif config.MODEL.LABEL_SMOOTHING > 0.:
            criterion = LabelSmoothingCrossEntropy(
                smoothing=config.MODEL.LABEL_SMOOTHING)
        else:
            criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(
                    f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}"
                )
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(
                f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer,
                                       lr_scheduler, loss_scaler, logger)
        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model_without_ddp, logger)
        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # call set_epoch on dataset_train as well when distilling from saved logits
        if hasattr(dataset_train, 'set_epoch'):
            dataset_train.set_epoch(epoch)
        data_loader_train.sampler.set_epoch(epoch)

        if config.DISTILL.ENABLED:
            train_one_epoch_distill_using_saved_logits(
                args, config, model, criterion, data_loader_train, optimizer,
                epoch, mixup_fn, lr_scheduler, loss_scaler)
        else:
            train_one_epoch(args, config, model, criterion, data_loader_train,
                            optimizer, epoch, mixup_fn, lr_scheduler,
                            loss_scaler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0
                                     or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy,
                            optimizer, lr_scheduler, loss_scaler, logger)

        acc1, acc5, loss = validate(args, config, data_loader_val, model)
        logger.info(
            f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%"
        )
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

        if is_main_process() and args.use_wandb:
            wandb.log({
                "val/acc@1": acc1,
                "val/acc@5": acc5,
                "val/loss": loss,
                "epoch": epoch,
            })
            wandb.run.summary['epoch'] = epoch
            wandb.run.summary['best_acc@1'] = max_accuracy

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
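# `NativeScalerWithGradNormCount` is assumed to wrap torch.cuda.amp.GradScaler
# and to return the gradient norm after unscaling, in the style of the timm
# NativeScaler helper; a simplified sketch under that assumption follows.
class NativeScalerWithGradNormCount:
    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, parameters, clip_grad=None,
                 update_grad=True):
        # Scale the loss, backpropagate, optionally clip, then step and update.
        self._scaler.scale(loss).backward()
        norm = None
        if update_grad:
            self._scaler.unscale_(optimizer)
            if clip_grad is not None:
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            self._scaler.step(optimizer)
            self._scaler.update()
            optimizer.zero_grad()
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)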
def main():
    args = arg_parser()
    seed_everything(args.seed)

    if cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    train_df = pd.read_csv(args.train_df_path)
    valid_df = pd.read_csv(args.valid_df_path)
    # Three fixed random subsets of the validation set for evaluation.
    valid_df_sub = valid_df.sample(
        frac=1.0, random_state=42).reset_index(drop=True)[:40000]
    valid_df_sub1 = valid_df.sample(
        frac=1.0, random_state=52).reset_index(drop=True)[:40000]
    valid_df_sub2 = valid_df.sample(
        frac=1.0, random_state=62).reset_index(drop=True)[:40000]
    del valid_df
    gc.collect()

    if args.DEBUG:
        train_df = train_df[:1000]
        valid_df_sub = valid_df_sub[:1000]
        valid_df_sub1 = valid_df_sub1[:1000]
        valid_df_sub2 = valid_df_sub2[:1000]

    train_loader = build_dataset(args, train_df, is_train=True)
    batch_num = len(train_loader)
    valid_loader = build_dataset(args, valid_df_sub, is_train=False)
    valid_loader1 = build_dataset(args, valid_df_sub1, is_train=False)
    valid_loader2 = build_dataset(args, valid_df_sub2, is_train=False)

    model = build_model(args, device)

    if args.model == 'resnet50':
        save_path = os.path.join(args.PATH, 'weights', 'resnet50_best.pt')
    elif args.model == 'resnext':
        save_path = os.path.join(args.PATH, 'weights', 'resnext_best.pt')
    elif args.model == 'xception':
        save_path = os.path.join(args.PATH, 'weights', 'xception_best.pt')
    else:
        raise NotImplementedError

    optimizer = build_optimizer(args, model)
    scheduler = build_scheduler(args, optimizer, batch_num)

    train_cfg = {
        'train_loader': train_loader,
        'valid_loader': valid_loader,
        'valid_loader1': valid_loader1,
        'valid_loader2': valid_loader2,
        'model': model,
        'criterion': nn.BCEWithLogitsLoss(),
        'optimizer': optimizer,
        'scheduler': scheduler,
        'save_path': save_path,
        'device': device
    }
    train_model(args, train_cfg)
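# `seed_everything` is assumed to fix all relevant random seeds for
# reproducibility; a typical implementation is sketched below.
import random
import numpy as np


def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True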
def main():
    args = arg_parser()
    seed_everything(args.seed)

    if cuda.is_available() and not args.cpu:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print(device)

    if args.model_type == 'cnn':
        if args.preprocess:
            train_df = pd.read_csv('../input/preprocessed_train_df.csv')
            valid_df = pd.read_csv('../input/preprocessed_valid_df.csv')
        else:
            train_df = pd.read_csv('../input/train_df.csv')
            valid_df = pd.read_csv('../input/valid_df.csv')
        valid_sample_num = 40000
    elif args.model_type == 'lrcn':
        if args.preprocess:
            train_df = pd.read_pickle(
                '../input/preprocessed_lrcn_train_df.pkl')
            valid_df = pd.read_pickle(
                '../input/preprocessed_lrcn_valid_df.pkl')
        else:
            train_df = pd.read_pickle('../input/lrcn_train_df.pkl')
            valid_df = pd.read_pickle('../input/lrcn_valid_df.pkl')
        valid_sample_num = 15000

    print("number of train data {}".format(len(train_df)))
    print("number of valid data {}\n".format(len(valid_df)))

    train_df = train_df.sample(frac=args.train_sample_num,
                               random_state=args.seed).reset_index(drop=True)
    # Three fixed random subsets of the validation set for evaluation.
    valid_df_sub = valid_df.sample(
        frac=1.0, random_state=42).reset_index(drop=True)[:valid_sample_num]
    valid_df_sub1 = valid_df.sample(
        frac=1.0, random_state=52).reset_index(drop=True)[:valid_sample_num]
    valid_df_sub2 = valid_df.sample(
        frac=1.0, random_state=62).reset_index(drop=True)[:valid_sample_num]
    del valid_df
    gc.collect()

    if args.DEBUG:
        train_df = train_df[:1000]
        valid_df_sub = valid_df_sub[:1000]
        valid_df_sub1 = valid_df_sub1[:1000]
        valid_df_sub2 = valid_df_sub2[:1000]

    if args.model_type == 'cnn':
        train_transforms = albumentations.Compose([
            HorizontalFlip(p=0.3),
            # ShiftScaleRotate(p=0.3, scale_limit=0.25, border_mode=1, rotate_limit=25),
            # RandomBrightnessContrast(p=0.2, brightness_limit=0.25, contrast_limit=0.5),
            # MotionBlur(p=0.2),
            GaussNoise(p=0.3),
            JpegCompression(p=0.3, quality_lower=50),
            # Normalize()
        ])
        valid_transforms = albumentations.Compose([
            HorizontalFlip(p=0.2),
            albumentations.OneOf([
                JpegCompression(quality_lower=8, quality_upper=30, p=1.0),
                GaussNoise(p=1.0),
            ], p=0.22),
            # Normalize()
        ])
    elif args.model_type == 'lrcn':
        train_transforms = None
        valid_transforms = None

    train_loader = build_dataset(args, train_df, transforms=train_transforms,
                                 is_train=True)
    batch_num = len(train_loader)
    valid_loader = build_dataset(args, valid_df_sub,
                                 transforms=valid_transforms, is_train=False)
    valid_loader1 = build_dataset(args, valid_df_sub1,
                                  transforms=valid_transforms, is_train=False)
    valid_loader2 = build_dataset(args, valid_df_sub2,
                                  transforms=valid_transforms, is_train=False)

    model = build_model(args, device)

    if args.model == 'mobilenet_v2':
        save_path = os.path.join(args.PATH, 'weights', 'mobilenet_v2_best.pt')
    elif args.model == 'resnet18':
        save_path = os.path.join(args.PATH, 'weights', 'resnet18_best.pt')
    elif args.model == 'resnet50':
        save_path = os.path.join(args.PATH, 'weights', 'resnet50_best.pt')
    elif args.model == 'resnext':
        save_path = os.path.join(args.PATH, 'weights', 'resnext_best.pt')
    elif args.model == 'xception':
        save_path = os.path.join(args.PATH, 'weights', 'xception_best.pt')
    else:
        raise NotImplementedError
    if args.model_type == 'lrcn':
        save_path = os.path.join(args.PATH, 'weights', 'lrcn_best.pt')

    optimizer = build_optimizer(args, model)
    scheduler = build_scheduler(args, optimizer, batch_num)

    train_cfg = {
        'train_loader': train_loader,
        'valid_loader': valid_loader,
        'valid_loader1': valid_loader1,
        'valid_loader2': valid_loader2,
        'model': model,
        'criterion': nn.BCEWithLogitsLoss(),
        'optimizer': optimizer,
        'scheduler': scheduler,
        'save_path': save_path,
        'device': device
    }
    train_model(args, train_cfg)
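# `build_optimizer` and `build_scheduler` are not shown in this file. The
# sketch below illustrates one plausible implementation (Adam plus a one-cycle
# schedule driven by the number of batches per epoch); the attribute names
# `args.lr` and `args.epochs` are assumptions, not the project's actual flags.
def build_optimizer(args, model):
    return torch.optim.Adam(model.parameters(), lr=args.lr)


def build_scheduler(args, optimizer, batch_num):
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.lr,
        steps_per_epoch=batch_num,
        epochs=args.epochs)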