def predict_on_eval_using_fold(fold, train_eval_data):
    """Evaluate fold `fold`'s saved checkpoint on that fold's validation split.

    Collects per-sample logits and labels, fits a global temperature with
    find_temp_global, assigns classes per experiment, prints the resulting
    accuracy, and writes the class assignments to
    '<experiment_path>/eval_<fold>.csv'.

    Returns:
        (fold_labels, fold_logits, fold_exps, fold_ids) — concatenated label
        and logit tensors plus the collected experiment and sample id lists.
    """
    # NOTE(review): this module defines predict_on_eval_using_fold more than
    # once; only the last definition is live at import time — confirm which
    # variant is actually intended to run.
    _, eval_indices = indices_for_fold(fold, train_eval_data)

    eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                    transform=eval_transform)
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    # load this fold's best checkpoint saved by train_fold
    model.load_state_dict(
        torch.load(
            os.path.join(args.experiment_path, 'model_{}.pth'.format(fold))))

    model.eval()
    with torch.no_grad():
        fold_labels = []
        fold_logits = []
        fold_exps = []
        fold_ids = []
        for images, feats, exps, labels, ids in tqdm(
                eval_data_loader, desc='fold {} evaluation'.format(fold)):
            images, feats, labels = images.to(DEVICE), feats.to(
                DEVICE), labels.to(DEVICE)
            logits = model(images, feats)

            fold_labels.append(labels)
            fold_logits.append(logits)
            fold_exps.extend(exps)
            fold_ids.extend(ids)

        fold_labels = torch.cat(fold_labels, 0)
        fold_logits = torch.cat(fold_logits, 0)

        # work on a copy so the caller's frame is not mutated
        tmp = train_eval_data.iloc[eval_indices].copy()
        temp, _, _ = find_temp_global(input=fold_logits, target=fold_labels,
                                      exps=fold_exps)
        # NOTE(review): temperature is applied multiplicatively to the logits
        # here, while a sibling definition uses to_prob(fold_logits, temp) —
        # confirm both implement the same scaling convention.
        classes = assign_classes(
            probs=(fold_logits * temp).softmax(1).data.cpu().numpy(),
            exps=fold_exps)
        # print validation accuracy of the assigned classes
        print('{:.2f}'.format((tmp['sirna'] == classes).mean()))
        tmp['sirna'] = classes
        tmp.to_csv(os.path.join(args.experiment_path,
                                'eval_{}.csv'.format(fold)),
                   index=False)

    return fold_labels, fold_logits, fold_exps, fold_ids
def predict_on_eval_using_fold(fold, train_eval_data):
    """TTA variant: evaluate fold `fold`'s checkpoint on its validation split
    with test-time augmentation and return per-view probabilities.

    Each sample arrives as a stack of 2 * NUM_TTA augmented views; the views
    are flattened into the batch axis for the forward pass and restored
    afterwards, so probabilities have shape (batch, views, NUM_CLASSES).

    Returns:
        (fold_labels, fold_probs, fold_exps, fold_plates, fold_ids).
    """
    # NOTE(review): this module defines predict_on_eval_using_fold more than
    # once; only the last definition is live at import time.
    _, eval_indices = indices_for_fold(fold, train_eval_data)

    eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                    transform=eval_transform)
    # smaller batches: each sample expands into 2 * NUM_TTA views below
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size // 4,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    model.load_state_dict(
        torch.load(
            os.path.join(args.experiment_path, 'model_{}.pth'.format(fold))))

    model.eval()
    with torch.no_grad():
        fold_labels = []
        fold_probs = []
        fold_exps = []
        fold_ids = []
        for images, feats, exps, labels, ids in tqdm(
                eval_data_loader, desc='fold {} evaluation'.format(fold)):
            images, feats, labels = images.to(DEVICE), feats.to(
                DEVICE), labels.to(DEVICE)

            # flatten the TTA axis into the batch axis for the forward pass
            b, n, c, h, w = images.size()
            assert n == 2 * NUM_TTA
            images = images.view(b * n, c, h, w)
            feats = feats.view(b, 1, 2).repeat(1, n, 1).view(b * n, 2)
            logits = model(images, feats)
            logits = logits.view(b, n, NUM_CLASSES)
            # NOTE(review): softmax here is presumably a project helper that
            # normalizes over the class dimension — confirm its axis.
            probs = softmax(logits)

            fold_labels.append(labels)
            fold_probs.append(probs)
            fold_exps.extend(exps)
            fold_ids.extend(ids)

        fold_labels = torch.cat(fold_labels, 0)
        fold_probs = torch.cat(fold_probs, 0)
        fold_plates = train_eval_data.iloc[eval_indices]['plate'].values

    return fold_labels, fold_probs, fold_exps, fold_plates, fold_ids
