def test_savefun_and_writer_exclusive():
    """Supplying both ``savefun`` and ``writer`` must raise ``TypeError``.

    Checked for both ``snapshot()`` and ``snapshot_object()``.
    """
    def never_called(*args, **kwargs):
        # The extension must reject the combination before saving anything.
        assert False

    simple_writer = writing.SimpleWriter()
    with pytest.raises(TypeError):
        extensions.snapshot(savefun=never_called, writer=simple_writer)

    fake_trainer = mock.MagicMock()
    with pytest.raises(TypeError):
        extensions.snapshot_object(
            fake_trainer, savefun=never_called, writer=simple_writer)
def test_save_file(remover):
    """Snapshotting a finished trainer writes the requested file."""
    target = get_trainer(out_dir='.')
    target._done = True
    snapshot_ext = extensions.snapshot_object(
        target, 'myfile.dat', writer=writing.SimpleWriter())
    snapshot_ext(target)
    assert os.path.exists('myfile.dat')
def test_multi_target_autoload_not_found(remover):
    """Autoload with no snapshot on disk is a no-op.

    ``initialize()`` returns ``None`` and the companion object's state
    stays untouched.
    """
    trainer = get_trainer(out_dir='.')
    companion = _StateDictObj(state_dict={'original': 'state'})
    snap = extensions.snapshot_object(
        {'trainer': trainer, 'other': companion}, 'myfile.dat', autoload=True)
    assert snap.initialize(trainer) is None
    assert companion.state_dict() == {'original': 'state'}
def test_clean_up_tempdir(remover):
    """Temporary files created during a save must not be left behind."""
    trainer = get_trainer_with_mock_updater(out_dir='.')
    trainer._done = True
    ext = extensions.snapshot_object(trainer, 'myfile.dat')
    ext(trainer)
    # The writer stages output as 'tmpmyfile.dat*'; none may survive.
    leftovers = [
        name for name in os.listdir('.') if name.startswith('tmpmyfile.dat')]
    assert len(leftovers) == 0
def __call__(self, manager: ExtensionsManager) -> None:
    """Snapshot ``self.target`` whenever the tracked learning rate rises.

    Reads the current LR of the configured parameter group and compares it
    against the LR observed on the previous invocation; a rise is taken as
    the start of a new LR cycle (e.g. a warm restart), at which point the
    target is saved under a per-cycle filename.
    """
    lr_after = self.optimizer.param_groups[self.param_group]['lr']
    if self.lr_before < lr_after:
        # New cycle detected: persist the target with the cycle index
        # interpolated into the filename template.
        filename = self.filename.format(cycle_count=self.cycle_count)
        save_func = snapshot_object(self.target, filename, saver_rank=self.saver_rank)
        save_func(manager)
        self.cycle_count += 1
    # NOTE(review): placed outside the `if` so the comparison baseline
    # tracks the most recent LR on every call — confirm against the
    # original source's indentation.
    self.lr_before = lr_after
def test_multi_target_autoload(remover):
    """A dict of targets saved to disk can be restored via ``autoload``."""
    trainer = get_trainer(out_dir='.')
    trainer._done = True
    saved_other_state = {'test': True}
    other = _StateDictObj(state_dict=saved_other_state)
    snap = extensions.snapshot_object(
        {'trainer': trainer, 'other': other},
        'myfile.dat', writer=ppe.writing.SimpleWriter())
    snap(trainer)
    assert os.path.exists('myfile.dat')

    # Build fresh objects and let a second extension autoload into them.
    restored_trainer = get_trainer(out_dir='.')
    restored_other = _StateDictObj(state_dict={})
    autoloader = extensions.snapshot_object(
        {'trainer': restored_trainer, 'other': restored_other},
        'myfile.dat', autoload=True)
    # initialize() reports the filename it loaded from.
    assert autoloader.initialize(restored_trainer) == 'myfile.dat'
    assert restored_trainer.state_dict() == trainer.state_dict()
    assert restored_other.state_dict() == saved_other_state
