Example #1
def show() -> None:
    history = torch.load(GLOBAL_OPTS['input'])

    if GLOBAL_OPTS['loss_history_key'] not in history:
        raise ValueError('No key [%s] in history file [%s] (try using --mode=probe)' %\
                         (str(GLOBAL_OPTS['loss_history_key']), str(GLOBAL_OPTS['input']))
        )

    loss_history = history[
        GLOBAL_OPTS['loss_history_key']][0:history['loss_iter']]

    if GLOBAL_OPTS['acc_history_key'] in history:
        acc_history = history[
            GLOBAL_OPTS['acc_history_key']][0:history['loss_iter']]
    else:
        acc_history = None
    # Prefer the configured test loss key; fall back to the default
    # 'test_loss_history' key if the configured one is missing.
    if GLOBAL_OPTS['test_loss_history_key'] in history:
        test_loss_history = history[
            GLOBAL_OPTS['test_loss_history_key']][0:history['test_loss_iter']]
    elif 'test_loss_history' in history:
        test_loss_history = history['test_loss_history'][
            0:history['test_loss_iter']]
    else:
        test_loss_history = None

    if test_loss_history is not None and GLOBAL_OPTS['verbose']:
        print('%d test loss iterations' % history['test_loss_iter'])

    if acc_history is not None:
        fig, ax = vis_loss_history.get_figure_subplots(2)
    else:
        fig, ax = vis_loss_history.get_figure_subplots(1)

    vis_loss_history.plot_train_history_2subplots(
        ax,
        loss_history,
        test_loss_curve=test_loss_history,
        acc_curve=acc_history,
        title=GLOBAL_OPTS['title'],
        iter_per_epoch=history['iter_per_epoch'],
        cur_epoch=history['cur_epoch'])

    if GLOBAL_OPTS['print_loss']:
        print(str(loss_history))

    if GLOBAL_OPTS['print_acc']:
        print(str(acc_history))

    if GLOBAL_OPTS['plot_filename'] is not None:
        fig.tight_layout()
        fig.savefig(GLOBAL_OPTS['plot_filename'])
    else:
        plt.show()
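show() reads everything from a module-level GLOBAL_OPTS dict that the snippet never defines. A minimal sketch of how it might be populated with argparse; the flag names are assumptions, only the dict keys are taken from the code above:

import argparse

GLOBAL_OPTS = dict()

def parse_args() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument('input', type=str, help='History file produced by torch.save()')
    parser.add_argument('--loss-history-key', type=str, default='loss_history')
    parser.add_argument('--acc-history-key', type=str, default='acc_history')
    parser.add_argument('--test-loss-history-key', type=str, default='test_loss_history')
    parser.add_argument('--title', type=str, default='Training history')
    parser.add_argument('--plot-filename', type=str, default=None)
    parser.add_argument('--print-loss', action='store_true')
    parser.add_argument('--print-acc', action='store_true')
    parser.add_argument('--verbose', action='store_true')
    args = parser.parse_args()
    # argparse converts dashes to underscores, producing the keys show() expects
    for k, v in vars(args).items():
        GLOBAL_OPTS[k] = v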
Example #2
def main() -> None:
    model = get_model(GLOBAL_OPTS['model'])
    t = get_trainer(model, "superconverge_cifar10_", trainer_type="cifar")
    lr_finder = get_lr_finder(t, 1e-7, 1.0)
    find_start_time = time.time()
    lr_finder.find()
    lr_min, lr_max = lr_finder.get_lr_range()
    find_end_time = time.time()
    find_total_time = find_end_time - find_start_time
    print('Found learning rate range as %.4f -> %.4f' % (lr_min, lr_max))
    print('Learning rate search took %.2f seconds' % find_total_time)

    if GLOBAL_OPTS['sched_stepsize'] > 0:
        stepsize = GLOBAL_OPTS['sched_stepsize']
    else:
        stepsize = int(len(t.train_loader) / 2)

    # get a scheduler for learning rate
    lr_scheduler = get_scheduler(
        lr_min,
        lr_max,
        stepsize,
        sched_type='TriangularScheduler'
    )
    # get a scheduler for momentum
    mtm_scheduler = get_scheduler(
        lr_min,
        lr_max,
        stepsize,
        sched_type='InvTriangularScheduler'
    )
    t.set_lr_scheduler(lr_scheduler)
    t.set_mtm_scheduler(mtm_scheduler)
    train_start_time = time.time()
    t.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Training time : %.2f seconds' % train_total_time)

    # plot outputs
    loss_title   = 'CIFAR10 Superconvergence Loss'
    acc_title    = 'CIFAR10 Superconvergence Accuracy'
    fig_filename = 'figures/cifar10_superconv_test.png'
    train_fig, train_ax = vis_loss_history.get_figure_subplots(2)
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        t.get_loss_history(),
        acc_history = t.get_acc_history(),
        cur_epoch = t.cur_epoch,
        iter_per_epoch = t.iter_per_epoch,
        loss_title = loss_title,
        acc_title = acc_title
    )
    train_fig.tight_layout()
    train_fig.savefig(fig_filename)
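The TriangularScheduler and InvTriangularScheduler implementations are not shown anywhere in these examples. A minimal sketch of what the two schedules plausibly compute, following the usual cyclical-learning-rate triangle (the function names and signatures here are hypothetical):

def triangular_lr(it: int, lr_min: float, lr_max: float, stepsize: int) -> float:
    # rises lr_min -> lr_max over `stepsize` iterations, then falls back
    cycle = it // (2 * stepsize)
    x = abs(it / stepsize - 2 * cycle - 1)
    return lr_min + (lr_max - lr_min) * max(0.0, 1.0 - x)

def inv_triangular_mtm(it: int, mtm_min: float, mtm_max: float, stepsize: int) -> float:
    # momentum mirrors the learning rate: lowest when the lr peaks
    return mtm_max - (triangular_lr(it, mtm_min, mtm_max, stepsize) - mtm_min)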
Example #3
    def test_train(self) -> None:
        test_checkpoint_name = self.checkpoint_dir + 'resnet_trainer_train_checkpoint.pkl'
        test_history_name    = self.checkpoint_dir + 'resnet_trainer_train_history.pkl'
        train_num_epochs = 4
        train_batch_size = 128
        # get a model
        model = resnets.WideResnet(
            depth = self.resnet_depth,
            num_classes = 10,     # using CIFAR-10 data
            input_channels=3,
            w_factor=1
        )
        # get a trainer
        trainer = resnet_trainer.ResnetTrainer(
            model,
            # training parameters
            batch_size    = train_batch_size,
            num_epochs    = train_num_epochs,
            learning_rate = self.test_learning_rate,
            # device
            device_id = util.get_device_id(),
            # checkpoint
            checkpoint_dir = self.checkpoint_dir,
            checkpoint_name = 'resnet_trainer_test',
            # display,
            print_every = self.print_every,
            save_every = 5000,
            verbose = self.verbose
        )

        if self.verbose:
            print('Created %s object' % repr(trainer))
            print(trainer)

        print('Training model %s for %d epochs (batch size = %d)' %\
              (repr(trainer), train_num_epochs, train_batch_size)
        )
        trainer.train()

        # save the final checkpoint
        trainer.save_checkpoint(test_checkpoint_name)
        trainer.save_history(test_history_name)

        fig, ax = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax,
            trainer.loss_history,
            acc_history = trainer.acc_history,
            cur_epoch = trainer.cur_epoch,
            iter_per_epoch = trainer.iter_per_epoch
        )
        fig.savefig('figures/resnet_trainer_train_history.png', bbox_inches='tight')

        print('======== TestResnetTrainer.test_train <END>')
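save_checkpoint() and save_history() are also not shown. A minimal sketch of what a checkpoint save might persist with torch.save; the trainer attribute names and dict keys are assumptions:

import torch

def save_checkpoint_sketch(trainer, fname: str) -> None:
    # attribute names (model, optimizer, cur_epoch) are assumptions
    checkpoint = {
        'model_state' : trainer.model.state_dict(),
        'optim_state' : trainer.optimizer.state_dict(),
        'cur_epoch'   : trainer.cur_epoch,
    }
    torch.save(checkpoint, fname)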
Example #4
def generate_plot(trainer, loss_title, acc_title, fig_filename):
    train_fig, train_ax = vis_loss_history.get_figure_subplots(2)
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        trainer.get_loss_history(),
        acc_history=trainer.get_acc_history(),
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title=loss_title,
        acc_title=acc_title)
    train_fig.tight_layout()
    train_fig.savefig(fig_filename)
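A hypothetical call site for generate_plot(); the trainer object and file name are placeholders, not taken from the original examples:

# 'trainer' stands in for any trainer exposing get_loss_history(),
# get_acc_history(), cur_epoch and iter_per_epoch, as in the examples above
generate_plot(
    trainer,
    loss_title='CIFAR-10 Training Loss',
    acc_title='CIFAR-10 Training Accuracy',
    fig_filename='figures/cifar10_train.png'
)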
Example #5
    def test_model_param_save(self) -> None:
        # get a trainer, etc
        trainer = get_trainer()
        lr_finder = lr_common.LogFinder(
            trainer,
            lr_min      = GLOBAL_TEST_PARAMS['test_lr_min'],
            lr_max      = GLOBAL_TEST_PARAMS['test_lr_max'],
            num_iter    = GLOBAL_TEST_PARAMS['test_num_iter'],
            num_epochs  = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
            acc_test    = True,
            max_batches = self.test_max_batches,
            verbose     = self.verbose
        )

        # shut linter up
        if self.verbose:
            print(lr_finder)

        # make a copy of the model parameters before we start looking for a new
        # learning rate (assumes the trainer exposes get_model_params(), as
        # ResnetTrainer does in the save/load example below)
        params_before = {k: v.clone() for k, v in trainer.get_model_params().items()}
        lr_find_min, lr_find_max = lr_finder.find()
        # show plot
        fig1, ax1 = plt.subplots()
        lr_finder.plot_lr_vs_acc(ax1)

        # now check that the restored parameters match the copy of the
        # parameters saved earlier
        params_after = trainer.get_model_params()
        for name, param in params_before.items():
            assert torch.equal(param, params_after[name])

        if self.draw_plot is True:
            plt.show()
        else:
            fig1.savefig('figures/test_lr_range_find_lr_vs_acc.png', bbox_inches='tight')

        trainer.print_every = 200
        trainer.train()

        fig2, ax2 = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax2,
            trainer.get_loss_history(),
            acc_curve = trainer.get_acc_history(),
            iter_per_epoch = trainer.iter_per_epoch,
            cur_epoch = trainer.cur_epoch
        )
        if self.draw_plot is True:
            plt.show()
        else:
            fig2.savefig('figures/test_lr_range_find_train_results.png', bbox_inches='tight')
Example #6
    def test_find_lr(self) -> None:
        # get an LRFinder
        trainer = get_trainer()
        lr_finder = lr_common.LogFinder(
            trainer,
            lr_min         = GLOBAL_TEST_PARAMS['test_lr_min'],
            lr_max         = GLOBAL_TEST_PARAMS['test_lr_max'],
            num_epochs     = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
            explode_thresh = GLOBAL_TEST_PARAMS['test_lr_explode_thresh'],
            max_batches    = self.test_max_batches,
            verbose        = self.verbose
        )

        if self.verbose:
            print('Created LogFinder object')
            print(lr_finder)

        lr_find_min, lr_find_max = lr_finder.find()
        print('Found learning rate range as %.3f -> %.3f' % (lr_find_min, lr_find_max))

        # show plot
        finder_fig, finder_ax = vis_loss_history.get_figure_subplots(2)
        lr_finder.plot_lr_vs_acc(finder_ax[0])
        lr_finder.plot_lr_vs_loss(finder_ax[1])
        if self.draw_plot is True:
            plt.show()
        else:
            finder_fig.savefig('figures/test_find_lr_plots.png', bbox_inches='tight')

        # train the network with the discovered parameters
        trainer.print_every = 200
        trainer.train()

        train_fig, train_ax = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            train_ax,
            trainer.get_loss_history(),
            acc_curve = trainer.get_acc_history(),
            iter_per_epoch = trainer.iter_per_epoch,
            cur_epoch = trainer.cur_epoch
        )
        if self.draw_plot is True:
            plt.show()
        else:
            train_fig.savefig('figures/test_find_lr_train_results.png', bbox_inches='tight')
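LogFinder itself is not shown in any of these examples. A minimal sketch of the log-spaced sweep it presumably performs, including the explode_thresh early exit; train_one_batch_with_lr() is a hypothetical trainer hook:

import numpy as np

def log_lr_sweep_sketch(trainer, lr_min: float, lr_max: float,
                        num_iter: int, explode_thresh: float = 4.0):
    lrs = np.logspace(np.log10(lr_min), np.log10(lr_max), num_iter)
    history = []
    best_loss = float('inf')
    for lr in lrs:
        # train_one_batch_with_lr() is a hypothetical hook; the real
        # LogFinder presumably steps the trainer one batch at this lr
        loss = trainer.train_one_batch_with_lr(lr)
        best_loss = min(best_loss, loss)
        if loss > explode_thresh * best_loss:
            break        # loss has exploded; stop the sweep early
        history.append((lr, loss))
    return history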
Example #7
    def test_lr_range_find(self) -> None:
        trainer = get_trainer()
        lr_finder = lr_common.LogFinder(
            trainer,
            lr_min      = GLOBAL_TEST_PARAMS['test_lr_min'],
            lr_max      = GLOBAL_TEST_PARAMS['test_lr_max'],
            num_iter    = GLOBAL_TEST_PARAMS['test_num_iter'],
            num_epochs  = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
            acc_test    = True,
            max_batches = self.test_max_batches,
            verbose     = self.verbose
        )

        # shut linter up
        if self.verbose:
            print(lr_finder)

        lr_find_min, lr_find_max = lr_finder.find()
        # show plot
        fig1, ax1 = plt.subplots()
        lr_finder.plot_lr_vs_acc(ax1)

        if self.draw_plot is True:
            plt.show()
        else:
            fig1.savefig('figures/test_lr_range_find_lr_vs_acc.png', bbox_inches='tight')

        trainer.print_every = 200
        trainer.train()

        fig2, ax2 = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax2,
            trainer.get_loss_history(),
            acc_curve = trainer.get_acc_history(),
            iter_per_epoch = trainer.iter_per_epoch,
            cur_epoch = trainer.cur_epoch
        )
        if self.draw_plot is True:
            plt.show()
        else:
            fig2.savefig('figures/test_lr_range_find_train_results.png', bbox_inches='tight')
Example #8
    def test_train(self) -> None:
        test_checkpoint = 'checkpoint/trainer_train_test.pkl'
        model = cifar.CIFAR10Net()

        # Get trainer object
        trainer = cifar_trainer.CIFAR10Trainer(
            model,
            save_every    = 0,
            print_every   = 50,
            device_id     = util.get_device_id(),
            # loader options,
            num_epochs    = self.test_num_epochs,
            learning_rate = 3e-4,
            batch_size    = 128,
            num_workers   = self.test_num_workers,
        )

        if self.verbose:
            print('Created trainer object')
            print(trainer)

        # train for one epoch
        trainer.train()
        trainer.save_checkpoint(test_checkpoint)

        fig, ax = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax,
            trainer.loss_history,
            acc_history = trainer.acc_history,
            cur_epoch = trainer.cur_epoch,
            iter_per_epoch = trainer.iter_per_epoch,
            loss_title = 'CIFAR-10 Trainer Test loss',
            acc_title = 'CIFAR-10 Trainer Test accuracy'
        )
        fig.savefig('figures/trainer_train_test_history.png', bbox_inches='tight')
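util.get_device_id() appears throughout these examples but is never defined in any of them. A plausible minimal implementation, assuming the common convention that -1 selects the CPU:

import torch

def get_device_id_sketch() -> int:
    # assumed convention: a non-negative CUDA device index when a GPU is
    # available, -1 to select the CPU
    return 0 if torch.cuda.is_available() else -1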
Example #9
def main() -> None:
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset_transform = transforms.Compose([
        transforms.RandomRotation(5),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(224,
                                     scale=(0.96, 1.0),
                                     ratio=(0.95, 1.05)),
        transforms.ToTensor(),
        normalize
    ])

    test_dataset_transform = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        normalize
    ])

    # HDF5 Datasets
    if GLOBAL_USE_HDF5 is True:
        cvd_train_dataset = hdf5_dataset.HDF5Dataset(
            GLOBAL_OPTS['train_dataset'],
            feature_name='images',
            label_name='labels',
            label_max_dim=1,
            transform=normalize)

        cvd_val_dataset = hdf5_dataset.HDF5Dataset(GLOBAL_OPTS['test_dataset'],
                                                   feature_name='images',
                                                   label_name='labels',
                                                   label_max_dim=1,
                                                   transform=normalize)
    else:
        cvd_train_dir = '/home/kreshnik/ml-data/cats-vs-dogs/train'
        # ImageFolder dataset
        cvd_train_dataset = datasets.ImageFolder(cvd_train_dir,
                                                 train_dataset_transform)

        cvd_val_dir = '/home/kreshnik/ml-data/cats-vs-dogs/test'
        cvd_val_dataset = datasets.ImageFolder(cvd_val_dir,
                                               test_dataset_transform)

    # get a network
    model = cvdnet.CVDNet2()
    cvd_train = cvd_trainer.CVDTrainer(
        model,
        # dataset options
        train_dataset=cvd_train_dataset,
        val_dataset=cvd_val_dataset,
        # training options
        loss_function='CrossEntropyLoss',
        learning_rate=GLOBAL_OPTS['learning_rate'],
        weight_decay=GLOBAL_OPTS['weight_decay'],
        momentum=GLOBAL_OPTS['momentum'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        batch_size=GLOBAL_OPTS['batch_size'],
        val_batch_size=GLOBAL_OPTS['val_batch_size'],
        # checkpoint
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        save_every=GLOBAL_OPTS['save_every'],
        # device
        device_id=GLOBAL_OPTS['device_id'],
        # other
        print_every=GLOBAL_OPTS['print_every'],
        verbose=GLOBAL_OPTS['verbose'])

    # Add a tensorboard writer
    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        if not os.path.isdir(GLOBAL_OPTS['tensorboard_dir']):
            os.mkdir(GLOBAL_OPTS['tensorboard_dir'])
        writer = tensorboard.SummaryWriter(
            log_dir=GLOBAL_OPTS['tensorboard_dir'])
        cvd_train.set_tb_writer(writer)

    # Optionally do a search pass here and add a scheduler
    if GLOBAL_OPTS['find_lr']:
        lr_finder = expr_util.get_lr_finder(cvd_train)
        lr_find_start_time = time.time()
        lr_finder.find()
        lr_find_min, lr_find_max = lr_finder.get_lr_range()
        lr_find_end_time = time.time()
        lr_find_total_time = lr_find_end_time - lr_find_start_time
        print('Found learning rate range %.4f -> %.4f' %
              (lr_find_min, lr_find_max))
        print('Total find time [%s] ' %\
                str(timedelta(seconds = lr_find_total_time))
        )

        # Now get a scheduler
        stepsize = cvd_train.get_num_epochs() * len(cvd_train.train_loader) // 2
        # get scheduler
        lr_scheduler = expr_util.get_scheduler(
            lr_find_min,
            lr_find_max,
            stepsize,
            sched_type='TriangularScheduler')
        cvd_train.set_lr_scheduler(lr_scheduler)

    # train the model
    train_start_time = time.time()
    cvd_train.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time

    print('Total scheduled training time [%s] (%d epochs)  %s' %\
            (repr(cvd_train), cvd_train.cur_epoch,
             str(timedelta(seconds = train_total_time)))
    )

    # Show results
    fig, ax = vis_loss_history.get_figure_subplots(num_subplots=2)
    vis_loss_history.plot_train_history_2subplots(
        ax,
        cvd_train.get_loss_history(),
        acc_history=cvd_train.get_acc_history(),
        iter_per_epoch=cvd_train.iter_per_epoch,
        loss_title='CVD Loss',
        acc_title='CVD Acc',
        cur_epoch=cvd_train.cur_epoch)

    fig.savefig('figures/cvd_train.png')
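hdf5_dataset.HDF5Dataset is not shown. A minimal sketch of an h5py-backed Dataset with the same constructor keywords (label_max_dim is omitted; its exact semantics are unclear from the call site):

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset

class HDF5DatasetSketch(Dataset):
    def __init__(self, fname, feature_name='images', label_name='labels',
                 transform=None):
        self.fp        = h5py.File(fname, 'r')
        self.features  = self.fp[feature_name]
        self.labels    = self.fp[label_name]
        self.transform = transform

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        x = torch.from_numpy(np.asarray(self.features[idx])).float()
        y = int(np.asarray(self.labels[idx]).flatten()[0])
        if self.transform is not None:
            x = self.transform(x)
        return x, y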
Example #10
def main() -> None:
    train_dataset, val_dataset = get_datasets(GLOBAL_OPTS['data_dir'])

    # get some models
    encoder = denoise_ae.DAEEncoder(num_channels=1)
    decoder = denoise_ae.DAEDecoder(num_channels=1)

    trainer = dae_trainer.DAETrainer(
        encoder,
        decoder,
        # datasets
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        device_id=GLOBAL_OPTS['device_id'],
        # trainer params
        batch_size=GLOBAL_OPTS['batch_size'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        learning_rate=GLOBAL_OPTS['learning_rate'],
        # checkpoints, saving, etc
        checkpoint_dir=GLOBAL_OPTS['checkpoint_dir'],
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        save_every=GLOBAL_OPTS['save_every'],
        print_every=GLOBAL_OPTS['print_every'],
        verbose=GLOBAL_OPTS['verbose'])

    # Add a summary writer
    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        if GLOBAL_OPTS['verbose']:
            print('Adding tensorboard writer to [%s]' % repr(trainer))
        writer = tensorboard.SummaryWriter(
            log_dir=GLOBAL_OPTS['tensorboard_dir'])
        trainer.set_tb_writer(writer)

    # train the model
    train_start_time = time.time()
    trainer.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time

    print('Trained [%s] for %d epochs, total time : %s' %\
          (repr(trainer), trainer.cur_epoch, str(timedelta(seconds = train_total_time)))
    )

    # Take the models and load them into an inferrer
    if GLOBAL_OPTS['infer']:
        trainer.drop_last = True
        # make the number of squares in the output figure settable later
        subplot_square = 8
        # batch size = NxN so that the outputs fill an NxN grid
        trainer.set_batch_size(subplot_square * subplot_square)
        # Get some figure stuff
        out_fig, out_ax_list = vis_img.get_grid_subplots(subplot_square)
        noise_fig, noise_ax_list = vis_img.get_grid_subplots(subplot_square)

        # Get an inferrer
        inferrer = dae_inferrer.DAEInferrer(
            trainer.encoder,
            trainer.decoder,
            noise_bias=GLOBAL_OPTS['noise_bias'],
            noise_factor=GLOBAL_OPTS['noise_factor'],
            device_id=GLOBAL_OPTS['device_id'])
        if GLOBAL_OPTS['verbose']:
            print('Created [%s] object and attached models [%s], [%s]' %\
                  (repr(inferrer), repr(trainer.encoder), repr(trainer.decoder))
            )

        infer_start_time = time.time()
        for batch_idx, (data, _) in enumerate(trainer.val_loader):
            print('Inferring batch [%d / %d]' %
                  (batch_idx + 1, len(trainer.val_loader)),
                  end='\r')
            noise_batch = inferrer.get_noise(data)
            out_batch = inferrer.forward(data)

            # Plot noise
            plot_denoise(noise_ax_list, noise_batch)
            noise_fig_fname = 'figures/dae/mnist_dae_batch_%d_noise.png' % batch_idx
            noise_fig.tight_layout()
            noise_fig.savefig(noise_fig_fname)

            # Plot outputs
            plot_denoise(out_ax_list, out_batch)
            out_fig_fname = 'figures/dae/mnist_dae_batch_%d_output.png' % batch_idx
            out_fig.tight_layout()
            out_fig.savefig(out_fig_fname)

        print('\n OK')
        infer_end_time = time.time()
        infer_total_time = infer_end_time - infer_start_time
        print('Inferrer [%s] inferred %d batches of size %d, total time : %s' %\
            (repr(inferrer), len(trainer.val_loader), trainer.batch_size, str(timedelta(seconds = infer_total_time)))
        )

    # Plot the loss history
    hist_fig, hist_ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        hist_ax,
        trainer.get_loss_history(),
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title='Denoising AE MNIST Training loss')
    hist_fig.savefig(GLOBAL_OPTS['loss_history_file'], bbox_inches='tight')
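DAEInferrer.get_noise() is not shown. Given the noise_bias and noise_factor parameters, it plausibly adds biased Gaussian noise to the input batch; a minimal sketch under that assumption:

import torch

def get_noise_sketch(data: torch.Tensor, noise_bias: float,
                     noise_factor: float) -> torch.Tensor:
    # additive noise model implied by the (noise_bias, noise_factor)
    # parameters; clamped to keep pixel values in the valid [0, 1] range
    noise = noise_bias + noise_factor * torch.randn_like(data)
    return torch.clamp(data + noise, 0.0, 1.0)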
Example #11
    def test_save_load_checkpoint(self) -> None:
        test_checkpoint_name = self.checkpoint_dir + 'resnet_trainer_checkpoint.pkl'
        test_history_name    = self.checkpoint_dir + 'resnet_trainer_history.pkl'
        # get a model
        model = resnets.WideResnet(
            depth = self.resnet_depth,
            num_classes = 10,     # using CIFAR-10 data
            input_channels=3,
            w_factor = 1
        )
        # get a trainer
        src_tr = resnet_trainer.ResnetTrainer(
            model,
            # training parameters
            batch_size    = self.test_batch_size,
            num_epochs    = self.test_num_epochs,
            learning_rate = self.test_learning_rate,
            # device
            device_id     = util.get_device_id(),
            # display,
            print_every   = self.print_every,
            save_every    = 0,
            verbose       = self.verbose
        )

        if self.verbose:
            print('Created %s object' % repr(src_tr))
            print(src_tr)

        print('Training model %s for %d epochs' % (repr(src_tr), self.test_num_epochs))
        src_tr.train()

        # save the final checkpoint
        src_tr.save_checkpoint(test_checkpoint_name)
        src_tr.save_history(test_history_name)

        # get a new trainer and load checkpoint
        dst_tr = resnet_trainer.ResnetTrainer(
            model
        )
        dst_tr.load_checkpoint(test_checkpoint_name)

        # Test object parameters
        assert src_tr.num_epochs == dst_tr.num_epochs
        assert src_tr.learning_rate == dst_tr.learning_rate
        assert src_tr.weight_decay == dst_tr.weight_decay
        assert src_tr.print_every == dst_tr.print_every
        assert src_tr.save_every == dst_tr.save_every

        print('\t Comparing model parameters ')
        src_model_params = src_tr.get_model_params()
        dst_model_params = dst_tr.get_model_params()
        assert len(src_model_params.items()) == len(dst_model_params.items())

        # p1, p2 are (k, v) tuple pairs of each model's parameters
        # k : str
        # v : torch.Tensor
        for n, (p1, p2) in enumerate(zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' % (str(p1[0]), n+1, len(src_model_params.items())), end='\r')
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')

        # test history
        dst_tr.load_history(test_history_name)

        # loss history
        assert len(src_tr.loss_history) == len(dst_tr.loss_history)
        assert src_tr.loss_iter == dst_tr.loss_iter
        for n in range(len(src_tr.loss_history)):
            assert src_tr.loss_history[n] == dst_tr.loss_history[n]

        # test loss history
        assert len(src_tr.val_loss_history) == len(dst_tr.val_loss_history)
        assert src_tr.val_loss_iter == dst_tr.val_loss_iter
        for n in range(len(src_tr.val_loss_history)):
            assert src_tr.val_loss_history[n] == dst_tr.val_loss_history[n]

        # test acc history
        assert len(src_tr.acc_history) == len(dst_tr.acc_history)
        assert src_tr.acc_iter == dst_tr.acc_iter
        for n in range(len(src_tr.acc_history)):
            assert src_tr.acc_history[n] == dst_tr.acc_history[n]

        fig, ax = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax,
            src_tr.loss_history,
            acc_history = src_tr.acc_history,
            cur_epoch = src_tr.cur_epoch,
            iter_per_epoch = src_tr.iter_per_epoch
        )
        fig.savefig('figures/resnet_trainer_train_test_history.png', bbox_inches='tight')
Example #12
def main() -> None:
    # Get a model
    model = cifar.CIFAR10Net()
    # Get a trainer
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # training parameters
        batch_size=GLOBAL_OPTS['batch_size'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        learning_rate=GLOBAL_OPTS['learning_rate'],
        momentum=GLOBAL_OPTS['momentum'],
        weight_decay=GLOBAL_OPTS['weight_decay'],
        # device
        device_id=GLOBAL_OPTS['device_id'],
        # checkpoint
        checkpoint_dir=GLOBAL_OPTS['checkpoint_dir'],
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        # display,
        print_every=GLOBAL_OPTS['print_every'],
        save_every=GLOBAL_OPTS['save_every'],
        verbose=GLOBAL_OPTS['verbose'])

    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        writer = tensorboard.SummaryWriter(
            log_dir=GLOBAL_OPTS['tensorboard_dir'])
        trainer.set_tb_writer(writer)

    # Optionally do a search pass here and add a scheduler
    if GLOBAL_OPTS['find_lr']:
        lr_finder = expr_util.get_lr_finder(trainer)
        lr_find_start_time = time.time()
        lr_finder.find()
        lr_find_min, lr_find_max = lr_finder.get_lr_range()
        lr_find_end_time = time.time()
        lr_find_total_time = lr_find_end_time - lr_find_start_time
        print('Found learning rate range %.4f -> %.4f' %
              (lr_find_min, lr_find_max))
        print('Total find time [%s] ' %\
                str(timedelta(seconds = lr_find_total_time))
        )

        # Now get a scheduler
        stepsize = trainer.get_num_epochs() * len(trainer.train_loader) // 2
        # get scheduler
        lr_scheduler = expr_util.get_scheduler(
            lr_find_min,
            lr_find_max,
            stepsize,
            sched_type='TriangularScheduler')
        trainer.set_lr_scheduler(lr_scheduler)

    # train the model
    train_start_time = time.time()
    trainer.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Total training time [%s] (%d epochs)  %s' %\
            (repr(trainer), trainer.cur_epoch,
             str(timedelta(seconds = train_total_time)))
    )

    # Visualise the output
    train_fig, train_ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        trainer.get_loss_history(),
        acc_history=trainer.get_acc_history(),
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title='CIFAR-10 Training Loss',
        acc_title='CIFAR-10 Training Accuracy')
    train_fig.savefig(GLOBAL_OPTS['fig_name'], bbox_inches='tight')