def predict_on_eval_using_fold(fold, train_eval_data):
    """TTA + score-refinement variant of the eval prediction for fold `fold`.

    Runs the fold's checkpoint over its validation split with test-time
    augmentation, fits a global temperature, assigns classes per experiment,
    then masks already-assigned classes in the logits via refine_scores
    (setting them to -inf per experiment/plate).

    Returns:
        (fold_labels, fold_logits, fold_exps, fold_ids) where fold_logits is
        the refined (batch, views, NUM_CLASSES) tensor.
    """
    # NOTE(review): this module defines predict_on_eval_using_fold more than
    # once; this (last) definition is the one live at import time.
    _, eval_indices = indices_for_fold(fold, train_eval_data)

    eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                    transform=eval_transform)
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    model.load_state_dict(torch.load(os.path.join(
        args.experiment_path, 'model_{}.pth'.format(fold))))

    model.eval()
    with torch.no_grad():
        fold_labels = []
        fold_logits = []
        fold_exps = []
        fold_ids = []
        for images, feats, exps, labels, ids in tqdm(
                eval_data_loader, desc='fold {} evaluation'.format(fold)):
            images, feats, labels = images.to(DEVICE), feats.to(
                DEVICE), labels.to(DEVICE)

            # flatten the TTA axis into the batch axis for the forward pass
            b, n, c, h, w = images.size()
            images = images.view(b * n, c, h, w)
            feats = feats.view(b, 1, 2).repeat(1, n, 1).view(b * n, 2)
            logits = model(images, feats)
            logits = logits.view(b, n, NUM_CLASSES)

            fold_labels.append(labels)
            fold_logits.append(logits)
            fold_exps.extend(exps)
            fold_ids.extend(ids)

        fold_labels = torch.cat(fold_labels, 0)
        fold_logits = torch.cat(fold_logits, 0)
        fold_plates = train_eval_data.iloc[eval_indices]['plate'].values

        temp, _, _ = find_temp_global(input=fold_logits, target=fold_labels,
                                      exps=fold_exps)
        classes = assign_classes(
            probs=to_prob(fold_logits, temp).data.cpu().numpy(),
            exps=fold_exps)
        # suppress assigned classes in the logits (value=-inf) so downstream
        # consumers cannot pick the same class twice within a plate
        fold_logits = refine_scores(
            fold_logits, classes, exps=fold_exps, plates=fold_plates,
            value=float('-inf'))

    return fold_labels, fold_logits, fold_exps, fold_ids
