Example #1
def predict_on_eval_using_fold(fold, train_eval_data):
    _, eval_indices = indices_for_fold(fold, train_eval_data)
    eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                    transform=eval_transform)
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    model.load_state_dict(
        torch.load(
            os.path.join(args.experiment_path, 'model_{}.pth'.format(fold))))

    model.eval()
    with torch.no_grad():
        fold_labels = []
        fold_logits = []
        fold_exps = []
        fold_ids = []

        for images, feats, exps, labels, ids in tqdm(
                eval_data_loader, desc='fold {} evaluation'.format(fold)):
            images, feats, labels = (images.to(DEVICE), feats.to(DEVICE),
                                     labels.to(DEVICE))

            # each sample carries n views (e.g. TTA crops or acquisition
            # sites): fold them into the batch dimension for the forward
            # pass, then restore the (batch, views, classes) shape
            b, n, c, h, w = images.size()
            images = images.view(b * n, c, h, w)
            feats = feats.view(b, 1, 2).repeat(1, n, 1).view(b * n, 2)
            logits = model(images, feats)
            logits = logits.view(b, n, NUM_CLASSES)

            fold_labels.append(labels)
            fold_logits.append(logits)
            fold_exps.extend(exps)
            fold_ids.extend(ids)

        fold_labels = torch.cat(fold_labels, 0)
        fold_logits = torch.cat(fold_logits, 0)

        return fold_labels, fold_logits, fold_exps, fold_ids
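
A minimal sketch of a driver that collects out-of-fold predictions with this function and scores them; the FOLDS iterable and the view-averaging step are assumptions, not part of the original script:

# hypothetical driver: gather out-of-fold predictions from every fold
all_labels, all_logits = [], []
for fold in FOLDS:
    labels, logits, exps, ids = predict_on_eval_using_fold(fold, train_eval_data)
    all_labels.append(labels)
    all_logits.append(logits.mean(1))  # average the logits over the n views

all_labels = torch.cat(all_labels, 0)
all_logits = torch.cat(all_logits, 0)
accuracy = (all_logits.argmax(1) == all_labels).float().mean()
print('OOF accuracy@1: {:.4f}'.format(accuracy.item()))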
Example #2
def predict_on_test_using_fold(fold, test_data):
    test_dataset = TestDataset(test_data, transform=test_transform)
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.batch_size // 2,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    model.load_state_dict(
        torch.load(
            os.path.join(args.experiment_path, 'model_{}.pth'.format(fold))))

    model.eval()
    with torch.no_grad():
        fold_logits = []
        fold_exps = []
        fold_plates = []
        fold_ids = []

        for images, feats, exps, plates, ids in tqdm(
                test_data_loader, desc='fold {} inference'.format(fold)):
            images, feats = images.to(DEVICE), feats.to(DEVICE)

            b, n, c, h, w = images.size()
            images = images.view(b * n, c, h, w)
            feats = feats.view(b, 1, 2).repeat(1, n, 1).view(b * n, 2)
            logits = model(images, feats)
            logits = logits.view(b, n, NUM_CLASSES)

            fold_logits.append(logits)
            fold_exps.extend(exps)
            fold_plates.extend(plates)
            fold_ids.extend(ids)

        fold_logits = torch.cat(fold_logits, 0)

    # cache this fold's predictions (note: plates are returned to the
    # caller but not included in the saved tuple)
    torch.save((fold_logits, fold_exps, fold_ids),
               './test_{}.pth'.format(fold))

    return fold_logits, fold_exps, fold_plates, fold_ids
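
Once every fold has written its cache file, a simple fold ensemble is a probability average. A minimal sketch, assuming a FOLDS iterable and using only the cached tuple written above:

# hypothetical ensembling step over the cached per-fold predictions
fold_probs = []
for fold in FOLDS:
    logits, exps, ids = torch.load('./test_{}.pth'.format(fold),
                                   map_location='cpu')
    fold_probs.append(logits.softmax(2).mean(1))  # average the n views

probs = torch.stack(fold_probs, 0).mean(0)  # average over folds
classes = probs.argmax(1)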
Example #3
def compute_features_using_fold(fold, data):
    dataset = TestDataset(data, transform=test_transform)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size // 2,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES, return_features=True)
    model = model.to(DEVICE)
    model.load_state_dict(
        torch.load(
            os.path.join(args.experiment_path, 'model_{}.pth'.format(fold))))

    model.eval()
    with torch.no_grad():
        fold_embs = []
        fold_exps = []
        fold_ids = []

        for images, feats, exps, ids in tqdm(
                data_loader, desc='fold {} inference'.format(fold)):
            images, feats = images.to(DEVICE), feats.to(DEVICE)

            b, n, c, h, w = images.size()
            images = images.view(b * n, c, h, w)
            feats = feats.view(b, 1, 2).repeat(1, n, 1).view(b * n, 2)
            # with return_features=True the model yields (logits, embeddings)
            _, embds = model(images, feats)
            embds = embds.view(b, n, embds.size(1))

            fold_embs.append(embds)
            fold_exps.extend(exps)
            fold_ids.extend(ids)

        fold_embs = torch.cat(fold_embs, 0)

    return fold_embs, fold_exps, fold_ids
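
A short usage sketch: collapse the n views per sample and L2-normalize, e.g. for a nearest-neighbour lookup over the embeddings (the fold index and test_data are placeholders):

# hypothetical post-processing of the extracted embeddings
embs, exps, ids = compute_features_using_fold(0, test_data)
embs = embs.mean(1)                           # (b, n, d) -> (b, d)
embs = embs / embs.norm(dim=1, keepdim=True)  # unit length for cosine search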
Example #4
def train_fold(fold, train_eval_data, unsup_data):
    train_indices, eval_indices = indices_for_fold(fold, train_eval_data)

    train_dataset = TrainEvalDataset(train_eval_data.iloc[train_indices],
                                     transform=train_transform)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)
    eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                    transform=eval_transform)
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)
    unsup_dataset = TestDataset(unsup_data, transform=unsup_transform)
    unsup_data_loader = torch.utils.data.DataLoader(
        unsup_dataset,
        batch_size=config.batch_size // 2,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(
            torch.load(
                os.path.join(args.restore_path, 'model_{}.pth'.format(fold))))

    optimizer = build_optimizer(config.opt, model.parameters())

    if config.sched.type == 'onecycle':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            OneCycleScheduler(optimizer,
                              lr=(config.opt.lr / 20, config.opt.lr),
                              beta_range=config.sched.onecycle.beta,
                              max_steps=len(train_data_loader) * config.epochs,
                              annealing=config.sched.onecycle.anneal,
                              peak_pos=config.sched.onecycle.peak_pos,
                              end_pos=config.sched.onecycle.end_pos))
    elif config.sched.type == 'step':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=config.sched.step.step_size,
                gamma=config.sched.step.decay))
    elif config.sched.type == 'cyclic':
        step_size_up = len(
            train_data_loader) * config.sched.cyclic.step_size_up
        step_size_down = len(
            train_data_loader) * config.sched.cyclic.step_size_down

        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode='triangular2',
                # note: CyclicLR only uses gamma in 'exp_range' mode, so it
                # has no effect here with mode='triangular2'
                gamma=config.sched.cyclic.decay**(
                    1 / (step_size_up + step_size_down)),
                cycle_momentum=True,
                base_momentum=0.85,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_data_loader), T_mult=2))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    best_score = 0
    for epoch in range(1, config.epochs + 1):
        # assumes a custom optimizer exposing train()/eval(), e.g. for
        # Lookahead-style fast/slow weight swapping
        optimizer.train()
        train_epoch(model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    data_loader=train_data_loader,
                    unsup_data_loader=unsup_data_loader,
                    fold=fold,
                    epoch=epoch)
        gc.collect()
        optimizer.eval()  # counterpart hook: switch to evaluation weights
        metric = eval_epoch(model=model,
                            data_loader=eval_data_loader,
                            fold=fold,
                            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']

        scheduler.step_epoch()
        scheduler.step_score(score)

        if score > best_score:
            best_score = score
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_path,
                             'model_{}.pth'.format(fold)))
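
The loop above leans on the lr_scheduler_wrapper module to give per-step, per-epoch, and score-driven schedulers one uniform surface. A minimal sketch of that interface as this code appears to assume it (the real module may differ):

# hypothetical reconstruction of the wrapper interface: each wrapper
# forwards exactly one of the three hooks to the wrapped scheduler
class StepWrapper:
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def step(self):  # called once per batch inside train_epoch
        self.scheduler.step()

    def step_epoch(self):
        pass

    def step_score(self, score):
        pass


class EpochWrapper(StepWrapper):
    def step(self):
        pass

    def step_epoch(self):
        self.scheduler.step()


class ScoreWrapper(StepWrapper):
    def step(self):
        pass

    def step_score(self, score):
        self.scheduler.step(score)  # ReduceLROnPlateau consumes the metric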
Example #5
def train_fold(fold, train_eval_data):
    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(
            torch.load(
                os.path.join(args.restore_path, 'model_{}.pth'.format(fold))))

    optimizer = build_optimizer(config.opt, model.parameters())

    if config.sched.type == 'onecycle':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            OneCycleScheduler(optimizer,
                              lr=(config.opt.lr / 20, config.opt.lr),
                              beta_range=config.sched.onecycle.beta,
                              max_steps=config.epochs,
                              annealing=config.sched.onecycle.anneal,
                              peak_pos=config.sched.onecycle.peak_pos,
                              end_pos=config.sched.onecycle.end_pos))
    elif config.sched.type == 'cyclic':
        # NOTE: train_data_loader is only built inside the epoch loop below,
        # so this branch (and 'cawr') would raise a NameError as written
        step_size_up = len(train_data_loader) * config.sched.cyclic.step_size_up
        step_size_down = len(train_data_loader) * config.sched.cyclic.step_size_down

        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode='exp_range',
                gamma=config.sched.cyclic.decay**(
                    1 / (step_size_up + step_size_down)),
                cycle_momentum=True,
                base_momentum=0.75,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_data_loader), T_mult=2))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    best_score = 0
    for epoch in range(1, config.epochs + 1):
        train_indices, eval_indices = indices_for_fold(fold, train_eval_data)

        # load pseudo-labels produced by an earlier experiment: per-fold
        # eval predictions plus test predictions
        eval_pl = pd.read_csv(
            './tf_log/cells/tmp-512-progres-crop-norm-la/eval_{}.csv'.format(
                fold))
        eval_pl['root'] = os.path.join(args.dataset_path, 'train')
        test_pl = pd.read_csv(
            './tf_log/cells/tmp-512-progres-crop-norm-la/test.csv')
        test_pl['root'] = os.path.join(args.dataset_path, 'test')
        pl = pd.concat([eval_pl, test_pl])
        pl_size = len(pl)
        # linearly anneal the pseudo-label fraction from 1.0 to 0.0 over training
        pl = pl.sample(
            frac=np.linspace(1., 0., config.epochs)[epoch - 1].item())
        print('frac: {:.4f}, lr: {:.8f}'.format(
            len(pl) / pl_size, scheduler.get_lr()))

        # mix the supervised fold-train split with the sampled pseudo-labels
        train_dataset = TrainEvalDataset(
            pd.concat([train_eval_data.iloc[train_indices], pl]),
            transform=train_transform)
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            drop_last=True,
            shuffle=True,
            num_workers=args.workers,
            worker_init_fn=worker_init_fn)
        eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                        transform=eval_transform)
        eval_data_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=config.batch_size,
            num_workers=args.workers,
            worker_init_fn=worker_init_fn)

        train_epoch(model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    data_loader=train_data_loader,
                    fold=fold,
                    epoch=epoch)
        gc.collect()
        metric = eval_epoch(model=model,
                            data_loader=eval_data_loader,
                            fold=fold,
                            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']

        scheduler.step_epoch()
        scheduler.step_score(score)

        if score > best_score:
            best_score = score
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_path,
                             'model_{}.pth'.format(fold)))
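
For reference, the frac passed to pl.sample above follows a linear decay; a tiny illustration with a hypothetical config.epochs of 5:

import numpy as np

# the pseudo-label fraction schedule used in the loop above
print(np.linspace(1., 0., 5))  # [1.   0.75 0.5  0.25 0.  ]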