def show() -> None:
    history = torch.load(GLOBAL_OPTS['input'])
    if GLOBAL_OPTS['loss_history_key'] not in history:
        raise ValueError('No key [%s] in history file [%s] (try using --mode=probe)' %\
            (str(GLOBAL_OPTS['loss_history_key']), str(GLOBAL_OPTS['input']))
        )
    loss_history = history[GLOBAL_OPTS['loss_history_key']][0 : history['loss_iter']]

    if GLOBAL_OPTS['acc_history_key'] in history:
        acc_history = history[GLOBAL_OPTS['acc_history_key']][0 : history['loss_iter']]
    else:
        acc_history = None

    if GLOBAL_OPTS['test_loss_history_key'] in history:
        test_loss_history = history[GLOBAL_OPTS['test_loss_history_key']][0 : history['test_loss_iter']]
        if GLOBAL_OPTS['verbose']:
            print('%d test loss iterations' % history['test_loss_iter'])
    else:
        test_loss_history = None

    if acc_history is not None:
        fig, ax = vis_loss_history.get_figure_subplots(2)
    else:
        fig, ax = vis_loss_history.get_figure_subplots(1)

    vis_loss_history.plot_train_history_2subplots(
        ax,
        loss_history,
        test_loss_curve=test_loss_history,
        acc_curve=acc_history,
        title=GLOBAL_OPTS['title'],
        iter_per_epoch=history['iter_per_epoch'],
        cur_epoch=history['cur_epoch']
    )

    if GLOBAL_OPTS['print_loss']:
        print(str(loss_history))
    if GLOBAL_OPTS['print_acc']:
        print(str(acc_history))

    if GLOBAL_OPTS['plot_filename'] is not None:
        fig.tight_layout()
        fig.savefig(GLOBAL_OPTS['plot_filename'])
    else:
        plt.show()
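
# For reference, a hedged sketch of the minimal history dict that show() above
# can consume. The key names mirror the lookups in show() (assuming GLOBAL_OPTS
# defaults of 'loss_history' / 'acc_history' for the history keys); the values
# here are synthetic, not output from a real training run.
def _demo_history_file(fname: str = 'checkpoint/demo_history.pkl') -> None:
    import numpy as np

    num_iter = 1000
    history = {
        'loss_history'   : np.linspace(2.3, 0.4, num_iter),   # fake decaying loss
        'loss_iter'      : num_iter,                          # number of valid loss entries
        'acc_history'    : np.linspace(0.1, 0.9, num_iter),   # fake rising accuracy
        'iter_per_epoch' : 200,
        'cur_epoch'      : 5,
    }
    torch.save(history, fname)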
def test_find_lr(self) -> None:
    # get an LRFinder
    trainer = get_trainer()
    lr_finder = lr_common.LogFinder(
        trainer,
        lr_min         = GLOBAL_TEST_PARAMS['test_lr_min'],
        lr_max         = GLOBAL_TEST_PARAMS['test_lr_max'],
        num_epochs     = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
        explode_thresh = GLOBAL_TEST_PARAMS['test_lr_explode_thresh'],
        max_batches    = self.test_max_batches,
        verbose        = self.verbose
    )
    if self.verbose:
        print('Created LRFinder object')
        print(lr_finder)

    lr_find_min, lr_find_max = lr_finder.find()
    print('Found learning rate range as %.3f -> %.3f' % (lr_find_min, lr_find_max))

    # show plot
    finder_fig, finder_ax = vis_loss_history.get_figure_subplots(2)
    lr_finder.plot_lr_vs_acc(finder_ax[0])
    lr_finder.plot_lr_vs_loss(finder_ax[1])
    if self.draw_plot is True:
        plt.show()
    else:
        finder_fig.savefig('figures/test_find_lr_plots.png', bbox_inches='tight')

    # train the network with the discovered parameters
    trainer.print_every = 200
    trainer.train()

    train_fig, train_ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        trainer.get_loss_history(),
        acc_curve      = trainer.get_acc_history(),
        iter_per_epoch = trainer.iter_per_epoch,
        cur_epoch      = trainer.cur_epoch
    )
    if self.draw_plot is True:
        plt.show()
    else:
        train_fig.savefig('figures/test_find_lr_train_results.png', bbox_inches='tight')
def main() -> None:
    model = get_model(GLOBAL_OPTS['model'])
    t = get_trainer(model, "superconverge_cifar10_", trainer_type="cifar")
    lr_finder = get_lr_finder(t, 1e-7, 1.0)

    find_start_time = time.time()
    lr_finder.find()
    lr_min, lr_max = lr_finder.get_lr_range()
    find_end_time = time.time()
    find_total_time = find_end_time - find_start_time
    print('Found learning rate range as %.4f -> %.4f' % (lr_min, lr_max))
    print('Learning rate search took %s' % str(timedelta(seconds=find_total_time)))

    if GLOBAL_OPTS['sched_stepsize'] > 0:
        stepsize = GLOBAL_OPTS['sched_stepsize']
    else:
        stepsize = int(len(t.train_loader) / 2)

    # get a scheduler for learning rate
    lr_scheduler = get_scheduler(
        lr_min,
        lr_max,
        stepsize,
        sched_type='TriangularScheduler'
    )
    # get a scheduler for momentum
    mtm_scheduler = get_scheduler(
        lr_min,
        lr_max,
        stepsize,
        sched_type='InvTriangularScheduler'
    )
    t.set_lr_scheduler(lr_scheduler)
    t.set_mtm_scheduler(mtm_scheduler)

    train_start_time = time.time()
    t.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Training time : %s' % str(timedelta(seconds=train_total_time)))

    # plot outputs
    loss_title = 'CIFAR10 Superconvergence Loss'
    acc_title = 'CIFAR10 Superconvergence Accuracy'
    fig_filename = 'figures/cifar10_superconv_test.png'
    train_fig, train_ax = vis_loss_history.get_figure_subplots(2)
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        t.get_loss_history(),
        acc_history    = t.get_acc_history(),
        cur_epoch      = t.cur_epoch,
        iter_per_epoch = t.iter_per_epoch,
        loss_title     = loss_title,
        acc_title      = acc_title
    )
    train_fig.tight_layout()
    train_fig.savefig(fig_filename)
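
# The schedules requested above implement Smith-style cyclical learning rates:
# the learning rate sweeps linearly lr_min -> lr_max -> lr_min over 2 * stepsize
# iterations while the momentum sweeps in the opposite direction. A minimal
# sketch of those waveforms (an illustration of the idea only, not the actual
# TriangularScheduler / InvTriangularScheduler implementations):
def triangular_lr(it: int, lr_min: float, lr_max: float, stepsize: int) -> float:
    cycle = it // (2 * stepsize)
    x = abs(it / stepsize - 2 * cycle - 1)      # 1 -> 0 -> 1 within each cycle
    return lr_min + (lr_max - lr_min) * (1.0 - x)

def inv_triangular_lr(it: int, lr_min: float, lr_max: float, stepsize: int) -> float:
    # mirror image of the triangular wave: high when the learning rate is low
    return lr_max - (triangular_lr(it, lr_min, lr_max, stepsize) - lr_min)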
def test_train(self) -> None:
    test_checkpoint_name = self.checkpoint_dir + 'resnet_trainer_train_checkpoint.pkl'
    test_history_name    = self.checkpoint_dir + 'resnet_trainer_train_history.pkl'
    train_num_epochs = 4
    train_batch_size = 128

    # get a model
    model = resnets.WideResnet(
        depth          = self.resnet_depth,
        num_classes    = 10,        # using CIFAR-10 data
        input_channels = 3,
        w_factor       = 1
    )
    # get a trainer
    trainer = resnet_trainer.ResnetTrainer(
        model,
        # training parameters
        batch_size      = train_batch_size,
        num_epochs      = train_num_epochs,
        learning_rate   = self.test_learning_rate,
        # device
        device_id       = util.get_device_id(),
        # checkpoint
        checkpoint_dir  = self.checkpoint_dir,
        checkpoint_name = 'resnet_trainer_test',
        # display
        print_every     = self.print_every,
        save_every      = 5000,
        verbose         = self.verbose
    )
    if self.verbose:
        print('Created %s object' % repr(trainer))
        print(trainer)

    print('Training model %s for %d epochs (batch size = %d)' %\
        (repr(trainer), train_num_epochs, train_batch_size)
    )
    trainer.train()

    # save the final checkpoint
    trainer.save_checkpoint(test_checkpoint_name)
    trainer.save_history(test_history_name)

    fig, ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        ax,
        trainer.loss_history,
        acc_history    = trainer.acc_history,
        cur_epoch      = trainer.cur_epoch,
        iter_per_epoch = trainer.iter_per_epoch
    )
    fig.savefig('figures/resnet_trainer_train_history.png', bbox_inches='tight')
    print('======== TestResnetTrainer.test_train <END>')
def generate_plot(trainer, loss_title: str, acc_title: str, fig_filename: str) -> None:
    train_fig, train_ax = vis_loss_history.get_figure_subplots(2)
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        trainer.get_loss_history(),
        acc_history=trainer.get_acc_history(),
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title=loss_title,
        acc_title=acc_title)
    train_fig.tight_layout()
    train_fig.savefig(fig_filename)
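
# Example usage of generate_plot() with a hypothetical trainer variable after
# a training run:
#
#   generate_plot(trainer,
#                 loss_title='CIFAR10 Loss',
#                 acc_title='CIFAR10 Accuracy',
#                 fig_filename='figures/cifar10_train.png')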
def test_model_param_save(self) -> None:
    # get a trainer, etc
    trainer = get_trainer()
    lr_finder = lr_common.LogFinder(
        trainer,
        lr_min      = GLOBAL_TEST_PARAMS['test_lr_min'],
        lr_max      = GLOBAL_TEST_PARAMS['test_lr_max'],
        num_iter    = GLOBAL_TEST_PARAMS['test_num_iter'],
        num_epochs  = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
        acc_test    = True,
        max_batches = self.test_max_batches,
        verbose     = self.verbose
    )
    # shut linter up
    if self.verbose:
        print(lr_finder)

    # make a copy of the model parameters before we start looking for a new
    # learning rate
    params_before_find = {k: v.clone() for k, v in trainer.get_model_params().items()}
    lr_find_min, lr_find_max = lr_finder.find()

    # now check that the restored parameters match the copy of the
    # parameters saved earlier
    params_after_find = trainer.get_model_params()
    assert len(params_before_find) == len(params_after_find)
    for param_name, param in params_before_find.items():
        assert torch.equal(param, params_after_find[param_name])

    # show plot
    fig1, ax1 = plt.subplots()
    lr_finder.plot_lr_vs_acc(ax1)
    if self.draw_plot is True:
        plt.show()
    else:
        fig1.savefig('figures/test_model_param_save_lr_vs_acc.png', bbox_inches='tight')

    trainer.print_every = 200
    trainer.train()

    fig2, ax2 = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        ax2,
        trainer.get_loss_history(),
        acc_curve      = trainer.get_acc_history(),
        iter_per_epoch = trainer.iter_per_epoch,
        cur_epoch      = trainer.cur_epoch
    )
    if self.draw_plot is True:
        plt.show()
    else:
        fig2.savefig('figures/test_model_param_save_train_results.png', bbox_inches='tight')
def test_lr_range_find(self) -> None:
    trainer = get_trainer()
    lr_finder = lr_common.LogFinder(
        trainer,
        lr_min      = GLOBAL_TEST_PARAMS['test_lr_min'],
        lr_max      = GLOBAL_TEST_PARAMS['test_lr_max'],
        num_iter    = GLOBAL_TEST_PARAMS['test_num_iter'],
        num_epochs  = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
        acc_test    = True,
        max_batches = self.test_max_batches,
        verbose     = self.verbose
    )
    # shut linter up
    if self.verbose:
        print(lr_finder)

    lr_find_min, lr_find_max = lr_finder.find()
    print('Found learning rate range as %.3f -> %.3f' % (lr_find_min, lr_find_max))

    # show plot
    fig1, ax1 = plt.subplots()
    lr_finder.plot_lr_vs_acc(ax1)
    if self.draw_plot is True:
        plt.show()
    else:
        fig1.savefig('figures/test_lr_range_find_lr_vs_acc.png', bbox_inches='tight')

    trainer.print_every = 200
    trainer.train()

    fig2, ax2 = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        ax2,
        trainer.get_loss_history(),
        acc_curve      = trainer.get_acc_history(),
        iter_per_epoch = trainer.iter_per_epoch,
        cur_epoch      = trainer.cur_epoch
    )
    if self.draw_plot is True:
        plt.show()
    else:
        fig2.savefig('figures/test_lr_range_find_train_results.png', bbox_inches='tight')
print('Creating subdir [%s] for schedule %d/%d [%s]' %\
    (subdir, idx+1, len(schedulers), str(schedulers[idx]))
)
os.mkdir(subdir)
writer = tensorboard.SummaryWriter(log_dir=subdir)
trainer.set_tb_writer(writer)

if idx == 0:
    lr_find_min, lr_find_max, lr_finder = find_lr(trainer, return_finder=True)
    # create plots
    lr_acc_title = '[' + str(GLOBAL_OPTS['model']) + '][' + str(GLOBAL_OPTS['find_lr_select_method']) + '] ' +\
        str(schedulers[idx]) + ' learning rate vs acc (log)'
    lr_loss_title = '[' + str(GLOBAL_OPTS['model']) + '][' + str(GLOBAL_OPTS['find_lr_select_method']) + '] ' +\
        str(schedulers[idx]) + ' learning rate vs loss (log)'
    lr_fig, lr_ax = vis_loss_history.get_figure_subplots(2)
    lr_finder.plot_lr_vs_acc(lr_ax[0], lr_acc_title, log=True)
    lr_finder.plot_lr_vs_loss(lr_ax[1], lr_loss_title, log=True)
    # save
    lr_fig.tight_layout()
    lr_fig.savefig('figures/[%s][%s]_%s_lr_finder_output.png' %\
        (str(GLOBAL_OPTS['model']), str(GLOBAL_OPTS['find_lr_select_method']), str(schedulers[idx]))
    )
    print('Found learning rates as %.4f -> %.4f' % (lr_find_min, lr_find_max))

    #if GLOBAL_OPTS['tensorboard_dir'] is not None:
    #    writer.add_hparams(
    #        hparam_dict = {
    #            'lr_find_min': lr_find_min,
    #            'lr_find_max': lr_find_max
def main() -> None:
    gan_data_transform = transforms.Compose([
        transforms.Resize(GLOBAL_OPTS['image_size']),
        transforms.CenterCrop(GLOBAL_OPTS['image_size']),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if GLOBAL_OPTS['dataset'] is not None:
        train_dataset = lmdb_dataset.LMDBDataset(
            GLOBAL_OPTS['dataset'],
            transform=gan_data_transform)
    else:
        train_dataset = datasets.ImageFolder(
            root=GLOBAL_OPTS['dataset_root'],
            transform=gan_data_transform)

    # get some models
    generator = dcgan.DCGANGenerator(
        zvec_dim=GLOBAL_OPTS['zvec_dim'],
        num_filters=GLOBAL_OPTS['g_num_filters'],
        img_size=GLOBAL_OPTS['image_size'])
    discriminator = dcgan.DCGANDiscriminator(
        num_filters=GLOBAL_OPTS['d_num_filters'],
        img_size=GLOBAL_OPTS['image_size'])

    # get a trainer
    gan_trainer = dcgan_trainer.DCGANTrainer(
        discriminator,
        generator,
        # DCGAN trainer specific arguments
        beta1=GLOBAL_OPTS['beta1'],
        # general trainer arguments
        train_dataset=train_dataset,
        # training opts
        learning_rate=GLOBAL_OPTS['learning_rate'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        batch_size=GLOBAL_OPTS['batch_size'],
        # checkpoints
        save_every=GLOBAL_OPTS['save_every'],
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        checkpoint_dir=GLOBAL_OPTS['checkpoint_dir'],
        # display
        print_every=GLOBAL_OPTS['print_every'],
        verbose=GLOBAL_OPTS['verbose'],
        # device
        device_id=GLOBAL_OPTS['device_id'])

    if GLOBAL_OPTS['load_checkpoint'] is not None:
        gan_trainer.load_checkpoint(GLOBAL_OPTS['load_checkpoint'])
    print(gan_trainer.device)

    # Get a scheduler
    lr_scheduler = schedule.DecayToEpoch(
        non_decay_time=int(GLOBAL_OPTS['num_epochs'] // 2),
        decay_length=int(GLOBAL_OPTS['num_epochs'] // 2),
        initial_lr=GLOBAL_OPTS['learning_rate'],
        final_lr=0.0)
    print('Created scheduler')
    print(lr_scheduler)
    gan_trainer.set_lr_scheduler(lr_scheduler)

    train_start_time = time.time()
    gan_trainer.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Total training time : %s' % str(timedelta(seconds=train_total_time)))

    # show the training results
    dcgan_fig, dcgan_ax = vis_loss_history.get_figure_subplots(1)
    vis_loss_history.plot_train_history_dcgan(
        dcgan_ax,
        gan_trainer.get_g_loss_history(),
        gan_trainer.get_d_loss_history(),
        cur_epoch=gan_trainer.cur_epoch,
        iter_per_epoch=gan_trainer.iter_per_epoch)
    dcgan_fig.tight_layout()
    dcgan_fig.savefig('figures/dcgan_train_history.png')
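
# For reference, DecayToEpoch above is configured to hold initial_lr for the
# first half of training, then decay to final_lr over the second half. A
# minimal sketch of that shape (assuming linear decay; an illustration only,
# not the schedule module's implementation):
def decay_to_epoch_lr(epoch: int, non_decay_time: int, decay_length: int,
                      initial_lr: float, final_lr: float) -> float:
    if epoch < non_decay_time:
        return initial_lr
    frac = min((epoch - non_decay_time) / decay_length, 1.0)
    return initial_lr + frac * (final_lr - initial_lr)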
def main() -> None:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
    train_dataset_transform = transforms.Compose([
        transforms.RandomRotation(5),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(224, scale=(0.96, 1.0), ratio=(0.95, 1.05)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    test_dataset_transform = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # HDF5 datasets
    if GLOBAL_USE_HDF5 is True:
        cvd_train_dataset = hdf5_dataset.HDF5Dataset(
            GLOBAL_OPTS['train_dataset'],
            feature_name='images',
            label_name='labels',
            label_max_dim=1,
            transform=normalize)
        cvd_val_dataset = hdf5_dataset.HDF5Dataset(
            GLOBAL_OPTS['test_dataset'],
            feature_name='images',
            label_name='labels',
            label_max_dim=1,
            transform=normalize)
    # ImageFolder datasets
    else:
        cvd_train_dir = '/home/kreshnik/ml-data/cats-vs-dogs/train'
        cvd_train_dataset = datasets.ImageFolder(cvd_train_dir, train_dataset_transform)
        cvd_val_dir = '/home/kreshnik/ml-data/cats-vs-dogs/test'
        cvd_val_dataset = datasets.ImageFolder(cvd_val_dir, test_dataset_transform)

    # get a network
    model = cvdnet.CVDNet2()
    cvd_train = cvd_trainer.CVDTrainer(
        model,
        # dataset options
        train_dataset=cvd_train_dataset,
        val_dataset=cvd_val_dataset,
        # training options
        loss_function='CrossEntropyLoss',
        learning_rate=GLOBAL_OPTS['learning_rate'],
        weight_decay=GLOBAL_OPTS['weight_decay'],
        momentum=GLOBAL_OPTS['momentum'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        batch_size=GLOBAL_OPTS['batch_size'],
        val_batch_size=GLOBAL_OPTS['val_batch_size'],
        # checkpoint
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        save_every=GLOBAL_OPTS['save_every'],
        # device
        device_id=GLOBAL_OPTS['device_id'],
        # other
        print_every=GLOBAL_OPTS['print_every'],
        verbose=GLOBAL_OPTS['verbose'])

    # Add a tensorboard writer
    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        if not os.path.isdir(GLOBAL_OPTS['tensorboard_dir']):
            os.mkdir(GLOBAL_OPTS['tensorboard_dir'])
        writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        cvd_train.set_tb_writer(writer)

    # Optionally do a search pass here and add a scheduler
    if GLOBAL_OPTS['find_lr']:
        lr_finder = expr_util.get_lr_finder(cvd_train)
        lr_find_start_time = time.time()
        lr_finder.find()
        lr_find_min, lr_find_max = lr_finder.get_lr_range()
        lr_find_end_time = time.time()
        lr_find_total_time = lr_find_end_time - lr_find_start_time
        print('Found learning rate range %.4f -> %.4f' % (lr_find_min, lr_find_max))
        print('Total find time [%s] ' %\
            str(timedelta(seconds = lr_find_total_time))
        )
        # Now get a scheduler
        stepsize = cvd_train.get_num_epochs() * len(cvd_train.train_loader) // 2
        lr_scheduler = expr_util.get_scheduler(
            lr_find_min,
            lr_find_max,
            stepsize,
            sched_type='TriangularScheduler')
        cvd_train.set_lr_scheduler(lr_scheduler)

    # train the model
    train_start_time = time.time()
    cvd_train.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Total scheduled training time [%s] (%d epochs) %s' %\
        (repr(cvd_train), cvd_train.cur_epoch, str(timedelta(seconds = train_total_time)))
    )

    # Show results
    fig, ax = vis_loss_history.get_figure_subplots(num_subplots=2)
    vis_loss_history.plot_train_history_2subplots(
        ax,
        cvd_train.get_loss_history(),
        acc_history=cvd_train.get_acc_history(),
        iter_per_epoch=cvd_train.iter_per_epoch,
        loss_title='CVD Loss',
        acc_title='CVD Acc',
        cur_epoch=cvd_train.cur_epoch)
    fig.savefig('figures/cvd_train.png')
def main() -> None:
    train_dataset, val_dataset = get_datasets(GLOBAL_OPTS['data_dir'])

    # get some models
    encoder = denoise_ae.DAEEncoder(num_channels=1)
    decoder = denoise_ae.DAEDecoder(num_channels=1)
    trainer = dae_trainer.DAETrainer(
        encoder,
        decoder,
        # datasets
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        device_id=GLOBAL_OPTS['device_id'],
        # trainer params
        batch_size=GLOBAL_OPTS['batch_size'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        learning_rate=GLOBAL_OPTS['learning_rate'],
        # checkpoints, saving, etc
        checkpoint_dir=GLOBAL_OPTS['checkpoint_dir'],
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        save_every=GLOBAL_OPTS['save_every'],
        print_every=GLOBAL_OPTS['print_every'],
        verbose=GLOBAL_OPTS['verbose'])

    # Add a summary writer
    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        if GLOBAL_OPTS['verbose']:
            print('Adding tensorboard writer to [%s]' % repr(trainer))
        writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        trainer.set_tb_writer(writer)

    # train the model
    train_start_time = time.time()
    trainer.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Trained [%s] for %d epochs, total time : %s' %\
        (repr(trainer), trainer.cur_epoch, str(timedelta(seconds = train_total_time)))
    )

    # Take the models and load them into an inferrer
    if GLOBAL_OPTS['infer']:
        trainer.drop_last = True
        # make the number of squares in the output figure settable later
        subplot_square = 8
        # so that we have an NxN grid of outputs
        trainer.set_batch_size(subplot_square * subplot_square)

        # Get some figure stuff
        out_fig, out_ax_list = vis_img.get_grid_subplots(subplot_square)
        noise_fig, noise_ax_list = vis_img.get_grid_subplots(subplot_square)

        # Get an inferrer
        inferrer = dae_inferrer.DAEInferrer(
            trainer.encoder,
            trainer.decoder,
            noise_bias=GLOBAL_OPTS['noise_bias'],
            noise_factor=GLOBAL_OPTS['noise_factor'],
            device_id=GLOBAL_OPTS['device_id'])
        if GLOBAL_OPTS['verbose']:
            print('Created [%s] object and attached models [%s], [%s]' %\
                (repr(inferrer), repr(trainer.encoder), repr(trainer.decoder))
            )

        infer_start_time = time.time()
        for batch_idx, (data, _) in enumerate(trainer.val_loader):
            print('Inferring batch [%d / %d]' % (batch_idx + 1, len(trainer.val_loader)), end='\r')
            noise_batch = inferrer.get_noise(data)
            out_batch = inferrer.forward(data)

            # Plot noise
            plot_denoise(noise_ax_list, noise_batch)
            noise_fig_fname = 'figures/dae/mnist_dae_batch_%d_noise.png' % batch_idx
            noise_fig.tight_layout()
            noise_fig.savefig(noise_fig_fname)

            # Plot outputs
            plot_denoise(out_ax_list, out_batch)
            out_fig_fname = 'figures/dae/mnist_dae_batch_%d_output.png' % batch_idx
            out_fig.tight_layout()
            out_fig.savefig(out_fig_fname)

        print('\n OK')
        infer_end_time = time.time()
        infer_total_time = infer_end_time - infer_start_time
        print('Inferrer [%s] inferred %d batches of size %d, total time : %s' %\
            (repr(inferrer), len(trainer.val_loader), trainer.batch_size,
             str(timedelta(seconds = infer_total_time)))
        )

    # Plot the loss history
    hist_fig, hist_ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        hist_ax,
        trainer.get_loss_history(),
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title='Denoising AE MNIST Training loss')
    hist_fig.savefig(GLOBAL_OPTS['loss_history_file'], bbox_inches='tight')
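
# plot_denoise() above is assumed to tile one batch of single-channel images
# onto the grid of axes returned by vis_img.get_grid_subplots(). A minimal
# sketch of such a helper (hypothetical; the real implementation may differ):
def plot_denoise_sketch(ax_list, batch: torch.Tensor) -> None:
    for ax, img in zip(ax_list, batch):
        # each img is (1, H, W); drop the channel dim for imshow
        ax.imshow(img.detach().cpu().squeeze(0).numpy(), cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])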
def test_save_load_checkpoint(self) -> None:
    test_checkpoint_name = self.checkpoint_dir + 'resnet_trainer_checkpoint.pkl'
    test_history_name    = self.checkpoint_dir + 'resnet_trainer_history.pkl'

    # get a model
    model = resnets.WideResnet(
        depth          = self.resnet_depth,
        num_classes    = 10,        # using CIFAR-10 data
        input_channels = 3,
        w_factor       = 1
    )
    # get a trainer
    src_tr = resnet_trainer.ResnetTrainer(
        model,
        # training parameters
        batch_size    = self.test_batch_size,
        num_epochs    = self.test_num_epochs,
        learning_rate = self.test_learning_rate,
        # device
        device_id     = util.get_device_id(),
        # display
        print_every   = self.print_every,
        save_every    = 0,
        verbose       = self.verbose
    )
    if self.verbose:
        print('Created %s object' % repr(src_tr))
        print(src_tr)

    print('Training model %s for %d epochs' % (repr(src_tr), self.test_num_epochs))
    src_tr.train()

    # save the final checkpoint
    src_tr.save_checkpoint(test_checkpoint_name)
    src_tr.save_history(test_history_name)

    # get a new trainer and load checkpoint
    dst_tr = resnet_trainer.ResnetTrainer(model)
    dst_tr.load_checkpoint(test_checkpoint_name)

    # Test object parameters
    assert src_tr.num_epochs == dst_tr.num_epochs
    assert src_tr.learning_rate == dst_tr.learning_rate
    assert src_tr.weight_decay == dst_tr.weight_decay
    assert src_tr.print_every == dst_tr.print_every
    assert src_tr.save_every == dst_tr.save_every

    print('\t Comparing model parameters ')
    src_model_params = src_tr.get_model_params()
    dst_model_params = dst_tr.get_model_params()
    assert len(src_model_params.items()) == len(dst_model_params.items())

    # p1, p2 are (k, v) tuple pairs of each model's parameters
    # k = str
    # v = torch.Tensor
    for n, (p1, p2) in enumerate(zip(src_model_params.items(), dst_model_params.items())):
        assert p1[0] == p2[0]
        print('Checking parameter %s [%d/%d] \t\t' % (str(p1[0]), n+1, len(src_model_params.items())), end='\r')
        assert torch.equal(p1[1], p2[1])
    print('\n ...done')

    # test history
    dst_tr.load_history(test_history_name)

    # loss history
    assert len(src_tr.loss_history) == len(dst_tr.loss_history)
    assert src_tr.loss_iter == dst_tr.loss_iter
    for n in range(len(src_tr.loss_history)):
        assert src_tr.loss_history[n] == dst_tr.loss_history[n]

    # val loss history
    assert len(src_tr.val_loss_history) == len(dst_tr.val_loss_history)
    assert src_tr.val_loss_iter == dst_tr.val_loss_iter
    for n in range(len(src_tr.val_loss_history)):
        assert src_tr.val_loss_history[n] == dst_tr.val_loss_history[n]

    # acc history
    assert len(src_tr.acc_history) == len(dst_tr.acc_history)
    assert src_tr.acc_iter == dst_tr.acc_iter
    for n in range(len(src_tr.acc_history)):
        assert src_tr.acc_history[n] == dst_tr.acc_history[n]

    fig, ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        ax,
        src_tr.loss_history,
        acc_history    = src_tr.acc_history,
        cur_epoch      = src_tr.cur_epoch,
        iter_per_epoch = src_tr.iter_per_epoch
    )
    fig.savefig('figures/resnet_trainer_train_test_history.png', bbox_inches='tight')
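
# The element-wise parameter comparison above can be factored into a small
# reusable helper; a sketch (not part of the original test suite):
def model_params_equal(src_params: dict, dst_params: dict) -> bool:
    if src_params.keys() != dst_params.keys():
        return False
    return all(torch.equal(src_params[k], dst_params[k]) for k in src_params)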
def main() -> None:
    # Get a model
    model = cifar.CIFAR10Net()
    # Get a trainer
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # training parameters
        batch_size=GLOBAL_OPTS['batch_size'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        learning_rate=GLOBAL_OPTS['learning_rate'],
        momentum=GLOBAL_OPTS['momentum'],
        weight_decay=GLOBAL_OPTS['weight_decay'],
        # device
        device_id=GLOBAL_OPTS['device_id'],
        # checkpoint
        checkpoint_dir=GLOBAL_OPTS['checkpoint_dir'],
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        # display
        print_every=GLOBAL_OPTS['print_every'],
        save_every=GLOBAL_OPTS['save_every'],
        verbose=GLOBAL_OPTS['verbose'])

    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        trainer.set_tb_writer(writer)

    # Optionally do a search pass here and add a scheduler
    if GLOBAL_OPTS['find_lr']:
        lr_finder = expr_util.get_lr_finder(trainer)
        lr_find_start_time = time.time()
        lr_finder.find()
        lr_find_min, lr_find_max = lr_finder.get_lr_range()
        lr_find_end_time = time.time()
        lr_find_total_time = lr_find_end_time - lr_find_start_time
        print('Found learning rate range %.4f -> %.4f' % (lr_find_min, lr_find_max))
        print('Total find time [%s] ' %\
            str(timedelta(seconds = lr_find_total_time))
        )
        # Now get a scheduler
        stepsize = trainer.get_num_epochs() * len(trainer.train_loader) // 2
        lr_scheduler = expr_util.get_scheduler(
            lr_find_min,
            lr_find_max,
            stepsize,
            sched_type='TriangularScheduler')
        trainer.set_lr_scheduler(lr_scheduler)

    # train the model
    train_start_time = time.time()
    trainer.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Total training time [%s] (%d epochs) %s' %\
        (repr(trainer), trainer.cur_epoch, str(timedelta(seconds = train_total_time)))
    )

    # Visualise the output
    train_fig, train_ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        trainer.get_loss_history(),
        acc_history=trainer.get_acc_history(),
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title='CIFAR-10 Training Loss',
        acc_title='CIFAR-10 Training Accuracy')
    train_fig.savefig(GLOBAL_OPTS['fig_name'], bbox_inches='tight')
def main() -> None:
    # get a model and train it as a reference
    ref_model = get_model()
    ref_trainer = get_trainer(ref_model, 'ex_cifar10_lr_find_schedule_')
    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        ref_writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        ref_trainer.set_tb_writer(ref_writer)

    ref_train_start_time = time.time()
    ref_trainer.train()
    ref_train_end_time = time.time()
    ref_train_total_time = ref_train_end_time - ref_train_start_time
    print('Total reference training time [%s] (%d epochs) %s' %\
        (repr(ref_trainer), ref_trainer.cur_epoch, str(timedelta(seconds = ref_train_total_time)))
    )

    # get a model and train it with a scheduler
    sched_model = get_model()
    sched_trainer = get_trainer(sched_model, 'ex_cifar10_lr_find_schedule_')
    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        if not os.path.isdir(GLOBAL_OPTS['tensorboard_dir']):
            os.mkdir(GLOBAL_OPTS['tensorboard_dir'])
        sched_writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        sched_trainer.set_tb_writer(sched_writer)

    # get an LRFinder object
    lr_finder = lr_common.LogFinder(
        sched_trainer,
        lr_min         = GLOBAL_OPTS['find_lr_min'],
        lr_max         = GLOBAL_OPTS['find_lr_max'],
        num_epochs     = GLOBAL_OPTS['find_num_epochs'],
        explode_thresh = GLOBAL_OPTS['find_explode_thresh'],
        print_every    = GLOBAL_OPTS['find_print_every']
    )
    print(lr_finder)

    lr_find_start_time = time.time()
    lr_finder.find()
    lr_find_min, lr_find_max = lr_finder.get_lr_range()
    lr_find_end_time = time.time()
    lr_find_total_time = lr_find_end_time - lr_find_start_time
    print('Total parameter search time : %s' % str(timedelta(seconds = lr_find_total_time)))
    if GLOBAL_OPTS['verbose']:
        print('Found learning rate range as %.4f -> %.4f' % (lr_find_min, lr_find_max))

    # get a scheduler
    lr_sched_obj = getattr(schedule, GLOBAL_OPTS['sched_type'])
    lr_scheduler = lr_sched_obj(
        stepsize = int(len(sched_trainer.train_loader) / 4),
        lr_min   = lr_find_min,
        lr_max   = lr_find_max
    )
    assert sched_trainer.acc_iter == 0
    sched_trainer.set_lr_scheduler(lr_scheduler)

    sched_train_start_time = time.time()
    sched_trainer.train()
    sched_train_end_time = time.time()
    sched_train_total_time = sched_train_end_time - sched_train_start_time
    print('Total scheduled training time [%s] (%d epochs) %s' %\
        (repr(sched_trainer), sched_trainer.cur_epoch, str(timedelta(seconds = sched_train_total_time)))
    )
    print('Scheduled training time (including find time) : %s' %\
        str(timedelta(seconds = sched_train_total_time + lr_find_total_time))
    )

    # Compare loss, accuracy
    fig, ax = vis_loss_history.get_figure_subplots(2)
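    # A minimal sketch of one way to draw the comparison on the subplots just
    # created (an assumption, not the original code: ax[0] is taken to hold
    # loss and ax[1] accuracy, and both trainers expose get_loss_history() /
    # get_acc_history() as used elsewhere in this repo):
    ax[0].plot(ref_trainer.get_loss_history(), label='fixed lr')
    ax[0].plot(sched_trainer.get_loss_history(), label='scheduled lr')
    ax[0].set_title('CIFAR-10 training loss')
    ax[0].legend()
    ax[1].plot(ref_trainer.get_acc_history(), label='fixed lr')
    ax[1].plot(sched_trainer.get_acc_history(), label='scheduled lr')
    ax[1].set_title('CIFAR-10 training accuracy')
    ax[1].legend()
    fig.tight_layout()
    fig.savefig('figures/ex_cifar10_lr_find_schedule_compare.png')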