def predict_on_test_using_fold(fold, test_data):
    """Run TTA inference for fold `fold`'s saved checkpoint over the test set.

    Each sample arrives as a stack of augmented views; the views are folded
    into the batch axis for the forward pass and restored afterwards, giving
    logits of shape (batch, views, NUM_CLASSES). The raw logits (with the
    experiment and sample ids) are also checkpointed to './test_<fold>.pth'.

    Returns:
        (fold_logits, fold_exps, fold_plates, fold_ids).
    """
    loader = torch.utils.data.DataLoader(
        TestDataset(test_data, transform=test_transform),
        batch_size=config.batch_size // 2,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES).to(DEVICE)
    checkpoint_path = os.path.join(args.experiment_path,
                                   'model_{}.pth'.format(fold))
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()

    all_logits, all_exps, all_plates, all_ids = [], [], [], []
    with torch.no_grad():
        for images, feats, exps, plates, ids in tqdm(
                loader, desc='fold {} inference'.format(fold)):
            images = images.to(DEVICE)
            feats = feats.to(DEVICE)

            # merge the TTA axis into the batch axis for a single forward pass
            b, n, c, h, w = images.size()
            flat_images = images.view(b * n, c, h, w)
            flat_feats = feats.view(b, 1, 2).repeat(1, n, 1).view(b * n, 2)
            logits = model(flat_images, flat_feats).view(b, n, NUM_CLASSES)

            all_logits.append(logits)
            all_exps.extend(exps)
            all_plates.extend(plates)
            all_ids.extend(ids)

        all_logits = torch.cat(all_logits, 0)
        # persist raw per-view logits so downstream blending can reuse them
        torch.save((all_logits, all_exps, all_ids),
                   './test_{}.pth'.format(fold))

    return all_logits, all_exps, all_plates, all_ids
def compute_features_using_fold(fold, data):
    """Extract embedding features from fold `fold`'s checkpoint over `data`.

    The model is constructed with return_features=True, so its forward pass
    yields (logits, embeddings); only the embeddings are kept, reshaped to
    (batch, views, emb_dim) to preserve the TTA axis.

    Returns:
        (fold_embs, fold_exps, fold_ids).
    """
    loader = torch.utils.data.DataLoader(
        TestDataset(data, transform=test_transform),
        batch_size=config.batch_size // 2,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES, return_features=True).to(DEVICE)
    model.load_state_dict(torch.load(os.path.join(
        args.experiment_path, 'model_{}.pth'.format(fold))))
    model.eval()

    all_embs, all_exps, all_ids = [], [], []
    with torch.no_grad():
        for images, feats, exps, ids in tqdm(
                loader, desc='fold {} inference'.format(fold)):
            images = images.to(DEVICE)
            feats = feats.to(DEVICE)

            # merge the TTA axis into the batch axis for a single forward pass
            b, n, c, h, w = images.size()
            flat_images = images.view(b * n, c, h, w)
            flat_feats = feats.view(b, 1, 2).repeat(1, n, 1).view(b * n, 2)
            _, embds = model(flat_images, flat_feats)

            all_embs.append(embds.view(b, n, embds.size(1)))
            all_exps.extend(exps)
            all_ids.extend(ids)

        all_embs = torch.cat(all_embs, 0)

    return all_embs, all_exps, all_ids
def train_fold(fold, train_eval_data, unsup_data):
    """Train fold `fold` with labeled train/eval splits plus an unlabeled
    (semi-supervised) data loader, saving the best checkpoint by accuracy@1.

    Builds the three data loaders, optionally restores weights from
    args.restore_path, constructs the LR scheduler selected by
    config.sched.type, then runs config.epochs of train/eval, checkpointing
    to '<experiment_path>/model_<fold>.pth' whenever the eval score improves.
    """
    # NOTE(review): this module defines train_fold more than once; only the
    # last definition is live at import time.
    train_indices, eval_indices = indices_for_fold(fold, train_eval_data)

    train_dataset = TrainEvalDataset(train_eval_data.iloc[train_indices],
                                     transform=train_transform)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)
    eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                    transform=eval_transform)
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    # unlabeled pool for the semi-supervised part of train_epoch
    unsup_data = TestDataset(unsup_data, transform=unsup_transform)
    unsup_data_loader = torch.utils.data.DataLoader(
        unsup_data,
        batch_size=config.batch_size // 2,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        # warm-start from a previous experiment's checkpoint for this fold
        model.load_state_dict(
            torch.load(
                os.path.join(args.restore_path, 'model_{}.pth'.format(fold))))

    optimizer = build_optimizer(config.opt, model.parameters())

    # scheduler dispatch: Step/Epoch/Score wrappers decide which of
    # step()/step_epoch()/step_score() actually advances the schedule
    if config.sched.type == 'onecycle':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            OneCycleScheduler(optimizer,
                              lr=(config.opt.lr / 20, config.opt.lr),
                              beta_range=config.sched.onecycle.beta,
                              max_steps=len(train_data_loader) * config.epochs,
                              annealing=config.sched.onecycle.anneal,
                              peak_pos=config.sched.onecycle.peak_pos,
                              end_pos=config.sched.onecycle.end_pos))
    elif config.sched.type == 'step':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=config.sched.step.step_size,
                gamma=config.sched.step.decay))
    elif config.sched.type == 'cyclic':
        step_size_up = len(
            train_data_loader) * config.sched.cyclic.step_size_up
        step_size_down = len(
            train_data_loader) * config.sched.cyclic.step_size_down
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode='triangular2',
                # gamma chosen so the amplitude decays by sched.cyclic.decay
                # per full up+down cycle
                gamma=config.sched.cyclic.decay**(
                    1 / (step_size_up + step_size_down)),
                cycle_momentum=True,
                base_momentum=0.85,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_data_loader), T_mult=2))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    best_score = 0
    for epoch in range(1, config.epochs + 1):
        # NOTE(review): optimizer.train()/.eval() suggests a custom optimizer
        # wrapper with mode switching (e.g. lookahead-style) — confirm.
        optimizer.train()
        train_epoch(model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    data_loader=train_data_loader,
                    unsup_data_loader=unsup_data_loader,
                    fold=fold,
                    epoch=epoch)
        gc.collect()
        optimizer.eval()
        metric = eval_epoch(model=model,
                            data_loader=eval_data_loader,
                            fold=fold,
                            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']
        scheduler.step_epoch()
        scheduler.step_score(score)

        # keep only the best checkpoint for this fold
        if score > best_score:
            best_score = score
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_path,
                             'model_{}.pth'.format(fold)))
