def load_trainer(N, D_in, H, D_out, num_epochs, data_loader, data_loader_steps):
    """Load the manually saved trainer using freshly built components.

    The device, model, loss function, optimizer and scheduler come from
    ``get_trainer_base(D_in, H, D_out)``; the metrics it also returns are
    not forwarded to the load call. The same ``data_loader`` and
    ``data_loader_steps`` are used for both the train and validation sides.

    ``N`` and ``num_epochs`` are accepted but not used in this function.

    NOTE(review): ``save_to_dir`` and ``trainer_file_name`` are not defined
    here; they are presumably module-level names -- confirm they are in
    scope wherever this function is called.

    Returns:
        Whatever ``Trainer.load_trainer`` returns for the file named
        ``trainer_file_name + '_manual_save'`` inside ``save_to_dir``.
    """
    components = get_trainer_base(D_in, H, D_out)
    dev, net, criterion, opt, sched, _unused_metrics = components

    # NOTICE: load_trainer is a static method on the Trainer class.
    return Trainer.load_trainer(
        dir_path=save_to_dir,
        file_name=trainer_file_name + '_manual_save',
        model=net,
        device=dev,
        loss_func=criterion,
        optimizer=opt,
        scheduler=sched,
        train_data_loader=data_loader,
        val_data_loader=data_loader,
        train_steps=data_loader_steps,
        val_steps=data_loader_steps,
    )
def test_save_and_load(self):
    """Train with full-trainer checkpoints, reload one, and compare state.

    A trainer is trained for ``num_epochs`` epochs with a ``ModelCheckPoint``
    callback configured with ``save_best_only=False`` and
    ``save_full_trainer=True``. The checkpoint file named
    ``'trainer' + f'_epoch_{num_epochs}'`` is then loaded, and the loaded
    trainer is compared with the trained one on epoch, optimizer learning
    rates, and the checkpoint monitor's best value.

    NOTE(review): the load relies on ModelCheckPoint naming its files
    ``<checkpoint_file_name>_epoch_<n>`` -- confirm against the callback.
    """
    gu.seed_all(42)

    checkpoint_dir = os.path.dirname(__file__) + '/trainer_checkpoint/'
    checkpoint_name = 'trainer'

    device = tu.get_gpu_device_if_available()
    model = eu.get_basic_model(10, 10, 10).to(device)
    loss_func = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = KerasDecay(optimizer, 0.0001, last_step=-1)
    metrics = CategoricalAccuracyWithLogits(name='acc')

    # Callbacks are constructed in the same order as they appear in the list.
    loss_handler = LossOptimizerHandler()
    val_loss_monitor = CallbackMonitor(
        monitor_type=MonitorType.LOSS,
        stats_type=StatsType.VAL,
        monitor_mode=MonitorMode.MIN,
    )
    checkpoint_cb = ModelCheckPoint(
        checkpoint_dir=checkpoint_dir,
        checkpoint_file_name=checkpoint_name,
        callback_monitor=val_loss_monitor,
        save_best_only=False,
        save_full_trainer=True,
        verbose=0,
    )
    scheduler_step = SchedulerStep(
        apply_on_phase=Phase.BATCH_END,
        apply_on_states=State.TRAIN,
    )
    stats_print = StatsPrint()
    callbacks = [loss_handler, checkpoint_cb, scheduler_step, stats_print]

    data_loader = eu.examples_data_generator(10, 10, 10, category_out=True)
    data_loader_steps = 100
    num_epochs = 5

    # Keyword arguments passed identically to the constructor and the loader.
    shared_kwargs = dict(
        model=model,
        device=device,
        loss_func=loss_func,
        optimizer=optimizer,
        scheduler=scheduler,
        train_data_loader=data_loader,
        val_data_loader=data_loader,
        train_steps=data_loader_steps,
        val_steps=data_loader_steps,
    )

    trainer = Trainer(
        metrics=metrics,
        callbacks=callbacks,
        name='Trainer-Test',
        **shared_kwargs,
    )
    trainer.train(num_epochs, verbose=0)

    loaded_trainer = Trainer.load_trainer(
        dir_path=checkpoint_dir,
        file_name=checkpoint_name + f'_epoch_{num_epochs}',
        **shared_kwargs,
    )

    self.assertEqual(loaded_trainer.epoch, trainer.epoch)
    self.assertListEqual(
        tu.get_lrs_from_optimizer(loaded_trainer.optimizer),
        tu.get_lrs_from_optimizer(trainer.optimizer),
    )
    # callbacks[1] is the ModelCheckPoint entry in the list built above.
    self.assertEqual(
        loaded_trainer.callbacks[1].monitor._get_best(),
        trainer.callbacks[1].monitor._get_best(),
    )