Example #1
def show() -> None:
    history = torch.load(GLOBAL_OPTS['input'])

    if GLOBAL_OPTS['loss_history_key'] not in history:
        raise ValueError('No key [%s] in history file [%s] (try using --mode=probe)' %
                         (str(GLOBAL_OPTS['loss_history_key']), str(GLOBAL_OPTS['input'])))

    loss_history = history[
        GLOBAL_OPTS['loss_history_key']][0:history['loss_iter']]

    if GLOBAL_OPTS['acc_history_key'] in history:
        acc_history = history[
            GLOBAL_OPTS['acc_history_key']][0:history['loss_iter']]
    else:
        acc_history = None
    if GLOBAL_OPTS['test_loss_history_key'] in history:
        test_loss_history = history[
            GLOBAL_OPTS['test_loss_history_key']][0:history['test_loss_iter']]
        if GLOBAL_OPTS['verbose']:
            print('%d test loss iterations' % history['test_loss_iter'])
    else:
        test_loss_history = None

    if acc_history is not None:
        fig, ax = vis_loss_history.get_figure_subplots(2)
    else:
        fig, ax = vis_loss_history.get_figure_subplots(1)

    vis_loss_history.plot_train_history_2subplots(
        ax,
        loss_history,
        test_loss_curve=test_loss_history,
        acc_curve=acc_history,
        title=GLOBAL_OPTS['title'],
        iter_per_epoch=history['iter_per_epoch'],
        cur_epoch=history['cur_epoch'])

    if GLOBAL_OPTS['print_loss']:
        print(str(loss_history))

    if GLOBAL_OPTS['print_acc']:
        print(str(acc_history))

    if GLOBAL_OPTS['plot_filename'] is not None:
        fig.tight_layout()
        fig.savefig(GLOBAL_OPTS['plot_filename'])
    else:
        plt.show()
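For reference, show() only touches a handful of keys in the loaded file. A minimal sketch of a compatible history file, with key names mirroring the reads above and dummy values (this assumes GLOBAL_OPTS leaves the configurable keys at 'loss_history' and 'acc_history'):

import torch

# hypothetical history dict for show(); all values are dummies
history = {
    'loss_history':      torch.linspace(2.3, 0.4, steps=1000),
    'loss_iter':         1000,                    # sliced as [0:loss_iter] above
    'acc_history':       torch.linspace(0.1, 0.9, steps=20),
    'test_loss_history': torch.linspace(2.3, 0.6, steps=20),
    'test_loss_iter':    20,
    'iter_per_epoch':    250,
    'cur_epoch':         4,
}
torch.save(history, 'checkpoint/dummy_history.pkl')   # pass this path via GLOBAL_OPTS['input']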
Example #2
    def test_find_lr(self) -> None:
        # get an LRFinder
        trainer = get_trainer()
        lr_finder = lr_common.LogFinder(
            trainer,
            lr_min         = GLOBAL_TEST_PARAMS['test_lr_min'],
            lr_max         = GLOBAL_TEST_PARAMS['test_lr_max'],
            num_epochs     = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
            explode_thresh = GLOBAL_TEST_PARAMS['test_lr_explode_thresh'],
            max_batches    = self.test_max_batches,
            verbose        = self.verbose
        )

        if self.verbose:
            print('Created LRFinder object')
            print(lr_finder)

        lr_find_min, lr_find_max = lr_finder.find()
        print('Found learning rate range as %.3f -> %.3f' % (lr_find_min, lr_find_max))

        # show plot
        finder_fig, finder_ax = vis_loss_history.get_figure_subplots(2)
        lr_finder.plot_lr_vs_acc(finder_ax[0])
        lr_finder.plot_lr_vs_loss(finder_ax[1])
        if self.draw_plot is True:
            plt.show()
        else:
            plt.savefig('figures/test_find_lr_plots.png', bbox_inches='tight')

        # train the network with the discovered parameters
        trainer.print_every = 200
        trainer.train()

        train_fig, train_ax = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            train_ax,
            trainer.get_loss_history(),
            acc_curve = trainer.get_acc_history(),
            iter_per_epoch = trainer.iter_per_epoch,
            cur_epoch = trainer.cur_epoch
        )
        if self.draw_plot is True:
            plt.show()
        else:
            train_fig.savefig('figures/test_find_lr_train_results.png', bbox_inches='tight')
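Both halves of this test follow the same pattern: run the range test with lr_finder.find(), inspect the lr-vs-accuracy and lr-vs-loss diagnostics, then train at the discovered rate. The draw_plot flag switches between showing figures interactively and saving them under figures/.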
Example #3
def main() -> None:
    model = get_model(GLOBAL_OPTS['model'])
    t = get_trainer(model, "superconverge_cifar10_", trainer_type="cifar")
    lr_finder = get_lr_finder(t, 1e-7, 1.0)
    find_start_time = time.time()
    lr_finder.find()
    lr_min, lr_max = lr_finder.get_lr_range()
    find_end_time = time.time()
    find_total_time = find_end_time - find_start_time
    print('Found learning rate range as %.4f -> %.4f' % (lr_min, lr_max))
    print('Learning rate search took %s' % str(timedelta(seconds=find_total_time)))  # assumes: from datetime import timedelta

    if GLOBAL_OPTS['sched_stepsize'] > 0:
        stepsize = GLOBAL_OPTS['sched_stepsize']
    else:
        stepsize = int(len(t.train_loader) / 2)

    # get a scheduler for learning rate
    lr_scheduler = get_scheduler(
        lr_min,
        lr_max,
        stepsize,
        sched_type='TriangularScheduler'
    )
    # get a scheduler for momentum
    mtm_scheduler = get_scheduler(
        lr_min,
        lr_max,
        stepsize,
        sched_type='InvTriangularScheduler'
    )
    t.set_lr_scheduler(lr_scheduler)
    t.set_mtm_scheduler(mtm_scheduler)
    train_start_time = time.time()
    t.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Training time : %s' % str(timedelta(seconds=train_total_time)))

    # plot outputs
    loss_title   = 'CIFAR10 Superconvergence Loss'
    acc_title    = 'CIFAR10 Superconvergence Accuracy'
    fig_filename = 'figures/cifar10_superconv_test.png'
    train_fig, train_ax = vis_loss_history.get_figure_subplots(2)
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        t.get_loss_history(),
        acc_history = t.get_acc_history(),
        cur_epoch = t.cur_epoch,
        iter_per_epoch = t.iter_per_epoch,
        loss_title = loss_title,
        acc_title = acc_title
    )
    train_fig.tight_layout()
    train_fig.savefig(fig_filename)
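The scheduler names point at cyclical learning rates in the superconvergence style: the rate sweeps lr_min -> lr_max -> lr_min over 2 * stepsize iterations while the momentum schedule ('InvTriangularScheduler') cycles the opposite way. A toy sketch of the standard triangular waveform, for intuition only (not the repo's TriangularScheduler):

# standard cyclical-LR triangle wave; an illustration, not the repo's code
def triangular_lr(it: int, lr_min: float, lr_max: float, stepsize: int) -> float:
    # normalized distance from the nearest peak: 1 at cycle ends, 0 at the peak
    cycle_pos = abs((it % (2 * stepsize)) - stepsize) / stepsize
    return lr_max - (lr_max - lr_min) * cycle_pos    # lr_min at it = 0, lr_max at it = stepsize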
Example #4
    def test_train(self) -> None:
        test_checkpoint_name = self.checkpoint_dir + 'resnet_trainer_train_checkpoint.pkl'
        test_history_name    = self.checkpoint_dir + 'resnet_trainer_train_history.pkl'
        train_num_epochs = 4
        train_batch_size = 128
        # get a model
        model = resnets.WideResnet(
            depth = self.resnet_depth,
            num_classes = 10,     # using CIFAR-10 data
            input_channels=3,
            w_factor=1
        )
        # get a trainer
        trainer = resnet_trainer.ResnetTrainer(
            model,
            # training parameters
            batch_size    = train_batch_size,
            num_epochs    = train_num_epochs,
            learning_rate = self.test_learning_rate,
            # device
            device_id = util.get_device_id(),
            # checkpoint
            checkpoint_dir = self.checkpoint_dir,
            checkpoint_name = 'resnet_trainer_test',
            # display
            print_every = self.print_every,
            save_every = 5000,
            verbose = self.verbose
        )

        if self.verbose:
            print('Created %s object' % repr(trainer))
            print(trainer)

        print('Training model %s for %d epochs (batch size = %d)' %\
              (repr(trainer), train_num_epochs, train_batch_size)
        )
        trainer.train()

        # save the final checkpoint
        trainer.save_checkpoint(test_checkpoint_name)
        trainer.save_history(test_history_name)

        fig, ax = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax,
            trainer.loss_history,
            acc_history = trainer.acc_history,
            cur_epoch = trainer.cur_epoch,
            iter_per_epoch = trainer.iter_per_epoch
        )
        fig.savefig('figures/resnet_trainer_train_history.png', bbox_inches='tight')

        print('======== TestResnetTrainer.test_train <END>')
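Note that the trainer state and the training curves are persisted separately: save_checkpoint() writes the resumable trainer and model state, while save_history() writes the loss and accuracy arrays that load_history() reads back (Example #12 below round-trips both).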
Example #5
def generate_plot(trainer, loss_title, acc_title, fig_filename) -> None:
    train_fig, train_ax = vis_loss_history.get_figure_subplots(2)
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        trainer.get_loss_history(),
        acc_history=trainer.get_acc_history(),
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title=loss_title,
        acc_title=acc_title)
    train_fig.tight_layout()
    train_fig.savefig(fig_filename)
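A call might look like this (titles and filename are illustrative):

generate_plot(trainer, 'CIFAR10 Loss', 'CIFAR10 Accuracy', 'figures/cifar10_train.png')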
Example #6
    def test_model_param_save(self) -> None:
        # get a trainer, etc
        trainer = get_trainer()
        lr_finder = lr_common.LogFinder(
            trainer,
            lr_min      = GLOBAL_TEST_PARAMS['test_lr_min'],
            lr_max      = GLOBAL_TEST_PARAMS['test_lr_max'],
            num_iter    = GLOBAL_TEST_PARAMS['test_num_iter'],
            num_epochs  = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
            acc_test    = True,
            max_batches = self.test_max_batches,
            verbose     = self.verbose
        )

        # shut linter up
        if self.verbose:
            print(lr_finder)

        # make a copy of the model parameters before we start looking for a
        # new learning rate (sketch: assumes `import copy` and the
        # get_model_params() accessor used in the checkpoint tests)
        params_before = copy.deepcopy(trainer.get_model_params())
        lr_find_min, lr_find_max = lr_finder.find()
        # show plot
        fig1, ax1 = plt.subplots()
        lr_finder.plot_lr_vs_acc(ax1)

        # now check that the restored parameters match the copy of the
        # parameters saved earlier
        params_after = trainer.get_model_params()
        for name, param in params_before.items():
            assert torch.equal(param, params_after[name])

        if self.draw_plot is True:
            plt.show()
        else:
            plt.savefig('figures/test_lr_range_find_lr_vs_acc.png', bbox_inches='tight')

        trainer.print_every = 200
        trainer.train()

        fig2, ax2 = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax2,
            trainer.get_loss_history(),
            acc_curve = trainer.get_acc_history(),
            iter_per_epoch = trainer.iter_per_epoch,
            cur_epoch = trainer.cur_epoch
        )
        if self.draw_plot is True:
            plt.show()
        else:
            plt.savefig('figures/test_lr_range_find_train_results.png', bbox_inches='tight')
Example #7
    def test_lr_range_find(self) -> None:
        trainer = get_trainer()
        lr_finder = lr_common.LogFinder(
            trainer,
            lr_min         = GLOBAL_TEST_PARAMS['test_lr_min'],
            lr_max         = GLOBAL_TEST_PARAMS['test_lr_max'],
            num_iter       = GLOBAL_TEST_PARAMS['test_num_iter'],
            num_epochs     = GLOBAL_TEST_PARAMS['test_lr_num_epochs'],
            acc_test       = True,
            max_batches    = self.test_max_batches,
            verbose        = self.verbose
        )

        # shut linter up
        if self.verbose:
            print(lr_finder)

        lr_find_min, lr_find_max = lr_finder.find()
        # show plot
        fig1, ax1 = plt.subplots()
        lr_finder.plot_lr_vs_acc(ax1)

        if self.draw_plot is True:
            plt.show()
        else:
            plt.savefig('figures/test_lr_range_find_lr_vs_acc.png', bbox_inches='tight')

        trainer.print_every = 200
        trainer.train()

        fig2, ax2 = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax2,
            trainer.get_loss_history(),
            acc_curve = trainer.get_acc_history(),
            iter_per_epoch = trainer.iter_per_epoch,
            cur_epoch = trainer.cur_epoch
        )
        if self.draw_plot is True:
            plt.show()
        else:
            plt.savefig('figures/test_lr_range_find_train_results.png', bbox_inches='tight')
Example #8
            print('Creating subdir [%s] for schedule %d/%d [%s]' %\
                  (subdir, idx+1, len(schedulers), str(schedulers[idx]))
            )
            os.mkdir(subdir)
            writer = tensorboard.SummaryWriter(log_dir=subdir)
            trainer.set_tb_writer(writer)

        if idx == 0:
            lr_find_min, lr_find_max, lr_finder = find_lr(trainer,
                                                          return_finder=True)
            # create plots
            lr_acc_title  = '[' + str(GLOBAL_OPTS['model']) + '][' + str(GLOBAL_OPTS['find_lr_select_method']) + '] ' +\
                str(schedulers[idx]) + ' learning rate vs acc (log)'
            lr_loss_title = '[' + str(GLOBAL_OPTS['model']) + '][' + str(GLOBAL_OPTS['find_lr_select_method']) + '] ' +\
                str(schedulers[idx]) + ' learning rate vs loss (log)'
            lr_fig, lr_ax = vis_loss_history.get_figure_subplots(2)
            lr_finder.plot_lr_vs_acc(lr_ax[0], lr_acc_title, log=True)
            lr_finder.plot_lr_vs_loss(lr_ax[1], lr_loss_title, log=True)
            # save
            lr_fig.tight_layout()
            lr_fig.savefig('figures/[%s][%s]_%s_lr_finder_output.png' %\
                           (str(GLOBAL_OPTS['model']), str(GLOBAL_OPTS['find_lr_select_method']), str(schedulers[idx]))
            )

            print('Found learning rates as %.4f -> %.4f' %
                  (lr_find_min, lr_find_max))
            #if GLOBAL_OPTS['tensorboard_dir'] is not None:
            #    writer.add_hparams(
            #        hparam_dict = {
            #            'lr_find_min': lr_find_min,
            #            'lr_find_max': lr_find_max
Example #9
def main() -> None:
    gan_data_transform = transforms.Compose([
        transforms.Resize(GLOBAL_OPTS['image_size']),
        transforms.CenterCrop(GLOBAL_OPTS['image_size']),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    if GLOBAL_OPTS['dataset'] is not None:
        train_dataset = lmdb_dataset.LMDBDataset(GLOBAL_OPTS['dataset'],
                                                 transform=gan_data_transform)
    else:
        train_dataset = datasets.ImageFolder(root=GLOBAL_OPTS['dataset_root'],
                                             transform=gan_data_transform)

    # get some models
    generator = dcgan.DCGANGenerator(zvec_dim=GLOBAL_OPTS['zvec_dim'],
                                     num_filters=GLOBAL_OPTS['g_num_filters'],
                                     img_size=GLOBAL_OPTS['image_size'])
    discriminator = dcgan.DCGANDiscriminator(
        num_filters=GLOBAL_OPTS['d_num_filters'],
        img_size=GLOBAL_OPTS['image_size'])

    # get a trainer
    gan_trainer = dcgan_trainer.DCGANTrainer(
        discriminator,
        generator,
        #  DCGAN trainer specific arguments
        beta1=GLOBAL_OPTS['beta1'],
        # general trainer arguments
        train_dataset=train_dataset,
        # training opts
        learning_rate=GLOBAL_OPTS['learning_rate'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        batch_size=GLOBAL_OPTS['batch_size'],
        # Checkpoints
        save_every=GLOBAL_OPTS['save_every'],
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        checkpoint_dir=GLOBAL_OPTS['checkpoint_dir'],
        # display
        print_every=GLOBAL_OPTS['print_every'],
        verbose=GLOBAL_OPTS['verbose'],
        # device
        device_id=GLOBAL_OPTS['device_id'])

    if GLOBAL_OPTS['load_checkpoint'] is not None:
        gan_trainer.load_checkpoint(GLOBAL_OPTS['load_checkpoint'])

    print(gan_trainer.device)

    # Get a scheduler
    lr_scheduler = schedule.DecayToEpoch(
        non_decay_time=int(GLOBAL_OPTS['num_epochs'] // 2),
        decay_length=int(GLOBAL_OPTS['num_epochs'] // 2),
        initial_lr=GLOBAL_OPTS['learning_rate'],
        final_lr=0.0)
    print('Created scheduler')
    print(lr_scheduler)
    gan_trainer.set_lr_scheduler(lr_scheduler)

    train_start_time = time.time()
    gan_trainer.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Total training time : %s' %
          str(timedelta(seconds=train_total_time)))

    # show the training results
    dcgan_fig, dcgan_ax = vis_loss_history.get_figure_subplots(1)
    vis_loss_history.plot_train_history_dcgan(
        dcgan_ax,
        gan_trainer.get_g_loss_history(),
        gan_trainer.get_d_loss_history(),
        cur_epoch=gan_trainer.cur_epoch,
        iter_per_epoch=gan_trainer.iter_per_epoch)

    dcgan_fig.tight_layout()
    dcgan_fig.savefig('figures/dcgan_train_history.png')
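The Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform shifts ToTensor()'s [0, 1] output to [-1, 1], matching the tanh output range conventionally used by DCGAN generators:

x = torch.rand(3, 64, 64)     # ToTensor() output lies in [0, 1]
y = (x - 0.5) / 0.5           # what Normalize((0.5,)*3, (0.5,)*3) computes per channel
assert y.min() >= -1.0 and y.max() <= 1.0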
Example #10
def main() -> None:
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset_transform = transforms.Compose([
        transforms.RandomRotation(5),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(224,
                                     scale=(0.96, 1.0),
                                     ratio=(0.95, 1.05)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    test_dataset_transform = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # HDF5 Datasets
    if GLOBAL_USE_HDF5 is True:
        cvd_train_dataset = hdf5_dataset.HDF5Dataset(
            GLOBAL_OPTS['train_dataset'],
            feature_name='images',
            label_name='labels',
            label_max_dim=1,
            transform=normalize)

        cvd_val_dataset = hdf5_dataset.HDF5Dataset(GLOBAL_OPTS['test_dataset'],
                                                   feature_name='images',
                                                   label_name='labels',
                                                   label_max_dim=1,
                                                   transform=normalize)
    else:
        cvd_train_dir = '/home/kreshnik/ml-data/cats-vs-dogs/train'
        # ImageFolder dataset
        cvd_train_dataset = datasets.ImageFolder(cvd_train_dir,
                                                 train_dataset_transform)

        cvd_val_dir = '/home/kreshnik/ml-data/cats-vs-dogs/test'
        cvd_val_dataset = datasets.ImageFolder(cvd_val_dir,
                                               test_dataset_transform)

    # get a network
    model = cvdnet.CVDNet2()
    cvd_train = cvd_trainer.CVDTrainer(
        model,
        # dataset options
        train_dataset=cvd_train_dataset,
        val_dataset=cvd_val_dataset,
        # training options
        loss_function='CrossEntropyLoss',
        learning_rate=GLOBAL_OPTS['learning_rate'],
        weight_decay=GLOBAL_OPTS['weight_decay'],
        momentum=GLOBAL_OPTS['momentum'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        batch_size=GLOBAL_OPTS['batch_size'],
        val_batch_size=GLOBAL_OPTS['val_batch_size'],
        # checkpoint
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        save_every=GLOBAL_OPTS['save_every'],
        # device
        device_id=GLOBAL_OPTS['device_id'],
        # other
        print_every=GLOBAL_OPTS['print_every'],
        verbose=GLOBAL_OPTS['verbose'])

    # Add a tensorboard writer
    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        if not os.path.isdir(GLOBAL_OPTS['tensorboard_dir']):
            os.mkdir(GLOBAL_OPTS['tensorboard_dir'])
        writer = tensorboard.SummaryWriter(
            log_dir=GLOBAL_OPTS['tensorboard_dir'])
        cvd_train.set_tb_writer(writer)

    # Optionally do a search pass here and add a scheduler
    if GLOBAL_OPTS['find_lr']:
        lr_finder = expr_util.get_lr_finder(cvd_train)
        lr_find_start_time = time.time()
        lr_finder.find()
        lr_find_min, lr_find_max = lr_finder.get_lr_range()
        lr_find_end_time = time.time()
        lr_find_total_time = lr_find_end_time - lr_find_start_time
        print('Found learning rate range %.4f -> %.4f' %
              (lr_find_min, lr_find_max))
        print('Total find time [%s] ' %\
                str(timedelta(seconds = lr_find_total_time))
        )

        # Now get a scheduler
        stepsize = cvd_train.get_num_epochs() * len(cvd_train.train_loader) // 2
        # get scheduler
        lr_scheduler = expr_util.get_scheduler(
            lr_find_min,
            lr_find_max,
            stepsize,
            sched_type='TriangularScheduler')
        cvd_train.set_lr_scheduler(lr_scheduler)

    # train the model
    train_start_time = time.time()
    cvd_train.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time

    print('Total training time [%s] (%d epochs)  %s' %\
            (repr(cvd_train), cvd_train.cur_epoch,
             str(timedelta(seconds = train_total_time)))
    )

    # Show results
    fig, ax = vis_loss_history.get_figure_subplots(num_subplots=2)
    vis_loss_history.plot_train_history_2subplots(
        ax,
        cvd_train.get_loss_history(),
        acc_history=cvd_train.get_acc_history(),
        iter_per_epoch=cvd_train.iter_per_epoch,
        loss_title='CVD Loss',
        acc_title='CVD Acc',
        cur_epoch=cvd_train.cur_epoch)

    fig.savefig('figures/cvd_train.png')
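The stepsize here is num_epochs * len(train_loader) // 2, i.e. half of all training iterations, so the TriangularScheduler completes exactly one up-then-down cycle over the whole run.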
Example #11
def main() -> None:
    train_dataset, val_dataset = get_datasets(GLOBAL_OPTS['data_dir'])

    # get some models
    encoder = denoise_ae.DAEEncoder(num_channels=1)
    decoder = denoise_ae.DAEDecoder(num_channels=1)

    trainer = dae_trainer.DAETrainer(
        encoder,
        decoder,
        # datasets
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        device_id=GLOBAL_OPTS['device_id'],
        # trainer params
        batch_size=GLOBAL_OPTS['batch_size'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        learning_rate=GLOBAL_OPTS['learning_rate'],
        # checkpoints, saving, etc
        checkpoint_dir=GLOBAL_OPTS['checkpoint_dir'],
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        save_every=GLOBAL_OPTS['save_every'],
        print_every=GLOBAL_OPTS['print_every'],
        verbose=GLOBAL_OPTS['verbose'])

    # Add a summary writer
    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        if GLOBAL_OPTS['verbose']:
            print('Adding tensorboard writer to [%s]' % repr(trainer))
        writer = tensorboard.SummaryWriter(
            log_dir=GLOBAL_OPTS['tensorboard_dir'])
        trainer.set_tb_writer(writer)

    # train the model
    train_start_time = time.time()
    trainer.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time

    print('Trained [%s] for %d epochs, total time : %s' %\
          (repr(trainer), trainer.cur_epoch, str(timedelta(seconds = train_total_time)))
    )

    # Take the models and load them into an inferrer
    if GLOBAL_OPTS['infer']:
        trainer.drop_last = True
        # make the number of squares in the output figure settable later
        subplot_square = 8
        # so that we have an NxN grid of outputs
        trainer.set_batch_size(subplot_square * subplot_square)
        # Get some figure stuff
        out_fig, out_ax_list = vis_img.get_grid_subplots(subplot_square)
        noise_fig, noise_ax_list = vis_img.get_grid_subplots(subplot_square)

        # Get an inferrer
        inferrer = dae_inferrer.DAEInferrer(
            trainer.encoder,
            trainer.decoder,
            noise_bias=GLOBAL_OPTS['noise_bias'],
            noise_factor=GLOBAL_OPTS['noise_factor'],
            device_id=GLOBAL_OPTS['device_id'])
        if GLOBAL_OPTS['verbose']:
            print('Created [%s] object and attached models [%s], [%s]' %\
                  (repr(trainer), repr(trainer.encoder), repr(trainer.decoder))
            )

        infer_start_time = time.time()
        for batch_idx, (data, _) in enumerate(trainer.val_loader):
            print('Inferring batch [%d / %d]' %
                  (batch_idx + 1, len(trainer.val_loader)),
                  end='\r')
            noise_batch = inferrer.get_noise(data)
            out_batch = inferrer.forward(data)

            # Plot noise
            plot_denoise(noise_ax_list, noise_batch)
            noise_fig_fname = 'figures/dae/mnist_dae_batch_%d_noise.png' % int(
                batch_idx)
            noise_fig.tight_layout()
            noise_fig.savefig(noise_fig_fname)

            # Plot outputs
            plot_denoise(out_ax_list, out_batch)
            out_fig_fname = 'figures/dae/mnist_dae_batch_%d_output.png' % int(
                batch_idx)
            out_fig.tight_layout()
            out_fig.savefig(out_fig_fname)

        print('\n OK')
        infer_end_time = time.time()
        infer_total_time = infer_end_time - infer_start_time
        print('Inferrer [%s] inferred %d batches of size %d, total time : %s' %\
            (repr(inferrer), len(trainer.val_loader), trainer.batch_size, str(timedelta(seconds = infer_total_time)))
        )

    # Plot the loss history
    hist_fig, hist_ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        hist_ax,
        trainer.get_loss_history(),
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title='Denoising AE MNIST Training loss')
    hist_fig.savefig(GLOBAL_OPTS['loss_history_file'], bbox_inches='tight')
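Setting the batch size to subplot_square ** 2 makes every batch exactly fill the N x N grid returned by vis_img.get_grid_subplots, and drop_last = True discards a smaller final batch that would otherwise leave grid cells empty.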
Example #12
    def test_save_load_checkpoint(self) -> None:
        test_checkpoint_name = self.checkpoint_dir + 'resnet_trainer_checkpoint.pkl'
        test_history_name    = self.checkpoint_dir + 'resnet_trainer_history.pkl'
        # get a model
        model = resnets.WideResnet(
            depth = self.resnet_depth,
            num_classes = 10,     # using CIFAR-10 data
            input_channels=3,
            w_factor = 1
        )
        # get a trainer
        src_tr = resnet_trainer.ResnetTrainer(
            model,
            # training parameters
            batch_size    = self.test_batch_size,
            num_epochs    = self.test_num_epochs,
            learning_rate = self.test_learning_rate,
            # device
            device_id     = util.get_device_id(),
            # display
            print_every   = self.print_every,
            save_every    = 0,
            verbose       = self.verbose
        )

        if self.verbose:
            print('Created %s object' % repr(src_tr))
            print(src_tr)

        print('Training model %s for %d epochs' % (repr(src_tr), self.test_num_epochs))
        src_tr.train()

        # save the final checkpoint
        src_tr.save_checkpoint(test_checkpoint_name)
        src_tr.save_history(test_history_name)

        # get a new trainer and load checkpoint
        dst_tr = resnet_trainer.ResnetTrainer(
            model
        )
        dst_tr.load_checkpoint(test_checkpoint_name)

        # Test object parameters
        assert src_tr.num_epochs == dst_tr.num_epochs
        assert src_tr.learning_rate == dst_tr.learning_rate
        assert src_tr.weight_decay == dst_tr.weight_decay
        assert src_tr.print_every == dst_tr.print_every
        assert src_tr.save_every == dst_tr.save_every

        print('\t Comparing model parameters ')
        src_model_params = src_tr.get_model_params()
        dst_model_params = dst_tr.get_model_params()
        assert len(src_model_params.items()) == len(dst_model_params.items())

        # p1, p2 are k,v tuple pairs of each model parameters
        # k = str
        # v = torch.Tensor
        for n, (p1, p2) in enumerate(zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' % (str(p1[0]), n+1, len(src_model_params.items())), end='\r')
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')

        # test history
        dst_tr.load_history(test_history_name)

        # loss history
        assert len(src_tr.loss_history) == len(dst_tr.loss_history)
        assert src_tr.loss_iter == dst_tr.loss_iter
        for n in range(len(src_tr.loss_history)):
            assert src_tr.loss_history[n] == dst_tr.loss_history[n]

        # validation loss history
        assert len(src_tr.val_loss_history) == len(dst_tr.val_loss_history)
        assert src_tr.val_loss_iter == dst_tr.val_loss_iter
        for n in range(len(src_tr.val_loss_history)):
            assert src_tr.val_loss_history[n] == dst_tr.val_loss_history[n]

        # test acc history
        assert len(src_tr.acc_history) == len(dst_tr.acc_history)
        assert src_tr.acc_iter == dst_tr.acc_iter
        for n in range(len(src_tr.acc_history)):
            assert src_tr.acc_history[n] == dst_tr.acc_history[n]

        fig, ax = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax,
            src_tr.loss_history,
            acc_history = src_tr.acc_history,
            cur_epoch = src_tr.cur_epoch,
            iter_per_epoch = src_tr.iter_per_epoch
        )
        fig.savefig('figures/resnet_trainer_train_test_history.png', bbox_inches='tight')
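torch.equal is a strict comparison: it returns True only when the two tensors have the same shape and identical element values, so the loop above verifies an exact restore of every model parameter.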
Example #13
def main() -> None:
    # Get a model
    model = cifar.CIFAR10Net()
    # Get a trainer
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # training parameters
        batch_size=GLOBAL_OPTS['batch_size'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        learning_rate=GLOBAL_OPTS['learning_rate'],
        momentum=GLOBAL_OPTS['momentum'],
        weight_decay=GLOBAL_OPTS['weight_decay'],
        # device
        device_id=GLOBAL_OPTS['device_id'],
        # checkpoint
        checkpoint_dir=GLOBAL_OPTS['checkpoint_dir'],
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        # display
        print_every=GLOBAL_OPTS['print_every'],
        save_every=GLOBAL_OPTS['save_every'],
        verbose=GLOBAL_OPTS['verbose'])

    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        writer = tensorboard.SummaryWriter(
            log_dir=GLOBAL_OPTS['tensorboard_dir'])
        trainer.set_tb_writer(writer)

    # Optionally do a search pass here and add a scheduler
    if GLOBAL_OPTS['find_lr']:
        lr_finder = expr_util.get_lr_finder(trainer)
        lr_find_start_time = time.time()
        lr_finder.find()
        lr_find_min, lr_find_max = lr_finder.get_lr_range()
        lr_find_end_time = time.time()
        lr_find_total_time = lr_find_end_time - lr_find_start_time
        print('Found learning rate range %.4f -> %.4f' %
              (lr_find_min, lr_find_max))
        print('Total find time [%s] ' %\
                str(timedelta(seconds = lr_find_total_time))
        )

        # Now get a scheduler
        stepsize = trainer.get_num_epochs() * len(trainer.train_loader) // 2
        # get scheduler
        lr_scheduler = expr_util.get_scheduler(
            lr_find_min,
            lr_find_max,
            stepsize,
            sched_type='TriangularScheduler')
        trainer.set_lr_scheduler(lr_scheduler)

    # train the model
    train_start_time = time.time()
    trainer.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Total training time [%s] (%d epochs)  %s' %\
            (repr(trainer), trainer.cur_epoch,
             str(timedelta(seconds = train_total_time)))
    )

    # Visualise the output
    train_fig, train_ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        trainer.get_loss_history(),
        acc_history=trainer.get_acc_history(),
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title='CIFAR-10 Training Loss',
        acc_title='CIFAR-10 Training Accuracy ')
    train_fig.savefig(GLOBAL_OPTS['fig_name'], bbox_inches='tight')
Example #14
def main() -> None:
    # get a model and train it as a reference
    ref_model = get_model()
    ref_trainer = get_trainer(ref_model, 'ex_cifar10_lr_find_schedule_')

    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        ref_writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        ref_trainer.set_tb_writer(ref_writer)

    ref_train_start_time = time.time()
    ref_trainer.train()
    ref_train_end_time = time.time()
    ref_train_total_time = ref_train_end_time - ref_train_start_time
    print('Total reference training time [%s] (%d epochs)  %s' %\
            (repr(ref_trainer), ref_trainer.cur_epoch,
             str(timedelta(seconds = ref_train_total_time)))
    )

    # get a model and train it with a scheduler
    sched_model = get_model()
    sched_trainer = get_trainer(sched_model, 'ex_cifar10_lr_find_schedule_')

    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        if not os.path.isdir(GLOBAL_OPTS['tensorboard_dir']):
            os.mkdir(GLOBAL_OPTS['tensorboard_dir'])
        sched_writer = tensorboard.SummaryWriter(log_dir=GLOBAL_OPTS['tensorboard_dir'])
        sched_trainer.set_tb_writer(sched_writer)

    # get an LRFinder object
    lr_finder = lr_common.LogFinder(
        sched_trainer,
        lr_min         = GLOBAL_OPTS['find_lr_min'],
        lr_max         = GLOBAL_OPTS['find_lr_max'],
        num_epochs     = GLOBAL_OPTS['find_num_epochs'],
        explode_thresh = GLOBAL_OPTS['find_explode_thresh'],
        print_every    = GLOBAL_OPTS['find_print_every']
    )
    print(lr_finder)

    lr_find_start_time = time.time()
    lr_finder.find()
    lr_find_min, lr_find_max = lr_finder.get_lr_range()
    lr_find_end_time = time.time()
    lr_find_total_time = lr_find_end_time - lr_find_start_time
    print('Total parameter search time : %s' % str(timedelta(seconds = lr_find_total_time)))

    if GLOBAL_OPTS['verbose']:
        print('Found learning rate range as %.4f -> %.4f' % (lr_find_min, lr_find_max))

    # get a scheduler
    lr_sched_obj = getattr(schedule, GLOBAL_OPTS['sched_type'])
    lr_scheduler = lr_sched_obj(
        stepsize = int(len(sched_trainer.train_loader) / 4),
        lr_min = lr_find_min,
        lr_max = lr_find_max
    )
    assert sched_trainer.acc_iter == 0
    sched_trainer.set_lr_scheduler(lr_scheduler)

    sched_train_start_time = time.time()
    sched_trainer.train()
    sched_train_end_time = time.time()
    sched_train_total_time = sched_train_end_time - sched_train_start_time
    print('Total scheduled training time [%s] (%d epochs)  %s' %\
            (repr(sched_trainer), sched_trainer.cur_epoch,
             str(timedelta(seconds = sched_train_total_time)))
    )
    print('Scheduled training time (including find time) : %s' %\
          str(timedelta(seconds = sched_train_total_time + lr_find_total_time))
    )

    # Compare loss, accuracy
    fig, ax = vis_loss_history.get_figure_subplots(2)