Ejemplo n.º 1
0
def predict_on_eval_using_fold(fold, train_eval_data):
    """Evaluate the saved checkpoint of ``fold`` on that fold's eval split.

    The model weights are read from ``model_<fold>.pth`` inside
    ``args.experiment_path``. After collecting logits for the whole split, a
    temperature is fitted with ``find_temp_global``, classes are assigned with
    ``assign_classes``, the match rate against the ``sirna`` column is printed
    and the predictions are written to ``eval_<fold>.csv`` in the experiment
    directory.

    Args:
        fold: Fold identifier, used to pick the split and the checkpoint file.
        train_eval_data: DataFrame holding all train/eval rows; it is indexed
            positionally with the indices returned by ``indices_for_fold``.

    Returns:
        Tuple ``(fold_labels, fold_logits, fold_exps, fold_ids)`` where labels
        and logits are concatenated tensors and exps/ids are flat lists.
    """
    _, eval_indices = indices_for_fold(fold, train_eval_data)

    dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                               transform=eval_transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=config.batch_size,
                                         num_workers=args.workers,
                                         worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES).to(DEVICE)
    checkpoint_path = os.path.join(args.experiment_path,
                                   'model_{}.pth'.format(fold))
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()

    label_batches, logit_batches = [], []
    fold_exps, fold_ids = [], []

    # Everything below, including the temperature search, runs without
    # gradient tracking, exactly as the loop itself does.
    with torch.no_grad():
        progress = tqdm(loader, desc='fold {} evaluation'.format(fold))
        for images, feats, exps, labels, ids in progress:
            images = images.to(DEVICE)
            feats = feats.to(DEVICE)
            labels = labels.to(DEVICE)

            logit_batches.append(model(images, feats))
            label_batches.append(labels)
            fold_exps.extend(exps)
            fold_ids.extend(ids)

        fold_labels = torch.cat(label_batches, 0)
        fold_logits = torch.cat(logit_batches, 0)

        # Work on a copy so the caller's frame keeps its original labels.
        predictions = train_eval_data.iloc[eval_indices].copy()

        temp, _, _ = find_temp_global(input=fold_logits,
                                      target=fold_labels,
                                      exps=fold_exps)
        scaled_probs = (fold_logits * temp).softmax(1).data.cpu().numpy()
        classes = assign_classes(probs=scaled_probs, exps=fold_exps)

        # Fraction of rows whose assigned class equals the stored 'sirna'.
        match_rate = (predictions['sirna'] == classes).mean()
        print('{:.2f}'.format(match_rate))

        predictions['sirna'] = classes
        output_path = os.path.join(args.experiment_path,
                                   'eval_{}.csv'.format(fold))
        predictions.to_csv(output_path, index=False)

        return fold_labels, fold_logits, fold_exps, fold_ids
Ejemplo n.º 2
0
def predict_on_eval_using_fold(fold, train_eval_data):
    """Compute per-view class probabilities on the eval split of ``fold``.

    Each batch of images is 5-D ``(b, n, c, h, w)``; the ``n`` views of every
    sample are flattened into the batch axis for the forward pass and the
    logits are folded back to ``(b, n, NUM_CLASSES)`` before ``softmax`` is
    applied. ``n`` must equal ``2 * NUM_TTA``.

    Args:
        fold: Fold identifier, used to pick the split and the checkpoint file.
        train_eval_data: DataFrame with all train/eval rows; must contain a
            ``plate`` column.

    Returns:
        Tuple ``(fold_labels, fold_probs, fold_exps, fold_plates, fold_ids)``.
    """
    _, eval_indices = indices_for_fold(fold, train_eval_data)

    dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                               transform=eval_transform)
    # A quarter of the configured batch size is used here because every
    # sample expands into n views before reaching the model.
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=config.batch_size // 4,
                                         num_workers=args.workers,
                                         worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES).to(DEVICE)
    checkpoint_path = os.path.join(args.experiment_path,
                                   'model_{}.pth'.format(fold))
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()

    label_batches, prob_batches = [], []
    fold_exps, fold_ids = [], []

    with torch.no_grad():
        progress = tqdm(loader, desc='fold {} evaluation'.format(fold))
        for images, feats, exps, labels, ids in progress:
            images = images.to(DEVICE)
            feats = feats.to(DEVICE)
            labels = labels.to(DEVICE)

            b, n, c, h, w = images.size()
            assert n == 2 * NUM_TTA

            # Flatten the views into the batch axis and replicate the two
            # per-sample features once per view so they stay aligned.
            flat_images = images.view(b * n, c, h, w)
            flat_feats = feats.view(b, 1, 2).repeat(1, n, 1).view(b * n, 2)

            logits = model(flat_images, flat_feats).view(b, n, NUM_CLASSES)

            prob_batches.append(softmax(logits))
            label_batches.append(labels)
            fold_exps.extend(exps)
            fold_ids.extend(ids)

        fold_labels = torch.cat(label_batches, 0)
        fold_probs = torch.cat(prob_batches, 0)
        fold_plates = train_eval_data.iloc[eval_indices]['plate'].values

        return fold_labels, fold_probs, fold_exps, fold_plates, fold_ids
