Example #1
    def test_save_load(self) -> None:
        test_finder_state_file = 'data/test_lr_finder_state.pth'
        # get a trainer, etc
        trainer = get_trainer()
        src_lr_finder = lr_common.LogFinder(
            trainer,
            lr_min      = GLOBAL_TEST_PARAMS['test_lr_min'],
            lr_max      = GLOBAL_TEST_PARAMS['test_lr_max'],
            num_iter    = GLOBAL_TEST_PARAMS['test_num_iter'],
            num_epochs  = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
            acc_test    = True,
            max_batches = self.test_max_batches,
            verbose     = self.verbose
        )
        print('max_batches set to %d' % src_lr_finder.max_batches)

        # run the range search so that the finder has state worth saving
        lr_find_min, lr_find_max = src_lr_finder.find()
        assert src_lr_finder.smooth_loss_history is not None
        # save the finder state and load into a new object
        src_lr_finder.save(test_finder_state_file)

        dst_lr_finder = lr_common.LogFinder(
            None,
            verbose    = self.verbose
        )
        dst_lr_finder.load(test_finder_state_file)
        # Since the trainer is not preserved by the save operation, there is
        # no point in checking it here.

        # compare the saved state attribute by attribute
        # (TODO: convert to a dict and compare in one step)
        assert src_lr_finder.lr_mult == dst_lr_finder.lr_mult
        assert src_lr_finder.lr_min == dst_lr_finder.lr_min
        assert src_lr_finder.lr_max == dst_lr_finder.lr_max
        assert src_lr_finder.explode_thresh == dst_lr_finder.explode_thresh
        assert src_lr_finder.beta == dst_lr_finder.beta
        assert src_lr_finder.gamma == dst_lr_finder.gamma
        assert src_lr_finder.lr_min_factor == dst_lr_finder.lr_min_factor
        assert src_lr_finder.lr_max_scale == dst_lr_finder.lr_max_scale
        assert src_lr_finder.lr_select_method == dst_lr_finder.lr_select_method

        # check histories
        print('Checking smooth loss history...', end=' ')
        assert len(src_lr_finder.smooth_loss_history) == len(dst_lr_finder.smooth_loss_history)
        for n in range(len(src_lr_finder.smooth_loss_history)):
            assert src_lr_finder.smooth_loss_history[n] == dst_lr_finder.smooth_loss_history[n]
        print(' OK')

        print('Checking log learning rate history...', end=' ')
        assert len(src_lr_finder.log_lr_history) == len(dst_lr_finder.log_lr_history)
        for n in range(len(src_lr_finder.log_lr_history)):
            assert src_lr_finder.log_lr_history[n] == dst_lr_finder.log_lr_history[n]
        print(' OK')

        print('Checking acc history...', end=' ')
        assert len(src_lr_finder.acc_history) == len(dst_lr_finder.acc_history)
        for n in range(len(src_lr_finder.acc_history)):
            assert src_lr_finder.acc_history[n] == dst_lr_finder.acc_history[n]
        print(' OK')
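
The tests above pull their search bounds out of a shared GLOBAL_TEST_PARAMS dict that none of these snippets define. A minimal sketch of what such a dict might contain is shown below; the keys are the ones the tests actually read, but every value here is a hypothetical placeholder, not the project's real configuration.

# Hypothetical values for illustration only; the real test module defines its own.
GLOBAL_TEST_PARAMS = {
    'test_lr_min':            1e-7,   # lower bound of the log-scale LR search
    'test_lr_max':            1.0,    # upper bound of the log-scale LR search
    'test_num_iter':          5000,   # iterations to run the range test for
    'test_lr_num_epochs':     8,      # epochs per finder run
    'test_lr_explode_thresh': 8.0,    # abort once the loss explodes past this factor
}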
Example #2
    def test_model_param_save(self) -> None:
        # get a trainer, etc
        trainer = get_trainer()
        lr_finder = lr_common.LogFinder(
            trainer,
            lr_min      = GLOBAL_TEST_PARAMS['test_lr_min'],
            lr_max      = GLOBAL_TEST_PARAMS['test_lr_max'],
            num_iter    = GLOBAL_TEST_PARAMS['test_num_iter'],
            num_epochs  = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
            acc_test    = True,
            max_batches = self.test_max_batches,
            verbose     = self.verbose
        )

        # shut linter up
        if self.verbose:
            print(lr_finder)

        # search for a usable learning rate range
        lr_find_min, lr_find_max = lr_finder.find()
        # show plot
        fig1, ax1 = plt.subplots()
        lr_finder.plot_lr_vs_acc(ax1)

        # now check that the restored parameters match the copy of the
        # parameters saved earlier

        if self.draw_plot:
            plt.show()
        else:
            plt.savefig('figures/test_lr_range_find_lr_vs_acc.png', bbox_inches='tight')

        trainer.print_every = 200
        trainer.train()

        fig2, ax2 = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax2,
            trainer.get_loss_history(),
            acc_curve = trainer.get_acc_history(),
            iter_per_epoch = trainer.iter_per_epoch,
            cur_epoch = trainer.cur_epoch
        )
        if self.draw_plot:
            plt.show()
        else:
            plt.savefig('figures/test_lr_range_find_train_results.png', bbox_inches='tight')
Example #3
def get_lr_finder_binary(tr: trainer.Trainer,
                         lr_min: float,
                         lr_max: float,
                         lr_select_method: str = 'max_acc',
                         num_epochs: int = 8,
                         explode_thresh: float = 8.0,
                         print_every: int = 32) -> lr_common.LogFinder:
    lr_finder = lr_common.LogFinder(tr,
                                    lr_min=lr_min,
                                    lr_max=lr_max,
                                    lr_select_method=lr_select_method,
                                    num_epochs=num_epochs,
                                    explode_thresh=explode_thresh,
                                    print_every=print_every)

    return lr_finder
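
A minimal usage sketch for this helper, assuming the get_trainer() factory seen in the other examples; the bounds passed here are illustrative placeholders, not recommended values.

# Hypothetical usage; lr_min/lr_max are placeholders.
tr = get_trainer()
finder = get_lr_finder_binary(tr, lr_min=1e-6, lr_max=1.0)
lr_find_min, lr_find_max = finder.find()
print('LR range found: %.6f -> %.6f' % (lr_find_min, lr_find_max))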
Example #4
    def test_find_lr(self) -> None:
        # get an LRFinder
        trainer = get_trainer()
        lr_finder = lr_common.LogFinder(
            trainer,
            lr_min         = GLOBAL_TEST_PARAMS['test_lr_min'],
            lr_max         = GLOBAL_TEST_PARAMS['test_lr_max'],
            num_epochs     = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
            explode_thresh = GLOBAL_TEST_PARAMS['test_lr_explode_thresh'],
            max_batches    = self.test_max_batches,
            verbose        = self.verbose
        )

        if self.verbose:
            print('Created LRFinder object')
            print(lr_finder)

        lr_find_min, lr_find_max = lr_finder.find()
        print('Found learning rate range as %.3f -> %.3f' % (lr_find_min, lr_find_max))

        # show plot
        finder_fig, finder_ax = vis_loss_history.get_figure_subplots(2)
        lr_finder.plot_lr_vs_acc(finder_ax[0])
        lr_finder.plot_lr_vs_loss(finder_ax[1])
        if self.draw_plot:
            plt.show()
        else:
            plt.savefig('figures/test_find_lr_plots.png', bbox_inches='tight')

        # train the network with the discovered parameters
        trainer.print_every = 200
        trainer.train()

        train_fig, train_ax = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            train_ax,
            trainer.get_loss_history(),
            acc_curve = trainer.get_acc_history(),
            iter_per_epoch = trainer.iter_per_epoch,
            cur_epoch = trainer.cur_epoch
        )
        if self.draw_plot:
            plt.show()
        else:
            train_fig.savefig('figures/test_find_lr_train_results.png', bbox_inches='tight')
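
Each test method above reads self.test_max_batches, self.verbose, and self.draw_plot, which none of the snippets define. A plausible setUp for the enclosing test case might look like the sketch below; the class name and every default value are assumptions for illustration.

import unittest

class TestLRFinder(unittest.TestCase):
    def setUp(self) -> None:
        # All values below are assumptions, not the project's real defaults.
        self.verbose          = False   # print extra diagnostics during the run
        self.draw_plot        = False   # show figures interactively instead of saving them
        self.test_max_batches = 0       # hypothetically, 0 means no batch limit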
Example #5
    def test_lr_range_find(self) -> None:
        trainer = get_trainer()
        lr_finder = lr_common.LogFinder(
            trainer,
            lr_min      = GLOBAL_TEST_PARAMS['test_lr_min'],
            lr_max      = GLOBAL_TEST_PARAMS['test_lr_max'],
            num_iter    = GLOBAL_TEST_PARAMS['test_num_iter'],
            num_epochs  = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
            acc_test    = True,
            max_batches = self.test_max_batches,
            verbose     = self.verbose
        )

        # shut linter up
        if self.verbose:
            print(lr_finder)

        lr_find_min, lr_find_max = lr_finder.find()
        # show plot
        fig1, ax1 = plt.subplots()
        lr_finder.plot_lr_vs_acc(ax1)

        if self.draw_plot:
            plt.show()
        else:
            plt.savefig('figures/test_lr_range_find_lr_vs_acc.png', bbox_inches='tight')

        trainer.print_every = 200
        trainer.train()

        fig2, ax2 = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax2,
            trainer.get_loss_history(),
            acc_curve = trainer.get_acc_history(),
            iter_per_epoch = trainer.iter_per_epoch,
            cur_epoch = trainer.cur_epoch
        )
        if self.draw_plot:
            plt.show()
        else:
            plt.savefig('figures/test_lr_range_find_train_results.png', bbox_inches='tight')
def main() -> None:
    # get a model and train it as a reference
    ref_model = get_model()
    ref_trainer = get_trainer(ref_model, 'ex_cifar10_lr_find_schedule_')

    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        ref_writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        ref_trainer.set_tb_writer(ref_writer)

    ref_train_start_time = time.time()
    ref_trainer.train()
    ref_train_end_time = time.time()
    ref_train_total_time = ref_train_end_time - ref_train_start_time
    print('Total reference training time [%s] (%d epochs)  %s' %\
            (repr(ref_trainer), ref_trainer.cur_epoch,
             str(timedelta(seconds = ref_train_total_time)))
    )

    # get a model and train it with a scheduler
    sched_model = get_model()
    sched_trainer = get_trainer(sched_model, 'ex_cifar10_lr_find_schedule_')

    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        if not os.path.isdir(GLOBAL_OPTS['tensorboard_dir']):
            os.mkdir(GLOBAL_OPTS['tensorboard_dir'])
        sched_writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        sched_trainer.set_tb_writer(sched_writer)

    # get an LRFinder object
    lr_finder = lr_common.LogFinder(
        sched_trainer,
        lr_min         = GLOBAL_OPTS['find_lr_min'],
        lr_max         = GLOBAL_OPTS['find_lr_max'],
        num_epochs     = GLOBAL_OPTS['find_num_epochs'],
        explode_thresh = GLOBAL_OPTS['find_explode_thresh'],
        print_every    = GLOBAL_OPTS['find_print_every']
    )
    print(lr_finder)

    lr_find_start_time = time.time()
    lr_finder.find()
    lr_find_min, lr_find_max = lr_finder.get_lr_range()
    lr_find_end_time = time.time()
    lr_find_total_time = lr_find_end_time - lr_find_start_time
    print('Total parameter search time : %s' % str(timedelta(seconds = lr_find_total_time)))

    if GLOBAL_OPTS['verbose']:
        print('Found learning rate range as %.4f -> %.4f' % (lr_find_min, lr_find_max))

    # get a scheduler
    lr_sched_obj = getattr(schedule, GLOBAL_OPTS['sched_type'])
    lr_scheduler = lr_sched_obj(
        stepsize = int(len(sched_trainer.train_loader) / 4),
        lr_min = lr_find_min,
        lr_max = lr_find_max
    )
    assert sched_trainer.acc_iter == 0
    sched_trainer.set_lr_scheduler(lr_scheduler)

    sched_train_start_time = time.time()
    sched_trainer.train()
    sched_train_end_time = time.time()
    sched_train_total_time = sched_train_end_time - sched_train_start_time
    print('Total scheduled training time [%s] (%d epochs)  %s' %\
            (repr(sched_trainer), sched_trainer.cur_epoch,
             str(timedelta(seconds = sched_train_total_time)))
    )
    print('Scheduled training time (including find time) : %s' %\
          str(timedelta(seconds = sched_train_total_time + lr_find_total_time))
    )

    # Compare loss, accuracy
    fig, ax = vis_loss_history.get_figure_subplots(2)