Example #1
    def test_history_extend(self) -> None:
        test_checkpoint = 'checkpoint/test_history_extend.pkl'
        test_history = 'checkpoint/test_history_extend_history.pkl'
        model = cifar.CIFAR10Net()
        # TODO : adjust this so we don't need 10 epochs of training
        # Get trainer object
        test_num_epochs = 10
        trainer = cifar_trainer.CIFAR10Trainer(
            model,
            save_every    = 0,
            print_every   = 50,
            device_id     = util.get_device_id(),
            # loader options
            num_epochs    = test_num_epochs,
            learning_rate = 3e-4,
            batch_size    = 64,
            num_workers   = self.test_num_workers,
        )
        print('Training original model')
        trainer.train()
        trainer.save_checkpoint(test_checkpoint)
        trainer.save_history(test_history)

        # Load a new trainer, train for another 10 epochs (20 total)
        extend_trainer = cifar_trainer.CIFAR10Trainer(
            model,
            save_every    = 0,
            print_every   = 50,
            device_id     = util.get_device_id(),
            # loader options
            num_epochs    = test_num_epochs,
            learning_rate = 3e-4,
            batch_size    = 64,
            num_workers   = self.test_num_workers,
        )
        print('Loading checkpoint [%s] into extend trainer...' % str(test_checkpoint))
        extend_trainer.load_checkpoint(test_checkpoint)
        print('Loading history [%s] into extend trainer...' % str(test_history))
        extend_trainer.load_history(test_history)
        # Check history before extending
        assert trainer.device_id == extend_trainer.device_id
        print('extend_trainer device : %s' % str(extend_trainer.device))
        assert test_num_epochs == extend_trainer.num_epochs
        assert test_num_epochs == extend_trainer.cur_epoch
        assert extend_trainer.loss_history is not None
        assert (test_num_epochs * len(extend_trainer.train_loader)) == len(extend_trainer.loss_history)

        extend_trainer.set_num_epochs(2 * test_num_epochs)
        assert (2 * test_num_epochs * len(extend_trainer.train_loader)) == len(extend_trainer.loss_history)
        for i in range(test_num_epochs * len(extend_trainer.train_loader)):
            print('Checking loss iter [%d / %d]' % (i, 2 * test_num_epochs * len(extend_trainer.train_loader)), end='\r')
            assert trainer.loss_history[i] == extend_trainer.loss_history[i]

        extend_trainer.train()
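
The save → load → extend flow this test exercises can be distilled into a small helper. A minimal sketch, assuming only the CIFAR10Trainer API used above (load_checkpoint, load_history, set_num_epochs, num_epochs, train); extend_training and extra_epochs are hypothetical names:

def extend_training(trainer: cifar_trainer.CIFAR10Trainer,
                    checkpoint_file: str,
                    history_file: str,
                    extra_epochs: int) -> None:
    # Hypothetical helper: restore a trainer from a saved checkpoint and loss
    # history, then continue training for extra_epochs additional epochs
    trainer.load_checkpoint(checkpoint_file)
    trainer.load_history(history_file)
    trainer.set_num_epochs(trainer.num_epochs + extra_epochs)
    trainer.train()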
Example #2
def get_trainer(model: common.LernomaticModel,
                checkpoint_name: str,
                trainer_type: str = 'cifar') -> trainer.Trainer:
    if trainer_type == 'cifar':
        t = cifar_trainer.CIFAR10Trainer(
            model,
            # initial learning rate
            learning_rate   = GLOBAL_OPTS['learning_rate'],
            num_epochs      = GLOBAL_OPTS['num_epochs'],
            batch_size      = GLOBAL_OPTS['batch_size'],
            stop_when_acc   = GLOBAL_OPTS['stop_when_acc'],
            # optimization
            optim_function  = 'SGD',
            # device
            device_id       = GLOBAL_OPTS['device_id'],
            # checkpoint
            checkpoint_dir  = GLOBAL_OPTS['checkpoint_dir'],
            checkpoint_name = checkpoint_name,
            # other
            verbose         = GLOBAL_OPTS['verbose'],
            print_every     = GLOBAL_OPTS['print_every'],
            save_every      = GLOBAL_OPTS['save_every']
        )

        return t
    else:
        raise ValueError("Trainer type [%s] not implemented" % str(trainer_type))
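
A minimal usage sketch for this factory; GLOBAL_OPTS is assumed to be populated (e.g. from argparse) before the call, and 'cifar10_test.pkl' is an illustrative checkpoint name:

model = cifar.CIFAR10Net()
trainer = get_trainer(model, 'cifar10_test.pkl', trainer_type='cifar')
trainer.train()

Any other trainer_type raises ValueError, so a misconfigured experiment fails loudly rather than silently.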
Example #3
def get_trainer(model: common.LernomaticModel, batch_size: int,
                checkpoint_name: str) -> cifar_trainer.CIFAR10Trainer:
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        num_epochs=4,
        checkpoint_name=checkpoint_name,
        # since we don't train for long, a large learning rate helps
        learning_rate=1.5e-3,
        save_every=0,
        print_every=50,
        batch_size=batch_size,
        device_id=util.get_device_id(),
        verbose=True)
    return trainer
Example #4
def get_trainer(model: common.LernomaticModel, checkpoint_name: str,
                batch_size: int,
                save_every: int) -> cifar_trainer.CIFAR10Trainer:
    trainer = cifar_trainer.CIFAR10Trainer(model,
                                           batch_size=batch_size,
                                           test_batch_size=1,
                                           device_id=util.get_device_id(),
                                           checkpoint_name=checkpoint_name,
                                           save_every=save_every,
                                           save_hist=False,
                                           print_every=50,
                                           num_epochs=4,
                                           learning_rate=9e-4)

    return trainer
Example #5
def get_trainer() -> cifar_trainer.CIFAR10Trainer:
    # get a model to test on and its corresponding trainer
    model = cifar.CIFAR10Net()
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # turn off checkpointing
        save_every = 0,
        print_every = GLOBAL_TEST_PARAMS['test_print_every'],
        # data options
        batch_size = GLOBAL_TEST_PARAMS['test_batch_size'],
        # training options
        learning_rate = GLOBAL_TEST_PARAMS['test_learning_rate'],
        num_epochs = GLOBAL_TEST_PARAMS['train_num_epochs'],
        device_id = util.get_device_id(),
    )

    return trainer
Example #6
def get_trainer(model: common.LernomaticModel, checkpoint_name: str) -> cifar_trainer.CIFAR10Trainer:
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        batch_size=GLOBAL_OPTS['batch_size'],
        val_batch_size=GLOBAL_OPTS['val_batch_size'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        learning_rate=GLOBAL_OPTS['learning_rate'],
        # momentum = GLOBAL_OPTS['momentum'],
        weight_decay=GLOBAL_OPTS['weight_decay'],
        # device
        device_id=GLOBAL_OPTS['device_id'],
        # checkpoint
        checkpoint_dir=GLOBAL_OPTS['checkpoint_dir'],
        checkpoint_name=checkpoint_name,
        # display
        print_every=GLOBAL_OPTS['print_every'],
        save_every=GLOBAL_OPTS['save_every'],
        verbose=GLOBAL_OPTS['verbose'])

    return trainer
Example #7
def get_trainer(learning_rate: float = None) -> cifar_trainer.CIFAR10Trainer:
    if learning_rate is not None:
        test_learning_rate = learning_rate
    else:
        test_learning_rate = GLOBAL_OPTS['learning_rate']

    model = cifar.CIFAR10Net()
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # data options
        batch_size=GLOBAL_OPTS['batch_size'],
        num_workers=GLOBAL_OPTS['num_workers'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        # set initial learning rate
        learning_rate=test_learning_rate,
        # other options
        save_every=0,
        print_every=200,
        device_id=util.get_device_id(),
        verbose=GLOBAL_OPTS['verbose'])

    return trainer
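
A short sketch of the override behaviour: with no argument the trainer picks up GLOBAL_OPTS['learning_rate'], while an explicit value takes precedence (1e-2 is an arbitrary illustrative rate; learning_rate is exposed as a trainer attribute, as example #10 relies on):

default_trainer = get_trainer()
sweep_trainer = get_trainer(learning_rate=1e-2)
assert default_trainer.learning_rate == GLOBAL_OPTS['learning_rate']
assert sweep_trainer.learning_rate == 1e-2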
Example #8
    def test_train(self) -> None:
        test_checkpoint = 'checkpoint/trainer_train_test.pkl'
        model = cifar.CIFAR10Net()

        # Get trainer object
        trainer = cifar_trainer.CIFAR10Trainer(
            model,
            save_every    = 0,
            print_every   = 50,
            device_id     = util.get_device_id(),
            # loader options
            num_epochs    = self.test_num_epochs,
            learning_rate = 3e-4,
            batch_size    = 128,
            num_workers   = self.test_num_workers,
        )

        if self.verbose:
            print('Created trainer object')
            print(trainer)

        # train for one epoch
        trainer.train()
        trainer.save_checkpoint(test_checkpoint)

        fig, ax = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax,
            trainer.loss_history,
            acc_history = trainer.acc_history,
            cur_epoch = trainer.cur_epoch,
            iter_per_epoch = trainer.iter_per_epoch,
            loss_title = 'CIFAR-10 Trainer Test loss',
            acc_title = 'CIFAR-10 Trainer Test accuracy'
        )
        fig.savefig('figures/trainer_train_test_history.png', bbox_inches='tight')
Example #9
def main() -> None:
    # Get a model
    model = cifar.CIFAR10Net()
    # Get a trainer
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # training parameters
        batch_size=GLOBAL_OPTS['batch_size'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        learning_rate=GLOBAL_OPTS['learning_rate'],
        momentum=GLOBAL_OPTS['momentum'],
        weight_decay=GLOBAL_OPTS['weight_decay'],
        # device
        device_id=GLOBAL_OPTS['device_id'],
        # checkpoint
        checkpoint_dir=GLOBAL_OPTS['checkpoint_dir'],
        checkpoint_name=GLOBAL_OPTS['checkpoint_name'],
        # display
        print_every=GLOBAL_OPTS['print_every'],
        save_every=GLOBAL_OPTS['save_every'],
        verbose=GLOBAL_OPTS['verbose'])

    if GLOBAL_OPTS['tensorboard_dir'] is not None:
        writer = tensorboard.SummaryWriter(
            log_dir=GLOBAL_OPTS['tensorboard_dir'])
        trainer.set_tb_writer(writer)

    # Optionally do a search pass here and add a scheduler
    if GLOBAL_OPTS['find_lr']:
        lr_finder = expr_util.get_lr_finder(trainer)
        lr_find_start_time = time.time()
        lr_finder.find()
        lr_find_min, lr_find_max = lr_finder.get_lr_range()
        lr_find_end_time = time.time()
        lr_find_total_time = lr_find_end_time - lr_find_start_time
        print('Found learning rate range %.4f -> %.4f' %
              (lr_find_min, lr_find_max))
        print('Total find time [%s]' %
              str(timedelta(seconds=lr_find_total_time)))

        # Now get a scheduler
        stepsize = trainer.get_num_epochs() * len(trainer.train_loader) // 2
        # get scheduler
        lr_scheduler = expr_util.get_scheduler(
            lr_find_min,
            lr_find_max,
            stepsize,
            sched_type='TriangularScheduler')
        trainer.set_lr_scheduler(lr_scheduler)

    # train the model
    train_start_time = time.time()
    trainer.train()
    train_end_time = time.time()
    train_total_time = train_end_time - train_start_time
    print('Total training time for %s (%d epochs) : %s' %
          (repr(trainer), trainer.cur_epoch,
           str(timedelta(seconds=train_total_time))))

    # Visualise the output
    train_fig, train_ax = vis_loss_history.get_figure_subplots()
    vis_loss_history.plot_train_history_2subplots(
        train_ax,
        trainer.get_loss_history(),
        acc_history=trainer.get_acc_history(),
        cur_epoch=trainer.cur_epoch,
        iter_per_epoch=trainer.iter_per_epoch,
        loss_title='CIFAR-10 Training Loss',
        acc_title='CIFAR-10 Training Accuracy')
    train_fig.savefig(GLOBAL_OPTS['fig_name'], bbox_inches='tight')
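
main() presumably runs as a script entry point; a minimal guard, assuming GLOBAL_OPTS is filled in by an argument parser elsewhere in the same file:

if __name__ == '__main__':
    # GLOBAL_OPTS is assumed to be populated (e.g. via argparse) before main() runs
    main()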
Example #10
    def test_save_load_checkpoint(self) -> None:
        test_checkpoint = 'checkpoint/trainer_test_save_load.pkl'
        # get a model
        model = cifar.CIFAR10Net()
        # get a trainer
        src_tr = cifar_trainer.CIFAR10Trainer(
            model,
            num_epochs  = self.test_num_epochs,
            save_every  = 0,
            device_id   = util.get_device_id(),
            batch_size  = self.test_batch_size,
            num_workers = self.test_num_workers
        )

        if self.verbose:
            print('Created trainer object')
            print(src_tr)

        # train for one epoch
        src_tr.train()
        src_tr.save_checkpoint(test_checkpoint)
        # Make a new trainer and load all of the saved parameters into it
        dst_tr = cifar_trainer.CIFAR10Trainer(
            model,
            device_id = util.get_device_id()
        )
        dst_tr.load_checkpoint(test_checkpoint)

        # Test object parameters
        assert src_tr.num_epochs == dst_tr.num_epochs
        assert src_tr.learning_rate == dst_tr.learning_rate
        assert src_tr.weight_decay == dst_tr.weight_decay
        assert src_tr.print_every == dst_tr.print_every
        assert src_tr.save_every == dst_tr.save_every
        assert src_tr.device_id == dst_tr.device_id

        print('\t Comparing model parameters ')
        src_model_params = src_tr.get_model_params()
        dst_model_params = dst_tr.get_model_params()
        assert len(src_model_params.items()) == len(dst_model_params.items())

        # p1, p2 are k,v tuple pairs of each model parameters
        # k = str
        # v = torch.Tensor
        for n, (p1, p2) in enumerate(zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' % (str(p1[0]), n+1, len(src_model_params.items())), end='\r')
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')

        # Test loss history
        test_history = 'test_save_load_history.pkl'
        src_tr.save_history(test_history)
        dst_tr.load_history(test_history)
        print('\t Comparing loss history....')
        assert src_tr.loss_iter == dst_tr.loss_iter
        for n in range(src_tr.loss_iter):
            print('Checking loss element [%d/%d]' % (n, src_tr.loss_iter), end='\r')
            assert src_tr.loss_history[n] == dst_tr.loss_history[n]

        # Try to train for another epoch
        dst_tr.set_num_epochs(src_tr.num_epochs+1)
        assert dst_tr.num_epochs == src_tr.num_epochs+1
        dst_tr.train()
        assert src_tr.num_epochs+1 == dst_tr.cur_epoch

        print('\n ...done')
        os.remove(test_checkpoint)
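
The parameter-comparison loop above can also be written as a reusable predicate. A sketch, assuming get_model_params() returns a dict-like mapping of parameter names to torch.Tensor values (as the loop above implies); models_equal is a hypothetical name:

def models_equal(src_tr, dst_tr) -> bool:
    # Hypothetical helper: True when both trainers hold identical model parameters
    src_params = src_tr.get_model_params()
    dst_params = dst_tr.get_model_params()
    if src_params.keys() != dst_params.keys():
        return False
    return all(torch.equal(p, dst_params[k]) for k, p in src_params.items())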
Example #11
    def test_save_load_device_map(self) -> None:
        test_checkpoint = 'checkpoint/trainer_save_load_device_map.pkl'
        test_history = 'checkpoint/trainer_save_load_device_map_history.pkl'

        model = cifar.CIFAR10Net()
        # Get trainer object
        src_tr = cifar_trainer.CIFAR10Trainer(
            model,
            save_every  = 0,
            print_every = 50,
            device_id   = util.get_device_id(),
            # loader options
            num_epochs  = self.test_num_epochs,
            batch_size  = self.test_batch_size,
            num_workers = self.test_num_workers,
        )

        if self.verbose:
            print('Created trainer object')
            print(src_tr)

        # train for one epoch
        src_tr.train()
        src_tr.save_checkpoint(test_checkpoint)
        assert src_tr.acc_history is not None
        src_tr.save_history(test_history)

        # Now try to load a checkpoint and ensure that there is an
        # acc history attribute that is not None
        dst_tr = cifar_trainer.CIFAR10Trainer(
            model,
            device_id = -1,
            verbose = self.verbose
        )
        dst_tr.load_checkpoint(test_checkpoint)

        # Test object parameters
        assert src_tr.num_epochs == dst_tr.num_epochs
        assert src_tr.learning_rate == dst_tr.learning_rate
        assert src_tr.weight_decay == dst_tr.weight_decay
        assert src_tr.print_every == dst_tr.print_every
        assert src_tr.save_every == dst_tr.save_every

        print('\t Comparing model parameters ')
        src_model_params = src_tr.get_model_params()
        dst_model_params = dst_tr.get_model_params()
        assert len(src_model_params.items()) == len(dst_model_params.items())

        # p1, p2 are k,v tuple pairs of each model parameters
        # k = str
        # v = torch.Tensor
        print('Checking that tensors from the checkpoint have been transferred to the new device')
        for n, (p1, p2) in enumerate(zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' % (str(p1[0]), n+1, len(src_model_params.items())), end='\r')
            assert 'cpu' == p2[1].device.type
        print('\n ...done')

        # Test loss history
        dst_tr.load_history(test_history)
        assert dst_tr.acc_history is not None
        print('\t Comparing loss history....')
        assert src_tr.loss_iter == dst_tr.loss_iter
        for n in range(src_tr.loss_iter):
            print('Checking loss element [%d/%d]' % (n, src_tr.loss_iter), end='\r')
            assert src_tr.loss_history[n] == dst_tr.loss_history[n]