Example #1
def test_savefun_and_writer_exclusive():
    # savefun and writer arguments cannot be specified together.
    def savefun(*args, **kwargs):
        assert False
    writer = writing.SimpleWriter()
    with pytest.raises(TypeError):
        extensions.snapshot(savefun=savefun, writer=writer)

    trainer = mock.MagicMock()
    with pytest.raises(TypeError):
        extensions.snapshot_object(trainer, savefun=savefun, writer=writer)
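
The test above only exercises the error path. For contrast, here is a minimal sketch of the accepted usage, passing only one of savefun or writer per call (the Linear target and file names are illustrative assumptions, not taken from the test suite):

import torch
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions

target = torch.nn.Linear(2, 2)

# Either a custom save function ...
snap_a = extensions.snapshot_object(target, 'target.pt', savefun=torch.save)
# ... or an explicit writer, but never both in the same call.
snap_b = extensions.snapshot_object(
    target, 'target.pt', writer=ppe.writing.SimpleWriter())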
Example #2
def test_save_file(remover):
    trainer = get_trainer(out_dir='.')
    trainer._done = True
    w = writing.SimpleWriter()
    snapshot = extensions.snapshot_object(trainer, 'myfile.dat', writer=w)
    snapshot(trainer)

    assert os.path.exists('myfile.dat')
Example #3
def test_multi_target_autoload_not_found(remover):
    trainer = get_trainer(out_dir='.')
    other = _StateDictObj(state_dict={'original': 'state'})

    target = {'trainer': trainer, 'other': other}
    snapshot = extensions.snapshot_object(target, 'myfile.dat', autoload=True)

    assert snapshot.initialize(trainer) is None
    assert other.state_dict() == {'original': 'state'}
Example #4
def test_clean_up_tempdir(remover):
    trainer = get_trainer_with_mock_updater(out_dir='.')
    trainer._done = True
    snapshot = extensions.snapshot_object(trainer, 'myfile.dat')
    snapshot(trainer)

    left_tmps = [fn for fn in os.listdir('.')
                 if fn.startswith('tmpmyfile.dat')]
    assert len(left_tmps) == 0
Example #5
def __call__(self, manager: ExtensionsManager) -> None:
    lr_after = self.optimizer.param_groups[self.param_group]['lr']
    # A learning-rate increase marks the start of a new cycle, so snapshot
    # the target (in distributed runs, saver_rank selects the writing rank).
    if self.lr_before < lr_after:
        filename = self.filename.format(cycle_count=self.cycle_count)
        save_func = snapshot_object(self.target,
                                    filename,
                                    saver_rank=self.saver_rank)
        save_func(manager)
        self.cycle_count += 1
    self.lr_before = lr_after
Example #6
def test_multi_target_autoload(remover):
    trainer = get_trainer(out_dir='.')
    trainer._done = True
    other_state_dict = {'test': True}
    other = _StateDictObj(state_dict=other_state_dict)
    w = ppe.writing.SimpleWriter()
    target = {'trainer': trainer, 'other': other}
    snapshot = extensions.snapshot_object(target, 'myfile.dat', writer=w)
    snapshot(trainer)

    assert os.path.exists('myfile.dat')
    new_trainer = get_trainer(out_dir='.')
    new_other = _StateDictObj(state_dict={})

    target = {'trainer': new_trainer, 'other': new_other}
    snapshot2 = extensions.snapshot_object(target, 'myfile.dat', autoload=True)
    # Load the snapshot and verify it
    assert snapshot2.initialize(new_trainer) == 'myfile.dat'
    assert new_trainer.state_dict() == trainer.state_dict()
    assert new_other.state_dict() == other_state_dict
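
As a hedged sketch of how autoload is typically wired into a training script (the model, optimizer, and snapshot file name below are illustrative assumptions): when the manager initializes its extensions, an existing snapshot under out_dir is loaded back into the target before training resumes.

import torch
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
manager = ppe.training.ExtensionsManager(
    {'main': model}, {'main': optimizer}, 1,
    iters_per_epoch=1, out_dir='result')

# Saves model.state_dict() every epoch; on restart, an existing
# 'model_snapshot.pt' under out_dir is loaded back automatically.
manager.extend(
    extensions.snapshot_object(model, 'model_snapshot.pt', autoload=True),
    trigger=(1, 'epoch'))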
Example #7
def test_on_error():
    # Will fail when accessing the dummy optimizer
    optimizers = {'main': object()}
    trainer = training.ExtensionsManager(
        {}, optimizers, 1,
        iters_per_epoch=1,
        out_dir='.')
    filename = 'myfile-deadbeef.dat'

    snapshot = extensions.snapshot_object(trainer, filename,
                                          snapshot_on_error=True)
    trainer.extend(snapshot)
    assert not os.path.exists(filename)
    with pytest.raises(AttributeError):
        with trainer.run_iteration():
            pass
    assert not os.path.exists(filename)
Example #8
def test_multi_target(remover):
    trainer = get_trainer(out_dir='.')
    trainer._done = True
    other_state_dict = {'test': True}
    other = _StateDictObj(state_dict=other_state_dict)
    w = ppe.writing.SimpleWriter()
    target = {'trainer': trainer, 'other': other}
    snapshot = extensions.snapshot_object(target, 'myfile.dat', writer=w)
    snapshot(trainer)

    assert os.path.exists('myfile.dat')
    # Load the snapshot and verify it
    state = torch.load('myfile.dat')
    new_trainer = get_trainer(out_dir='.')
    new_other = _StateDictObj(state_dict={})
    new_trainer.load_state_dict(state['trainer'])
    new_other.load_state_dict(state['other'])
    assert new_trainer.state_dict() == trainer.state_dict()
    assert new_other.state_dict() == other_state_dict
    models = {"main": model}
    optimizers = {"main": optimizer}
    manager = IgniteExtensionsManager(
        trainer,
        models,
        optimizers,
        epoch,
        extensions=extensions,
        out_dir=str(out_dir),
    )
    # Run evaluation on the validation dataset at the configured interval.
    manager.extend(valid_evaluator, trigger=(flags.validation_freq, "iteration"))
    if local_rank == 0:
        # Save predictor.pt at the configured snapshot frequency.
        manager.extend(E.snapshot_object(predictor, "predictor.pt"),
                       trigger=(flags.snapshot_freq, "iteration"))
        # Check & save the best validation predictor.pt.
        # manager.extend(E.snapshot_object(predictor, "best_predictor.pt"),
        #                trigger=MinValueTrigger("validation/mainmodule/nll",
        #                                        trigger=(flags.snapshot_freq,
        #                                                 "iteration")))

    # --- lr scheduler ---
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #     optimizer, mode='min', factor=0.7, patience=5, min_lr=1e-10)
    if flags.scheduler_type != "":
        scheduler_type = flags.scheduler_type

        # For backward compatibility
        if scheduler_type == "exponential":
            scheduler_type = "ExponentialLR"
Example #10
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--snapshot', type=str)
    parser.add_argument('--snapmodel', type=str)
    args = parser.parse_args()

    # get config
    config = Config()

    # set seed
    random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    torch.backends.cudnn.deterministic = True

    # create model
    model = utils.create_model(config)
    device = 'cuda'
    model.cuda()

    # define transforms
    train_trans = transforms.train_transform(resize=(config.input_size_h,
                                                     config.input_size_w),
                                             normalize=config.normalize)
    val_trans = transforms.eval_transform(resize=(config.input_size_h,
                                                  config.input_size_w),
                                          normalize=config.normalize)

    # copy config and src
    if not os.path.exists(os.path.join(config.result, 'src')):
        os.makedirs(os.path.join(config.result, 'src'), exist_ok=True)
    for src_file in glob.glob('/work/*.py') + glob.glob('/work/*/*.py'):
        shutil.copy(
            src_file,
            os.path.join(config.result, 'src', os.path.basename(src_file)))

    # create dataset
    train_dataset = dataset.Alaska2Dataset(
        root=config.data,
        transforms=train_trans,
        train=True,
        batchsize=config.batchsize,
        uniform=config.batch_uniform,
    )
    val_dataset = dataset.Alaska2Dataset(root=config.data,
                                         transforms=val_trans,
                                         train=False,
                                         uniform=False)

    # create data loader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=config.batchsize,
                                               num_workers=config.num_workers,
                                               shuffle=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config.batchsize,
                                             num_workers=config.num_workers,
                                             shuffle=False)

    # set optimizer
    # optimizer = torch.optim.AdamW(
    #     [{'params': model.parameters()}, {'params': metrics_fc.parameters()}],
    #     lr=config.lr)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    # optimizer = torch.optim.SGD(model.parameters(),
    #                             lr=config.lr,
    #                             momentum=0.9)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    if config.fp16:
        opt_level = 'O1'
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=opt_level
                                          #keep_batchnorm_fp32=True
                                          )

    # set scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=0.5,
        patience=2,
        threshold_mode='abs',
        min_lr=1e-8,
        eps=1e-08)

    # set criterion
    #criterion = torch.nn.CrossEntropyLoss()
    criterion = LabelSmoothing().cuda()
    num_epochs = config.num_epochs

    # set manager
    iters_per_epoch = len(train_loader)

    manager = ppe.training.ExtensionsManager(model,
                                             optimizer,
                                             num_epochs,
                                             iters_per_epoch=iters_per_epoch,
                                             out_dir=config.result,
                                             stop_trigger=None)

    log_interval = (100, 'iteration')
    #eval_interval = (500, 'iteration')
    eval_interval = (1, 'epoch')

    manager.extend(extensions.snapshot(filename='best_snapshot'),
                   trigger=MaxValueTrigger('validation/auc',
                                           trigger=eval_interval))
    if config.fp16:
        manager.extend(extensions.snapshot_object(amp, filename='amp.ckpt'),
                       trigger=MaxValueTrigger('validation/auc',
                                               trigger=eval_interval))

    manager.extend(extensions.LogReport(trigger=log_interval))

    manager.extend(extensions.PlotReport(['train/loss', 'validation/loss'],
                                         'epoch',
                                         filename='loss.png'),
                   trigger=(1, 'epoch'))

    manager.extend(extensions.PrintReport([
        'epoch', 'iteration', 'train/loss', 'validation/loss',
        'validation/auc', 'lr', 'elapsed_time'
    ]),
                   trigger=log_interval)

    manager.extend(extensions.ProgressBar(update_interval=100))
    manager.extend(extensions.observe_lr(optimizer=optimizer),
                   trigger=log_interval)
    #manager.extend(extensions.ParameterStatistics(model, prefix='model'))
    #manager.extend(extensions.VariableStatisticsPlot(model))

    manager.extend(ALASKAEvaluator(val_loader,
                                   model,
                                   eval_hook=None,
                                   eval_func=None,
                                   loss_criterion=criterion,
                                   auc_criterion=auc_eval_func,
                                   device=device,
                                   scheduler=scheduler,
                                   metric_learning=config.metric_learning),
                   trigger=eval_interval)

    # Let's load the snapshot
    if args.snapshot is not None:
        state = torch.load(args.snapshot)
        manager.load_state_dict(state)
        #amp = torch.load('amp.ckpt')
    elif args.snapmodel is not None:
        print('load snapshot model {}'.format(args.snapmodel))
        state = torch.load(args.snapmodel)
        manager._models['main'].load_state_dict(state['models']['main'])

    train_func(manager,
               model,
               criterion,
               optimizer,
               train_loader,
               device,
               metric_learning=config.metric_learning,
               fp16=config.fp16)