Ejemplo n.º 3
0
def predict_on_eval_using_fold(fold, train_eval_data):
    """Collect refined per-view logits on the eval split of ``fold``.

    Images arrive as ``(b, n, c, h, w)``; the ``n`` views are flattened into
    the batch axis for the forward pass and the logits are reshaped back to
    ``(b, n, NUM_CLASSES)``. After the whole split is processed, a global
    temperature is fitted, classes are assigned, and the logits are passed
    through ``refine_scores`` with ``value=-inf``.

    Args:
        fold: Fold identifier, used to pick the split and the checkpoint file.
        train_eval_data: DataFrame with all train/eval rows; must contain a
            ``plate`` column.

    Returns:
        Tuple ``(fold_labels, fold_logits, fold_exps, fold_ids)``.
    """
    _, eval_indices = indices_for_fold(fold, train_eval_data)

    dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                               transform=eval_transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=config.batch_size,
                                         num_workers=args.workers,
                                         worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES).to(DEVICE)
    checkpoint_path = os.path.join(args.experiment_path,
                                   'model_{}.pth'.format(fold))
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()

    label_batches, logit_batches = [], []
    fold_exps, fold_ids = [], []

    with torch.no_grad():
        progress = tqdm(loader, desc='fold {} evaluation'.format(fold))
        for images, feats, exps, labels, ids in progress:
            images = images.to(DEVICE)
            feats = feats.to(DEVICE)
            labels = labels.to(DEVICE)

            b, n, c, h, w = images.size()
            # One forward pass over all views; features are repeated per view.
            flat_images = images.view(b * n, c, h, w)
            flat_feats = feats.view(b, 1, 2).repeat(1, n, 1).view(b * n, 2)
            logits = model(flat_images, flat_feats).view(b, n, NUM_CLASSES)

            label_batches.append(labels)
            logit_batches.append(logits)
            fold_exps.extend(exps)
            fold_ids.extend(ids)

        fold_labels = torch.cat(label_batches, 0)
        fold_logits = torch.cat(logit_batches, 0)

        fold_plates = train_eval_data.iloc[eval_indices]['plate'].values

        temp, _, _ = find_temp_global(input=fold_logits,
                                      target=fold_labels,
                                      exps=fold_exps)
        probs = to_prob(fold_logits, temp).data.cpu().numpy()
        classes = assign_classes(probs=probs, exps=fold_exps)

        fold_logits = refine_scores(fold_logits,
                                    classes,
                                    exps=fold_exps,
                                    plates=fold_plates,
                                    value=float('-inf'))

        return fold_labels, fold_logits, fold_exps, fold_ids
