def train_fold(fold, train_eval_data, unsup_data):
    # Split the labelled data into train/eval indices for this fold.
    train_indices, eval_indices = indices_for_fold(fold, train_eval_data)

    train_dataset = TrainEvalDataset(train_eval_data.iloc[train_indices],
                                     transform=train_transform)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)
    eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                    transform=eval_transform)
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)
    # Unlabelled data used for the unsupervised part of training.
    unsup_dataset = TestDataset(unsup_data, transform=unsup_transform)
    unsup_data_loader = torch.utils.data.DataLoader(
        unsup_dataset,
        batch_size=config.batch_size // 2,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    # Build the model and optionally restore this fold's checkpoint.
    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(
            torch.load(
                os.path.join(args.restore_path, 'model_{}.pth'.format(fold))))

    optimizer = build_optimizer(config.opt, model.parameters())

    # Select the LR scheduler and wrap it so it can be stepped per batch,
    # per epoch, or on the validation score.
    if config.sched.type == 'onecycle':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            OneCycleScheduler(
                optimizer,
                lr=(config.opt.lr / 20, config.opt.lr),
                beta_range=config.sched.onecycle.beta,
                max_steps=len(train_data_loader) * config.epochs,
                annealing=config.sched.onecycle.anneal,
                peak_pos=config.sched.onecycle.peak_pos,
                end_pos=config.sched.onecycle.end_pos))
    elif config.sched.type == 'step':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=config.sched.step.step_size,
                gamma=config.sched.step.decay))
    elif config.sched.type == 'cyclic':
        step_size_up = len(train_data_loader) * config.sched.cyclic.step_size_up
        step_size_down = len(train_data_loader) * config.sched.cyclic.step_size_down
        # mode='triangular2' halves the LR amplitude each cycle; the gamma
        # argument is only used by the 'exp_range' mode.
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode='triangular2',
                gamma=config.sched.cyclic.decay**(
                    1 / (step_size_up + step_size_down)),
                cycle_momentum=True,
                base_momentum=0.85,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=len(train_data_loader),
                T_mult=2))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    best_score = 0
    for epoch in range(1, config.epochs + 1):
        # optimizer.train()/optimizer.eval() assume build_optimizer returns a
        # custom optimizer exposing these hooks (e.g. for averaged weights).
        optimizer.train()
        train_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            data_loader=train_data_loader,
            unsup_data_loader=unsup_data_loader,
            fold=fold,
            epoch=epoch)
        gc.collect()
        optimizer.eval()
        metric = eval_epoch(
            model=model,
            data_loader=eval_data_loader,
            fold=fold,
            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']

        scheduler.step_epoch()
        scheduler.step_score(score)

        # Keep only the best checkpoint for this fold.
        if score > best_score:
            best_score = score
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_path, 'model_{}.pth'.format(fold)))
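# The lr_scheduler_wrapper module used above is not shown in this file. A
# minimal sketch of the assumed interface follows: StepWrapper is expected to
# advance its scheduler every batch (inside train_epoch), EpochWrapper once per
# epoch via step_epoch(), and ScoreWrapper on the validation score via
# step_score(). The classes below are an illustration under those assumptions,
# not the project's actual implementation.


class _WrapperSketchBase:
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def step(self):
        # per-batch hook; no-op unless overridden
        pass

    def step_epoch(self):
        # per-epoch hook; no-op unless overridden
        pass

    def step_score(self, score):
        # per-validation-score hook; no-op unless overridden
        pass


class StepWrapperSketch(_WrapperSketchBase):
    def step(self):
        self.scheduler.step()


class EpochWrapperSketch(_WrapperSketchBase):
    def step_epoch(self):
        self.scheduler.step()


class ScoreWrapperSketch(_WrapperSketchBase):
    def step_score(self, score):
        self.scheduler.step(score)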
def train_fold(fold, train_eval_data):
    # Pseudo-label variant: predictions saved by a previous run are mixed into
    # the training data, with the mixed-in fraction decaying to zero over the
    # course of training.
    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(
            torch.load(
                os.path.join(args.restore_path, 'model_{}.pth'.format(fold))))

    optimizer = build_optimizer(config.opt, model.parameters())

    if config.sched.type == 'onecycle':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            OneCycleScheduler(
                optimizer,
                lr=(config.opt.lr / 20, config.opt.lr),
                beta_range=config.sched.onecycle.beta,
                max_steps=config.epochs,
                annealing=config.sched.onecycle.anneal,
                peak_pos=config.sched.onecycle.peak_pos,
                end_pos=config.sched.onecycle.end_pos))
    elif config.sched.type == 'cyclic':
        # NOTE: train_data_loader is not defined at this point in this
        # function (it is only created inside the epoch loop below), so this
        # branch and 'cawr' would fail as written.
        step_size_up = len(train_data_loader) * config.sched.cyclic.step_size_up
        step_size_down = len(train_data_loader) * config.sched.cyclic.step_size_down
        # gamma is chosen so that gamma**(step_size_up + step_size_down) equals
        # decay, i.e. the 'exp_range' envelope shrinks by `decay` per full cycle.
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode='exp_range',
                gamma=config.sched.cyclic.decay**(
                    1 / (step_size_up + step_size_down)),
                cycle_momentum=True,
                base_momentum=0.75,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=len(train_data_loader),
                T_mult=2))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    best_score = 0
    for epoch in range(1, config.epochs + 1):
        train_indices, eval_indices = indices_for_fold(fold, train_eval_data)

        # Load pseudo-labels produced by a previous experiment for the eval
        # split and the test set.
        eval_pl = pd.read_csv(
            './tf_log/cells/tmp-512-progres-crop-norm-la/eval_{}.csv'.format(
                fold))
        eval_pl['root'] = os.path.join(args.dataset_path, 'train')
        test_pl = pd.read_csv(
            './tf_log/cells/tmp-512-progres-crop-norm-la/test.csv')
        test_pl['root'] = os.path.join(args.dataset_path, 'test')
        pl = pd.concat([eval_pl, test_pl])
        pl_size = len(pl)
        # Keep a fraction of the pseudo-labels that decays linearly from 1 at
        # the first epoch to 0 at the last one.
        pl = pl.sample(frac=np.linspace(1., 0., config.epochs)[epoch - 1].item())
        print('frac: {:.4f}, lr: {:.8f}'.format(
            len(pl) / pl_size, scheduler.get_lr()))

        train_dataset = TrainEvalDataset(
            pd.concat([train_eval_data.iloc[train_indices], pl]),
            transform=train_transform)
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            drop_last=True,
            shuffle=True,
            num_workers=args.workers,
            worker_init_fn=worker_init_fn)
        eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                        transform=eval_transform)
        eval_data_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=config.batch_size,
            num_workers=args.workers,
            worker_init_fn=worker_init_fn)

        train_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            data_loader=train_data_loader,
            fold=fold,
            epoch=epoch)
        gc.collect()
        metric = eval_epoch(
            model=model,
            data_loader=eval_data_loader,
            fold=fold,
            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']

        scheduler.step_epoch()
        scheduler.step_score(score)

        if score > best_score:
            best_score = score
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_path, 'model_{}.pth'.format(fold)))
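# The pseudo-label schedule above is driven by np.linspace(1., 0., config.epochs):
# the kept fraction falls linearly from 1.0 at epoch 1 to 0.0 at the final
# epoch. Below is a small self-contained illustration of that schedule; the
# function name pl_fraction is hypothetical and only used for this example
# (numpy may already be imported at the top of the file).

import numpy as np


def pl_fraction(epoch, num_epochs):
    """Fraction of pseudo-labelled rows kept at a given 1-based epoch."""
    return np.linspace(1., 0., num_epochs)[epoch - 1].item()


if __name__ == '__main__':
    # With 5 epochs the kept fractions are 1.0, 0.75, 0.5, 0.25, 0.0.
    print([pl_fraction(e, 5) for e in range(1, 6)])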