def test_on_error():
    """No snapshot file appears when the iteration fails.

    The dummy optimizer (a bare ``object()``) raises ``AttributeError``
    when accessed during the iteration; the snapshot file must not exist
    before or after the failure.
    """
    optimizers = {'main': object()}
    manager = training.ExtensionsManager(
        {}, optimizers, 1, iters_per_epoch=1, out_dir='.')
    filename = 'myfile-deadbeef.dat'
    manager.extend(
        extensions.snapshot_object(manager, filename, snapshot_on_error=True))
    assert not os.path.exists(filename)
    with pytest.raises(AttributeError):
        with manager.run_iteration():
            pass
    assert not os.path.exists(filename)
def test_multi_target(remover):
    """Saving a dict of targets stores each entry's state under its key."""
    trainer = get_trainer(out_dir='.')
    trainer._done = True
    saved_other_state = {'test': True}
    other = _StateDictObj(state_dict=saved_other_state)
    snap = extensions.snapshot_object(
        {'trainer': trainer, 'other': other},
        'myfile.dat', writer=ppe.writing.SimpleWriter())
    snap(trainer)
    assert os.path.exists('myfile.dat')

    # Read the raw file back and feed each entry into a fresh object.
    stored = torch.load('myfile.dat')
    restored_trainer = get_trainer(out_dir='.')
    restored_other = _StateDictObj(state_dict={})
    restored_trainer.load_state_dict(stored['trainer'])
    restored_other.load_state_dict(stored['other'])
    assert restored_trainer.state_dict() == trainer.state_dict()
    assert restored_other.state_dict() == saved_other_state
