Example #1
    def propose_H(self, dataset):
        config = self.get_base_config(dataset)

        # Wrap the model in KWayLogisticWrapper.
        original_class_name = config.model.__class__.__name__
        config.model = KWayLogisticWrapper(config.model)
        config.model = config.model.to(self.args.device)

        h_path = Models.get_ref_model_path(self.args, original_class_name, dataset.name, suffix_str='KLogistic')
        best_h_path = path.join(h_path, 'model.best.pth')
        
        trainer = IterativeTrainer(config, self.args)

        if not path.isfile(best_h_path):      
            raise NotImplementedError("Please use setup_model to pretrain the networks first!")
        else:
            print(colored('Loading H1 model from %s'%best_h_path, 'red'))
            config.model.load_state_dict(torch.load(best_h_path))
        
        trainer.run_epoch(0, phase='all')
        test_average_acc = config.logger.get_measure('all_accuracy').mean_epoch(epoch=0)
        print("All average accuracy %s"%colored('%.4f%%'%(test_average_acc*100), 'red'))

        self.base_model = config.model
        self.base_model.eval()
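Note that propose_H only consumes a checkpoint; it never trains. Below is a minimal sketch of a caller built on that contract. The helper name and the freezing step are assumptions for illustration, not taken from the snippet.

def load_base_model(method, dataset):
    # Sketch: obtain the frozen base network via propose_H's contract.
    # propose_H raises NotImplementedError when no pretrained checkpoint
    # exists, so surface that to the user instead of crashing mid-run.
    try:
        method.propose_H(dataset)
    except NotImplementedError as err:
        raise SystemExit('Pretrain the networks first: %s' % err)
    model = method.base_model        # cached by propose_H, already in eval mode
    for p in model.parameters():     # freeze it; the detector only reads it
        p.requires_grad_(False)
    return model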
Example #2
    def propose_H(self, dataset):
        config = self.get_base_config(dataset)

        import models as Models
        if self.default_model == 0:
            config.model.netid = "BCE." + config.model.netid
        else:
            config.model.netid = "MSE." + config.model.netid

        home_path = Models.get_ref_model_path(self.args, config.model.__class__.__name__, dataset.name, suffix_str=config.model.netid)
        best_h_path = path.join(home_path, 'model.best.pth')

        trainer = IterativeTrainer(config, self.args)

        if not path.isfile(best_h_path):
            raise NotImplementedError("%s not found!, Please use setup_model to pretrain the networks first!"%best_h_path)
        else:
            print(colored('Loading H1 model from %s'%best_h_path, 'red'))
            config.model.load_state_dict(torch.load(best_h_path))
        
        trainer.run_epoch(0, phase='all')
        test_loss = config.logger.get_measure('all_loss').mean_epoch(epoch=0)
        print("All average loss %s"%colored('%.4f'%(test_loss), 'red'))

        self.base_model = config.model
        self.base_model.eval()
Example #3
    def propose_H(self, dataset):
        config = self.get_base_config(dataset)

        from models import get_ref_model_path
        h_path = get_ref_model_path(self.args, config.model.__class__.__name__,
                                    dataset.name)
        best_h_path = path.join(h_path, 'model.best.pth')

        trainer = IterativeTrainer(config, self.args)

        if not path.isfile(best_h_path):
            raise NotImplementedError(
                "Please use model_setup to pretrain the networks first!")
        else:
            print(colored('Loading H1 model from %s' % best_h_path, 'red'))
            config.model.load_state_dict(torch.load(best_h_path))

        trainer.run_epoch(0, phase='all')
        test_average_acc = config.logger.get_measure(
            'all_accuracy').mean_epoch(epoch=0)
        print("All average accuracy %s" %
              colored('%.4f%%' % (test_average_acc * 100), 'red'))

        self.base_model = config.model
        self.base_model.eval()
Example #4
def train_autoencoder(args, model, dataset, BCE_Loss):
    if BCE_Loss:
        model.netid = "BCE." + model.netid
    else:
        model.netid = "MSE." + model.netid

    home_path = Models.get_ref_model_path(args, model.__class__.__name__, dataset.name, model_setup=True, suffix_str=model.netid)
    hbest_path = os.path.join(home_path, 'model.best.pth')
    hlast_path = os.path.join(home_path, 'model.last.pth')

    if not os.path.isdir(home_path):
        os.makedirs(home_path)

    if not os.path.isfile(hbest_path+".done"):
        config = get_ae_config(args, model, dataset, BCE_Loss=BCE_Loss)
        trainer = IterativeTrainer(config, args)
        print(colored('Training from scratch', 'green'))
        best_loss = float('inf')
        for epoch in range(1, config.max_epoch+1):

            # Track the learning rates.
            lrs = [float(param_group['lr']) for param_group in config.optim.param_groups]
            config.logger.log('LRs', lrs, epoch)
            config.logger.get_measure('LRs').legend = ['LR%d'%i for i in range(len(lrs))]
            
            # One epoch of train and test.
            trainer.run_epoch(epoch, phase='train')
            trainer.run_epoch(epoch, phase='test')

            train_loss = config.logger.get_measure('train_loss').mean_epoch()
            test_loss = config.logger.get_measure('test_loss').mean_epoch()

            config.scheduler.step(train_loss)

            if config.visualize:
                # Show the average losses for all the phases in one figure.
                config.logger.visualize_average_keys('.*_loss', 'Average Loss', trainer.visdom)
                config.logger.visualize_average_keys('.*_accuracy', 'Average Accuracy', trainer.visdom)
                config.logger.visualize_average('LRs', trainer.visdom)

            # Save the logger for future reference.
            torch.save(config.logger.measures, os.path.join(home_path, 'logger.pth'))

            # Saving a checkpoint. Enable if needed!
            # if args.save and epoch % 10 == 0:
            #     print('Saving a %s at iter %s'%(colored('snapshot', 'yellow'), colored('%d'%epoch, 'yellow')))
            #     torch.save(config.model.state_dict(), os.path.join(home_path, 'model.%d.pth'%epoch))

            if args.save and test_loss < best_loss:
                print('Updating the on-file model with %s'%(colored('%.4f'%test_loss, 'red')))
                best_loss = test_loss
                torch.save(config.model.state_dict(), hbest_path)
        
        torch.save({'finished':True}, hbest_path+".done")
        torch.save(config.model.state_dict(), hlast_path)

        if config.visualize:
            trainer.visdom.save([trainer.visdom.env])
    else:
        print("Skipping %s"%(colored(home_path, 'yellow')))
Example #5
def train_classifier(args, model, dataset):
    config = None

    for mid in range(5):
        home_path = Models.get_ref_model_path(args,
                                              model.__class__.__name__,
                                              dataset.name,
                                              model_setup=True,
                                              suffix_str='DE.%d' % mid)
        hbest_path = os.path.join(home_path, 'model.best.pth')

        if not os.path.isdir(home_path):
            os.makedirs(home_path)
        else:
            if os.path.isfile(hbest_path + ".done"):
                print("Skipping %s" % (colored(home_path, 'yellow')))
                continue

        config = get_classifier_config(args,
                                       model.__class__(),
                                       dataset,
                                       mid=mid)

        trainer = IterativeTrainer(config, args)

        if not os.path.isfile(hbest_path + ".done"):
            print(colored('Training from scratch', 'green'))
            best_accuracy = -1
            for epoch in range(1, config.max_epoch + 1):

                # Track the learning rates.
                lrs = [
                    float(param_group['lr'])
                    for param_group in config.optim.param_groups
                ]
                config.logger.log('LRs', lrs, epoch)
                config.logger.get_measure('LRs').legend = [
                    'LR%d' % i for i in range(len(lrs))
                ]

                # One epoch of train and test.
                trainer.run_epoch(epoch, phase='train')
                trainer.run_epoch(epoch, phase='test')

                train_loss = config.logger.get_measure(
                    'train_loss').mean_epoch()
                config.scheduler.step(train_loss)

                if config.visualize:
                    # Show the average losses for all the phases in one figure.
                    config.logger.visualize_average_keys(
                        '.*_loss', 'Average Loss', trainer.visdom)
                    config.logger.visualize_average_keys(
                        '.*_accuracy', 'Average Accuracy', trainer.visdom)
                    config.logger.visualize_average('LRs', trainer.visdom)

                test_average_acc = config.logger.get_measure(
                    'test_accuracy').mean_epoch()

                # Save the logger for future reference.
                torch.save(config.logger.measures,
                           os.path.join(home_path, 'logger.pth'))

                # Saving a checkpoint. Enable if needed!
                # if args.save and epoch % 10 == 0:
                #     print('Saving a %s at iter %s'%(colored('snapshot', 'yellow'), colored('%d'%epoch, 'yellow')))
                #     torch.save(config.model.state_dict(), os.path.join(home_path, 'model.%d.pth'%epoch))

                if args.save and best_accuracy < test_average_acc:
                    print('Updating the on-file model with %s' %
                          (colored('%.4f' % test_average_acc, 'red')))
                    best_accuracy = test_average_acc
                    torch.save(config.model.state_dict(), hbest_path)

            torch.save({'finished': True}, hbest_path + ".done")
            if config.visualize:
                trainer.visdom.save([trainer.visdom.env])
        else:
            print("Skipping %s" % (colored(home_path, 'yellow')))

        print("Loading the best model.")
        config.model.load_state_dict(torch.load(hbest_path))
        config.model.eval()

        trainer.run_epoch(0, phase='all')
        test_average_acc = config.logger.get_measure(
            'all_accuracy').mean_epoch(epoch=0)
        print("All average accuracy %s" %
              colored('%.4f%%' % (test_average_acc * 100), 'red'))
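The DE.%d suffix and the range(5) loop make train_classifier produce a five-member deep ensemble of independently initialized classifiers. Below is a sketch of how the saved members could be loaded back and averaged at prediction time; the get_ref_model_path arguments mirror the snippet, while the helper names and the softmax averaging are our assumptions.

import os
import torch
import models as Models   # as imported in Example #2

def load_ensemble(args, model_cls, dataset, n_members=5):
    # Sketch: reload the 'DE.<mid>' checkpoints written by train_classifier.
    members = []
    for mid in range(n_members):
        home_path = Models.get_ref_model_path(
            args, model_cls.__name__, dataset.name,
            model_setup=True, suffix_str='DE.%d' % mid)
        member = model_cls()
        member.load_state_dict(
            torch.load(os.path.join(home_path, 'model.best.pth')))
        member.eval()
        members.append(member)
    return members

def ensemble_predict(members, x):
    # Average per-member softmax outputs (assumes the models emit logits).
    with torch.no_grad():
        probs = [torch.softmax(m(x), dim=1) for m in members]
    return torch.stack(probs).mean(dim=0)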
Example #6
    def train_H(self, dataset):
        # Wrap the (mixture) dataset in SubDataset so we can easily
        # split it later.
        dataset = SubDataset('%s-%s' % (self.args.D1, self.args.D2), dataset,
                             torch.arange(len(dataset)).int())

        # 80%, 20% for local train+test
        train_ds, valid_ds = dataset.split_dataset(0.8)

        if self.args.D1 in Global.mirror_augment:
            print(colored("Mirror augmenting %s" % self.args.D1, 'green'))
            new_train_ds = train_ds + MirroredDataset(train_ds)
            train_ds = new_train_ds

        # As suggested by the authors.
        all_temperatures = [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]
        all_epsilons = torch.linspace(0, 0.004, 21)
        total_params = len(all_temperatures) * len(all_epsilons)
        best_accuracy = -1

        h_path = path.join(self.args.experiment_path,
                           '%s' % (self.__class__.__name__),
                           '%d' % (self.default_model),
                           '%s-%s.pth' % (self.args.D1, self.args.D2))
        h_parent = path.dirname(h_path)
        if not path.isdir(h_parent):
            os.makedirs(h_parent)

        done_path = h_path + '.done'
        trainer, h_config = None, None

        if self.args.force_train_h or not path.isfile(done_path):
            # Grid search over the temperature and the epsilons.
            for i_eps, eps in enumerate(all_epsilons):
                for i_temp, temp in enumerate(all_temperatures):
                    so_far = i_eps * len(all_temperatures) + i_temp + 1
                    print(
                        colored(
                            'Checking eps=%.2e temp=%.1f (%d/%d)' %
                            (eps, temp, so_far, total_params), 'green'))
                    start_time = timeit.default_timer()

                    h_config = self.get_H_config(train_ds=train_ds,
                                                 valid_ds=valid_ds,
                                                 epsilon=eps,
                                                 temperature=temp)

                    trainer = IterativeTrainer(h_config, self.args)

                    print(colored('Training from scratch', 'green'))
                    trainer.run_epoch(0, phase='test')
                    for epoch in range(1, h_config.max_epoch + 1):
                        trainer.run_epoch(epoch, phase='train')
                        trainer.run_epoch(epoch, phase='test')

                        train_loss = h_config.logger.get_measure(
                            'train_loss').mean_epoch()
                        h_config.scheduler.step(train_loss)

                        # Track the learning rates and threshold.
                        lrs = [
                            float(param_group['lr'])
                            for param_group in h_config.optim.param_groups
                        ]
                        h_config.logger.log('LRs', lrs, epoch)
                        h_config.logger.get_measure('LRs').legend = [
                            'LR%d' % i for i in range(len(lrs))
                        ]

                        if hasattr(h_config.model, 'H') and hasattr(
                                h_config.model.H, 'threshold'):
                            h_config.logger.log(
                                'threshold',
                                h_config.model.H.threshold.cpu().numpy(),
                                epoch - 1)
                            h_config.logger.get_measure('threshold').legend = [
                                'threshold'
                            ]
                            if h_config.visualize:
                                h_config.logger.get_measure(
                                    'threshold').visualize_all_epochs(
                                        trainer.visdom)

                        if h_config.visualize:
                            # Show the average losses for all the phases in one figure.
                            h_config.logger.visualize_average_keys(
                                '.*_loss', 'Average Loss', trainer.visdom)
                            h_config.logger.visualize_average_keys(
                                '.*_accuracy', 'Average Accuracy',
                                trainer.visdom)
                            h_config.logger.visualize_average(
                                'LRs', trainer.visdom)

                        test_average_acc = h_config.logger.get_measure(
                            'test_accuracy').mean_epoch()

                        if best_accuracy < test_average_acc:
                            print('Updating the on-file model with %s' %
                                  (colored('%.4f' % test_average_acc, 'red')))
                            best_accuracy = test_average_acc
                            torch.save(h_config.model.H.state_dict(), h_path)

                    elapsed = timeit.default_timer() - start_time
                    print('Hyper-param check (%.2e, %.1f) in %.2fs' %
                          (eps, temp, elapsed))

            torch.save({'finished': True}, done_path)

        # If we load the pretrained model directly, we will have to initialize these.
        if trainer is None or h_config is None:
            h_config = self.get_H_config(train_ds=train_ds,
                                         valid_ds=valid_ds,
                                         epsilon=0,
                                         temperature=1,
                                         will_train=False)
            # Don't worry about the values of epsilon and temperature; they will be overwritten.
            trainer = IterativeTrainer(h_config, self.args)

        # Load the best model.
        print(colored('Loading H model from %s' % h_path, 'red'))
        state_dict = torch.load(h_path)
        for key, val in state_dict.items():
            if val.shape == torch.Size([]):
                state_dict[key] = val.view((1, ))
        h_config.model.H.load_state_dict(state_dict)
        h_config.model.set_eval_direct(False)
        print('Temperature %s Epsilon %s' %
              (colored(h_config.model.H.temperature.item(), 'red'),
               colored(h_config.model.H.epsilon.item(), 'red')))

        trainer.run_epoch(0, phase='testU')
        test_average_acc = h_config.logger.get_measure(
            'testU_accuracy').mean_epoch(epoch=0)
        print("Valid/Test average accuracy %s" %
              colored('%.4f%%' % (test_average_acc * 100), 'red'))
        self.H_class = h_config.model
        self.H_class.eval()
        self.H_class.set_eval_direct(False)
        return test_average_acc
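Example #6 grid-searches two hyperparameters, a temperature and a perturbation epsilon, keeping a single best checkpoint across all combinations. Stripped of the training details, the bookkeeping looks like the sketch below; the evaluate callback is hypothetical.

import torch

def grid_search(all_epsilons, all_temperatures, evaluate):
    # Sketch of the loop above: try every (eps, temp) pair, report progress
    # as so_far/total, and remember the best validation accuracy.
    total = len(all_epsilons) * len(all_temperatures)
    best_acc, best_params = -1, None
    for i_eps, eps in enumerate(all_epsilons):
        for i_temp, temp in enumerate(all_temperatures):
            so_far = i_eps * len(all_temperatures) + i_temp + 1
            print('Checking eps=%.2e temp=%.1f (%d/%d)'
                  % (eps, temp, so_far, total))
            acc = evaluate(eps, temp)          # train + validate one config
            if acc > best_acc:
                best_acc, best_params = acc, (eps, temp)
    return best_acc, best_params

# Usage with the snippet's search space:
# grid_search(torch.linspace(0, 0.004, 21),
#             [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000],
#             evaluate=my_eval_fn)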
Example #7
    def train_H(self, dataset):
        # Wrap the (mixture) dataset in SubDataset so we can easily
        # split it later.
        from datasets import SubDataset
        dataset = SubDataset('%s-%s' % (self.args.D1, self.args.D2), dataset,
                             torch.arange(len(dataset)).int())

        h_path = path.join(self.args.experiment_path,
                           '%s' % (self.__class__.__name__),
                           '%d' % (self.default_model),
                           '%s->%s.pth' % (self.args.D1, self.args.D2))
        h_parent = path.dirname(h_path)
        if not path.isdir(h_parent):
            os.makedirs(h_parent)

        done_path = h_path + '.done'
        will_train = self.args.force_train_h or not path.isfile(done_path)

        h_config = self.get_H_config(dataset, will_train)

        trainer = IterativeTrainer(h_config, self.args)

        if will_train:
            print(colored('Training from scratch', 'green'))
            best_accuracy = -1
            trainer.run_epoch(0, phase='test')
            for epoch in range(1, h_config.max_epoch + 1):
                trainer.run_epoch(epoch, phase='train')
                trainer.run_epoch(epoch, phase='test')

                train_loss = h_config.logger.get_measure(
                    'train_loss').mean_epoch()
                h_config.scheduler.step(train_loss)

                # Track the learning rates and threshold.
                lrs = [
                    float(param_group['lr'])
                    for param_group in h_config.optim.param_groups
                ]
                h_config.logger.log('LRs', lrs, epoch)
                h_config.logger.get_measure('LRs').legend = [
                    'LR%d' % i for i in range(len(lrs))
                ]

                viz_params = ['threshold', 'transfer']
                for viz_param in viz_params:
                    if hasattr(h_config.model, 'H') and hasattr(
                            h_config.model.H, viz_param):
                        h_config.logger.log(
                            viz_param,
                            getattr(h_config.model.H, viz_param).cpu().numpy(),
                            epoch - 1)
                        h_config.logger.get_measure(viz_param).legend = [
                            viz_param
                        ]
                        if h_config.visualize:
                            h_config.logger.get_measure(
                                viz_param).visualize_all_epochs(trainer.visdom)

                if h_config.visualize:
                    # Show the average losses for all the phases in one figure.
                    h_config.logger.visualize_average_keys(
                        '.*_loss', 'Average Loss', trainer.visdom)
                    h_config.logger.visualize_average_keys(
                        '.*_accuracy', 'Average Accuracy', trainer.visdom)
                    h_config.logger.visualize_average('LRs', trainer.visdom)

                test_average_acc = h_config.logger.get_measure(
                    'test_accuracy').mean_epoch()

                # Save the logger for future reference.
                torch.save(
                    h_config.logger.measures,
                    path.join(
                        h_parent,
                        'logger.%s->%s.pth' % (self.args.D1, self.args.D2)))

                if best_accuracy < test_average_acc:
                    print('Updating the on-file model with %s' %
                          (colored('%.4f' % test_average_acc, 'red')))
                    best_accuracy = test_average_acc
                    torch.save(h_config.model.H.state_dict(), h_path)

            torch.save({'finished': True}, done_path)

            if h_config.visualize:
                trainer.visdom.save([trainer.visdom.env])

        # Load the best model.
        print(colored('Loading H model from %s' % h_path, 'red'))
        h_config.model.H.load_state_dict(torch.load(h_path))
        h_config.model.set_eval_direct(False)

        trainer.run_epoch(0, phase='testU')
        test_average_acc = h_config.logger.get_measure(
            'testU_accuracy').mean_epoch(epoch=0)
        print("Valid/Test average accuracy %s" %
              colored('%.4f%%' % (test_average_acc * 100), 'red'))
        self.H_class = h_config.model
        self.H_class.eval()
        self.H_class.set_eval_direct(False)
        return test_average_acc
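The viz_params loop above logs optional learned attributes (threshold, transfer) by name, using hasattr/getattr so the same trainer works for models that do not define them. Reduced to a sketch with the logger replaced by a plain dict; everything except the attribute lookup is our simplification.

def log_optional_params(model, names, history, epoch):
    # Sketch: record model.H.<name> for every attribute the model has.
    for name in names:
        if hasattr(model, 'H') and hasattr(model.H, name):
            value = getattr(model.H, name).detach().cpu().numpy()
            history.setdefault(name, []).append((epoch, value))

# Usage, mirroring the snippet:
# log_optional_params(h_config.model, ['threshold', 'transfer'], history, epoch)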
Example #8
def train_variational_autoencoder(args, model, dataset, BCE_Loss=True):
    if BCE_Loss:
        model.netid = "BCE." + model.netid
    else:
        model.netid = "MSE." + model.netid
    home_path = Models.get_ref_model_path(args,
                                          model.__class__.__name__,
                                          dataset.name,
                                          model_setup=True,
                                          suffix_str=model.netid)
    hbest_path = os.path.join(home_path, 'model.best.pth')
    hlast_path = os.path.join(home_path, 'model.last.pth')

    if not os.path.isdir(home_path):
        os.makedirs(home_path)

    if not os.path.isfile(hbest_path + ".done"):
        config = get_vae_config(args, model, dataset, home_path, BCE_Loss)
        trainer = IterativeTrainer(config, args)
        print(colored('Training from scratch', 'green'))
        best_loss = float('inf')
        for epoch in range(1, config.max_epoch + 1):

            # Track the learning rates.
            lrs = [
                float(param_group['lr'])
                for param_group in config.optim.param_groups
            ]
            config.logger.log('LRs', lrs, epoch)
            config.logger.get_measure('LRs').legend = [
                'LR%d' % i for i in range(len(lrs))
            ]

            # One epoch of train and test.
            trainer.run_epoch(epoch, phase='train')
            trainer.run_epoch(epoch, phase='test')

            train_loss = config.logger.get_measure('train_loss').mean_epoch()
            test_loss = config.logger.get_measure('test_loss').mean_epoch()

            config.logger.writer.add_scalar('train_loss', train_loss, epoch)
            config.logger.writer.add_scalar('test_loss', test_loss, epoch)
            config.scheduler.step(train_loss)

            # Visualize reconstructions in TensorBoard.
            for (image, label) in config.valid_loader:
                prediction = model(image.cuda()).data.cpu().squeeze().numpy()
                prediction = (prediction - prediction.min()) / (
                    prediction.max() - prediction.min())
                if len(prediction.shape) > 3 and prediction.shape[1] == 3:
                    prediction = prediction.transpose(
                        (0, 2, 3, 1))  # NCHW -> NHWC
                N = min(prediction.shape[0], 5)
                fig, ax = plt.subplots(N, 2)
                image = image.data.squeeze().numpy()
                image = (image - image.min()) / (image.max() - image.min())
                if len(image.shape) > 3 and image.shape[1] == 3:
                    image = image.transpose((0, 2, 3, 1))
                for i in range(N):
                    ax[i, 0].imshow(prediction[i])
                    ax[i, 1].imshow(image[i])
                config.logger.writer.add_figure('Vis', fig, epoch)
                plt.close(fig)
                break

            if config.visualize:
                # Show the average losses for all the phases in one figure.
                config.logger.visualize_average_keys('.*_loss', 'Average Loss',
                                                     trainer.visdom)
                config.logger.visualize_average_keys('.*_accuracy',
                                                     'Average Accuracy',
                                                     trainer.visdom)
                config.logger.visualize_average('LRs', trainer.visdom)

            # Save the logger for future reference.
            torch.save(config.logger.measures,
                       os.path.join(home_path, 'logger.pth'))

            # Saving a checkpoint. Enable if needed!
            # if args.save and epoch % 10 == 0:
            #     print('Saving a %s at iter %s'%(colored('snapshot', 'yellow'), colored('%d'%epoch, 'yellow')))
            #     torch.save(config.model.state_dict(), os.path.join(home_path, 'model.%d.pth'%epoch))

            if args.save and test_loss < best_loss:
                print('Updating the on-file model with %s' %
                      (colored('%.4f' % test_loss, 'red')))
                best_loss = test_loss
                torch.save(config.model.state_dict(), hbest_path)

        torch.save({'finished': True}, hbest_path + ".done")
        torch.save(config.model.state_dict(), hlast_path)

        if config.visualize:
            trainer.visdom.save([trainer.visdom.env])
    else:
        print("Skipping %s" % (colored(home_path, 'yellow')))