Example #1
    def test_save_load_checkpoint_train(self) -> None:
        test_dataset_file = 'hdf5/trainer_unit_test.h5'
        test_checkpoint_name = 'checkpoint/save_load_test_checkpoint.pkl'

        # get a model, trainer
        model = mnist_net.MNISTNet()
        src_tr = mnist_trainer.MNISTTrainer(
            model,
            num_epochs=self.test_num_epochs,
            save_every=0,
            print_every=250,
            device_id=util.get_device_id(),
            # dataload options
            checkpoint_name='save_load_test',
            batch_size=16,
            num_workers=1
            #num_workers = GLOBAL_OPTS['num_workers']
        )

        if self.verbose:
            print('Created trainer object')
            print(src_tr)

        # train for one epoch
        src_tr.train()
        src_tr.save_checkpoint(test_checkpoint_name)

        # Make a new trainer and load all parameters into that
        # I guess we need to put some kind of loader and model here...
        new_model = mnist_net.MNISTNet()
        dst_tr = mnist_trainer.MNISTTrainer(new_model,
                                            device_id=util.get_device_id())
        dst_tr.load_checkpoint(test_checkpoint_name)

        # Test object parameters
        assert src_tr.num_epochs == dst_tr.num_epochs
        assert src_tr.learning_rate == dst_tr.learning_rate
        assert src_tr.weight_decay == dst_tr.weight_decay
        assert src_tr.print_every == dst_tr.print_every
        assert src_tr.save_every == dst_tr.save_every
        assert src_tr.device_id == dst_tr.device_id

        print('\t Comparing model parameters ')
        src_model_params = src_tr.get_model_params()
        dst_model_params = dst_tr.get_model_params()
        assert len(src_model_params.items()) == len(dst_model_params.items())

        # p1, p2 are k,v tuple pairs of each model parameters
        # k = str
        # v = torch.Tensor
        for n, (p1, p2) in enumerate(
                zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' %
                  (str(p1[0]), n + 1, len(src_model_params.items())),
                  end='\r')
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')
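
The per-parameter check above (walking matching items of two state dicts and asserting that both the key and the tensor agree) recurs in most of the checkpoint tests below. A minimal sketch of a shared helper, assuming both arguments are str -> torch.Tensor mappings such as the return value of get_model_params() (the compare_state_dicts name is hypothetical, not part of the original test suite):

import torch

def compare_state_dicts(src_params: dict, dst_params: dict) -> None:
    # both arguments are assumed to map parameter name (str) -> torch.Tensor
    assert len(src_params) == len(dst_params)
    for n, ((src_key, src_val), (dst_key, dst_val)) in enumerate(
            zip(src_params.items(), dst_params.items())):
        assert src_key == dst_key
        print('Checking parameter %s [%d/%d]' %
              (src_key, n + 1, len(src_params)), end='\r')
        assert torch.equal(src_val, dst_val)
    print('\n ...done')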
Example #2
    def test_history_extend(self) -> None:
        test_checkpoint = 'checkpoint/test_history_extend.pkl'
        test_history = 'checkpoint/test_history_extend_history.pkl'
        model = cifar.CIFAR10Net()
        # TODO : adjust this so we don't need 10 epochs of training
        # Get trainer object
        test_num_epochs = 10
        trainer = cifar_trainer.CIFAR10Trainer(
            model,
            save_every    = 0,
            print_every   = 50,
            device_id     = util.get_device_id(),
            # loader options,
            num_epochs    = test_num_epochs,
            learning_rate = 3e-4,
            batch_size    = 64,
            num_workers   = self.test_num_workers,
        )
        print('Training original model')
        trainer.train()
        trainer.save_checkpoint(test_checkpoint)
        trainer.save_history(test_history)

        # Load a new trainer, train for another 10 epochs (20 total)
        extend_trainer = cifar_trainer.CIFAR10Trainer(
            model,
            save_every    = 0,
            print_every   = 50,
            device_id     = util.get_device_id(),
            # loader options,
            num_epochs    = 10,
            learning_rate = 3e-4,
            batch_size    = 64,
            num_workers   = self.test_num_workers,
        )
        print('Loading checkpoint [%s] into extend trainer...' % str(test_checkpoint))
        extend_trainer.load_checkpoint(test_checkpoint)
        print('Loading history [%s] into extend trainer...' % str(test_history))
        extend_trainer.load_history(test_history)
        # Check history before extending
        assert trainer.device_id == extend_trainer.device_id
        print('extend_trainer device : %s' % str(extend_trainer.device))
        assert test_num_epochs == extend_trainer.num_epochs
        assert 10 == extend_trainer.cur_epoch
        assert extend_trainer.loss_history is not None
        assert (10 * len(extend_trainer.train_loader)) ==  len(extend_trainer.loss_history)

        extend_trainer.set_num_epochs(20)
        assert 20 * len(extend_trainer.train_loader) == len(extend_trainer.loss_history)
        for i in range(10 * len(extend_trainer.train_loader)):
            print('Checking loss iter [%d / %d]' % (i, 20 * len(extend_trainer.train_loader)), end='\r')
            assert trainer.loss_history[i] == extend_trainer.loss_history[i]

        extend_trainer.train()
Example #3
    def test_save_load(self) -> None:
        infer_test_checkpoint = 'checkpoint/infer_save_load_test.pkl'

        model = get_model()
        trainer = get_trainer(model, None, 64, 0)
        # train the model for a while
        trainer.train()
        # save a training checkpoint to disk and load it into an inferrer
        trainer.save_checkpoint(infer_test_checkpoint)

        infer = inferrer.Inferrer(device_id=util.get_device_id())
        infer.load_model(infer_test_checkpoint)

        infer_model = infer.get_model()
        trainer_model = trainer.get_model()

        # check model parameters
        train_model_params = trainer.model.get_net_state_dict()
        infer_model_params = infer.model.get_net_state_dict()
        print('Comparing models')
        for n, (p1, p2) in enumerate(
                zip(train_model_params.items(), infer_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' %
                  (str(p1[0]), n + 1, len(train_model_params.items())))
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')

        # run the forward pass
        test_img, _ = next(iter(trainer.val_loader))
        pred = infer.forward(test_img)
        print('Complete prediction vector (shape: %s)' % (str(pred.shape)))
        print(str(pred))
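
If a single predicted label per image is wanted from the prediction vector printed above, the usual follow-up is an argmax over the class dimension. A small self-contained sketch, assuming pred is a [batch, num_classes] tensor (the random tensor below is only a stand-in, not output from the actual model):

import torch

# stand-in for a [batch, num_classes] prediction tensor
pred = torch.randn(4, 10)
pred_labels = torch.argmax(pred, dim=1)   # index of the largest score per row
print('Predicted class indices: %s' % str(pred_labels.tolist()))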
Example #4
    def test_train(self) -> None:
        test_checkpoint_name = self.checkpoint_dir + 'resnet_trainer_train_checkpoint.pkl'
        test_history_name    = self.checkpoint_dir + 'resnet_trainer_train_history.pkl'
        train_num_epochs = 4
        train_batch_size = 128
        # get a model
        model = resnets.WideResnet(
            depth = self.resnet_depth,
            num_classes = 10,     # using CIFAR-10 data
            input_channels=3,
            w_factor=1
        )
        # get a trainer
        trainer = resnet_trainer.ResnetTrainer(
            model,
            # training parameters
            batch_size    = train_batch_size,
            num_epochs    = train_num_epochs,
            learning_rate = self.test_learning_rate,
            # device
            device_id = util.get_device_id(),
            # checkpoint
            checkpoint_dir = self.checkpoint_dir,
            checkpoint_name = 'resnet_trainer_test',
            # display,
            print_every = self.print_every,
            save_every = 5000,
            verbose = self.verbose
        )

        if self.verbose:
            print('Created %s object' % repr(trainer))
            print(trainer)

        print('Training model %s for %d epochs (batch size = %d)' %\
              (repr(trainer), train_num_epochs, train_batch_size)
        )
        trainer.train()

        # save the final checkpoint
        trainer.save_checkpoint(test_checkpoint_name)
        trainer.save_history(test_history_name)

        fig, ax = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax,
            trainer.loss_history,
            acc_history = trainer.acc_history,
            cur_epoch = trainer.cur_epoch,
            iter_per_epoch = trainer.iter_per_epoch
        )
        fig.savefig('figures/resnet_trainer_train_history.png', bbox_inches='tight')

        print('======== TestResnetTrainer.test_train <END>')
Example #5
    def test_save_load_history(self) -> None:
        test_history_name = 'checkpoint/test_history.pkl'

        # get a model, trainer
        model = mnist_net.MNISTNet()
        src_tr = mnist_trainer.MNISTTrainer(
            model,
            num_epochs=self.test_num_epochs,
            save_every=0,
            print_every=250,
            device_id=util.get_device_id(),
            # dataload options
            checkpoint_name='save_load_test',
            batch_size=16,
            num_workers=1)

        if self.verbose:
            print('Created trainer object')
            print(src_tr)

        # train for one epoch
        src_tr.train()
        src_tr.save_history(test_history_name)

        # Load history into new object
        dst_tr = mnist_trainer.MNISTTrainer(
            model,
            device_id=util.get_device_id(),
        )
        dst_tr.load_history(test_history_name)

        # Test loss history
        print('\t Comparing loss history....')
        assert src_tr.loss_iter == dst_tr.loss_iter
        for n in range(src_tr.loss_iter):
            print('Checking loss element [%d/%d]' % (n, src_tr.loss_iter),
                  end='\r')
            assert src_tr.loss_history[n] == dst_tr.loss_history[n]

        print('\n ...done')
Example #6
def get_trainer(model: common.LernomaticModel, batch_size: int,
                checkpoint_name: str) -> cifar_trainer.CIFAR10Trainer:
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        num_epochs=4,
        checkpoint_name=checkpoint_name,
        # since we don't train for long a large learning rate helps
        learning_rate=1.5e-3,
        save_every=0,
        print_every=50,
        batch_size=batch_size,
        device_id=util.get_device_id(),
        verbose=True)
    return trainer
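
A caller might use this factory roughly as follows. This is only a sketch: cifar.CIFAR10Net, the keyword names, and the checkpoint path are borrowed from the surrounding examples rather than taken from the original test module.

# hypothetical usage of the factory above
model = cifar.CIFAR10Net()
trainer = get_trainer(model, batch_size=64, checkpoint_name='factory_usage_test')
trainer.train()
trainer.save_checkpoint('checkpoint/factory_usage_test.pkl')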
Example #7
def get_trainer(model: common.LernomaticModel, checkpoint_name: str,
                batch_size: int,
                save_every: int) -> cifar_trainer.CIFAR10Trainer:
    trainer = cifar_trainer.CIFAR10Trainer(model,
                                           batch_size=batch_size,
                                           test_batch_size=1,
                                           device_id=util.get_device_id(),
                                           checkpoint_name=checkpoint_name,
                                           save_every=save_every,
                                           save_hist=False,
                                           print_every=50,
                                           num_epochs=4,
                                           learning_rate=9e-4)

    return trainer
Example #8
def get_trainer() -> cifar_trainer.CIFAR10Trainer:
    # get a model to test on and its corresponding trainer
    model = cifar.CIFAR10Net()
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # turn off checkpointing
        save_every = 0,
        print_every = GLOBAL_TEST_PARAMS['test_print_every'],
        # data options
        batch_size = GLOBAL_TEST_PARAMS['test_batch_size'],
        # training options
        learning_rate = GLOBAL_TEST_PARAMS['test_learning_rate'],
        num_epochs = GLOBAL_TEST_PARAMS['train_num_epochs'],
        device_id = util.get_device_id(),
    )

    return trainer
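
GLOBAL_TEST_PARAMS is referenced here but not defined in this snippet. A purely illustrative definition covering the keys this function reads (the values are assumptions, not taken from the original test module) might look like:

GLOBAL_TEST_PARAMS = {
    'test_print_every':   50,    # iterations between progress prints
    'test_batch_size':    64,
    'test_learning_rate': 3e-4,
    'train_num_epochs':   4,     # keep short; this is only a unit test
}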
Example #9
def get_trainer(learning_rate: float = None) -> cifar_trainer.CIFAR10Trainer:
    if learning_rate is not None:
        test_learning_rate = learning_rate
    else:
        test_learning_rate = GLOBAL_OPTS['learning_rate']

    model = cifar.CIFAR10Net()
    trainer = cifar_trainer.CIFAR10Trainer(
        model,
        # data options
        batch_size=GLOBAL_OPTS['batch_size'],
        num_workers=GLOBAL_OPTS['num_workers'],
        num_epochs=GLOBAL_OPTS['num_epochs'],
        # set initial learning rate
        learning_rate=test_learning_rate,
        # other options
        save_every=0,
        print_every=200,
        device_id=util.get_device_id(),
        verbose=GLOBAL_OPTS['verbose'])

    return trainer
Example #10
    def test_train(self) -> None:
        test_checkpoint = 'checkpoint/trainer_train_test.pkl'
        model = cifar.CIFAR10Net()

        # Get trainer object
        trainer = cifar_trainer.CIFAR10Trainer(
            model,
            save_every    = 0,
            print_every   = 50,
            device_id     = util.get_device_id(),
            # loader options,
            num_epochs    = self.test_num_epochs,
            learning_rate = 3e-4,
            batch_size    = 128,
            num_workers   = self.test_num_workers,
        )

        if self.verbose:
            print('Created trainer object')
            print(trainer)

        # train for one epoch
        trainer.train()
        trainer.save_checkpoint(test_checkpoint)

        fig, ax = get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax,
            trainer.loss_history,
            acc_history = trainer.acc_history,
            cur_epoch = trainer.cur_epoch,
            iter_per_epoch = trainer.iter_per_epoch,
            loss_title = 'CIFAR-10 Trainer Test loss',
            acc_title = 'CIFAR-10 Trainer Test accuracy'
        )
        fig.savefig('figures/trainer_train_test_history.png', bbox_inches='tight')
Example #11
    def test_save_load(self) -> None:
        # Get some data
        train_ab_paths = [path for path in os.listdir(self.train_data_root)]
        val_ab_paths = [path for path in os.listdir(self.val_data_root)]

        train_dataset = aligned_dataset.AlignedDatasetHDF5(self.test_dataset)

        # Get some models - we use resnet and PatchGAN here for now. At some
        # point the bugs in the UnetGenerator also need to be solved and this
        # test should be smaller than a 'real' training run.
        generator = resnet_gen.ResnetGenerator(3, 3, num_filters=64)
        discriminator = pixel_disc.PixelDiscriminator(3 + 3, num_filters=64)

        test_checkpoint_file = 'checkpoint/pix2pix_trainer_checkpoint_test.pkl'
        test_history_file = 'checkpoint/pix2pix_trainer_history_test.pkl'
        # Get a trainer
        src_trainer = pix2pix_trainer.Pix2PixTrainer(
            generator,
            discriminator,
            # dataset
            train_dataset=train_dataset,
            val_dataset=None,
            # trainer general options
            batch_size=self.batch_size,
            device_id=util.get_device_id(),
            num_epochs=self.test_num_epochs,
            # checkpoint
            save_every=0,
            print_every=self.print_every,
        )
        src_trainer.train()

        print('Saving checkpoint to file [%s]' % str(test_checkpoint_file))
        src_trainer.save_checkpoint(test_checkpoint_file)
        print('Saving history to file [%s]' % str(test_history_file))
        src_trainer.save_history(test_history_file)

        # get a new trainer and load
        dst_trainer = pix2pix_trainer.Pix2PixTrainer(
            None, None, device_id=util.get_device_id())
        dst_trainer.load_checkpoint(test_checkpoint_file)

        # Check that some models were loaded
        assert dst_trainer.g_net is not None
        assert dst_trainer.d_net is not None
        assert repr(src_trainer.g_net) == repr(dst_trainer.g_net)
        assert repr(src_trainer.d_net) == repr(dst_trainer.d_net)

        # check model params
        src_models = [src_trainer.g_net, src_trainer.d_net]
        dst_models = [dst_trainer.g_net, dst_trainer.d_net]

        for src_mod, dst_mod in zip(src_models, dst_models):
            print('Checking parameters for model [%s]' % repr(src_mod))
            src_model_params = src_mod.get_net_state_dict()
            dst_model_params = dst_mod.get_net_state_dict()

            assert len(src_model_params.items()) == len(
                dst_model_params.items())

            # p1, p2 are k,v tuple pairs of each model parameters
            # k = str
            # v = torch.Tensor
            for n, (p1, p2) in enumerate(
                    zip(src_model_params.items(), dst_model_params.items())):
                assert p1[0] == p2[0]
                print('Checking parameter %s [%d/%d] \t\t' %
                      (str(p1[0]), n + 1, len(src_model_params.items())),
                      end='\r')
                assert torch.equal(p1[1], p2[1]) is True
            print('\n ...done')

        # check the various trainer stats
        assert src_trainer.beta1 == dst_trainer.beta1
        assert src_trainer.l1_lambda == dst_trainer.l1_lambda
        assert src_trainer.gan_mode == dst_trainer.gan_mode
        assert src_trainer.learning_rate == dst_trainer.learning_rate
        assert src_trainer.batch_size == dst_trainer.batch_size
        assert src_trainer.print_every == dst_trainer.print_every
        assert src_trainer.save_every == dst_trainer.save_every
        assert src_trainer.cur_epoch == dst_trainer.cur_epoch
        assert src_trainer.num_epochs == dst_trainer.num_epochs

        # check history
        dst_trainer.load_history(test_history_file)
        assert dst_trainer.g_loss_history is not None
        assert dst_trainer.d_loss_history is not None
        assert len(src_trainer.g_loss_history) == len(
            dst_trainer.g_loss_history)
        assert len(src_trainer.d_loss_history) == len(
            dst_trainer.d_loss_history)

        for loss_elem in range(len(src_trainer.g_loss_history)):
            print('Checking g_loss_history [%d / %d]' %
                  (loss_elem + 1, len(src_trainer.g_loss_history)),
                  end='\r')
            assert src_trainer.g_loss_history[
                loss_elem] == dst_trainer.g_loss_history[loss_elem]
        print('\n OK')

        for loss_elem in range(len(src_trainer.d_loss_history)):
            print('Checking d_loss_history [%d / %d]' %
                  (loss_elem + 1, len(src_trainer.d_loss_history)),
                  end='\r')
            assert src_trainer.d_loss_history[
                loss_elem] == dst_trainer.d_loss_history[loss_elem]
        print('\n OK')

        # we need to set the train loaders
        dst_trainer.set_train_dataset(train_dataset)
        # Now try to extend the trainer history and train for another epoch
        dst_trainer.set_num_epochs(src_trainer.num_epochs + 1)
        assert dst_trainer.cur_epoch == src_trainer.cur_epoch

        print('Continuing training for dst_trainer from epoch %d' %
              dst_trainer.cur_epoch)
        dst_trainer.train()
        assert src_trainer.num_epochs + 1 == dst_trainer.cur_epoch
Example #12
    def test_save_load_checkpoint(self) -> None:
        test_checkpoint = 'checkpoint/trainer_test_save_load.pkl'
        # get a model
        model = cifar.CIFAR10Net()
        # get a trainer
        src_tr = cifar_trainer.CIFAR10Trainer(
            model,
            num_epochs  = self.test_num_epochs,
            save_every  = 0,
            device_id   = util.get_device_id(),
            batch_size  = self.test_batch_size,
            num_workers = self.test_num_workers
        )

        if self.verbose:
            print('Created trainer object')
            print(src_tr)

        # train for one epoch
        src_tr.train()
        src_tr.save_checkpoint(test_checkpoint)
        # Make a new trainer and load all parameters into that
        # I guess we need to put some kind of loader and model here...
        dst_tr = cifar_trainer.CIFAR10Trainer(
            model,
            device_id = util.get_device_id()
        )
        dst_tr.load_checkpoint(test_checkpoint)

        # Test object parameters
        assert src_tr.num_epochs == dst_tr.num_epochs
        assert src_tr.learning_rate == dst_tr.learning_rate
        assert src_tr.weight_decay == dst_tr.weight_decay
        assert src_tr.print_every == dst_tr.print_every
        assert src_tr.save_every == dst_tr.save_every
        assert src_tr.device_id == dst_tr.device_id

        print('\t Comparing model parameters ')
        src_model_params = src_tr.get_model_params()
        dst_model_params = dst_tr.get_model_params()
        assert len(src_model_params.items()) == len(dst_model_params.items())

        # p1, p2 are k,v tuple pairs of each model parameters
        # k = str
        # v = torch.Tensor
        for n, (p1, p2) in enumerate(zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' % (str(p1[0]), n+1, len(src_model_params.items())), end='\r')
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')

        # Test loss history
        val_loss_history = 'test_save_load_history.pkl'
        src_tr.save_history(val_loss_history)
        dst_tr.load_history(val_loss_history)
        print('\t Comparing loss history....')
        assert src_tr.loss_iter == dst_tr.loss_iter
        for n in range(src_tr.loss_iter):
            print('Checking loss element [%d/%d]' % (n, src_tr.loss_iter), end='\r')
            assert src_tr.loss_history[n] == dst_tr.loss_history[n]

        # Try to train for another epoch
        dst_tr.set_num_epochs(src_tr.num_epochs+1)
        assert dst_tr.num_epochs == src_tr.num_epochs+1
        dst_tr.train()
        assert src_tr.num_epochs+1 == dst_tr.cur_epoch

        print('\n ...done')
        os.remove(test_checkpoint)
Example #13
    def save_load_checkpoint_train_test(self) -> None:
        model = mnist_net.MNISTNet()

        # Get trainer object
        src_tr = mnist_trainer.MNISTTrainer(
            model,
            save_every=0,
            print_every=250,
            checkpoint_name='save_load_test',
            device_id=util.get_device_id(),
            # loader options,
            num_epochs=self.test_num_epochs,
            batch_size=self.test_batch_size,
            num_workers=1
            # Need to provide both training and test datasets
        )

        if self.verbose:
            print('Created trainer object')
            print(src_tr)

        # train for one epoch
        src_tr.train()

        # Now try to load a checkpoint and ensure that there is an
        # acc history attribute that is not None
        dst_tr = mnist_trainer.MNISTTrainer(
            model,
            device_id=util.get_device_id(),
        )
        ck_fname = 'checkpoint/save_load_test_epoch-%d.pkl' % (
            self.test_num_epochs - 1)
        dst_tr.load_checkpoint(ck_fname)

        # Test object parameters
        assert src_tr.num_epochs == dst_tr.num_epochs
        assert src_tr.learning_rate == dst_tr.learning_rate
        assert src_tr.weight_decay == dst_tr.weight_decay
        assert src_tr.print_every == dst_tr.print_every
        assert src_tr.save_every == dst_tr.save_every
        assert src_tr.device_id == dst_tr.device_id

        print('\t Comparing model parameters ')
        src_model_params = src_tr.get_model_params()
        dst_model_params = dst_tr.get_model_params()
        assert len(src_model_params.items()) == len(dst_model_params.items())

        # p1, p2 are k,v tuple pairs of each model parameters
        # k = str
        # v = torch.Tensor
        for n, (p1, p2) in enumerate(
                zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' %
                  (str(p1[0]), n + 1, len(src_model_params.items())),
                  end='\r')
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')
Example #14
    def test_save_load(self) -> None:
        test_checkpoint_file = 'checkpoint/test_aae_trainer_checkpoint.pth'
        test_history_file = 'checkpoint/test_aae_trainer_history.pth'

        # get some models
        q_net = aae_common.AAEQNet(self.x_dim, self.z_dim, self.hidden_size)
        p_net = aae_common.AAEPNet(self.x_dim, self.z_dim, self.hidden_size)
        d_net = aae_common.AAEDNetGauss(self.z_dim, self.hidden_size)

        train_dataset, val_dataset = get_mnist_datasets(self.test_data_dir)

        # get a trainer
        src_trainer = aae_trainer.AAETrainer(
            q_net,
            p_net,
            d_net,
            # datasets
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            # train options
            num_epochs=self.test_num_epochs,
            batch_size=self.test_batch_size,
            # misc
            print_every=self.print_every,
            save_every=0,
            device_id=util.get_device_id(),
            verbose=self.verbose)
        # generate the source parameters
        src_trainer.train()
        src_trainer.save_checkpoint(test_checkpoint_file)
        src_trainer.save_history(test_history_file)

        # get a new trainer
        dst_trainer = aae_trainer.AAETrainer(device_id=util.get_device_id())
        dst_trainer.load_checkpoint(test_checkpoint_file)

        # Check parameters of each model in turn
        src_models = [src_trainer.q_net, src_trainer.p_net, src_trainer.d_net]
        dst_models = [dst_trainer.q_net, dst_trainer.p_net, dst_trainer.d_net]

        for src_mod, dst_mod in zip(src_models, dst_models):

            print('\t Comparing parameters for %s model' % repr(src_mod))
            src_model_params = src_mod.get_net_state_dict()
            dst_model_params = dst_mod.get_net_state_dict()

            assert len(src_model_params.items()) == len(
                dst_model_params.items())

            # p1, p2 are k,v tuple pairs of each model parameters
            # k = str
            # v = torch.Tensor
            for n, (p1, p2) in enumerate(
                    zip(src_model_params.items(), dst_model_params.items())):
                assert p1[0] == p2[0]
                print('Checking parameter %s [%d/%d] \t\t' %
                      (str(p1[0]), n + 1, len(src_model_params.items())),
                      end='\r')
                assert torch.equal(p1[1], p2[1]) is True
            print('\n ...done')

        # History
        dst_trainer.load_history(test_history_file)
        assert dst_trainer.d_loss_history is not None
        assert dst_trainer.g_loss_history is not None
        assert dst_trainer.recon_loss_history is not None

        assert len(src_trainer.d_loss_history) == len(
            dst_trainer.d_loss_history)
        assert len(src_trainer.g_loss_history) == len(
            dst_trainer.g_loss_history)
        assert len(src_trainer.recon_loss_history) == len(
            dst_trainer.recon_loss_history)
Example #15
    def test_save_load(self) -> None:
        train_dataset, val_dataset = get_mnist_datasets(self.data_dir)

        # Get some models. For this test we just accept the default constructor
        # parameters (num_blocks = 4, start_size = 32, kernel_size = 3)
        encoder = denoise_ae.DAEEncoder()
        decoder = denoise_ae.DAEDecoder()

        test_checkpoint_file = 'checkpoint/dae_trainer_checkpoint.pkl'
        test_history_file = 'checkpoint/dae_trainer_history.pkl'
        src_trainer = dae_trainer.DAETrainer(
            encoder,
            decoder,
            # datasets
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            device_id=util.get_device_id(),
            # trainer params
            batch_size=self.batch_size,
            num_epochs=self.test_num_epochs,
            # disable saving
            save_every=0,
            print_every=self.print_every,
            verbose=self.verbose)
        train_start_time = time.time()
        src_trainer.train()
        train_end_time = time.time()
        train_total_time = train_end_time - train_start_time

        print('Trainer %s trained %d epochs in %s' %\
                (repr(src_trainer), src_trainer.cur_epoch, str(timedelta(seconds = train_total_time)))
        )

        print('Saving checkpoint to file [%s]' % str(test_checkpoint_file))
        src_trainer.save_checkpoint(test_checkpoint_file)
        src_trainer.save_history(test_history_file)

        # get a new trainer and load
        dst_trainer = dae_trainer.DAETrainer(device_id=util.get_device_id())
        dst_trainer.load_checkpoint(test_checkpoint_file)

        # check the basic trainer params
        assert src_trainer.num_epochs == dst_trainer.num_epochs
        assert src_trainer.learning_rate == dst_trainer.learning_rate
        assert src_trainer.momentum == dst_trainer.momentum
        assert src_trainer.weight_decay == dst_trainer.weight_decay
        assert src_trainer.loss_function == dst_trainer.loss_function
        assert src_trainer.optim_function == dst_trainer.optim_function
        assert src_trainer.cur_epoch == dst_trainer.cur_epoch
        assert src_trainer.iter_per_epoch == dst_trainer.iter_per_epoch
        assert src_trainer.save_every == dst_trainer.save_every
        assert src_trainer.print_every == dst_trainer.print_every
        assert src_trainer.batch_size == dst_trainer.batch_size
        assert src_trainer.val_batch_size == dst_trainer.val_batch_size
        assert src_trainer.shuffle == dst_trainer.shuffle

        # Now check the models
        src_models = [src_trainer.encoder, src_trainer.decoder]
        dst_models = [dst_trainer.encoder, dst_trainer.decoder]

        for src_mod, dst_mod in zip(src_models, dst_models):

            print('\t Comparing parameters for %s model' % repr(src_mod))
            src_model_params = src_mod.get_net_state_dict()
            dst_model_params = dst_mod.get_net_state_dict()

            assert len(src_model_params.items()) == len(
                dst_model_params.items())

            # p1, p2 are k,v tuple pairs of each model parameters
            # k = str
            # v = torch.Tensor
            for n, (p1, p2) in enumerate(
                    zip(src_model_params.items(), dst_model_params.items())):
                assert p1[0] == p2[0]
                print('Checking parameter %s [%d/%d] \t\t' %
                      (str(p1[0]), n + 1, len(src_model_params.items())),
                      end='\r')
                assert torch.equal(p1[1], p2[1]) is True
            print('\n ...done')

        print('Checking history...')
        dst_trainer.load_history(test_history_file)

        assert len(src_trainer.loss_history) == len(dst_trainer.loss_history)
        for elem in range(len(src_trainer.loss_history)):
            print('Checking loss history element [%d / %d]' %
                  (elem + 1, len(src_trainer.loss_history)),
                  end='\r')
            assert src_trainer.loss_history[elem] == dst_trainer.loss_history[
                elem]

        print('\n OK')
Example #16
    def test_save_load(self) -> None:
        test_checkpoint_file = 'checkpoint/test_aae_semi_trainer_checkpoint.pth'
        test_history_file = 'checkpoint/test_aae_semi_trainer_history.pth'

        q_net = aae_common.AAEQNet(self.x_dim,
                                   self.z_dim,
                                   self.hidden_size,
                                   num_classes=self.num_classes)
        p_net = aae_common.AAEPNet(self.x_dim, self.z_dim + self.num_classes,
                                   self.hidden_size)
        d_cat_net = aae_common.AAEDNetGauss(self.num_classes, self.hidden_size)
        d_gauss_net = aae_common.AAEDNetGauss(self.z_dim, self.hidden_size)

        q_net.set_cat_mode()
        assert q_net.net.cat_mode

        # We also need to sub-sample some parts of the MNIST dataset to produce the
        # 'labelled' data loaders
        print('Creating MNIST sub-dataset...')
        train_label_dataset, val_label_dataset, train_unlabel_dataset = mnist_sub.gen_mnist_subset(
            self.test_data_dir, transform=self.transform, verbose=self.verbose)

        assert train_label_dataset is not None
        assert train_unlabel_dataset is not None
        assert val_label_dataset is not None

        src_trainer = aae_semisupervised_trainer.AAESemiTrainer(
            q_net,
            p_net,
            d_cat_net,
            d_gauss_net,
            # datasets
            train_label_dataset=train_label_dataset,
            train_unlabel_dataset=train_unlabel_dataset,
            val_label_dataset=val_label_dataset,
            # train options
            num_epochs=self.test_num_epochs,
            batch_size=self.batch_size,
            # misc
            print_every=self.print_every,
            save_every=0,
            device_id=util.get_device_id(),
            verbose=self.verbose)

        src_trainer.train()
        print('Saving checkpoint to file [%s]' % str(test_checkpoint_file))
        src_trainer.save_checkpoint(test_checkpoint_file)

        dst_trainer = aae_semisupervised_trainer.AAESemiTrainer(
            device_id=util.get_device_id())
        assert dst_trainer.q_net is None
        assert dst_trainer.p_net is None
        assert dst_trainer.d_cat_net is None
        assert dst_trainer.d_gauss_net is None

        # Test that models, etc are loaded
        print('Loading checkpoint data from [%s]' % str(test_checkpoint_file))
        dst_trainer.load_checkpoint(test_checkpoint_file)
        assert dst_trainer.q_net is not None
        assert dst_trainer.p_net is not None
        assert dst_trainer.d_cat_net is not None
        assert dst_trainer.d_gauss_net is not None

        model_list = ['q_net', 'p_net', 'd_cat_net', 'd_gauss_net']

        for model in model_list:
            src_model = getattr(src_trainer, model)
            dst_model = getattr(dst_trainer, model)
            assert src_model is not None
            assert dst_model is not None
            print('\t Comparing parameters for model [%s]' % repr(src_model))
            src_model_params = src_model.get_net_state_dict()
            dst_model_params = dst_model.get_net_state_dict()

            assert len(src_model_params.items()) == len(
                dst_model_params.items())

            # p1, p2 are k,v tuple pairs of each model parameters
            # k = str
            # v = torch.Tensor
            for n, (p1, p2) in enumerate(
                    zip(src_model_params.items(), dst_model_params.items())):
                assert p1[0] == p2[0]
                print('Checking parameter %s [%d/%d] \t\t' %
                      (str(p1[0]), n + 1, len(src_model_params.items())),
                      end='\r')
                assert torch.equal(p1[1], p2[1]) is True
            print('\n ...done')

        # Test that history is correctly loaded
        print('Saving history to file [%s]' % str(test_history_file))
        src_trainer.save_history(test_history_file)
        dst_trainer.load_history(test_history_file)

        # Check iteration values
        assert src_trainer.loss_iter == dst_trainer.loss_iter
        assert src_trainer.val_loss_iter == dst_trainer.val_loss_iter
        assert src_trainer.train_val_loss_iter == dst_trainer.train_val_loss_iter
        assert src_trainer.acc_iter == dst_trainer.acc_iter
        assert src_trainer.cur_epoch == dst_trainer.cur_epoch
        assert src_trainer.iter_per_epoch == dst_trainer.iter_per_epoch

        # Check history arrays
        assert dst_trainer.d_loss_history is not None
        assert dst_trainer.g_loss_history is not None
        assert dst_trainer.recon_loss_history is not None
        assert dst_trainer.class_loss_history is not None

        assert len(src_trainer.d_loss_history) == len(
            dst_trainer.d_loss_history)
        assert len(src_trainer.g_loss_history) == len(
            dst_trainer.g_loss_history)
        assert len(src_trainer.recon_loss_history) == len(
            dst_trainer.recon_loss_history)
        assert len(src_trainer.class_loss_history) == len(
            dst_trainer.class_loss_history)

        for idx in range(len(src_trainer.d_loss_history)):
            print('Checking d_loss_history idx [%d / %d]' %
                  (idx + 1, len(src_trainer.d_loss_history)),
                  end='\r')
            assert src_trainer.d_loss_history[
                idx] == dst_trainer.d_loss_history[idx]
        print('\n OK')

        for idx in range(len(src_trainer.g_loss_history)):
            print('Checking g_loss_history idx [%d / %d]' %
                  (idx + 1, len(src_trainer.g_loss_history)),
                  end='\r')
            assert src_trainer.g_loss_history[
                idx] == dst_trainer.g_loss_history[idx]
        print('\n OK')

        for idx in range(len(src_trainer.recon_loss_history)):
            print('Checking recon_loss_history idx [%d / %d]' %
                  (idx + 1, len(src_trainer.recon_loss_history)),
                  end='\r')
            assert src_trainer.recon_loss_history[
                idx] == dst_trainer.recon_loss_history[idx]
        print('\n OK')

        for idx in range(len(src_trainer.class_loss_history)):
            print('Checking class_loss_history idx [%d / %d]' %
                  (idx + 1, len(src_trainer.class_loss_history)),
                  end='\r')
            assert src_trainer.class_loss_history[
                idx] == dst_trainer.class_loss_history[idx]
        print('\n OK')
Example #17
    def test_save_load_checkpoint(self) -> None:
        test_checkpoint = 'checkpoint/dcgan_trainer_test.pkl'
        test_history = 'checkpoint/dcgan_trainer_test_history.pkl'

        train_dataset = get_dataset()
        # get models
        discriminator = dcgan.DCGANDiscriminator()
        generator = dcgan.DCGANGenerator()
        # get a trainer
        src_trainer = dcgan_trainer.DCGANTrainer(
            D=discriminator,
            G=generator,
            # device
            device_id=util.get_device_id(),
            batch_size=self.batch_size,
            # training params
            train_dataset=train_dataset,
            num_epochs=self.test_num_epochs,
            learning_rate=self.test_learning_rate,
            verbose=self.verbose,
            print_every=self.print_every,
            save_every=0,
        )
        src_trainer.train()
        print('Saving checkpoint data to file [%s]' % str(test_checkpoint))
        src_trainer.save_checkpoint(test_checkpoint)
        src_trainer.save_history(test_history)

        # load into new trainer
        dst_trainer = dcgan_trainer.DCGANTrainer(
            None,
            None,
            train_dataset=train_dataset,
            device_id=util.get_device_id())
        dst_trainer.load_checkpoint(test_checkpoint)
        assert src_trainer.num_epochs == dst_trainer.num_epochs
        assert src_trainer.learning_rate == dst_trainer.learning_rate
        assert src_trainer.weight_decay == dst_trainer.weight_decay
        assert src_trainer.print_every == dst_trainer.print_every
        assert src_trainer.save_every == dst_trainer.save_every
        assert src_trainer.device_id == dst_trainer.device_id

        print('\t Comparing generator model parameters ')
        src_g = src_trainer.generator.get_net_state_dict()
        dst_g = dst_trainer.generator.get_net_state_dict()
        assert len(src_g.items()) == len(dst_g.items())

        # p1, p2 are k,v tuple pairs of each model parameters
        # k = str
        # v = torch.Tensor
        for n, (p1, p2) in enumerate(zip(src_g.items(), dst_g.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' %
                  (str(p1[0]), n + 1, len(src_g.items())),
                  end='')
            assert torch.equal(p1[1], p2[1]) is True
            print('\t OK')
        print('\n ...done')

        print('\t Comparing discriminator model parameters')
        src_d = src_trainer.discriminator.get_net_state_dict()
        dst_d = dst_trainer.discriminator.get_net_state_dict()
        assert len(src_d.items()) == len(dst_d.items())

        # p1, p2 are k,v tuple pairs of each model parameters
        # k = str
        # v = torch.Tensor
        for n, (p1, p2) in enumerate(zip(src_d.items(), dst_d.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' %
                  (str(p1[0]), n + 1, len(src_d.items())),
                  end='')
            assert torch.equal(p1[1], p2[1]) is True
            print('\t OK')
        print('\n ...done')

        # load history and check
        dst_trainer.load_history(test_history)
        assert dst_trainer.d_loss_history is not None
        assert dst_trainer.g_loss_history is not None
        assert len(src_trainer.d_loss_history) == len(
            dst_trainer.d_loss_history)
        assert len(src_trainer.g_loss_history) == len(
            dst_trainer.g_loss_history)

        print('Checking D loss history...')
        for n in range(len(src_trainer.d_loss_history)):
            assert src_trainer.d_loss_history[n] == dst_trainer.d_loss_history[
                n]
        print(' OK')

        print('Checking G loss history...')
        for n in range(len(src_trainer.g_loss_history)):
            assert src_trainer.g_loss_history[n] == dst_trainer.g_loss_history[
                n]
        print(' OK')

        # Try training a bit more. Since the values of cur_epoch and num_epochs
        # are the same, there should be no effect at first
        dst_trainer.train()
        assert dst_trainer.cur_epoch == src_trainer.cur_epoch

        # If we then adjust the number of epochs (to at least cur_epoch+1) then
        # we should see another epoch of training
        dst_trainer.set_num_epochs(src_trainer.num_epochs + 1)
        dst_trainer.train()
        assert src_trainer.num_epochs + 1 == dst_trainer.cur_epoch

        os.remove(test_checkpoint)
        os.remove(test_history)
Example #18
    def test_save_load_checkpoint(self) -> None:
        test_checkpoint_name = self.checkpoint_dir + 'resnet_trainer_checkpoint.pkl'
        test_history_name    = self.checkpoint_dir + 'resnet_trainer_history.pkl'
        # get a model
        model = resnets.WideResnet(
            depth = self.resnet_depth,
            num_classes = 10,     # using CIFAR-10 data
            input_channels=3,
            w_factor = 1
        )
        # get a trainer
        src_tr = resnet_trainer.ResnetTrainer(
            model,
            # training parameters
            batch_size    = self.test_batch_size,
            num_epochs    = self.test_num_epochs,
            learning_rate = self.test_learning_rate,
            # device
            device_id     = util.get_device_id(),
            # display,
            print_every   = self.print_every,
            save_every    = 0,
            verbose       = self.verbose
        )

        if self.verbose:
            print('Created %s object' % repr(src_tr))
            print(src_tr)

        print('Training model %s for %d epochs' % (repr(src_tr), self.test_num_epochs))
        src_tr.train()

        # save the final checkpoint
        src_tr.save_checkpoint(test_checkpoint_name)
        src_tr.save_history(test_history_name)

        # get a new trainer and load checkpoint
        dst_tr = resnet_trainer.ResnetTrainer(
            model
        )
        dst_tr.load_checkpoint(test_checkpoint_name)

        # Test object parameters
        assert src_tr.num_epochs == dst_tr.num_epochs
        assert src_tr.learning_rate == dst_tr.learning_rate
        assert src_tr.weight_decay == dst_tr.weight_decay
        assert src_tr.print_every == dst_tr.print_every
        assert src_tr.save_every == dst_tr.save_every

        print('\t Comparing model parameters ')
        src_model_params = src_tr.get_model_params()
        dst_model_params = dst_tr.get_model_params()
        assert len(src_model_params.items()) == len(dst_model_params.items())

        # p1, p2 are k,v tuple pairs of each model parameters
        # k = str
        # v = torch.Tensor
        for n, (p1, p2) in enumerate(zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' % (str(p1[0]), n+1, len(src_model_params.items())), end='\r')
            assert torch.equal(p1[1], p2[1])
        print('\n ...done')

        # test history
        dst_tr.load_history(test_history_name)

        # loss history
        assert len(src_tr.loss_history) == len(dst_tr.loss_history)
        assert src_tr.loss_iter == dst_tr.loss_iter
        for n in range(len(src_tr.loss_history)):
            assert src_tr.loss_history[n] == dst_tr.loss_history[n]

        # validation loss history
        assert len(src_tr.val_loss_history) == len(dst_tr.val_loss_history)
        assert src_tr.val_loss_iter == dst_tr.val_loss_iter
        for n in range(len(src_tr.val_loss_history)):
            assert src_tr.val_loss_history[n] == dst_tr.val_loss_history[n]

        # test acc history
        assert len(src_tr.acc_history) == len(dst_tr.acc_history)
        assert src_tr.acc_iter == dst_tr.acc_iter
        for n in range(len(src_tr.acc_history)):
            assert src_tr.acc_history[n] == dst_tr.acc_history[n]

        fig, ax = vis_loss_history.get_figure_subplots()
        vis_loss_history.plot_train_history_2subplots(
            ax,
            src_tr.loss_history,
            acc_history = src_tr.acc_history,
            cur_epoch = src_tr.cur_epoch,
            iter_per_epoch = src_tr.iter_per_epoch
        )
        fig.savefig('figures/resnet_trainer_train_test_history.png', bbox_inches='tight')
Example #19
    def test_save_load_device_map(self) -> None:
        test_checkpoint = 'checkpoint/trainer_save_load_device_map.pkl'
        test_history = 'checkpoint/trainer_save_load_device_map_history.pkl'

        model = cifar.CIFAR10Net()
        # Get trainer object
        src_tr = cifar_trainer.CIFAR10Trainer(
            model,
            save_every  = 0,
            print_every = 50,
            device_id   = util.get_device_id(),
            # loader options,
            num_epochs  = self.test_num_epochs,
            batch_size  = self.test_batch_size,
            num_workers = self.test_num_workers,
        )

        if self.verbose:
            print('Created trainer object')
            print(src_tr)

        # train for one epoch
        src_tr.train()
        src_tr.save_checkpoint(test_checkpoint)
        assert src_tr.acc_history is not None
        src_tr.save_history(test_history)

        # Now try to load a checkpoint and ensure that there is an
        # acc history attribute that is not None
        dst_tr = cifar_trainer.CIFAR10Trainer(
            model,
            device_id = -1,
            verbose = self.verbose
        )
        dst_tr.load_checkpoint(test_checkpoint)

        # Test object parameters
        assert src_tr.num_epochs == dst_tr.num_epochs
        assert src_tr.learning_rate == dst_tr.learning_rate
        assert src_tr.weight_decay == dst_tr.weight_decay
        assert src_tr.print_every == dst_tr.print_every
        assert src_tr.save_every == dst_tr.save_every

        print('\t Comparing model parameters ')
        src_model_params = src_tr.get_model_params()
        dst_model_params = dst_tr.get_model_params()
        assert len(src_model_params.items()) == len(dst_model_params.items())

        # p1, p2 are k,v tuple pairs of each model parameters
        # k = str
        # v = torch.Tensor
        print('Checking that tensors from the checkpoint have been transferred to the new device')
        for n, (p1, p2) in enumerate(zip(src_model_params.items(), dst_model_params.items())):
            assert p1[0] == p2[0]
            print('Checking parameter %s [%d/%d] \t\t' % (str(p1[0]), n + 1, len(src_model_params.items())), end='\r')
            assert 'cpu' == p2[1].device.type
        print('\n ...done')

        # Test loss history
        dst_tr.load_history(test_history)
        assert dst_tr.acc_history is not None
        print('\t Comparing loss history....')
        assert src_tr.loss_iter == dst_tr.loss_iter
        for n in range(src_tr.loss_iter):
            print('Checking loss element [%d/%d]' % (n, src_tr.loss_iter), end='\r')
            assert src_tr.loss_history[n] == dst_tr.loss_history[n]
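
Mapping a checkpoint saved on one device onto another (here device_id = -1, i.e. the CPU) is typically done with torch.load's map_location argument; a minimal, trainer-independent sketch of that idea, assuming the checkpoint file written by the test above:

import torch

# force every tensor in the checkpoint onto the CPU, regardless of the
# device it was saved from (path comes from the test above)
state = torch.load('checkpoint/trainer_save_load_device_map.pkl',
                   map_location='cpu')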