Ejemplo n.º 4
0
def predict_on_test_using_fold(fold, test_data):
    """Run the checkpoint of ``fold`` over the test set and cache the logits.

    Images arrive as ``(b, n, c, h, w)``; the ``n`` views are flattened into
    the batch axis for the forward pass and the logits are reshaped back to
    ``(b, n, NUM_CLASSES)``. The tuple ``(fold_logits, fold_exps, fold_ids)``
    is also saved to ``./test_<fold>.pth`` in the current working directory.

    Args:
        fold: Fold identifier selecting ``model_<fold>.pth`` in
            ``args.experiment_path``.
        test_data: Data handed to ``TestDataset``.

    Returns:
        Tuple ``(fold_logits, fold_exps, fold_plates, fold_ids)``.
    """
    dataset = TestDataset(test_data, transform=test_transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=config.batch_size // 2,
                                         num_workers=args.workers,
                                         worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES).to(DEVICE)
    checkpoint_path = os.path.join(args.experiment_path,
                                   'model_{}.pth'.format(fold))
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()

    logit_batches = []
    fold_exps, fold_plates, fold_ids = [], [], []

    with torch.no_grad():
        progress = tqdm(loader, desc='fold {} inference'.format(fold))
        for images, feats, exps, plates, ids in progress:
            images = images.to(DEVICE)
            feats = feats.to(DEVICE)

            b, n, c, h, w = images.size()
            # Flatten views into the batch axis; repeat features per view.
            flat_images = images.view(b * n, c, h, w)
            flat_feats = feats.view(b, 1, 2).repeat(1, n, 1).view(b * n, 2)
            logits = model(flat_images, flat_feats).view(b, n, NUM_CLASSES)

            logit_batches.append(logits)
            fold_exps.extend(exps)
            fold_plates.extend(plates)
            fold_ids.extend(ids)

        fold_logits = torch.cat(logit_batches, 0)

    # Note: the cached tuple does not include the plates.
    cache_path = './test_{}.pth'.format(fold)
    torch.save((fold_logits, fold_exps, fold_ids), cache_path)

    return fold_logits, fold_exps, fold_plates, fold_ids
Ejemplo n.º 5
0
def compute_features_using_fold(fold, data):
    """Extract per-view embeddings with the checkpoint of ``fold``.

    The model is built with ``return_features=True`` and its second output is
    kept. Images arrive as ``(b, n, c, h, w)``; the ``n`` views are flattened
    into the batch axis for the forward pass and the embeddings are reshaped
    back to ``(b, n, embedding_size)``.

    Args:
        fold: Fold identifier selecting ``model_<fold>.pth`` in
            ``args.experiment_path``.
        data: Data handed to ``TestDataset``.

    Returns:
        Tuple ``(fold_embs, fold_exps, fold_ids)``.
    """
    dataset = TestDataset(data, transform=test_transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=config.batch_size // 2,
                                         num_workers=args.workers,
                                         worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES, return_features=True).to(DEVICE)
    checkpoint_path = os.path.join(args.experiment_path,
                                   'model_{}.pth'.format(fold))
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()

    embedding_batches = []
    fold_exps, fold_ids = [], []

    with torch.no_grad():
        progress = tqdm(loader, desc='fold {} inference'.format(fold))
        for images, feats, exps, ids in progress:
            images = images.to(DEVICE)
            feats = feats.to(DEVICE)

            b, n, c, h, w = images.size()
            # Flatten views into the batch axis; repeat features per view.
            flat_images = images.view(b * n, c, h, w)
            flat_feats = feats.view(b, 1, 2).repeat(1, n, 1).view(b * n, 2)

            # The first model output is discarded; only embeddings are kept.
            _, embds = model(flat_images, flat_feats)
            embedding_batches.append(embds.view(b, n, embds.size(1)))
            fold_exps.extend(exps)
            fold_ids.extend(ids)

        fold_embs = torch.cat(embedding_batches, 0)

    return fold_embs, fold_exps, fold_ids
