def lr_search(train_eval_data):
    # LR range test: sweep the learning rate exponentially from min_lr to max_lr
    # over a single pass of the data and record the loss at each step.
    train_eval_dataset = TrainEvalDataset(train_eval_data, transform=train_transform)
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_eval_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    min_lr = 1e-7
    max_lr = 10.
    gamma = (max_lr / min_lr)**(1 / len(train_eval_data_loader))

    lrs = []
    losses = []
    lim = None

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    optimizer = build_optimizer(config.opt, model.parameters())
    for param_group in optimizer.param_groups:
        param_group['lr'] = min_lr
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    optimizer.train()
    update_transforms(1.)
    model.train()
    optimizer.zero_grad()
    for i, (images, _, labels, _) in enumerate(tqdm(train_eval_data_loader, desc='lr search'), 1):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        labels = utils.one_hot(labels, NUM_CLASSES)
        images, labels = mixup(images, labels)

        logits = model(images, None, True)
        loss = compute_loss(input=logits, target=labels)
        labels = labels.argmax(1)

        lrs.append(np.squeeze(scheduler.get_lr()))
        losses.append(loss.data.cpu().numpy().mean())

        # stop the sweep once the loss diverges above 110% of the initial loss
        if lim is None:
            lim = losses[0] * 1.1
        if lim < losses[-1]:
            break

        # gradient accumulation over config.opt.acc_steps batches
        (loss.mean() / config.opt.acc_steps).backward()
        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

    writer = SummaryWriter(os.path.join(args.experiment_path, 'lr_search'))
    with torch.no_grad():
        losses = np.clip(losses, 0, lim)
        # pick the LR at the minimum of the smoothed loss curve
        minima_loss = losses[np.argmin(utils.smooth(losses))]
        minima_lr = lrs[np.argmin(utils.smooth(losses))]

        step = 0
        for loss, loss_sm in zip(losses, utils.smooth(losses)):
            writer.add_scalar('search_loss', loss, global_step=step)
            writer.add_scalar('search_loss_sm', loss_sm, global_step=step)
            step += config.batch_size

        fig = plt.figure()
        plt.plot(lrs, losses)
        plt.plot(lrs, utils.smooth(losses))
        plt.axvline(minima_lr)
        plt.xscale('log')
        plt.title('loss: {:.8f}, lr: {:.8f}'.format(minima_loss, minima_lr))
        writer.add_figure('search', fig, global_step=0)

    return minima_lr
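# lr_search above picks the learning rate at the minimum of the smoothed loss
# curve via utils.smooth. A minimal sketch of such a smoothing helper, assuming
# bias-corrected exponential moving average smoothing (the actual utils.smooth
# in this repository may be implemented differently):
def smooth_ema_sketch(values, beta=0.9):
    smoothed = []
    avg = 0.
    for i, value in enumerate(values, 1):
        avg = beta * avg + (1 - beta) * value
        smoothed.append(avg / (1 - beta**i))  # bias-corrected running average
    return np.array(smoothed)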
def train_fold(fold, train_eval_data, unsup_data):
    train_indices, eval_indices = indices_for_fold(fold, train_eval_data)

    train_dataset = TrainEvalDataset(train_eval_data.iloc[train_indices], transform=train_transform)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)
    eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices], transform=eval_transform)
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    # unlabeled data for the semi-supervised part of train_epoch
    unsup_dataset = TestDataset(unsup_data, transform=unsup_transform)
    unsup_data_loader = torch.utils.data.DataLoader(
        unsup_dataset,
        batch_size=config.batch_size // 2,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(
            torch.load(os.path.join(args.restore_path, 'model_{}.pth'.format(fold))))

    optimizer = build_optimizer(config.opt, model.parameters())

    if config.sched.type == 'onecycle':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            OneCycleScheduler(
                optimizer,
                lr=(config.opt.lr / 20, config.opt.lr),
                beta_range=config.sched.onecycle.beta,
                max_steps=len(train_data_loader) * config.epochs,
                annealing=config.sched.onecycle.anneal,
                peak_pos=config.sched.onecycle.peak_pos,
                end_pos=config.sched.onecycle.end_pos))
    elif config.sched.type == 'step':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=config.sched.step.step_size,
                gamma=config.sched.step.decay))
    elif config.sched.type == 'cyclic':
        step_size_up = len(train_data_loader) * config.sched.cyclic.step_size_up
        step_size_down = len(train_data_loader) * config.sched.cyclic.step_size_down
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode='triangular2',
                gamma=config.sched.cyclic.decay**(1 / (step_size_up + step_size_down)),
                cycle_momentum=True,
                base_momentum=0.85,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_data_loader), T_mult=2))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    best_score = 0
    for epoch in range(1, config.epochs + 1):
        # the optimizer built by build_optimizer exposes train()/eval() modes,
        # toggled around evaluation
        optimizer.train()
        train_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            data_loader=train_data_loader,
            unsup_data_loader=unsup_data_loader,
            fold=fold,
            epoch=epoch)
        gc.collect()
        optimizer.eval()
        metric = eval_epoch(
            model=model,
            data_loader=eval_data_loader,
            fold=fold,
            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']

        scheduler.step_epoch()
        scheduler.step_score(score)

        # keep only the best checkpoint per fold
        if score > best_score:
            best_score = score
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_path, 'model_{}.pth'.format(fold)))
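# train_fold above drives every scheduler through a common interface: step()
# once per optimizer step, step_epoch() once per epoch, and step_score(score)
# after each evaluation; each wrapper decides which of these calls actually
# advances the wrapped scheduler. A minimal sketch of that idea (the
# repository's lr_scheduler_wrapper classes may differ in detail):
class EpochWrapperSketch(object):
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def step(self):
        pass  # per-batch call is a no-op for epoch-based schedulers

    def step_epoch(self):
        self.scheduler.step()  # advance once per epoch

    def step_score(self, score):
        pass  # score is ignored by epoch-based schedulers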
def train_fold(fold, train_eval_data):
    # pseudo-label variant: the train set is rebuilt every epoch with a
    # shrinking fraction of pseudo-labeled samples mixed in
    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(
            torch.load(os.path.join(args.restore_path, 'model_{}.pth'.format(fold))))

    optimizer = build_optimizer(config.opt, model.parameters())

    if config.sched.type == 'onecycle':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            OneCycleScheduler(
                optimizer,
                lr=(config.opt.lr / 20, config.opt.lr),
                beta_range=config.sched.onecycle.beta,
                max_steps=config.epochs,
                annealing=config.sched.onecycle.anneal,
                peak_pos=config.sched.onecycle.peak_pos,
                end_pos=config.sched.onecycle.end_pos))
    elif config.sched.type == 'cyclic':
        # NOTE: train_data_loader is only built inside the epoch loop below,
        # so the 'cyclic' and 'cawr' branches fail here as written
        step_size_up = len(train_data_loader) * config.sched.cyclic.step_size_up
        step_size_down = len(train_data_loader) * config.sched.cyclic.step_size_down
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode='exp_range',
                gamma=config.sched.cyclic.decay**(1 / (step_size_up + step_size_down)),
                cycle_momentum=True,
                base_momentum=0.75,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_data_loader), T_mult=2))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    best_score = 0
    for epoch in range(1, config.epochs + 1):
        train_indices, eval_indices = indices_for_fold(fold, train_eval_data)

        # pseudo-labels previously predicted for the eval fold and the test set
        eval_pl = pd.read_csv(
            './tf_log/cells/tmp-512-progres-crop-norm-la/eval_{}.csv'.format(fold))
        eval_pl['root'] = os.path.join(args.dataset_path, 'train')
        test_pl = pd.read_csv('./tf_log/cells/tmp-512-progres-crop-norm-la/test.csv')
        test_pl['root'] = os.path.join(args.dataset_path, 'test')
        pl = pd.concat([eval_pl, test_pl])
        pl_size = len(pl)
        # linearly anneal the pseudo-label fraction from 1.0 (first epoch) to 0.0 (last epoch)
        pl = pl.sample(frac=np.linspace(1., 0., config.epochs)[epoch - 1].item())
        print('frac: {:.4f}, lr: {:.8f}'.format(len(pl) / pl_size, scheduler.get_lr()))

        train_dataset = TrainEvalDataset(
            pd.concat([train_eval_data.iloc[train_indices], pl]),
            transform=train_transform)
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            drop_last=True,
            shuffle=True,
            num_workers=args.workers,
            worker_init_fn=worker_init_fn)
        eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices], transform=eval_transform)
        eval_data_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=config.batch_size,
            num_workers=args.workers,
            worker_init_fn=worker_init_fn)

        train_epoch(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            data_loader=train_data_loader,
            fold=fold,
            epoch=epoch)
        gc.collect()
        metric = eval_epoch(
            model=model,
            data_loader=eval_data_loader,
            fold=fold,
            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']

        scheduler.step_epoch()
        scheduler.step_score(score)

        if score > best_score:
            best_score = score
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_path, 'model_{}.pth'.format(fold)))
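# The pseudo-label curriculum above anneals the fraction of sampled pseudo-labels
# linearly over training via np.linspace(1., 0., config.epochs). A small worked
# example, assuming a hypothetical 5-epoch run:
#     np.linspace(1., 0., 5)  ->  array([1.  , 0.75, 0.5 , 0.25, 0.  ])
# so the first epoch mixes in every pseudo-label and the final epoch trains on
# the original fold data only.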