def test_save_load_checkpoint_train(self) -> None:
    test_dataset_file = 'hdf5/trainer_unit_test.h5'
    test_checkpoint_name = 'checkpoint/save_load_test_checkpoint.pkl'
    # get a model and a trainer
    model = mnist_net.MNISTNet()
    src_tr = mnist_trainer.MNISTTrainer(
        model,
        num_epochs=self.test_num_epochs,
        save_every=0,
        print_every=250,
        device_id=util.get_device_id(),
        # dataload options
        checkpoint_name='save_load_test',
        batch_size=16,
        num_workers=1
        #num_workers = GLOBAL_OPTS['num_workers']
    )
    if self.verbose:
        print('Created trainer object')
        print(src_tr)

    # train the source trainer, then checkpoint it
    src_tr.train()
    src_tr.save_checkpoint(test_checkpoint_name)

    # Make a new trainer and load all parameters into it
    new_model = mnist_net.MNISTNet()
    dst_tr = mnist_trainer.MNISTTrainer(new_model, device_id=util.get_device_id())
    dst_tr.load_checkpoint(test_checkpoint_name)

    # Test object parameters
    assert src_tr.num_epochs == dst_tr.num_epochs
    assert src_tr.learning_rate == dst_tr.learning_rate
    assert src_tr.weight_decay == dst_tr.weight_decay
    assert src_tr.print_every == dst_tr.print_every
    assert src_tr.save_every == dst_tr.save_every
    assert src_tr.device_id == dst_tr.device_id

    print('\t Comparing model parameters ')
    src_model_params = src_tr.get_model_params()
    dst_model_params = dst_tr.get_model_params()
    assert len(src_model_params.items()) == len(dst_model_params.items())

    # p1, p2 are (k, v) tuple pairs of model parameters,
    # where k is a str and v is a torch.Tensor
    for n, (p1, p2) in enumerate(
            zip(src_model_params.items(), dst_model_params.items())):
        assert p1[0] == p2[0]
        print('Checking parameter %s [%d/%d] \t\t' %
              (str(p1[0]), n + 1, len(src_model_params.items())), end='\r')
        assert torch.equal(p1[1], p2[1])
    print('\n ...done')
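# The parameter-comparison loop above is repeated almost verbatim in most of
# the save/load tests in this suite. A possible shared helper (a sketch, not
# part of the existing suite -- the name compare_model_params is hypothetical)
# could factor it out:

import torch

def compare_model_params(src_params: dict, dst_params: dict) -> None:
    # Both arguments map parameter names (str) to torch.Tensor values, as
    # returned by get_model_params() / get_net_state_dict() above.
    assert len(src_params) == len(dst_params)
    for n, ((src_k, src_v), (dst_k, dst_v)) in enumerate(
            zip(src_params.items(), dst_params.items())):
        assert src_k == dst_k
        print('Checking parameter %s [%d/%d] \t\t' %
              (src_k, n + 1, len(src_params)), end='\r')
        assert torch.equal(src_v, dst_v)
    print('\n ...done')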
def test_history_extend(self) -> None:
    test_checkpoint = 'checkpoint/test_history_extend.pkl'
    test_history = 'checkpoint/test_history_extend_history.pkl'
    model = cifar.CIFAR10Net()
    # TODO : adjust this so we don't need 10 epochs of training
    # Get trainer object
    test_num_epochs = 10
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        save_every=0,
        print_every=50,
        device_id=util.get_device_id(),
        # loader options
        num_epochs=test_num_epochs,
        learning_rate=3e-4,
        batch_size=64,
        num_workers=self.test_num_workers,
    )
    print('Training original model')
    trainer.train()
    trainer.save_checkpoint(test_checkpoint)
    trainer.save_history(test_history)

    # Load a new trainer, then train for another 10 epochs (20 total)
    extend_trainer = cifar_trainer.CIFAR10Trainer(
        model,
        save_every=0,
        print_every=50,
        device_id=util.get_device_id(),
        # loader options
        num_epochs=test_num_epochs,
        learning_rate=3e-4,
        batch_size=64,
        num_workers=self.test_num_workers,
    )
    print('Loading checkpoint [%s] into extend trainer...' % str(test_checkpoint))
    extend_trainer.load_checkpoint(test_checkpoint)
    print('Loading history [%s] into extend trainer...' % str(test_history))
    extend_trainer.load_history(test_history)

    # Check history before extending
    assert trainer.device_id == extend_trainer.device_id
    print('extend_trainer device : %s' % str(extend_trainer.device))
    assert test_num_epochs == extend_trainer.num_epochs
    assert test_num_epochs == extend_trainer.cur_epoch
    assert extend_trainer.loss_history is not None
    assert (test_num_epochs * len(extend_trainer.train_loader)) == len(extend_trainer.loss_history)

    extend_trainer.set_num_epochs(2 * test_num_epochs)
    assert (2 * test_num_epochs * len(extend_trainer.train_loader)) == len(extend_trainer.loss_history)
    for i in range(test_num_epochs * len(extend_trainer.train_loader)):
        print('Checking loss iter [%d / %d]' %
              (i, 2 * test_num_epochs * len(extend_trainer.train_loader)), end='\r')
        assert trainer.loss_history[i] == extend_trainer.loss_history[i]

    extend_trainer.train()
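# The two length assertions above imply that the loss history is pre-allocated
# to num_epochs * iter_per_epoch entries and that set_num_epochs() grows the
# buffer in place while preserving recorded losses. A minimal sketch of that
# presumed behavior (an assumption about the trainer internals -- shown with a
# numpy-backed buffer; the real implementation may differ):

import numpy as np

def extend_history(loss_history: np.ndarray,
                   old_num_epochs: int,
                   new_num_epochs: int,
                   iter_per_epoch: int) -> np.ndarray:
    # Grow the buffer to hold new_num_epochs worth of iterations, keeping
    # the losses already recorded in the first old_num_epochs epochs.
    assert len(loss_history) == old_num_epochs * iter_per_epoch
    extended = np.zeros(new_num_epochs * iter_per_epoch)
    extended[:len(loss_history)] = loss_history
    return extended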
def test_save_load(self) -> None:
    infer_test_checkpoint = 'checkpoint/infer_save_load_test.pkl'
    model = get_model()
    trainer = get_trainer(model, None, 64, 0)
    # train the model for a while
    trainer.train()

    # save a training checkpoint to disk and load it into an inferrer
    trainer.save_checkpoint(infer_test_checkpoint)
    infer = inferrer.Inferrer(device_id=util.get_device_id())
    infer.load_model(infer_test_checkpoint)
    infer_model = infer.get_model()
    trainer_model = trainer.get_model()

    # check model parameters
    train_model_params = trainer_model.get_net_state_dict()
    infer_model_params = infer_model.get_net_state_dict()
    print('Comparing models')
    for n, (p1, p2) in enumerate(
            zip(train_model_params.items(), infer_model_params.items())):
        assert p1[0] == p2[0]
        print('Checking parameter %s [%d/%d] \t\t' %
              (str(p1[0]), n + 1, len(train_model_params.items())))
        assert torch.equal(p1[1], p2[1])
    print('\n ...done')

    # run the forward pass
    test_img, _ = next(iter(trainer.val_loader))
    pred = infer.forward(test_img)
    print('Complete prediction vector (shape: %s)' % str(pred.shape))
    print(str(pred))
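# Beyond parameter equality, the round trip could also be checked functionally:
# the same input batch should produce the same output from both models. A
# sketch of that extra assertion (not in the original test; it assumes access
# to the underlying torch.nn.Module objects, which the wrapper API here does
# not necessarily expose):

import torch

def assert_same_prediction(trainer_net: torch.nn.Module,
                           infer_net: torch.nn.Module,
                           test_img: torch.Tensor) -> None:
    trainer_net.eval()
    infer_net.eval()
    with torch.no_grad():
        src_pred = trainer_net(test_img)
        dst_pred = infer_net(test_img)
    # Identical weights should give numerically identical outputs on the same
    # device; allclose guards against dtype/device round-off.
    assert torch.allclose(src_pred, dst_pred)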
def test_train(self) -> None:
    test_checkpoint_name = self.checkpoint_dir + 'resnet_trainer_train_checkpoint.pkl'
    test_history_name = self.checkpoint_dir + 'resnet_trainer_train_history.pkl'
    train_num_epochs = 4
    train_batch_size = 128
    # get a model
    model = resnets.WideResnet(
        depth=self.resnet_depth,
        num_classes=10,     # using CIFAR-10 data
        input_channels=3,
        w_factor=1
    )
    # get a trainer
    trainer = resnet_trainer.ResnetTrainer(
        model,
        # training parameters
        batch_size=train_batch_size,
        num_epochs=train_num_epochs,
        learning_rate=self.test_learning_rate,
        # device
        device_id=util.get_device_id(),
        # checkpoint
        checkpoint_dir=self.checkpoint_dir,
        checkpoint_name='resnet_trainer_test',
        # display
        print_every=self.print_every,
        save_every=5000,
        verbose=self.verbose
    )
    if self.verbose:
        print('Created %s object' % repr(trainer))
        print(trainer)

    print('Training model %s for %d epochs (batch size = %d)' %
          (repr(trainer), train_num_epochs, train_batch_size))
    trainer.train()

    # save the final checkpoint
    trainer.save_checkpoint(test_checkpoint_name)
    trainer.save_history(test_history_name)

    fig, ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        ax,
        trainer.loss_history,
        acc_history=trainer.acc_history,
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch
    )
    fig.savefig('figures/resnet_trainer_train_history.png', bbox_inches='tight')
    print('======== TestResnetTrainer.test_train <END>')
def test_save_load_history(self) -> None:
    test_history_name = 'checkpoint/test_history.pkl'
    # get a model and a trainer
    model = mnist_net.MNISTNet()
    src_tr = mnist_trainer.MNISTTrainer(
        model,
        num_epochs=self.test_num_epochs,
        save_every=0,
        print_every=250,
        device_id=util.get_device_id(),
        # dataload options
        checkpoint_name='save_load_test',
        batch_size=16,
        num_workers=1)
    if self.verbose:
        print('Created trainer object')
        print(src_tr)

    # train the source trainer, then save its history
    src_tr.train()
    src_tr.save_history(test_history_name)

    # Load history into a new object
    dst_tr = mnist_trainer.MNISTTrainer(
        model,
        device_id=util.get_device_id(),
    )
    dst_tr.load_history(test_history_name)

    # Test loss history
    print('\t Comparing loss history....')
    assert src_tr.loss_iter == dst_tr.loss_iter
    for n in range(src_tr.loss_iter):
        print('Checking loss element [%d/%d]' % (n, src_tr.loss_iter), end='\r')
        assert src_tr.loss_history[n] == dst_tr.loss_history[n]
    print('\n ...done')
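# save_history()/load_history() round-trip the history arrays through a .pkl
# file. A minimal sketch of that presumed serialization (an assumption based
# on the file extension -- the real implementation may store additional keys
# such as acc_history and the iteration counters):

import pickle

def save_history_sketch(fname: str, loss_history, loss_iter: int) -> None:
    history = {'loss_history': loss_history, 'loss_iter': loss_iter}
    with open(fname, 'wb') as fp:
        pickle.dump(history, fp)

def load_history_sketch(fname: str) -> dict:
    with open(fname, 'rb') as fp:
        return pickle.load(fp)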
def get_trainer(model: common.LernomaticModel,
                batch_size: int,
                checkpoint_name: str) -> cifar_trainer.CIFAR10Trainer:
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        num_epochs=4,
        checkpoint_name=checkpoint_name,
        # since we don't train for long, a large learning rate helps
        learning_rate=1.5e-3,
        save_every=0,
        print_every=50,
        batch_size=batch_size,
        device_id=util.get_device_id(),
        verbose=True)

    return trainer
def get_trainer(model: common.LernomaticModel,
                checkpoint_name: str,
                batch_size: int,
                save_every: int) -> cifar_trainer.CIFAR10Trainer:
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        batch_size=batch_size,
        test_batch_size=1,
        device_id=util.get_device_id(),
        checkpoint_name=checkpoint_name,
        save_every=save_every,
        save_hist=False,
        print_every=50,
        num_epochs=4,
        learning_rate=9e-4)

    return trainer
def get_trainer() -> cifar_trainer.CIFAR10Trainer:
    # get a model to test on and its corresponding trainer
    model = cifar.CIFAR10Net()
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # turn off checkpointing
        save_every=0,
        print_every=GLOBAL_TEST_PARAMS['test_print_every'],
        # data options
        batch_size=GLOBAL_TEST_PARAMS['test_batch_size'],
        # training options
        learning_rate=GLOBAL_TEST_PARAMS['test_learning_rate'],
        num_epochs=GLOBAL_TEST_PARAMS['train_num_epochs'],
        device_id=util.get_device_id(),
    )

    return trainer
def get_trainer(learning_rate: float = None) -> cifar_trainer.CIFAR10Trainer:
    if learning_rate is not None:
        test_learning_rate = learning_rate
    else:
        test_learning_rate = GLOBAL_OPTS['learning_rate']

    model = cifar.CIFAR10Net()
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # data options
        batch_size=GLOBAL_OPTS['batch_size'],
        num_workers=GLOBAL_OPTS['num_workers'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        # set initial learning rate
        learning_rate=test_learning_rate,
        # other options
        save_every=0,
        print_every=200,
        device_id=util.get_device_id(),
        verbose=GLOBAL_OPTS['verbose'])

    return trainer
def test_train(self) -> None:
    test_checkpoint = 'checkpoint/trainer_train_test.pkl'
    model = cifar.CIFAR10Net()
    # Get trainer object
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        save_every=0,
        print_every=50,
        device_id=util.get_device_id(),
        # loader options
        num_epochs=self.test_num_epochs,
        learning_rate=3e-4,
        batch_size=128,
        num_workers=self.test_num_workers,
    )
    if self.verbose:
        print('Created trainer object')
        print(trainer)

    # train the model, then checkpoint it
    trainer.train()
    trainer.save_checkpoint(test_checkpoint)

    fig, ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        ax,
        trainer.loss_history,
        acc_history=trainer.acc_history,
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title='CIFAR-10 Trainer Test loss',
        acc_title='CIFAR-10 Trainer Test accuracy'
    )
    fig.savefig('figures/trainer_train_test_history.png', bbox_inches='tight')
def test_save_load(self) -> None:
    # Get some data
    train_ab_paths = [path for path in os.listdir(self.train_data_root)]
    val_ab_paths = [path for path in os.listdir(self.val_data_root)]
    train_dataset = aligned_dataset.AlignedDatasetHDF5(self.test_dataset)
    # Get some models - we use Resnet and PatchGAN here for now. At some
    # point the bugs in the UnetGenerator also need to be solved, and this
    # test should be smaller than a 'real' training run.
    generator = resnet_gen.ResnetGenerator(3, 3, num_filters=64)
    discriminator = pixel_disc.PixelDiscriminator(3 + 3, num_filters=64)

    test_checkpoint_file = 'checkpoint/pix2pix_trainer_checkpoint_test.pkl'
    test_history_file = 'checkpoint/pix2pix_trainer_history_test.pkl'

    # Get a trainer
    src_trainer = pix2pix_trainer.Pix2PixTrainer(
        generator,
        discriminator,
        # dataset
        train_dataset=train_dataset,
        val_dataset=None,
        # trainer general options
        batch_size=self.batch_size,
        device_id=util.get_device_id(),
        num_epochs=self.test_num_epochs,
        # checkpoint
        save_every=0,
        print_every=self.print_every,
    )
    src_trainer.train()
    print('Saving checkpoint to file [%s]' % str(test_checkpoint_file))
    src_trainer.save_checkpoint(test_checkpoint_file)
    print('Saving history to file [%s]' % str(test_history_file))
    src_trainer.save_history(test_history_file)

    # get a new trainer and load
    dst_trainer = pix2pix_trainer.Pix2PixTrainer(
        None, None, device_id=util.get_device_id())
    dst_trainer.load_checkpoint(test_checkpoint_file)

    # Check that the models were loaded
    assert dst_trainer.g_net is not None
    assert dst_trainer.d_net is not None
    assert repr(src_trainer.g_net) == repr(dst_trainer.g_net)
    assert repr(src_trainer.d_net) == repr(dst_trainer.d_net)

    # check model params
    src_models = [src_trainer.g_net, src_trainer.d_net]
    dst_models = [dst_trainer.g_net, dst_trainer.d_net]
    for src_mod, dst_mod in zip(src_models, dst_models):
        print('Checking parameters for model [%s]' % repr(src_mod))
        src_model_params = src_mod.get_net_state_dict()
        dst_model_params = dst_mod.get_net_state_dict()
        assert len(src_model_params.items()) == len(dst_model_params.items())
        # p1, p2 are (k, v) tuple pairs of model parameters,
        # where k is a str and v is a torch.Tensor
        for n, (p1, p2) in enumerate(
                zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' %
                  (str(p1[0]), n + 1, len(src_model_params.items())), end='\r')
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')

    # check the various trainer stats
    assert src_trainer.beta1 == dst_trainer.beta1
    assert src_trainer.l1_lambda == dst_trainer.l1_lambda
    assert src_trainer.gan_mode == dst_trainer.gan_mode
    assert src_trainer.learning_rate == dst_trainer.learning_rate
    assert src_trainer.batch_size == dst_trainer.batch_size
    assert src_trainer.print_every == dst_trainer.print_every
    assert src_trainer.save_every == dst_trainer.save_every
    assert src_trainer.cur_epoch == dst_trainer.cur_epoch
    assert src_trainer.num_epochs == dst_trainer.num_epochs

    # check history
    dst_trainer.load_history(test_history_file)
    assert dst_trainer.g_loss_history is not None
    assert dst_trainer.d_loss_history is not None
    assert len(src_trainer.g_loss_history) == len(dst_trainer.g_loss_history)
    assert len(src_trainer.d_loss_history) == len(dst_trainer.d_loss_history)

    for loss_elem in range(len(src_trainer.g_loss_history)):
        print('Checking g_loss_history [%d / %d]' %
              (loss_elem + 1, len(src_trainer.g_loss_history)), end='\r')
        assert src_trainer.g_loss_history[loss_elem] == dst_trainer.g_loss_history[loss_elem]
    print('\n OK')

    for loss_elem in range(len(src_trainer.d_loss_history)):
        print('Checking d_loss_history [%d / %d]' %
              (loss_elem + 1, len(src_trainer.d_loss_history)), end='\r')
        assert src_trainer.d_loss_history[loss_elem] == dst_trainer.d_loss_history[loss_elem]
    print('\n OK')

    # we need to set the train loaders
    dst_trainer.set_train_dataset(train_dataset)
    # Now try to extend the trainer history and train for another epoch
    dst_trainer.set_num_epochs(src_trainer.num_epochs + 1)
    assert dst_trainer.cur_epoch == src_trainer.cur_epoch
    print('Continuing training for dst_trainer from epoch %d' % dst_trainer.cur_epoch)
    dst_trainer.train()
    assert src_trainer.num_epochs + 1 == dst_trainer.cur_epoch
def test_save_load_checkpoint(self) -> None:
    test_checkpoint = 'checkpoint/trainer_test_save_load.pkl'
    # get a model
    model = cifar.CIFAR10Net()
    # get a trainer
    src_tr = cifar_trainer.CIFAR10Trainer(
        model,
        num_epochs=self.test_num_epochs,
        save_every=0,
        device_id=util.get_device_id(),
        batch_size=self.test_batch_size,
        num_workers=self.test_num_workers
    )
    if self.verbose:
        print('Created trainer object')
        print(src_tr)

    # train the source trainer, then checkpoint it
    src_tr.train()
    src_tr.save_checkpoint(test_checkpoint)

    # Make a new trainer and load all parameters into it
    dst_tr = cifar_trainer.CIFAR10Trainer(
        model,
        device_id=util.get_device_id()
    )
    dst_tr.load_checkpoint(test_checkpoint)

    # Test object parameters
    assert src_tr.num_epochs == dst_tr.num_epochs
    assert src_tr.learning_rate == dst_tr.learning_rate
    assert src_tr.weight_decay == dst_tr.weight_decay
    assert src_tr.print_every == dst_tr.print_every
    assert src_tr.save_every == dst_tr.save_every
    assert src_tr.device_id == dst_tr.device_id

    print('\t Comparing model parameters ')
    src_model_params = src_tr.get_model_params()
    dst_model_params = dst_tr.get_model_params()
    assert len(src_model_params.items()) == len(dst_model_params.items())

    # p1, p2 are (k, v) tuple pairs of model parameters,
    # where k is a str and v is a torch.Tensor
    for n, (p1, p2) in enumerate(
            zip(src_model_params.items(), dst_model_params.items())):
        assert p1[0] == p2[0]
        print('Checking parameter %s [%d/%d] \t\t' %
              (str(p1[0]), n + 1, len(src_model_params.items())), end='\r')
        assert torch.equal(p1[1], p2[1])
    print('\n ...done')

    # Test loss history
    test_history_file = 'test_save_load_history.pkl'
    src_tr.save_history(test_history_file)
    dst_tr.load_history(test_history_file)

    print('\t Comparing loss history....')
    assert src_tr.loss_iter == dst_tr.loss_iter
    for n in range(src_tr.loss_iter):
        print('Checking loss element [%d/%d]' % (n, src_tr.loss_iter), end='\r')
        assert src_tr.loss_history[n] == dst_tr.loss_history[n]

    # Try to train for another epoch
    dst_tr.set_num_epochs(src_tr.num_epochs + 1)
    assert dst_tr.num_epochs == src_tr.num_epochs + 1
    dst_tr.train()
    assert src_tr.num_epochs + 1 == dst_tr.cur_epoch
    print('\n ...done')

    os.remove(test_checkpoint)
    os.remove(test_history_file)
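# Only some tests in this suite remove their checkpoint files afterwards, and
# an os.remove at the end of the test body is skipped when an assertion fails.
# A pytest fixture like the sketch below (hypothetical -- the suite does not
# currently use fixtures for this) would guarantee cleanup in every case:

import pytest

@pytest.fixture
def checkpoint_file(tmp_path):
    # Hand the test a path inside pytest's per-test temporary directory;
    # tmp_path is cleaned up automatically, so no explicit os.remove is needed.
    return str(tmp_path / 'trainer_test_save_load.pkl')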
def save_load_checkpoint_train_test(self) -> None:
    model = mnist_net.MNISTNet()
    # Get trainer object
    src_tr = mnist_trainer.MNISTTrainer(
        model,
        save_every=0,
        print_every=250,
        checkpoint_name='save_load_test',
        device_id=util.get_device_id(),
        # loader options
        num_epochs=self.test_num_epochs,
        batch_size=self.test_batch_size,
        num_workers=1
        # Need to provide both training and test datasets
    )
    if self.verbose:
        print('Created trainer object')
        print(src_tr)

    # train the source trainer
    src_tr.train()
    # save_every=0 disables periodic checkpointing, so write the final epoch
    # checkpoint explicitly before loading it below
    ck_fname = 'checkpoint/save_load_test_epoch-%d.pkl' % (self.test_num_epochs - 1)
    src_tr.save_checkpoint(ck_fname)

    # Now try to load a checkpoint and ensure that there is an
    # acc history attribute that is not None
    dst_tr = mnist_trainer.MNISTTrainer(
        model,
        device_id=util.get_device_id(),
    )
    dst_tr.load_checkpoint(ck_fname)

    # Test object parameters
    assert src_tr.num_epochs == dst_tr.num_epochs
    assert src_tr.learning_rate == dst_tr.learning_rate
    assert src_tr.weight_decay == dst_tr.weight_decay
    assert src_tr.print_every == dst_tr.print_every
    assert src_tr.save_every == dst_tr.save_every
    assert src_tr.device_id == dst_tr.device_id

    print('\t Comparing model parameters ')
    src_model_params = src_tr.get_model_params()
    dst_model_params = dst_tr.get_model_params()
    assert len(src_model_params.items()) == len(dst_model_params.items())

    # p1, p2 are (k, v) tuple pairs of model parameters,
    # where k is a str and v is a torch.Tensor
    for n, (p1, p2) in enumerate(
            zip(src_model_params.items(), dst_model_params.items())):
        assert p1[0] == p2[0]
        print('Checking parameter %s [%d/%d] \t\t' %
              (str(p1[0]), n + 1, len(src_model_params.items())), end='\r')
        assert torch.equal(p1[1], p2[1])
    print('\n ...done')
def test_save_load(self) -> None:
    test_checkpoint_file = 'checkpoint/test_aae_trainer_checkpoint.pth'
    test_history_file = 'checkpoint/test_aae_trainer_history.pth'
    # get some models
    q_net = aae_common.AAEQNet(self.x_dim, self.z_dim, self.hidden_size)
    p_net = aae_common.AAEPNet(self.x_dim, self.z_dim, self.hidden_size)
    d_net = aae_common.AAEDNetGauss(self.z_dim, self.hidden_size)
    train_dataset, val_dataset = get_mnist_datasets(self.test_data_dir)

    # get a trainer
    src_trainer = aae_trainer.AAETrainer(
        q_net,
        p_net,
        d_net,
        # datasets
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        # train options
        num_epochs=self.test_num_epochs,
        batch_size=self.test_batch_size,
        # misc
        print_every=self.print_every,
        save_every=0,
        device_id=util.get_device_id(),
        verbose=self.verbose)

    # generate the source parameters
    src_trainer.train()
    src_trainer.save_checkpoint(test_checkpoint_file)
    src_trainer.save_history(test_history_file)

    # get a new trainer
    dst_trainer = aae_trainer.AAETrainer(device_id=util.get_device_id())
    dst_trainer.load_checkpoint(test_checkpoint_file)

    # Check parameters of each model in turn
    src_models = [src_trainer.q_net, src_trainer.p_net, src_trainer.d_net]
    dst_models = [dst_trainer.q_net, dst_trainer.p_net, dst_trainer.d_net]
    for src_mod, dst_mod in zip(src_models, dst_models):
        print('\t Comparing parameters for %s model' % repr(src_mod))
        src_model_params = src_mod.get_net_state_dict()
        dst_model_params = dst_mod.get_net_state_dict()
        assert len(src_model_params.items()) == len(dst_model_params.items())
        # p1, p2 are (k, v) tuple pairs of model parameters,
        # where k is a str and v is a torch.Tensor
        for n, (p1, p2) in enumerate(
                zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' %
                  (str(p1[0]), n + 1, len(src_model_params.items())), end='\r')
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')

    # History
    dst_trainer.load_history(test_history_file)
    assert dst_trainer.d_loss_history is not None
    assert dst_trainer.g_loss_history is not None
    assert dst_trainer.recon_loss_history is not None
    assert len(src_trainer.d_loss_history) == len(dst_trainer.d_loss_history)
    assert len(src_trainer.g_loss_history) == len(dst_trainer.g_loss_history)
    assert len(src_trainer.recon_loss_history) == len(dst_trainer.recon_loss_history)
def test_save_load(self) -> None:
    train_dataset, val_dataset = get_mnist_datasets(self.data_dir)
    # Get some models. For this test we just accept the default constructor
    # parameters (num_blocks = 4, start_size = 32, kernel_size = 3)
    encoder = denoise_ae.DAEEncoder()
    decoder = denoise_ae.DAEDecoder()

    test_checkpoint_file = 'checkpoint/dae_trainer_checkpoint.pkl'
    test_history_file = 'checkpoint/dae_trainer_history.pkl'
    src_trainer = dae_trainer.DAETrainer(
        encoder,
        decoder,
        # datasets
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        device_id=util.get_device_id(),
        # trainer params
        batch_size=self.batch_size,
        num_epochs=self.test_num_epochs,
        # disable saving
        save_every=0,
        print_every=self.print_every,
        verbose=self.verbose)

    train_start_time = time.time()
    src_trainer.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Trainer %s trained %d epochs in %s' %
          (repr(src_trainer), src_trainer.cur_epoch,
           str(timedelta(seconds=train_total_time))))

    print('Saving checkpoint to file [%s]' % str(test_checkpoint_file))
    src_trainer.save_checkpoint(test_checkpoint_file)
    src_trainer.save_history(test_history_file)

    # get a new trainer and load
    dst_trainer = dae_trainer.DAETrainer(device_id=util.get_device_id())
    dst_trainer.load_checkpoint(test_checkpoint_file)

    # check the basic trainer params
    assert src_trainer.num_epochs == dst_trainer.num_epochs
    assert src_trainer.learning_rate == dst_trainer.learning_rate
    assert src_trainer.momentum == dst_trainer.momentum
    assert src_trainer.weight_decay == dst_trainer.weight_decay
    assert src_trainer.loss_function == dst_trainer.loss_function
    assert src_trainer.optim_function == dst_trainer.optim_function
    assert src_trainer.cur_epoch == dst_trainer.cur_epoch
    assert src_trainer.iter_per_epoch == dst_trainer.iter_per_epoch
    assert src_trainer.save_every == dst_trainer.save_every
    assert src_trainer.print_every == dst_trainer.print_every
    assert src_trainer.batch_size == dst_trainer.batch_size
    assert src_trainer.val_batch_size == dst_trainer.val_batch_size
    assert src_trainer.shuffle == dst_trainer.shuffle

    # Now check the models
    src_models = [src_trainer.encoder, src_trainer.decoder]
    dst_models = [dst_trainer.encoder, dst_trainer.decoder]
    for src_mod, dst_mod in zip(src_models, dst_models):
        print('\t Comparing parameters for %s model' % repr(src_mod))
        src_model_params = src_mod.get_net_state_dict()
        dst_model_params = dst_mod.get_net_state_dict()
        assert len(src_model_params.items()) == len(dst_model_params.items())
        # p1, p2 are (k, v) tuple pairs of model parameters,
        # where k is a str and v is a torch.Tensor
        for n, (p1, p2) in enumerate(
                zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' %
                  (str(p1[0]), n + 1, len(src_model_params.items())), end='\r')
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')

    print('Checking history...')
    dst_trainer.load_history(test_history_file)
    assert len(src_trainer.loss_history) == len(dst_trainer.loss_history)
    for elem in range(len(src_trainer.loss_history)):
        print('Checking loss history element [%d / %d]' %
              (elem + 1, len(src_trainer.loss_history)), end='\r')
        assert src_trainer.loss_history[elem] == dst_trainer.loss_history[elem]
    print('\n OK')
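# The manual time.time() bookkeeping above could be wrapped in a small context
# manager (a sketch using only the standard library; the name train_timer is
# hypothetical and not part of this suite):

import time
from contextlib import contextmanager
from datetime import timedelta

@contextmanager
def train_timer(label: str):
    # Time the enclosed block and print the elapsed wall-clock duration.
    start = time.time()
    yield
    elapsed = timedelta(seconds=time.time() - start)
    print('%s took %s' % (label, str(elapsed)))

# usage:
#   with train_timer('DAETrainer training run'):
#       src_trainer.train()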
def test_save_load(self) -> None:
    test_checkpoint_file = 'checkpoint/test_aae_semi_trainer_checkpoint.pth'
    test_history_file = 'checkpoint/test_aae_semi_trainer_history.pth'
    q_net = aae_common.AAEQNet(self.x_dim, self.z_dim, self.hidden_size,
                               num_classes=self.num_classes)
    p_net = aae_common.AAEPNet(self.x_dim, self.z_dim + self.num_classes,
                               self.hidden_size)
    d_cat_net = aae_common.AAEDNetGauss(self.num_classes, self.hidden_size)
    d_gauss_net = aae_common.AAEDNetGauss(self.z_dim, self.hidden_size)
    q_net.set_cat_mode()
    assert q_net.net.cat_mode

    # We also need to sub-sample some parts of the MNIST dataset to produce
    # the 'labelled' data loaders
    print('Creating MNIST sub-dataset...')
    train_label_dataset, val_label_dataset, train_unlabel_dataset = \
        mnist_sub.gen_mnist_subset(
            self.test_data_dir,
            transform=self.transform,
            verbose=self.verbose)
    assert train_label_dataset is not None
    assert train_unlabel_dataset is not None
    assert val_label_dataset is not None

    src_trainer = aae_semisupervised_trainer.AAESemiTrainer(
        q_net,
        p_net,
        d_cat_net,
        d_gauss_net,
        # datasets
        train_label_dataset=train_label_dataset,
        train_unlabel_dataset=train_unlabel_dataset,
        val_label_dataset=val_label_dataset,
        # train options
        num_epochs=self.test_num_epochs,
        batch_size=self.batch_size,
        # misc
        print_every=self.print_every,
        save_every=0,
        device_id=util.get_device_id(),
        verbose=self.verbose)
    src_trainer.train()

    print('Saving checkpoint to file [%s]' % str(test_checkpoint_file))
    src_trainer.save_checkpoint(test_checkpoint_file)

    dst_trainer = aae_semisupervised_trainer.AAESemiTrainer(
        device_id=util.get_device_id())
    assert dst_trainer.q_net is None
    assert dst_trainer.p_net is None
    assert dst_trainer.d_cat_net is None
    assert dst_trainer.d_gauss_net is None

    # Test that models, etc are loaded
    print('Loading checkpoint data from [%s]' % str(test_checkpoint_file))
    dst_trainer.load_checkpoint(test_checkpoint_file)
    assert dst_trainer.q_net is not None
    assert dst_trainer.p_net is not None
    assert dst_trainer.d_cat_net is not None
    assert dst_trainer.d_gauss_net is not None

    model_list = ['q_net', 'p_net', 'd_cat_net', 'd_gauss_net']
    for model in model_list:
        src_model = getattr(src_trainer, model)
        dst_model = getattr(dst_trainer, model)
        assert src_model is not None
        assert dst_model is not None
        print('\t Comparing parameters for model [%s]' % repr(src_model))
        src_model_params = src_model.get_net_state_dict()
        dst_model_params = dst_model.get_net_state_dict()
        assert len(src_model_params.items()) == len(dst_model_params.items())
        # p1, p2 are (k, v) tuple pairs of model parameters,
        # where k is a str and v is a torch.Tensor
        for n, (p1, p2) in enumerate(
                zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' %
                  (str(p1[0]), n + 1, len(src_model_params.items())), end='\r')
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')

    # Test that history is correctly loaded
    print('Saving history to file [%s]' % str(test_history_file))
    src_trainer.save_history(test_history_file)
    dst_trainer.load_history(test_history_file)

    # Check iteration values
    assert src_trainer.loss_iter == dst_trainer.loss_iter
    assert src_trainer.val_loss_iter == dst_trainer.val_loss_iter
    assert src_trainer.train_val_loss_iter == dst_trainer.train_val_loss_iter
    assert src_trainer.acc_iter == dst_trainer.acc_iter
    assert src_trainer.cur_epoch == dst_trainer.cur_epoch
    assert src_trainer.iter_per_epoch == dst_trainer.iter_per_epoch

    # Check history arrays
    assert dst_trainer.d_loss_history is not None
    assert dst_trainer.g_loss_history is not None
    assert dst_trainer.recon_loss_history is not None
    assert dst_trainer.class_loss_history is not None
    assert len(src_trainer.d_loss_history) == len(dst_trainer.d_loss_history)
    assert len(src_trainer.g_loss_history) == len(dst_trainer.g_loss_history)
    assert len(src_trainer.recon_loss_history) == len(dst_trainer.recon_loss_history)
    assert len(src_trainer.class_loss_history) == len(dst_trainer.class_loss_history)

    for idx in range(len(src_trainer.d_loss_history)):
        print('Checking d_loss_history idx [%d / %d]' %
              (idx + 1, len(src_trainer.d_loss_history)), end='\r')
        assert src_trainer.d_loss_history[idx] == dst_trainer.d_loss_history[idx]
    print('\n OK')

    for idx in range(len(src_trainer.g_loss_history)):
        print('Checking g_loss_history idx [%d / %d]' %
              (idx + 1, len(src_trainer.g_loss_history)), end='\r')
        assert src_trainer.g_loss_history[idx] == dst_trainer.g_loss_history[idx]
    print('\n OK')

    for idx in range(len(src_trainer.recon_loss_history)):
        print('Checking recon_loss_history idx [%d / %d]' %
              (idx + 1, len(src_trainer.recon_loss_history)), end='\r')
        assert src_trainer.recon_loss_history[idx] == dst_trainer.recon_loss_history[idx]
    print('\n OK')

    for idx in range(len(src_trainer.class_loss_history)):
        print('Checking class_loss_history idx [%d / %d]' %
              (idx + 1, len(src_trainer.class_loss_history)), end='\r')
        assert src_trainer.class_loss_history[idx] == dst_trainer.class_loss_history[idx]
    print('\n OK')
def test_save_load_checkpoint(self) -> None:
    test_checkpoint = 'checkpoint/dcgan_trainer_test.pkl'
    test_history = 'checkpoint/dcgan_trainer_test_history.pkl'
    train_dataset = get_dataset()
    # get models
    discriminator = dcgan.DCGANDiscriminator()
    generator = dcgan.DCGANGenerator()
    # get a trainer
    src_trainer = dcgan_trainer.DCGANTrainer(
        D=discriminator,
        G=generator,
        # device
        device_id=util.get_device_id(),
        batch_size=self.batch_size,
        # training params
        train_dataset=train_dataset,
        num_epochs=self.test_num_epochs,
        learning_rate=self.test_learning_rate,
        verbose=self.verbose,
        print_every=self.print_every,
        save_every=0,
    )
    src_trainer.train()
    print('Saving checkpoint data to file [%s]' % str(test_checkpoint))
    src_trainer.save_checkpoint(test_checkpoint)
    src_trainer.save_history(test_history)

    # load into new trainer
    dst_trainer = dcgan_trainer.DCGANTrainer(
        None,
        None,
        train_dataset=train_dataset,
        device_id=util.get_device_id())
    dst_trainer.load_checkpoint(test_checkpoint)

    assert src_trainer.num_epochs == dst_trainer.num_epochs
    assert src_trainer.learning_rate == dst_trainer.learning_rate
    assert src_trainer.weight_decay == dst_trainer.weight_decay
    assert src_trainer.print_every == dst_trainer.print_every
    assert src_trainer.save_every == dst_trainer.save_every
    assert src_trainer.device_id == dst_trainer.device_id

    print('\t Comparing generator model parameters ')
    src_g = src_trainer.generator.get_net_state_dict()
    dst_g = dst_trainer.generator.get_net_state_dict()
    assert len(src_g.items()) == len(dst_g.items())
    # p1, p2 are (k, v) tuple pairs of model parameters,
    # where k is a str and v is a torch.Tensor
    for n, (p1, p2) in enumerate(zip(src_g.items(), dst_g.items())):
        assert p1[0] == p2[0]
        print('Checking parameter %s [%d/%d] \t\t' %
              (str(p1[0]), n + 1, len(src_g.items())), end='')
        assert torch.equal(p1[1], p2[1])
        print('\t OK')
    print('\n ...done')

    print('\t Comparing discriminator model parameters')
    src_d = src_trainer.discriminator.get_net_state_dict()
    dst_d = dst_trainer.discriminator.get_net_state_dict()
    assert len(src_d.items()) == len(dst_d.items())
    # p1, p2 are (k, v) tuple pairs of model parameters,
    # where k is a str and v is a torch.Tensor
    for n, (p1, p2) in enumerate(zip(src_d.items(), dst_d.items())):
        assert p1[0] == p2[0]
        print('Checking parameter %s [%d/%d] \t\t' %
              (str(p1[0]), n + 1, len(src_d.items())), end='')
        assert torch.equal(p1[1], p2[1])
        print('\t OK')
    print('\n ...done')

    # load history and check
    dst_trainer.load_history(test_history)
    assert dst_trainer.d_loss_history is not None
    assert dst_trainer.g_loss_history is not None
    assert len(src_trainer.d_loss_history) == len(dst_trainer.d_loss_history)
    assert len(src_trainer.g_loss_history) == len(dst_trainer.g_loss_history)

    print('Checking D loss history...')
    for n in range(len(src_trainer.d_loss_history)):
        assert src_trainer.d_loss_history[n] == dst_trainer.d_loss_history[n]
    print(' OK')
    print('Checking G loss history...')
    for n in range(len(src_trainer.g_loss_history)):
        assert src_trainer.g_loss_history[n] == dst_trainer.g_loss_history[n]
    print(' OK')

    # Try training a bit more. Since the values of cur_epoch and num_epochs
    # are the same, there should be no effect at first
    dst_trainer.train()
    assert dst_trainer.cur_epoch == src_trainer.cur_epoch
    # If we then raise the number of epochs (to at least cur_epoch + 1),
    # we should see another epoch of training
    dst_trainer.set_num_epochs(src_trainer.num_epochs + 1)
    dst_trainer.train()
    assert src_trainer.num_epochs + 1 == dst_trainer.cur_epoch

    os.remove(test_checkpoint)
    os.remove(test_history)
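# The two train() calls above rely on train() being a no-op once cur_epoch has
# reached num_epochs. A minimal sketch of the loop condition this implies (an
# assumption about the trainer internals, not the actual implementation):

def train_sketch(trainer) -> None:
    # Resuming is implicit: a freshly loaded checkpoint restores cur_epoch,
    # so training continues only while epochs remain.
    while trainer.cur_epoch < trainer.num_epochs:
        # ... run one epoch of training ...
        trainer.cur_epoch += 1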
def test_save_load_checkpoint(self) -> None:
    test_checkpoint_name = self.checkpoint_dir + 'resnet_trainer_checkpoint.pkl'
    test_history_name = self.checkpoint_dir + 'resnet_trainer_history.pkl'
    # get a model
    model = resnets.WideResnet(
        depth=self.resnet_depth,
        num_classes=10,     # using CIFAR-10 data
        input_channels=3,
        w_factor=1
    )
    # get a trainer
    src_tr = resnet_trainer.ResnetTrainer(
        model,
        # training parameters
        batch_size=self.test_batch_size,
        num_epochs=self.test_num_epochs,
        learning_rate=self.test_learning_rate,
        # device
        device_id=util.get_device_id(),
        # display
        print_every=self.print_every,
        save_every=0,
        verbose=self.verbose
    )
    if self.verbose:
        print('Created %s object' % repr(src_tr))
        print(src_tr)

    print('Training model %s for %d epochs' % (repr(src_tr), self.test_num_epochs))
    src_tr.train()

    # save the final checkpoint
    src_tr.save_checkpoint(test_checkpoint_name)
    src_tr.save_history(test_history_name)

    # get a new trainer and load checkpoint
    dst_tr = resnet_trainer.ResnetTrainer(model)
    dst_tr.load_checkpoint(test_checkpoint_name)

    # Test object parameters
    assert src_tr.num_epochs == dst_tr.num_epochs
    assert src_tr.learning_rate == dst_tr.learning_rate
    assert src_tr.weight_decay == dst_tr.weight_decay
    assert src_tr.print_every == dst_tr.print_every
    assert src_tr.save_every == dst_tr.save_every

    print('\t Comparing model parameters ')
    src_model_params = src_tr.get_model_params()
    dst_model_params = dst_tr.get_model_params()
    assert len(src_model_params.items()) == len(dst_model_params.items())

    # p1, p2 are (k, v) tuple pairs of model parameters,
    # where k is a str and v is a torch.Tensor
    for n, (p1, p2) in enumerate(
            zip(src_model_params.items(), dst_model_params.items())):
        assert p1[0] == p2[0]
        print('Checking parameter %s [%d/%d] \t\t' %
              (str(p1[0]), n + 1, len(src_model_params.items())), end='\r')
        assert torch.equal(p1[1], p2[1])
    print('\n ...done')

    # test history
    dst_tr.load_history(test_history_name)
    # training loss history
    assert len(src_tr.loss_history) == len(dst_tr.loss_history)
    assert src_tr.loss_iter == dst_tr.loss_iter
    for n in range(len(src_tr.loss_history)):
        assert src_tr.loss_history[n] == dst_tr.loss_history[n]
    # validation loss history
    assert len(src_tr.val_loss_history) == len(dst_tr.val_loss_history)
    assert src_tr.val_loss_iter == dst_tr.val_loss_iter
    for n in range(len(src_tr.val_loss_history)):
        assert src_tr.val_loss_history[n] == dst_tr.val_loss_history[n]
    # accuracy history
    assert len(src_tr.acc_history) == len(dst_tr.acc_history)
    assert src_tr.acc_iter == dst_tr.acc_iter
    for n in range(len(src_tr.acc_history)):
        assert src_tr.acc_history[n] == dst_tr.acc_history[n]

    fig, ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        ax,
        src_tr.loss_history,
        acc_history=src_tr.acc_history,
        cur_epoch=src_tr.cur_epoch,
        iter_per_epoch=src_tr.iter_per_epoch
    )
    fig.savefig('figures/resnet_trainer_train_test_history.png', bbox_inches='tight')
def test_save_load_device_map(self) -> None:
    test_checkpoint = 'checkpoint/trainer_save_load_device_map.pkl'
    test_history = 'checkpoint/trainer_save_load_device_map_history.pkl'
    model = cifar.CIFAR10Net()
    # Get trainer object
    src_tr = cifar_trainer.CIFAR10Trainer(
        model,
        save_every=0,
        print_every=50,
        device_id=util.get_device_id(),
        # loader options
        num_epochs=self.test_num_epochs,
        batch_size=self.test_batch_size,
        num_workers=self.test_num_workers,
    )
    if self.verbose:
        print('Created trainer object')
        print(src_tr)

    # train the source trainer, then save its checkpoint and history
    src_tr.train()
    src_tr.save_checkpoint(test_checkpoint)
    assert src_tr.acc_history is not None
    src_tr.save_history(test_history)

    # Now try to load the checkpoint into a CPU-only trainer (device_id = -1)
    # and ensure that there is an acc history attribute that is not None
    dst_tr = cifar_trainer.CIFAR10Trainer(
        model,
        device_id=-1,
        verbose=self.verbose
    )
    dst_tr.load_checkpoint(test_checkpoint)

    # Test object parameters
    assert src_tr.num_epochs == dst_tr.num_epochs
    assert src_tr.learning_rate == dst_tr.learning_rate
    assert src_tr.weight_decay == dst_tr.weight_decay
    assert src_tr.print_every == dst_tr.print_every
    assert src_tr.save_every == dst_tr.save_every

    print('\t Comparing model parameters ')
    src_model_params = src_tr.get_model_params()
    dst_model_params = dst_tr.get_model_params()
    assert len(src_model_params.items()) == len(dst_model_params.items())

    # p1, p2 are (k, v) tuple pairs of model parameters,
    # where k is a str and v is a torch.Tensor
    print('Checking that tensors from checkpoint have been transferred to the new device')
    for n, (p1, p2) in enumerate(
            zip(src_model_params.items(), dst_model_params.items())):
        assert p1[0] == p2[0]
        print('Checking parameter %s [%d/%d] \t\t' %
              (str(p1[0]), n + 1, len(src_model_params.items())), end='\r')
        assert 'cpu' == p2[1].device.type
    print('\n ...done')

    # Test loss history
    dst_tr.load_history(test_history)
    assert dst_tr.acc_history is not None
    print('\t Comparing loss history....')
    assert src_tr.loss_iter == dst_tr.loss_iter
    for n in range(src_tr.loss_iter):
        print('Checking loss element [%d/%d]' % (n, src_tr.loss_iter), end='\r')
        assert src_tr.loss_history[n] == dst_tr.loss_history[n]
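# The device assertions above check that checkpoint tensors land on the CPU
# when the destination trainer was built with device_id = -1. A plausible
# mechanism (an assumption about load_checkpoint(), illustrated with torch's
# real map_location argument; load_checkpoint_sketch is hypothetical):

import torch

def load_checkpoint_sketch(fname: str, device_id: int) -> dict:
    # Map all saved tensors onto the trainer's device as they are loaded;
    # device_id = -1 is taken to mean the CPU here.
    device = torch.device('cpu') if device_id < 0 else torch.device('cuda:%d' % device_id)
    return torch.load(fname, map_location=device)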