Ejemplo n.º 6
0
def train_fold(fold, train_eval_data, unsup_data):
    """Train one fold and keep the checkpoint with the best ``accuracy@1``.

    Builds train, eval and unsupervised data loaders, optionally restores
    weights from ``args.restore_path``, creates the optimizer and the
    scheduler selected by ``config.sched.type`` and then runs
    ``config.epochs`` epochs. Whenever the eval ``accuracy@1`` improves, the
    model state dict is written to ``model_<fold>.pth`` in
    ``args.experiment_path``.

    Args:
        fold: Fold identifier used for the split and the checkpoint name.
        train_eval_data: DataFrame with all labelled rows; indexed
            positionally with the indices from ``indices_for_fold``.
        unsup_data: Data handed to ``TestDataset`` for the unsupervised loader.

    Raises:
        AssertionError: If ``config.sched.type`` is not a known scheduler.
    """

    def make_loader(dataset, batch_size, **extra):
        # All loaders share the worker configuration.
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           num_workers=args.workers,
                                           worker_init_fn=worker_init_fn,
                                           **extra)

    train_indices, eval_indices = indices_for_fold(fold, train_eval_data)

    train_dataset = TrainEvalDataset(train_eval_data.iloc[train_indices],
                                     transform=train_transform)
    train_data_loader = make_loader(train_dataset,
                                    config.batch_size,
                                    drop_last=True,
                                    shuffle=True)

    eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                    transform=eval_transform)
    eval_data_loader = make_loader(eval_dataset, config.batch_size)

    unsup_dataset = TestDataset(unsup_data, transform=unsup_transform)
    unsup_data_loader = make_loader(unsup_dataset, config.batch_size // 2)

    model = Model(config.model, NUM_CLASSES).to(DEVICE)
    if args.restore_path is not None:
        restore_file = os.path.join(args.restore_path,
                                    'model_{}.pth'.format(fold))
        model.load_state_dict(torch.load(restore_file))

    optimizer = build_optimizer(config.opt, model.parameters())

    sched_type = config.sched.type
    steps_per_epoch = len(train_data_loader)

    if sched_type == 'onecycle':
        onecycle_cfg = config.sched.onecycle
        inner = OneCycleScheduler(optimizer,
                                  lr=(config.opt.lr / 20, config.opt.lr),
                                  beta_range=onecycle_cfg.beta,
                                  max_steps=steps_per_epoch * config.epochs,
                                  annealing=onecycle_cfg.anneal,
                                  peak_pos=onecycle_cfg.peak_pos,
                                  end_pos=onecycle_cfg.end_pos)
        scheduler = lr_scheduler_wrapper.StepWrapper(inner)
    elif sched_type == 'step':
        inner = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.sched.step.step_size,
            gamma=config.sched.step.decay)
        scheduler = lr_scheduler_wrapper.EpochWrapper(inner)
    elif sched_type == 'cyclic':
        # Cycle lengths are configured in epochs and converted to steps.
        step_size_up = steps_per_epoch * config.sched.cyclic.step_size_up
        step_size_down = steps_per_epoch * config.sched.cyclic.step_size_down
        cycle_length = step_size_up + step_size_down

        inner = torch.optim.lr_scheduler.CyclicLR(
            optimizer,
            0.,
            config.opt.lr,
            step_size_up=step_size_up,
            step_size_down=step_size_down,
            mode='triangular2',
            gamma=config.sched.cyclic.decay**(1 / cycle_length),
            cycle_momentum=True,
            base_momentum=0.85,
            max_momentum=0.95)
        scheduler = lr_scheduler_wrapper.StepWrapper(inner)
    elif sched_type == 'cawr':
        inner = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=steps_per_epoch, T_mult=2)
        scheduler = lr_scheduler_wrapper.StepWrapper(inner)
    elif sched_type == 'plateau':
        inner = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=config.sched.plateau.decay,
            patience=config.sched.plateau.patience,
            verbose=True)
        scheduler = lr_scheduler_wrapper.ScoreWrapper(inner)
    else:
        raise AssertionError('invalid sched {}'.format(sched_type))

    checkpoint_path = os.path.join(args.experiment_path,
                                   'model_{}.pth'.format(fold))
    best_score = 0
    for epoch in range(1, config.epochs + 1):
        # The optimizer object exposes train()/eval() mode switches.
        # NOTE(review): presumably a wrapper around a torch optimizer; confirm.
        optimizer.train()
        train_epoch(model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    data_loader=train_data_loader,
                    unsup_data_loader=unsup_data_loader,
                    fold=fold,
                    epoch=epoch)
        gc.collect()

        optimizer.eval()
        metric = eval_epoch(model=model,
                            data_loader=eval_data_loader,
                            fold=fold,
                            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']

        # Both hooks are called every epoch regardless of the wrapper type.
        scheduler.step_epoch()
        scheduler.step_score(score)

        if score > best_score:
            best_score = score
            torch.save(model.state_dict(), checkpoint_path)
Ejemplo n.º 7
0
def lr_search(train_eval_data):
    """Run a learning-rate range test and return the suggested rate.

    The learning rate starts at ``1e-7`` and is multiplied by a constant
    factor after every batch so that it would reach ``10`` after one pass
    over the data. Training stops early once the batch loss exceeds 110% of
    the first batch's loss. The raw and smoothed loss curves are logged to a
    ``lr_search`` TensorBoard directory inside ``args.experiment_path``.

    Args:
        train_eval_data: Data handed to ``TrainEvalDataset``.

    Returns:
        The learning rate at which the smoothed loss curve is minimal.
    """
    train_eval_dataset = TrainEvalDataset(train_eval_data,
                                          transform=train_transform)
    train_eval_data_loader = torch.utils.data.DataLoader(
        train_eval_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    min_lr = 1e-7
    max_lr = 10.
    # Per-step multiplier that takes min_lr to max_lr over one epoch.
    gamma = (max_lr / min_lr)**(1 / len(train_eval_data_loader))

    lrs = []
    losses = []
    lim = None

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)

    optimizer = build_optimizer(config.opt, model.parameters())
    for param_group in optimizer.param_groups:
        param_group['lr'] = min_lr
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    optimizer.train()
    update_transforms(1.)
    model.train()
    optimizer.zero_grad()
    for i, (images, _, labels,
            _) in enumerate(tqdm(train_eval_data_loader, desc='lr search'), 1):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        labels = utils.one_hot(labels, NUM_CLASSES)
        images, labels = mixup(images, labels)
        logits = model(images, None, True)

        loss = compute_loss(input=logits, target=labels)

        # Read the rate actually applied to this batch from the optimizer.
        # Calling scheduler.get_lr() outside of step() returns the rate
        # already multiplied by gamma on recent PyTorch releases, which
        # would shift the recorded curve by one step.
        lrs.append(
            np.squeeze([group['lr'] for group in optimizer.param_groups]))
        losses.append(loss.data.cpu().numpy().mean())

        if lim is None:
            # Divergence threshold: 110% of the very first loss.
            lim = losses[0] * 1.1

        if lim < losses[-1]:
            break

        # Scale for gradient accumulation over acc_steps batches.
        (loss.mean() / config.opt.acc_steps).backward()

        if i % config.opt.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        scheduler.step()

    writer = SummaryWriter(os.path.join(args.experiment_path, 'lr_search'))
    try:
        with torch.no_grad():
            losses = np.clip(losses, 0, lim)
            # Smooth once and reuse; the input does not change below.
            losses_sm = utils.smooth(losses)
            minima_index = np.argmin(losses_sm)
            minima_loss = losses[minima_index]
            minima_lr = lrs[minima_index]

            step = 0
            for loss, loss_sm in zip(losses, losses_sm):
                writer.add_scalar('search_loss', loss, global_step=step)
                writer.add_scalar('search_loss_sm', loss_sm, global_step=step)
                step += config.batch_size

            fig = plt.figure()
            plt.plot(lrs, losses)
            plt.plot(lrs, losses_sm)
            plt.axvline(minima_lr)
            plt.xscale('log')
            plt.title('loss: {:.8f}, lr: {:.8f}'.format(
                minima_loss, minima_lr))
            writer.add_figure('search', fig, global_step=0)
    finally:
        # Flush pending events and release the event file.
        writer.close()

    return minima_lr
Ejemplo n.º 8
0
def train_fold(fold, train_eval_data):
    """Train one fold on labelled data plus a shrinking set of extra rows.

    Besides the fold's training split, every epoch adds a random sample of
    rows read from two prediction CSVs (``eval_<fold>.csv`` and ``test.csv``
    under a hard-coded ``./tf_log/...`` directory; presumably pseudo-labels
    from an earlier run, confirm). The sampled fraction decreases linearly
    from 1 in the first epoch to 0 in the last. Whenever the eval
    ``accuracy@1`` improves, the model state dict is written to
    ``model_<fold>.pth`` in ``args.experiment_path``.

    Args:
        fold: Fold identifier used for the split, the CSV and checkpoint names.
        train_eval_data: DataFrame with all labelled rows; indexed
            positionally with the indices from ``indices_for_fold``.

    Raises:
        AssertionError: If ``config.sched.type`` is not a known scheduler.
    """
    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(
            torch.load(
                os.path.join(args.restore_path, 'model_{}.pth'.format(fold))))

    optimizer = build_optimizer(config.opt, model.parameters())

    # Everything in this section is identical for every epoch, so it is
    # prepared once. It also has to exist before the scheduler is built,
    # because the step-based schedulers need an epoch length.
    train_indices, eval_indices = indices_for_fold(fold, train_eval_data)
    train_fold_data = train_eval_data.iloc[train_indices]

    eval_pl = pd.read_csv(
        './tf_log/cells/tmp-512-progres-crop-norm-la/eval_{}.csv'.format(
            fold))
    eval_pl['root'] = os.path.join(args.dataset_path, 'train')
    test_pl = pd.read_csv(
        './tf_log/cells/tmp-512-progres-crop-norm-la/test.csv')
    test_pl['root'] = os.path.join(args.dataset_path, 'test')
    pl_all = pd.concat([eval_pl, test_pl])
    pl_size = len(pl_all)
    pl_fracs = np.linspace(1., 0., config.epochs)

    eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                    transform=eval_transform)
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    # Number of batches in the first epoch, when all extra rows are used and
    # the loader drops the last incomplete batch. Later epochs are shorter
    # because fewer extra rows are sampled, so this is an upper bound.
    steps_per_epoch = (len(train_fold_data) + pl_size) // config.batch_size

    if config.sched.type == 'onecycle':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            OneCycleScheduler(optimizer,
                              lr=(config.opt.lr / 20, config.opt.lr),
                              beta_range=config.sched.onecycle.beta,
                              max_steps=config.epochs,
                              annealing=config.sched.onecycle.anneal,
                              peak_pos=config.sched.onecycle.peak_pos,
                              end_pos=config.sched.onecycle.end_pos))
    elif config.sched.type == 'cyclic':
        step_size_up = steps_per_epoch * config.sched.cyclic.step_size_up
        step_size_down = steps_per_epoch * config.sched.cyclic.step_size_down

        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode='exp_range',
                gamma=config.sched.cyclic.decay**(
                    1 / (step_size_up + step_size_down)),
                cycle_momentum=True,
                base_momentum=0.75,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=steps_per_epoch, T_mult=2))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    best_score = 0
    for epoch in range(1, config.epochs + 1):
        # Draw this epoch's share of the extra rows (1.0 down to 0.0).
        pl = pl_all.sample(frac=pl_fracs[epoch - 1].item())
        # NOTE(review): the format spec assumes the wrapper's get_lr()
        # returns a scalar, not a list; confirm against the wrapper.
        print('frac: {:.4f}, lr: {:.8f}'.format(
            len(pl) / pl_size, scheduler.get_lr()))

        train_dataset = TrainEvalDataset(pd.concat([train_fold_data, pl]),
                                         transform=train_transform)
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            drop_last=True,
            shuffle=True,
            num_workers=args.workers,
            worker_init_fn=worker_init_fn)

        train_epoch(model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    data_loader=train_data_loader,
                    fold=fold,
                    epoch=epoch)
        gc.collect()
        metric = eval_epoch(model=model,
                            data_loader=eval_data_loader,
                            fold=fold,
                            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']

        # Both hooks are called every epoch regardless of the wrapper type.
        scheduler.step_epoch()
        scheduler.step_score(score)

        if score > best_score:
            best_score = score
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_path,
                             'model_{}.pth'.format(fold)))