def test_save_load(self) -> None:
    test_finder_state_file = 'data/test_lr_finder_state.pth'
    # get a trainer, etc
    trainer = get_trainer()
    src_lr_finder = lr_common.LogFinder(
        trainer,
        lr_min      = GLOBAL_TEST_PARAMS['test_lr_min'],
        lr_max      = GLOBAL_TEST_PARAMS['test_lr_max'],
        num_iter    = GLOBAL_TEST_PARAMS['test_num_iter'],
        num_epochs  = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
        acc_test    = True,
        max_batches = self.test_max_batches,
        verbose     = self.verbose
    )
    print('max_batches set to %d' % src_lr_finder.max_batches)

    # run a search so that the finder has real state to save
    lr_find_min, lr_find_max = src_lr_finder.find()
    assert src_lr_finder.smooth_loss_history is not None

    # save the finder state and load it into a new object
    src_lr_finder.save(test_finder_state_file)
    dst_lr_finder = lr_common.LogFinder(
        None,
        verbose = self.verbose
    )
    dst_lr_finder.load(test_finder_state_file)

    # The trainer is not preserved by the save operation, so there is
    # nothing to compare there; check each saved parameter instead.
    assert src_lr_finder.lr_mult == dst_lr_finder.lr_mult
    assert src_lr_finder.lr_min == dst_lr_finder.lr_min
    assert src_lr_finder.lr_max == dst_lr_finder.lr_max
    assert src_lr_finder.explode_thresh == dst_lr_finder.explode_thresh
    assert src_lr_finder.beta == dst_lr_finder.beta
    assert src_lr_finder.gamma == dst_lr_finder.gamma
    assert src_lr_finder.lr_min_factor == dst_lr_finder.lr_min_factor
    assert src_lr_finder.lr_max_scale == dst_lr_finder.lr_max_scale
    assert src_lr_finder.lr_select_method == dst_lr_finder.lr_select_method

    # check histories
    print('Checking smooth loss history...', end=' ')
    assert len(src_lr_finder.smooth_loss_history) == len(dst_lr_finder.smooth_loss_history)
    for n in range(len(src_lr_finder.smooth_loss_history)):
        assert src_lr_finder.smooth_loss_history[n] == dst_lr_finder.smooth_loss_history[n]
    print(' OK')

    print('Checking log learning rate history...', end=' ')
    assert len(src_lr_finder.log_lr_history) == len(dst_lr_finder.log_lr_history)
    for n in range(len(src_lr_finder.log_lr_history)):
        assert src_lr_finder.log_lr_history[n] == dst_lr_finder.log_lr_history[n]
    print(' OK')

    print('Checking acc history...', end=' ')
    assert len(src_lr_finder.acc_history) == len(dst_lr_finder.acc_history)
    for n in range(len(src_lr_finder.acc_history)):
        assert src_lr_finder.acc_history[n] == dst_lr_finder.acc_history[n]
    print(' OK')
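
# The attribute-by-attribute asserts in test_save_load() could also be
# written as a single loop. A minimal sketch, assuming only the attributes
# already checked above; CHECKED_ATTRS and assert_finders_equal() are
# hypothetical helpers, not part of lr_common.
CHECKED_ATTRS = [
    'lr_mult', 'lr_min', 'lr_max', 'explode_thresh', 'beta', 'gamma',
    'lr_min_factor', 'lr_max_scale', 'lr_select_method',
    'smooth_loss_history', 'log_lr_history', 'acc_history',
]

def assert_finders_equal(src, dst) -> None:
    # lists compare element-wise with ==, so the three history loops
    # above fold into the same getattr() comparison
    for attr in CHECKED_ATTRS:
        assert getattr(src, attr) == getattr(dst, attr), 'mismatch on %s' % attr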
def test_model_param_save(self) -> None:
    # get a trainer, etc
    trainer = get_trainer()
    lr_finder = lr_common.LogFinder(
        trainer,
        lr_min      = GLOBAL_TEST_PARAMS['test_lr_min'],
        lr_max      = GLOBAL_TEST_PARAMS['test_lr_max'],
        num_iter    = GLOBAL_TEST_PARAMS['test_num_iter'],
        num_epochs  = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
        acc_test    = True,
        max_batches = self.test_max_batches,
        verbose     = self.verbose
    )
    # shut linter up
    if self.verbose:
        print(lr_finder)

    # NOTE: despite the test name, the copy-and-compare of model parameters
    # is not implemented here; a sketch of that check follows this test.
    lr_find_min, lr_find_max = lr_finder.find()

    # show plot
    fig1, ax1 = plt.subplots()
    lr_finder.plot_lr_vs_acc(ax1)
    if self.draw_plot is True:
        plt.show()
    else:
        plt.savefig('figures/test_lr_range_find_lr_vs_acc.png', bbox_inches='tight')

    trainer.print_every = 200
    trainer.train()
    fig2, ax2 = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        ax2,
        trainer.get_loss_history(),
        acc_curve      = trainer.get_acc_history(),
        iter_per_epoch = trainer.iter_per_epoch,
        cur_epoch      = trainer.cur_epoch
    )
    if self.draw_plot is True:
        plt.show()
    else:
        plt.savefig('figures/test_lr_range_find_train_results.png', bbox_inches='tight')
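
# Sketch of the parameter check that test_model_param_save() names but does
# not implement: snapshot the weights before find() and confirm they are
# restored afterwards. This assumes the trainer exposes its model as
# `trainer.model` (an assumption; check_params_restored() is hypothetical).
import copy
import torch

def check_params_restored(trainer, lr_finder) -> None:
    params_before = copy.deepcopy(trainer.model.state_dict())
    lr_finder.find()
    params_after = trainer.model.state_dict()
    for key in params_before:
        # torch.equal() is True only for tensors of identical shape and values
        assert torch.equal(params_before[key], params_after[key]), \
            'parameter %s changed after find()' % key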
def get_lr_finder_binary(tr: trainer.Trainer,
                         lr_min: float,
                         lr_max: float,
                         lr_select_method: str = 'max_acc',
                         num_epochs: int = 8,
                         explode_thresh: float = 8.0,
                         print_every: int = 32) -> lr_common.LogFinder:
    lr_finder = lr_common.LogFinder(
        tr,
        lr_min           = lr_min,
        lr_max           = lr_max,
        lr_select_method = lr_select_method,
        num_epochs       = num_epochs,
        explode_thresh   = explode_thresh,
        print_every      = print_every
    )
    return lr_finder
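
# Example use of the factory above. The LR bounds are illustrative,
# get_trainer() is the same helper used by the tests in this module, and
# demo_lr_finder_binary() is a hypothetical wrapper, not part of the API.
def demo_lr_finder_binary() -> None:
    tr = get_trainer()
    finder = get_lr_finder_binary(tr, lr_min=1e-6, lr_max=1.0)
    lr_find_min, lr_find_max = finder.find()
    print('LR range: %g -> %g' % (lr_find_min, lr_find_max))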
def test_find_lr(self) -> None:
    # get an LRFinder
    trainer = get_trainer()
    lr_finder = lr_common.LogFinder(
        trainer,
        lr_min         = GLOBAL_TEST_PARAMS['test_lr_min'],
        lr_max         = GLOBAL_TEST_PARAMS['test_lr_max'],
        num_epochs     = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
        explode_thresh = GLOBAL_TEST_PARAMS['test_lr_explode_thresh'],
        max_batches    = self.test_max_batches,
        verbose        = self.verbose
    )
    if self.verbose:
        print('Created LRFinder object')
        print(lr_finder)

    lr_find_min, lr_find_max = lr_finder.find()
    print('Found learning rate range as %.3f -> %.3f' % (lr_find_min, lr_find_max))

    # show plot
    finder_fig, finder_ax = vis_loss_history.get_figure_subplots(2)
    lr_finder.plot_lr_vs_acc(finder_ax[0])
    lr_finder.plot_lr_vs_loss(finder_ax[1])
    if self.draw_plot is True:
        plt.show()
    else:
        finder_fig.savefig('figures/test_find_lr_plots.png', bbox_inches='tight')

    # train the network with the discovered parameters
    trainer.print_every = 200
    trainer.train()
    train_fig, train_ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        trainer.get_loss_history(),
        acc_curve      = trainer.get_acc_history(),
        iter_per_epoch = trainer.iter_per_epoch,
        cur_epoch      = trainer.cur_epoch
    )
    if self.draw_plot is True:
        plt.show()
    else:
        train_fig.savefig('figures/test_find_lr_train_results.png', bbox_inches='tight')
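
# For reference, a minimal self-contained sketch of the log-spaced LR range
# test that find() performs (after Smith, "Cyclical Learning Rates for
# Training Neural Networks"). This illustrates the technique only and is not
# LogFinder's actual implementation; run_batch(lr) is a hypothetical closure
# that trains one batch at the given learning rate and returns its loss.
import numpy as np

def lr_range_sweep(run_batch, lr_min: float, lr_max: float,
                   num_iter: int = 100, beta: float = 0.98,
                   explode_thresh: float = 4.0):
    lr_mult = (lr_max / lr_min) ** (1.0 / num_iter)   # multiplicative step -> log-spaced LRs
    lr = lr_min
    avg_loss = 0.0
    best_loss = float('inf')
    log_lr_history = []
    smooth_loss_history = []
    for i in range(1, num_iter + 1):
        loss = run_batch(lr)
        # exponentially-weighted loss with bias correction for early iterations
        avg_loss = beta * avg_loss + (1.0 - beta) * loss
        smooth_loss = avg_loss / (1.0 - beta ** i)
        if smooth_loss > explode_thresh * best_loss:
            break       # loss has exploded; stop the sweep early
        best_loss = min(best_loss, smooth_loss)
        log_lr_history.append(np.log10(lr))
        smooth_loss_history.append(smooth_loss)
        lr *= lr_mult
    return log_lr_history, smooth_loss_history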
def test_lr_range_find(self) -> None:
    trainer = get_trainer()
    lr_finder = lr_common.LogFinder(
        trainer,
        lr_min      = GLOBAL_TEST_PARAMS['test_lr_min'],
        lr_max      = GLOBAL_TEST_PARAMS['test_lr_max'],
        num_iter    = GLOBAL_TEST_PARAMS['test_num_iter'],
        num_epochs  = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
        acc_test    = True,
        max_batches = self.test_max_batches,
        verbose     = self.verbose
    )
    # shut linter up
    if self.verbose:
        print(lr_finder)

    lr_find_min, lr_find_max = lr_finder.find()

    # show plot
    fig1, ax1 = plt.subplots()
    lr_finder.plot_lr_vs_acc(ax1)
    if self.draw_plot is True:
        plt.show()
    else:
        plt.savefig('figures/test_lr_range_find_lr_vs_acc.png', bbox_inches='tight')

    trainer.print_every = 200
    trainer.train()
    fig2, ax2 = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        ax2,
        trainer.get_loss_history(),
        acc_curve      = trainer.get_acc_history(),
        iter_per_epoch = trainer.iter_per_epoch,
        cur_epoch      = trainer.cur_epoch
    )
    if self.draw_plot is True:
        plt.show()
    else:
        plt.savefig('figures/test_lr_range_find_train_results.png', bbox_inches='tight')
def main() -> None:
    # get a model and train it as a reference
    ref_model = get_model()
    ref_trainer = get_trainer(ref_model, 'ex_cifar10_lr_find_schedule_')
    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        ref_writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        ref_trainer.set_tb_writer(ref_writer)

    ref_train_start_time = time.time()
    ref_trainer.train()
    ref_train_end_time = time.time()
    ref_train_total_time = ref_train_end_time - ref_train_start_time
    print('Total reference training time [%s] (%d epochs) %s' %\
          (repr(ref_trainer), ref_trainer.cur_epoch,
           str(timedelta(seconds = ref_train_total_time)))
    )

    # get a model and train it with a scheduler
    sched_model = get_model()
    sched_trainer = get_trainer(sched_model, 'ex_cifar10_lr_find_schedule_')
    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        if not os.path.isdir(GLOBAL_OPTS['tensorboard_dir']):
            os.mkdir(GLOBAL_OPTS['tensorboard_dir'])
        sched_writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        sched_trainer.set_tb_writer(sched_writer)

    # get an LRFinder object
    lr_finder = lr_common.LogFinder(
        sched_trainer,
        lr_min         = GLOBAL_OPTS['find_lr_min'],
        lr_max         = GLOBAL_OPTS['find_lr_max'],
        num_epochs     = GLOBAL_OPTS['find_num_epochs'],
        explode_thresh = GLOBAL_OPTS['find_explode_thresh'],
        print_every    = GLOBAL_OPTS['find_print_every']
    )
    print(lr_finder)

    lr_find_start_time = time.time()
    lr_finder.find()
    lr_find_min, lr_find_max = lr_finder.get_lr_range()
    lr_find_end_time = time.time()
    lr_find_total_time = lr_find_end_time - lr_find_start_time
    print('Total parameter search time : %s' % str(timedelta(seconds = lr_find_total_time)))

    if GLOBAL_OPTS['verbose']:
        print('Found learning rate range as %.4f -> %.4f' % (lr_find_min, lr_find_max))

    # get a scheduler
    lr_sched_obj = getattr(schedule, GLOBAL_OPTS['sched_type'])
    lr_scheduler = lr_sched_obj(
        stepsize = int(len(sched_trainer.train_loader) / 4),
        lr_min   = lr_find_min,
        lr_max   = lr_find_max
    )
    assert sched_trainer.acc_iter == 0
    sched_trainer.set_lr_scheduler(lr_scheduler)

    sched_train_start_time = time.time()
    sched_trainer.train()
    sched_train_end_time = time.time()
    sched_train_total_time = sched_train_end_time - sched_train_start_time
    print('Total scheduled training time [%s] (%d epochs) %s' %\
          (repr(sched_trainer), sched_trainer.cur_epoch,
           str(timedelta(seconds = sched_train_total_time)))
    )
    print('Scheduled training time (including find time) : %s' %\
          str(timedelta(seconds = sched_train_total_time + lr_find_total_time))
    )

    # Compare loss, accuracy
    fig, ax = vis_loss_history.get_figure_subplots(2)