models = {"main": model} optimizers = {"main": optimizer} manager = IgniteExtensionsManager( trainer, models, optimizers, epoch, extensions=extensions, out_dir=str(out_dir), ) # Run evaluation for valid dataset in each epoch. manager.extend(valid_evaluator, trigger=(flags.validation_freq, "iteration")) if local_rank == 0: # Save predictor.pt every epoch manager.extend(E.snapshot_object(predictor, "predictor.pt"), trigger=(flags.snapshot_freq, "iteration")) # Check & Save best validation predictor.pt every epoch # manager.extend(E.snapshot_object(predictor, "best_predictor.pt"), # trigger=MinValueTrigger("validation/mainmodule/nll", # trigger=(flags.snapshot_freq, "iteration"))) # --- lr scheduler --- # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # optimizer, mode='min', factor=0.7, patience=5, min_lr=1e-10) if flags.scheduler_type != "": scheduler_type = flags.scheduler_type # For backward compatibility if scheduler_type == "exponential": scheduler_type = "ExponentialLR"
def main():
    """Train an ALASKA2 classifier: parse args, build data/model/manager, run.

    Command-line flags:
        --snapshot:  path to a full ExtensionsManager snapshot to resume from.
        --snapmodel: path to a snapshot from which only the model weights
                     ('models'/'main') are restored.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--snapshot', type=str)
    parser.add_argument('--snapmodel', type=str)
    args = parser.parse_args()

    # get config
    config = Config()

    # set seed — make runs reproducible (cudnn.deterministic trades speed
    # for determinism).
    random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    torch.backends.cudnn.deterministic = True

    # create model (moved to GPU immediately; 'cuda' device is assumed
    # available throughout this script)
    model = utils.create_model(config)
    device = 'cuda'
    model.cuda()

    # define transforms (train vs. eval pipelines share size/normalization)
    train_trans = transforms.train_transform(
        resize=(config.input_size_h, config.input_size_w),
        normalize=config.normalize)
    val_trans = transforms.eval_transform(
        resize=(config.input_size_h, config.input_size_w),
        normalize=config.normalize)

    # copy config and src — archive the source tree next to the results so
    # each run is self-describing.
    if not os.path.exists(os.path.join(config.result, 'src')):
        os.makedirs(os.path.join(config.result, 'src'), exist_ok=True)
    for src_file in glob.glob('/work/*.py') + glob.glob('/work/*/*.py'):
        shutil.copy(
            src_file,
            os.path.join(config.result, 'src', os.path.basename(src_file)))

    # create dataset
    train_dataset = dataset.Alaska2Dataset(
        root=config.data,
        transforms=train_trans,
        train=True,
        batchsize=config.batchsize,
        uniform=config.batch_uniform,
    )
    val_dataset = dataset.Alaska2Dataset(
        root=config.data, transforms=val_trans, train=False, uniform=False)

    # create data loader
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.batchsize,
        num_workers=config.num_workers, shuffle=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=config.batchsize,
        num_workers=config.num_workers, shuffle=False)

    # set optimizer
    # optimizer = torch.optim.AdamW([{'params': model.parameters()}, {'params':metrics_fc.parameters()}],
    #                               lr=config.lr
    #                               )
    #else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    #optimizer = torch.optim.SGD(model.parameters(),
    #                            lr=config.lr,
    #                            momentum=0.9)

    # Initialize Amp. Amp accepts either values or strings for the optional
    # override arguments, for convenient interoperation with argparse.
    if config.fp16:
        opt_level = 'O1'
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=opt_level
                                          #keep_batchnorm_fp32=True
                                          )

    # set scheduler — halve the LR when validation stops improving.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=2, threshold_mode='abs',
        min_lr=1e-8, eps=1e-08)

    # set criterion
    #criterion = torch.nn.CrossEntropyLoss()
    criterion = LabelSmoothing().cuda()

    num_epochs = config.num_epochs

    # set manager — pytorch-pfn-extras training loop driver.
    iters_per_epoch = len(train_loader)
    manager = ppe.training.ExtensionsManager(
        model, optimizer, num_epochs,
        iters_per_epoch=iters_per_epoch,
        out_dir=config.result,
        stop_trigger=None)
    log_interval = (100, 'iteration')
    #eval_interval = (500, 'iteration')
    eval_interval = (1, 'epoch')
    # Keep a snapshot of the best model by validation AUC.
    manager.extend(
        extensions.snapshot(filename='best_snapshot'),
        trigger=MaxValueTrigger('validation/auc', trigger=eval_interval))
    if config.fp16:
        # Amp carries loss-scaling state that must be saved alongside the
        # model to resume mixed-precision training correctly.
        manager.extend(
            extensions.snapshot_object(amp, filename='amp.ckpt'),
            trigger=MaxValueTrigger('validation/auc', trigger=eval_interval))
    manager.extend(extensions.LogReport(trigger=log_interval))
    manager.extend(
        extensions.PlotReport(['train/loss', 'validation/loss'],
                              'epoch', filename='loss.png'),
        trigger=(1, 'epoch'))
    manager.extend(extensions.PrintReport([
        'epoch', 'iteration', 'train/loss', 'validation/loss',
        'validation/auc', 'lr', 'elapsed_time'
    ]), trigger=log_interval)
    manager.extend(extensions.ProgressBar(update_interval=100))
    manager.extend(extensions.observe_lr(optimizer=optimizer),
                   trigger=log_interval)
    #manager.extend(extensions.ParnnameterStatistics(model, prefix='model'))
    #manager.extend(extensions.VariableStatisticsPlot(model))
    # Validation evaluator also drives the LR scheduler via its hook.
    manager.extend(ALASKAEvaluator(
        val_loader, model,
        eval_hook=None, eval_func=None,
        loss_criterion=criterion, auc_criterion=auc_eval_func,
        device=device, scheduler=scheduler,
        metric_learning=config.metric_learning),
        trigger=eval_interval)

    # Lets load the snapshot
    if args.snapshot is not None:
        # Full resume: restores model, optimizer, and manager bookkeeping.
        state = torch.load(args.snapshot)
        manager.load_state_dict(state)
        #amp = torch.load('amp.ckpt')
    elif args.snapmodel is not None:
        # Weights-only warm start from a previous snapshot.
        print('load snapshot model {}'.format(args.snapmodel))
        state = torch.load(args.snapmodel)
        manager._models['main'].load_state_dict(state['models']['main'])

    train_func(manager, model, criterion, optimizer, train_loader, device,
               metric_learning=config.metric_learning, fp16=config.fp16)