def get_model() -> common.LernomaticModel:
    if GLOBAL_OPTS['model'] == 'resnet':
        model = resnets.WideResnet(
            depth=GLOBAL_OPTS['resnet_depth'],
            num_classes=10,
            input_channels=3
        )
    elif GLOBAL_OPTS['model'] == 'cifar':
        model = cifar.CIFAR10Net()
    else:
        raise ValueError('Unknown model type [%s]' % str(GLOBAL_OPTS['model']))

    return model
def test_history_extend(self) -> None:
    test_checkpoint = 'checkpoint/test_history_extend.pkl'
    test_history    = 'checkpoint/test_history_extend_history.pkl'
    model = cifar.CIFAR10Net()

    # TODO : adjust this so we don't need 10 epochs of training
    # Get trainer object
    test_num_epochs = 10
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        save_every    = 0,
        print_every   = 50,
        device_id     = util.get_device_id(),
        # loader options
        num_epochs    = test_num_epochs,
        learning_rate = 3e-4,
        batch_size    = 64,
        num_workers   = self.test_num_workers,
    )
    print('Training original model')
    trainer.train()
    trainer.save_checkpoint(test_checkpoint)
    trainer.save_history(test_history)

    # Load a new trainer, train for another 10 epochs (20 total)
    extend_trainer = cifar_trainer.CIFAR10Trainer(
        model,
        save_every    = 0,
        print_every   = 50,
        device_id     = util.get_device_id(),
        # loader options
        num_epochs    = 10,
        learning_rate = 3e-4,
        batch_size    = 64,
        num_workers   = self.test_num_workers,
    )
    print('Loading checkpoint [%s] into extend trainer...' % str(test_checkpoint))
    extend_trainer.load_checkpoint(test_checkpoint)
    print('Loading history [%s] into extend trainer...' % str(test_history))
    extend_trainer.load_history(test_history)

    # Check history before extending
    assert trainer.device_id == extend_trainer.device_id
    print('extend_trainer device : %s' % str(extend_trainer.device))
    assert test_num_epochs == extend_trainer.num_epochs
    assert 10 == extend_trainer.cur_epoch
    assert extend_trainer.loss_history is not None
    assert (10 * len(extend_trainer.train_loader)) == len(extend_trainer.loss_history)

    extend_trainer.set_num_epochs(20)
    assert (20 * len(extend_trainer.train_loader)) == len(extend_trainer.loss_history)
    for i in range(10 * len(extend_trainer.train_loader)):
        print('Checking loss iter [%d / %d]' % (i, 20 * len(extend_trainer.train_loader)), end='\r')
        assert trainer.loss_history[i] == extend_trainer.loss_history[i]

    extend_trainer.train()
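# Distilled from the test above, the resume-and-extend workflow is just:
# save checkpoint + history, load both into a fresh trainer, extend the
# epoch count, and train again. A minimal sketch, assuming the cifar,
# cifar_trainer and util modules are imported as in this repo; the file
# paths and epoch counts here are illustrative only.
model = cifar.CIFAR10Net()
trainer = cifar_trainer.CIFAR10Trainer(
    model,
    num_epochs = 10,
    device_id  = util.get_device_id(),
)
trainer.train()
trainer.save_checkpoint('checkpoint/extend_example.pkl')
trainer.save_history('checkpoint/extend_example_history.pkl')

# A fresh trainer picks up where the old one stopped
resumed = cifar_trainer.CIFAR10Trainer(model, device_id=util.get_device_id())
resumed.load_checkpoint('checkpoint/extend_example.pkl')
resumed.load_history('checkpoint/extend_example_history.pkl')
resumed.set_num_epochs(20)      # extend 10 -> 20 epochs
resumed.train()                 # continues from epoch 10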
def get_model(model_type: str,
              num_classes: int = 10,
              input_channels: int = 3,
              depth: int = 58) -> common.LernomaticModel:
    if model_type == 'resnet':
        model = resnets.WideResnet(
            depth=depth,
            num_classes=num_classes,
            input_channels=input_channels
        )
    elif model_type == 'cifar':
        model = cifar.CIFAR10Net()
    else:
        # report the argument we actually received, not GLOBAL_OPTS
        raise ValueError('Unknown model type [%s]' % str(model_type))

    return model
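# Example usage of the parameterized get_model above (illustrative;
# assumes resnets and cifar are importable as in this repo, and the
# depth value is an arbitrary valid choice, not a repo default):
wrn_model   = get_model('resnet', num_classes=10, depth=40)
basic_model = get_model('cifar')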
def get_trainer() -> cifar_trainer.CIFAR10Trainer:
    # get a model to test on and its corresponding trainer
    model = cifar.CIFAR10Net()
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # turn off checkpointing
        save_every    = 0,
        print_every   = GLOBAL_TEST_PARAMS['test_print_every'],
        # data options
        batch_size    = GLOBAL_TEST_PARAMS['test_batch_size'],
        # training options
        learning_rate = GLOBAL_TEST_PARAMS['test_learning_rate'],
        num_epochs    = GLOBAL_TEST_PARAMS['train_num_epochs'],
        device_id     = util.get_device_id(),
    )

    return trainer
from typing import Optional

def get_trainer(learning_rate: Optional[float] = None) -> cifar_trainer.CIFAR10Trainer:
    # use the explicit learning rate if one was given, otherwise
    # fall back to the global options
    if learning_rate is not None:
        test_learning_rate = learning_rate
    else:
        test_learning_rate = GLOBAL_OPTS['learning_rate']

    model = cifar.CIFAR10Net()
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # data options
        batch_size=GLOBAL_OPTS['batch_size'],
        num_workers=GLOBAL_OPTS['num_workers'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        # set initial learning rate
        learning_rate=test_learning_rate,
        # other options
        save_every=0,
        print_every=200,
        device_id=util.get_device_id(),
        verbose=GLOBAL_OPTS['verbose']
    )

    return trainer
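# Example: default learning rate from GLOBAL_OPTS versus an explicit
# override (the 1e-3 value is arbitrary, for illustration only):
default_trainer = get_trainer()
override_trainer = get_trainer(learning_rate=1e-3)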
def test_train(self) -> None:
    test_checkpoint = 'checkpoint/trainer_train_test.pkl'
    model = cifar.CIFAR10Net()

    # Get trainer object
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        save_every    = 0,
        print_every   = 50,
        device_id     = util.get_device_id(),
        # loader options
        num_epochs    = self.test_num_epochs,
        learning_rate = 3e-4,
        batch_size    = 128,
        num_workers   = self.test_num_workers,
    )
    if self.verbose:
        print('Created trainer object')
        print(trainer)

    # train for one epoch
    trainer.train()
    trainer.save_checkpoint(test_checkpoint)

    fig, ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        ax,
        trainer.loss_history,
        acc_history    = trainer.acc_history,
        cur_epoch      = trainer.cur_epoch,
        iter_per_epoch = trainer.iter_per_epoch,
        loss_title     = 'CIFAR-10 Trainer Test loss',
        acc_title      = 'CIFAR-10 Trainer Test accuracy'
    )
    fig.savefig('figures/trainer_train_test_history.png', bbox_inches='tight')
def get_model() -> common.LernomaticModel:
    model = cifar.CIFAR10Net()
    return model
def main() -> None:
    # Get a model
    model = cifar.CIFAR10Net()

    # Get a trainer
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # training parameters
        batch_size=GLOBAL_OPTS['batch_size'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        learning_rate=GLOBAL_OPTS['learning_rate'],
        momentum=GLOBAL_OPTS['momentum'],
        weight_decay=GLOBAL_OPTS['weight_decay'],
        # device
        device_id=GLOBAL_OPTS['device_id'],
        # checkpoint
        checkpoint_dir=GLOBAL_OPTS['checkpoint_dir'],
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        # display
        print_every=GLOBAL_OPTS['print_every'],
        save_every=GLOBAL_OPTS['save_every'],
        verbose=GLOBAL_OPTS['verbose']
    )

    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        trainer.set_tb_writer(writer)

    # Optionally do a search pass here and add a scheduler
    if GLOBAL_OPTS['find_lr']:
        lr_finder = expr_util.get_lr_finder(trainer)
        lr_find_start_time = time.time()
        lr_finder.find()
        lr_find_min, lr_find_max = lr_finder.get_lr_range()
        lr_find_end_time = time.time()
        lr_find_total_time = lr_find_end_time - lr_find_start_time
        print('Found learning rate range %.4f -> %.4f' % (lr_find_min, lr_find_max))
        print('Total find time [%s]' % str(timedelta(seconds=lr_find_total_time)))

        # Get a scheduler that cycles over half the total number of iterations
        stepsize = trainer.get_num_epochs() * len(trainer.train_loader) // 2
        lr_scheduler = expr_util.get_scheduler(
            lr_find_min,
            lr_find_max,
            stepsize,
            sched_type='TriangularScheduler'
        )
        trainer.set_lr_scheduler(lr_scheduler)

    # train the model
    train_start_time = time.time()
    trainer.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Total training time [%s] (%d epochs) %s' %
          (repr(trainer), trainer.cur_epoch, str(timedelta(seconds=train_total_time))))

    # Visualise the output
    train_fig, train_ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        trainer.get_loss_history(),
        acc_history=trainer.get_acc_history(),
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title='CIFAR-10 Training Loss',
        acc_title='CIFAR-10 Training Accuracy'
    )
    train_fig.savefig(GLOBAL_OPTS['fig_name'], bbox_inches='tight')
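# One way GLOBAL_OPTS might be populated before main() runs. This is a
# hypothetical sketch: the option names match the keys used in main()
# above, but the flags and defaults are illustrative, not the repo's
# actual CLI values.
import argparse

def parse_args() -> dict:
    parser = argparse.ArgumentParser(description='CIFAR-10 training example')
    parser.add_argument('--batch-size',      type=int,   default=64)
    parser.add_argument('--num-epochs',      type=int,   default=10)
    parser.add_argument('--learning-rate',   type=float, default=3e-4)
    parser.add_argument('--momentum',        type=float, default=0.9)
    parser.add_argument('--weight-decay',    type=float, default=1e-4)
    parser.add_argument('--device-id',       type=int,   default=0)
    parser.add_argument('--checkpoint-dir',  type=str,   default='checkpoint')
    parser.add_argument('--checkpoint-name', type=str,   default='cifar10')
    parser.add_argument('--print-every',     type=int,   default=200)
    parser.add_argument('--save-every',      type=int,   default=0)
    parser.add_argument('--tensorboard-dir', type=str,   default=None)
    parser.add_argument('--find-lr',         action='store_true')
    parser.add_argument('--fig-name',        type=str,   default='figures/cifar10_train.png')
    parser.add_argument('--verbose',         action='store_true')
    # argparse converts dashes to underscores, matching the GLOBAL_OPTS keys
    return vars(parser.parse_args())

if __name__ == '__main__':
    GLOBAL_OPTS = parse_args()
    main()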
def test_save_load_checkpoint(self) -> None:
    test_checkpoint = 'checkpoint/trainer_test_save_load.pkl'

    # get a model
    model = cifar.CIFAR10Net()
    # get a trainer
    src_tr = cifar_trainer.CIFAR10Trainer(
        model,
        num_epochs  = self.test_num_epochs,
        save_every  = 0,
        device_id   = util.get_device_id(),
        batch_size  = self.test_batch_size,
        num_workers = self.test_num_workers
    )
    if self.verbose:
        print('Created trainer object')
        print(src_tr)

    # train for one epoch
    src_tr.train()
    src_tr.save_checkpoint(test_checkpoint)

    # Make a new trainer and load all parameters into that
    dst_tr = cifar_trainer.CIFAR10Trainer(
        model,
        device_id = util.get_device_id()
    )
    dst_tr.load_checkpoint(test_checkpoint)

    # Test object parameters
    assert src_tr.num_epochs == dst_tr.num_epochs
    assert src_tr.learning_rate == dst_tr.learning_rate
    assert src_tr.weight_decay == dst_tr.weight_decay
    assert src_tr.print_every == dst_tr.print_every
    assert src_tr.save_every == dst_tr.save_every
    assert src_tr.device_id == dst_tr.device_id

    print('\t Comparing model parameters ')
    src_model_params = src_tr.get_model_params()
    dst_model_params = dst_tr.get_model_params()
    assert len(src_model_params.items()) == len(dst_model_params.items())

    # p1, p2 are (k, v) tuple pairs of each model's parameters
    # k = str, v = torch.Tensor
    for n, (p1, p2) in enumerate(zip(src_model_params.items(), dst_model_params.items())):
        assert p1[0] == p2[0]
        print('Checking parameter %s [%d/%d] \t\t' % (str(p1[0]), n+1, len(src_model_params.items())), end='\r')
        assert torch.equal(p1[1], p2[1])
    print('\n ...done')

    # Test loss history
    val_loss_history = 'test_save_load_history.pkl'
    src_tr.save_history(val_loss_history)
    dst_tr.load_history(val_loss_history)

    print('\t Comparing loss history....')
    assert src_tr.loss_iter == dst_tr.loss_iter
    for n in range(src_tr.loss_iter):
        print('Checking loss element [%d/%d]' % (n, src_tr.loss_iter), end='\r')
        assert src_tr.loss_history[n] == dst_tr.loss_history[n]

    # Try to train for another epoch
    dst_tr.set_num_epochs(src_tr.num_epochs + 1)
    assert dst_tr.num_epochs == src_tr.num_epochs + 1
    dst_tr.train()
    assert src_tr.num_epochs + 1 == dst_tr.cur_epoch
    print('\n ...done')

    os.remove(test_checkpoint)
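# The parameter-comparison loop above, distilled into a reusable helper.
# This is a torch-only sketch (the helper name is ours, not the repo's);
# the inputs are state-dict-like mappings of name -> torch.Tensor, as
# returned by get_model_params() above.
import torch

def params_equal(params_a: dict, params_b: dict) -> bool:
    # dict view comparison checks the key sets match
    if params_a.keys() != params_b.keys():
        return False
    # torch.equal requires identical shape, dtype and values
    return all(torch.equal(params_a[k], params_b[k]) for k in params_a)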
def test_save_load_device_map(self) -> None:
    test_checkpoint = 'checkpoint/trainer_save_load_device_map.pkl'
    test_history    = 'checkpoint/trainer_save_load_device_map_history.pkl'
    model = cifar.CIFAR10Net()

    # Get trainer object
    src_tr = cifar_trainer.CIFAR10Trainer(
        model,
        save_every  = 0,
        print_every = 50,
        device_id   = util.get_device_id(),
        # loader options
        num_epochs  = self.test_num_epochs,
        batch_size  = self.test_batch_size,
        num_workers = self.test_num_workers,
    )
    if self.verbose:
        print('Created trainer object')
        print(src_tr)

    # train for one epoch
    src_tr.train()
    src_tr.save_checkpoint(test_checkpoint)
    assert src_tr.acc_history is not None
    src_tr.save_history(test_history)

    # Now try to load a checkpoint and ensure that there is an
    # acc history attribute that is not None
    dst_tr = cifar_trainer.CIFAR10Trainer(
        model,
        device_id = -1,
        verbose   = self.verbose
    )
    dst_tr.load_checkpoint(test_checkpoint)

    # Test object parameters
    assert src_tr.num_epochs == dst_tr.num_epochs
    assert src_tr.learning_rate == dst_tr.learning_rate
    assert src_tr.weight_decay == dst_tr.weight_decay
    assert src_tr.print_every == dst_tr.print_every
    assert src_tr.save_every == dst_tr.save_every

    print('\t Comparing model parameters ')
    src_model_params = src_tr.get_model_params()
    dst_model_params = dst_tr.get_model_params()
    assert len(src_model_params.items()) == len(dst_model_params.items())

    # p1, p2 are (k, v) tuple pairs of each model's parameters
    # k = str, v = torch.Tensor
    print('Checking that tensors from checkpoint have been transferred to new device')
    for n, (p1, p2) in enumerate(zip(src_model_params.items(), dst_model_params.items())):
        assert p1[0] == p2[0]
        print('Checking parameter %s [%d/%d] \t\t' % (str(p1[0]), n+1, len(src_model_params.items())), end='\r')
        assert 'cpu' == p2[1].device.type
    print('\n ...done')

    # Test loss history
    dst_tr.load_history(test_history)
    assert dst_tr.acc_history is not None

    print('\t Comparing loss history....')
    assert src_tr.loss_iter == dst_tr.loss_iter
    for n in range(src_tr.loss_iter):
        print('Checking loss element [%d/%d]' % (n, src_tr.loss_iter), end='\r')
        assert src_tr.loss_history[n] == dst_tr.loss_history[n]
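# The device remapping the test verifies typically comes down to
# torch.load's map_location argument; a minimal torch-only illustration
# (the path is the one used above, for reference only):
import torch

checkpoint = torch.load('checkpoint/trainer_save_load_device_map.pkl',
                        map_location='cpu')    # every tensor lands on the CPU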