def lr_search(train_eval_data):
    """LR range test: sweep the learning rate exponentially from min_lr to
    max_lr over (at most) one pass of the data, record the loss per step, and
    return the LR at the minimum of the smoothed loss curve.

    The sweep stops early once the loss exceeds 110% of the starting loss.
    Side effects: logs the raw and smoothed search curves plus a matplotlib
    figure to TensorBoard under '<experiment_path>/lr_search'.

    Returns:
        minima_lr — the learning rate at the smoothed-loss minimum.
    """
    train_eval_dataset = TrainEvalDataset(train_eval_data,
                                          transform=train_transform)
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_eval_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    min_lr = 1e-7
    max_lr = 10.
    # per-step multiplier so the LR spans [min_lr, max_lr] in one epoch
    gamma = (max_lr / min_lr)**(1 / len(train_eval_data_loader))

    lrs = []
    losses = []
    lim = None

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    optimizer = build_optimizer(config.opt, model.parameters())
    for param_group in optimizer.param_groups:
        param_group['lr'] = min_lr
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    # NOTE(review): optimizer.train() suggests a custom optimizer wrapper
    # with mode switching — confirm.
    optimizer.train()
    update_transforms(1.)
    model.train()
    optimizer.zero_grad()
    # NOTE(review): this unpacks 4 items per batch while the other loops over
    # TrainEvalDataset in this file unpack 5 (images, feats, exps, labels,
    # ids) — verify the dataset used here really yields 4-tuples.
    for i, (images, _, labels, _) in enumerate(
            tqdm(train_eval_data_loader, desc='lr search'), 1):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        labels = utils.one_hot(labels, NUM_CLASSES)
        images, labels = mixup(images, labels)
        logits = model(images, None, True)
        loss = compute_loss(input=logits, target=labels)
        # (removed a dead `labels = labels.argmax(1)` — the result was never
        # read after this point)

        lrs.append(np.squeeze(scheduler.get_lr()))
        losses.append(loss.data.cpu().numpy().mean())
        if lim is None:
            lim = losses[0] * 1.1
        if lim < losses[-1]:
            # loss diverged past 110% of the starting loss — stop the sweep
            break

        # gradient accumulation: scale the loss and step every acc_steps batches
        (loss.mean() / config.opt.acc_steps).backward()
        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()

    writer = SummaryWriter(os.path.join(args.experiment_path, 'lr_search'))
    with torch.no_grad():
        losses = np.clip(losses, 0, lim)
        minima_loss = losses[np.argmin(utils.smooth(losses))]
        minima_lr = lrs[np.argmin(utils.smooth(losses))]

        step = 0
        for loss, loss_sm in zip(losses, utils.smooth(losses)):
            writer.add_scalar('search_loss', loss, global_step=step)
            writer.add_scalar('search_loss_sm', loss_sm, global_step=step)
            step += config.batch_size

        fig = plt.figure()
        plt.plot(lrs, losses)
        plt.plot(lrs, utils.smooth(losses))
        plt.axvline(minima_lr)
        plt.xscale('log')
        plt.title('loss: {:.8f}, lr: {:.8f}'.format(minima_loss, minima_lr))
        writer.add_figure('search', fig, global_step=0)

    return minima_lr
