def _build_transforms():
    """Return the ``(train_transforms, test_transforms)`` pipelines."""
    # Standard ImageNet channel statistics (presumably matching the backbone's
    # pretraining); shared by both pipelines so they cannot drift apart.
    # Normalize is stateless, so one instance can be reused.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transforms = transforms.Compose([
        transforms.Resize(Config.input_shape),
        transforms.RandomApply([
            transforms.ColorJitter(
                brightness=0.3, contrast=0.3, saturation=0.3, hue=0)
        ], p=0.5),
        transforms.RandomHorizontalFlip(),
        # Pad then random-crop back to input_shape: a random translation.
        transforms.Pad(10),
        transforms.RandomCrop(Config.input_shape),
        transforms.ToTensor(),
        # NOTE(review): RandomErasing runs before Normalize, so erased regions
        # are filled with 0 in [0, 1] pixel space rather than in normalized
        # space. Confirm this ordering is intended.
        transforms.RandomErasing(),
        normalize,
    ])
    test_transforms = transforms.Compose([
        transforms.Resize(Config.input_shape),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transforms, test_transforms


def _build_loaders(train_transforms, test_transforms):
    """Return ``(train_loader, test_loader)`` for the VeRi-776 splits."""
    trainset = Veri776_train(transforms=train_transforms)
    testset = Veri776_test(transforms=test_transforms)
    # Presumably samples Config.P identities x Config.K images per batch;
    # verify against PKSampler.
    pksampler = PKSampler(trainset, p=Config.P, k=Config.K)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=Config.batch_size,
        sampler=pksampler,
        num_workers=Config.nr_worker,
        pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=Config.batch_size,
        sampler=torch.utils.data.SequentialSampler(testset),
        num_workers=Config.nr_worker,
        pin_memory=True)
    return train_loader, test_loader


def prepare(args):
    """Build everything needed for a training run.

    This function:

    - creates the model and the auxiliary binary classifier on the GPU;
    - builds the train/test data loaders;
    - creates a single Adam optimizer over the parameters of both networks;
    - creates a ``LambdaLR`` scheduler and the loss functions.

    When ``args.resume_from_checkpoint`` is truthy and
    ``Config.checkpoint_path`` exists, the training state is restored from
    ``load_checkpoint()``.

    Args:
        args: Parsed command-line arguments. Only ``resume_from_checkpoint``
            is read here.

    Returns:
        dict with keys ``'start_epoch'``, ``'model'``, ``'bclassifier'``,
        ``'train_loader'``, ``'test_loader'``, ``'optimizer'``,
        ``'scheduler'`` and ``'losses'``. ``start_epoch`` is the epoch after
        the checkpointed one, or 1 when starting fresh.
    """
    resume_from_checkpoint = args.resume_from_checkpoint
    prepare_start_time = time.time()
    logger.info('global', 'Start preparing.')
    check_config_dir()
    logger.info('setting', config_info(), time_report=False)

    model = Baseline(num_classes=Config.nr_class).cuda()
    logger.info('setting', model_summary(model), time_report=False)
    logger.info('setting', str(model), time_report=False)
    bclassifier = BinaryClassifier(Config.in_planes).cuda()

    train_transforms, test_transforms = _build_transforms()
    train_loader, test_loader = _build_loaders(train_transforms,
                                               test_transforms)

    # One optimizer drives both networks, so its state covers both.
    weight_decay_setting = parm_list_with_Wdecay(model)
    weight_decay_setting += parm_list_with_Wdecay(bclassifier)
    optimizer = torch.optim.Adam(weight_decay_setting, lr=Config.lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_multi_func)

    losses = {
        'cross_entropy_loss': torch.nn.CrossEntropyLoss().cuda(),
        'pair_loss': torch.nn.CrossEntropyLoss().cuda(),
    }

    start_epoch = 0
    if resume_from_checkpoint and os.path.exists(Config.checkpoint_path):
        checkpoint = load_checkpoint()
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        # The optimizer also holds bclassifier's parameters, so the
        # classifier's weights must be restored too. Otherwise resumed
        # optimizer state is paired with a freshly initialized classifier.
        # The key check keeps checkpoints saved without it loadable.
        # NOTE(review): the checkpoint-saving code must store
        # bclassifier.state_dict() under 'bclassifier' for this to apply.
        if 'bclassifier' in checkpoint:
            bclassifier.load_state_dict(checkpoint['bclassifier'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
    # Continue with the epoch after the checkpoint's, or start from 1.
    start_epoch += 1

    ret = {
        'start_epoch': start_epoch,
        'model': model,
        'bclassifier': bclassifier,
        'train_loader': train_loader,
        'test_loader': test_loader,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'losses': losses,
    }

    prepare_end_time = time.time()
    time_spent = sec2min_sec(prepare_start_time, prepare_end_time)
    logger.info(
        'global', 'Finish preparing, time spend: {}mins {}s.'.format(
            time_spent[0], time_spent[1]))
    return ret