Example #1
def train_fold(fold, train_eval_data, unsup_data):
    train_indices, eval_indices = indices_for_fold(fold, train_eval_data)

    train_dataset = TrainEvalDataset(train_eval_data.iloc[train_indices],
                                     transform=train_transform)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)
    eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                    transform=eval_transform)
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)
    # unlabeled data consumed by train_epoch alongside the labeled batches,
    # loaded at half the labeled batch size
    unsup_dataset = TestDataset(unsup_data, transform=unsup_transform)
    unsup_data_loader = torch.utils.data.DataLoader(
        unsup_dataset,
        batch_size=config.batch_size // 2,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(
            torch.load(
                os.path.join(args.restore_path, 'model_{}.pth'.format(fold))))

    optimizer = build_optimizer(config.opt, model.parameters())

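    # choose a learning-rate schedule; each scheduler is wrapped (see the
    # lr_scheduler_wrapper sketch after this example) so the training loop can
    # call step()/step_epoch()/step_score() uniformly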
    if config.sched.type == 'onecycle':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            OneCycleScheduler(optimizer,
                              lr=(config.opt.lr / 20, config.opt.lr),
                              beta_range=config.sched.onecycle.beta,
                              max_steps=len(train_data_loader) * config.epochs,
                              annealing=config.sched.onecycle.anneal,
                              peak_pos=config.sched.onecycle.peak_pos,
                              end_pos=config.sched.onecycle.end_pos))
    elif config.sched.type == 'step':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=config.sched.step.step_size,
                gamma=config.sched.step.decay))
    elif config.sched.type == 'cyclic':
        step_size_up = len(
            train_data_loader) * config.sched.cyclic.step_size_up
        step_size_down = len(
            train_data_loader) * config.sched.cyclic.step_size_down

        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode='triangular2',
                gamma=config.sched.cyclic.decay**(
                    1 / (step_size_up + step_size_down)),
                cycle_momentum=True,
                base_momentum=0.85,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_data_loader), T_mult=2))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    best_score = 0
    for epoch in range(1, config.epochs + 1):
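        # note: build_optimizer is assumed to return a custom optimizer that
        # exposes train()/eval() modes (e.g. Lookahead-style fast/slow weight
        # swapping); plain torch.optim optimizers have no such methods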
        optimizer.train()
        train_epoch(model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    data_loader=train_data_loader,
                    unsup_data_loader=unsup_data_loader,
                    fold=fold,
                    epoch=epoch)
        gc.collect()
        optimizer.eval()
        metric = eval_epoch(model=model,
                            data_loader=eval_data_loader,
                            fold=fold,
                            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']

        scheduler.step_epoch()
        scheduler.step_score(score)

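        # keep only the best checkpoint per fold, overwriting the previous one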
        if score > best_score:
            best_score = score
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_path,
                             'model_{}.pth'.format(fold)))
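
All four examples route their schedulers through lr_scheduler_wrapper so the
epoch loop can call step() (presumably per batch, inside train_epoch),
step_epoch() and step_score() unconditionally, and only the wrapper matching
the underlying scheduler actually advances it. The wrapper module itself is
not shown in these examples; the following is a minimal sketch of what it
could look like, inferred from the call sites above. The class internals here
are an assumption, not the repository's actual implementation.

class _SchedulerWrapper:
    # assumed base class: accepts all three calls so the training loop never
    # has to branch on the scheduler type
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def step(self):
        pass

    def step_epoch(self):
        pass

    def step_score(self, score):
        pass

    def get_lr(self):
        # learning rate of the first param group of the wrapped optimizer
        return self.scheduler.optimizer.param_groups[0]['lr']


class StepWrapper(_SchedulerWrapper):
    # advances the wrapped scheduler once per optimization step (per batch)
    def step(self):
        self.scheduler.step()


class EpochWrapper(_SchedulerWrapper):
    # advances the wrapped scheduler once per epoch
    def step_epoch(self):
        self.scheduler.step()


class ScoreWrapper(_SchedulerWrapper):
    # forwards the validation score to ReduceLROnPlateau-style schedulers
    def step_score(self, score):
        self.scheduler.step(score)
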
Example #2
def train():
    train_dataset = torchvision.datasets.ImageNet(args.dataset_path,
                                                  split='train',
                                                  transform=train_transform)
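    # utils.RandomSubset (a project helper) keeps a random quarter of the
    # ImageNet training split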
    train_dataset = utils.RandomSubset(train_dataset, len(train_dataset) // 4)
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    eval_dataset = torchvision.datasets.ImageNet(args.dataset_path,
                                                 split='val',
                                                 transform=eval_transform)
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size * 2,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = MobileNetV3(3, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(torch.load(args.restore_path))

    optimizer = build_optimizer(config.opt, model.parameters())

    if config.sched.type == 'step':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=config.sched.step.step_size,
                gamma=config.sched.step.decay))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    elif config.sched.type == 'cyclic':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=len(train_data_loader) *
                config.sched.cyclic.step_size_up,
                step_size_down=len(train_data_loader) *
                config.sched.cyclic.step_size_down,
                mode='triangular2',
                cycle_momentum=True,
                base_momentum=0.75,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_data_loader), T_mult=2))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    for epoch in range(config.epochs):
        train_epoch(model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    data_loader=train_data_loader,
                    epoch=epoch)
        gc.collect()
        metric = eval_epoch(model=model,
                            data_loader=eval_data_loader,
                            epoch=epoch)
        gc.collect()

        scheduler.step_epoch()
        scheduler.step_score(metric['accuracy@1'])

        torch.save(model.state_dict(),
                   os.path.join(args.experiment_path, 'model.pth'))
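
Every DataLoader in these examples is given a worker_init_fn, whose definition
is not included in the snippets. A common reason to supply one is to reseed
numpy inside each worker process so that numpy-based augmentations do not
repeat across workers; the sketch below shows that idiom and is only an
assumption about what the project's worker_init_fn actually does.

import numpy as np
import torch


def worker_init_fn(worker_id):
    # inside a DataLoader worker, torch.initial_seed() already returns a
    # per-worker seed; fold it into numpy's RNG so numpy-driven augmentations
    # differ between workers and between epochs
    np.random.seed(torch.initial_seed() % 2**32)
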
Example #3
def train():
    train_dataset = ADE20K(args.dataset_path,
                           train=True,
                           transform=train_transform)
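    # note: '// 1' keeps every sample, so this Subset only shuffles the order;
    # the eval split below is cut to an eighth with '// 8'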
    train_dataset = torch.utils.data.Subset(
        train_dataset,
        np.random.permutation(len(train_dataset))[:len(train_dataset) // 1])
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    eval_dataset = ADE20K(args.dataset_path,
                          train=False,
                          transform=eval_transform)
    eval_dataset = torch.utils.data.Subset(
        eval_dataset,
        np.random.permutation(len(eval_dataset))[:len(eval_dataset) // 8])
    eval_data_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        num_workers=args.workers,
        worker_init_fn=worker_init_fn)

    model = UNet(NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(torch.load(args.restore_path))

    optimizer = build_optimizer(config.opt, model.parameters())

    if config.sched.type == 'step':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=config.sched.step.step_size,
                gamma=config.sched.step.decay))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_data_loader), T_mult=2))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    for epoch in range(config.epochs):
        train_epoch(model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    data_loader=train_data_loader,
                    epoch=epoch)
        gc.collect()
        metric = eval_epoch(model=model,
                            data_loader=eval_data_loader,
                            epoch=epoch)
        gc.collect()

        scheduler.step_epoch()
        scheduler.step_score(metric['iou'])

        torch.save(model.state_dict(),
                   os.path.join(args.experiment_path, 'model.pth'))
Example #4
def train_fold(fold, train_eval_data):
    model = Model(config.model, NUM_CLASSES)
    model = model.to(DEVICE)
    if args.restore_path is not None:
        model.load_state_dict(
            torch.load(
                os.path.join(args.restore_path, 'model_{}.pth'.format(fold))))

    optimizer = build_optimizer(config.opt, model.parameters())

    if config.sched.type == 'onecycle':
        scheduler = lr_scheduler_wrapper.EpochWrapper(
            OneCycleScheduler(optimizer,
                              lr=(config.opt.lr / 20, config.opt.lr),
                              beta_range=config.sched.onecycle.beta,
                              max_steps=config.epochs,
                              annealing=config.sched.onecycle.anneal,
                              peak_pos=config.sched.onecycle.peak_pos,
                              end_pos=config.sched.onecycle.end_pos))
    elif config.sched.type == 'cyclic':
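        # note: train_data_loader is only created inside the epoch loop below,
        # so this branch (and 'cawr') would fail with an UnboundLocalError as
        # written; only 'onecycle' and 'plateau' are usable here as-is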
        step_size_up = len(
            train_data_loader) * config.sched.cyclic.step_size_up
        step_size_down = len(
            train_data_loader) * config.sched.cyclic.step_size_down

        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CyclicLR(
                optimizer,
                0.,
                config.opt.lr,
                step_size_up=step_size_up,
                step_size_down=step_size_down,
                mode='exp_range',
                gamma=config.sched.cyclic.decay**(
                    1 / (step_size_up + step_size_down)),
                cycle_momentum=True,
                base_momentum=0.75,
                max_momentum=0.95))
    elif config.sched.type == 'cawr':
        scheduler = lr_scheduler_wrapper.StepWrapper(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_data_loader), T_mult=2))
    elif config.sched.type == 'plateau':
        scheduler = lr_scheduler_wrapper.ScoreWrapper(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.sched.plateau.decay,
                patience=config.sched.plateau.patience,
                verbose=True))
    else:
        raise AssertionError('invalid sched {}'.format(config.sched.type))

    best_score = 0
    for epoch in range(1, config.epochs + 1):
        train_indices, eval_indices = indices_for_fold(fold, train_eval_data)

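        # pseudo-labels ('pl'): per-sample predictions for this fold's eval
        # split and for the test set, apparently exported by an earlier run;
        # the fraction mixed into training decays linearly from 1 to 0 over
        # the epochs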
        eval_pl = pd.read_csv(
            './tf_log/cells/tmp-512-progres-crop-norm-la/eval_{}.csv'.format(
                fold))
        eval_pl['root'] = os.path.join(args.dataset_path, 'train')
        test_pl = pd.read_csv(
            './tf_log/cells/tmp-512-progres-crop-norm-la/test.csv')
        test_pl['root'] = os.path.join(args.dataset_path, 'test')
        pl = pd.concat([eval_pl, test_pl])
        pl_size = len(pl)
        pl = pl.sample(
            frac=np.linspace(1., 0., config.epochs)[epoch - 1].item())
        print('frac: {:.4f}, lr: {:.8f}'.format(
            len(pl) / pl_size, scheduler.get_lr()))

        train_dataset = TrainEvalDataset(
            pd.concat([train_eval_data.iloc[train_indices], pl]),
            transform=train_transform)
        train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            drop_last=True,
            shuffle=True,
            num_workers=args.workers,
            worker_init_fn=worker_init_fn)
        eval_dataset = TrainEvalDataset(train_eval_data.iloc[eval_indices],
                                        transform=eval_transform)
        eval_data_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=config.batch_size,
            num_workers=args.workers,
            worker_init_fn=worker_init_fn)

        train_epoch(model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    data_loader=train_data_loader,
                    fold=fold,
                    epoch=epoch)
        gc.collect()
        metric = eval_epoch(model=model,
                            data_loader=eval_data_loader,
                            fold=fold,
                            epoch=epoch)
        gc.collect()

        score = metric['accuracy@1']

        scheduler.step_epoch()
        scheduler.step_score(score)

        if score > best_score:
            best_score = score
            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_path,
                             'model_{}.pth'.format(fold)))
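
The pseudo-label fraction in Example #4 follows
np.linspace(1., 0., config.epochs)[epoch - 1]: the first epoch samples the
full pseudo-label pool and the last epoch samples none of it. A small
standalone illustration of that schedule (the epoch count of 10 is chosen
just for this example):

import numpy as np

epochs = 10  # illustrative value, not taken from any config above
fracs = np.linspace(1., 0., epochs)
for epoch in range(1, epochs + 1):
    # epoch 1 keeps 100% of the pseudo-labels, the final epoch keeps 0%
    print('epoch {:2d}: pseudo-label frac = {:.2f}'.format(
        epoch, fracs[epoch - 1]))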