def train_fold(fold, train_eval_data):
    """Pseudo-label variant of fold training: each epoch mixes the fold's
    labeled training split with a decaying fraction of pseudo-labeled eval
    and test predictions read from a previous experiment's CSVs.

    The pseudo-label fraction follows np.linspace(1., 0., epochs), so epoch 1
    uses all pseudo-labels and the last epoch uses none. The best checkpoint
    by accuracy@1 is saved to '<experiment_path>/model_<fold>.pth'.
    """
    # NOTE(review): this module defines train_fold more than once; this
    # (last) definition shadows the earlier one at import time.
    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(
            torch.load(
                os.path.join(args.restore_path, 'model_{}.pth'.format(fold))))

    optimizer = build_optimizer(config.opt, model.parameters())

    if config.sched.type == 'onecycle':
        # stepped per-epoch here (max_steps=config.epochs), unlike the
        # per-batch OneCycle configuration used elsewhere in this file
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            OneCycleScheduler(optimizer,
                              lr=(config.opt.lr / 20, config.opt.lr),
                              beta_range=config.sched.onecycle.beta,
                              max_steps=config.epochs,
                              annealing=config.sched.onecycle.anneal,
                              peak_pos=config.sched.onecycle.peak_pos,
                              end_pos=config.sched.onecycle.end_pos))
    elif config.sched.type == 'cyclic':
        # BUG(review): train_data_loader is not defined until inside the
        # epoch loop below, so the 'cyclic' branch raises NameError at
        # scheduler-construction time. Same for the 'cawr' branch.
        step_size_up = len(
            train_data_loader) * config.sched.cyclic.step_size_up
        step_size_down = len(
            train_data_loader) * config.sched.cyclic.step_size_down
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode='exp_range',
                gamma=config.sched.cyclic.decay**(
                    1 / (step_size_up + step_size_down)),
                cycle_momentum=True,
                base_momentum=0.75,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        # BUG(review): same forward reference to train_data_loader as above.
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_data_loader), T_mult=2))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    best_score = 0
    for epoch in range(1, config.epochs + 1):
        train_indices, eval_indices = indices_for_fold(fold, train_eval_data)

        # NOTE(review): pseudo-label source paths are hard-coded to a specific
        # previous experiment directory — consider making this configurable.
        eval_pl = pd.read_csv(
            './tf_log/cells/tmp-512-progres-crop-norm-la/eval_{}.csv'.format(
                fold))
        eval_pl['root'] = os.path.join(args.dataset_path, 'train')
        test_pl = pd.read_csv(
            './tf_log/cells/tmp-512-progres-crop-norm-la/test.csv')
        test_pl['root'] = os.path.join(args.dataset_path, 'test')
        pl = pd.concat([eval_pl, test_pl])
        pl_size = len(pl)
        # linearly decay the pseudo-label fraction from 1.0 to 0.0 over epochs
        pl = pl.sample(
            frac=np.linspace(1., 0., config.epochs)[epoch - 1].item())
        # NOTE(review): '{:.8f}'.format(scheduler.get_lr()) assumes the
        # wrapper's get_lr() returns a scalar, not the list that raw torch
        # schedulers return — confirm.
        print('frac: {:.4f}, lr: {:.8f}'.format(
            len(pl) / pl_size, scheduler.get_lr()))

        # rebuild the train loader each epoch since the pseudo-label mix changes
        train_dataset = TrainEvalDataset(pd.concat(
            [train_eval_data.iloc[train_indices], pl]),
            transform=train_transform)
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            drop_last=True,
            shuffle=True,
            num_workers=args.workers,
            worker_init_fn=worker_init_fn)
        eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                        transform=eval_transform)
        eval_data_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=config.batch_size,
            num_workers=args.workers,
            worker_init_fn=worker_init_fn)

        train_epoch(model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    data_loader=train_data_loader,
                    fold=fold,
                    epoch=epoch)
        gc.collect()
        metric = eval_epoch(model=model,
                            data_loader=eval_data_loader,
                            fold=fold,
                            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']
        scheduler.step_epoch()
        scheduler.step_score(score)

        # keep only the best checkpoint for this fold
        if score > best_score:
            best_score = score
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_path,
                             'model_{}.pth'.format(fold)))