Example #1
    def propose_H(self, dataset):
        config = self.get_base_config(dataset)

        from models import get_ref_model_path
        h_path = get_ref_model_path(self.args, config.model.__class__.__name__,
                                    dataset.name)
        best_h_path = path.join(h_path, 'model.best.pth')

        if not path.isfile(best_h_path):
            raise NotImplementedError(
                "Please use model_setup to pretrain the networks first!")
        else:
            print(colored('Loading H1 model from %s' % best_h_path, 'red'))
            config.model.load_state_dict(torch.load(best_h_path))

        # trainer.run_epoch(0, phase='all')
        # test_average_acc = config.logger.get_measure('all_accuracy').mean_epoch(epoch=0)
        # print("All average accuracy %s"%colored('%.4f%%'%(test_average_acc*100), 'red'))

        self.base_model = MahaModelWrapper(config.model,
                                           2,
                                           intermediate_nodes=(11, ))
        loader = DataLoader(dataset,
                            batch_size=self.args.batch_size,
                            shuffle=True,
                            num_workers=self.args.workers,
                            pin_memory=True)
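        # A single pass over the data lets the wrapper collect whatever internal
        # statistics it needs (collect_states) before it is used for evaluation.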
        self.base_model.collect_states(loader, self.args.device)
        self.base_model.eval()
Example #2
    def propose_H(self, dataset):
        config = self.get_base_config(dataset)

        import models as Models
        if self.default_model == 0:
            config.model.netid = "BCE." + config.model.netid
        else:
            config.model.netid = "MSE." + config.model.netid

        home_path = Models.get_ref_model_path(self.args, config.model.__class__.__name__, dataset.name, suffix_str=config.model.netid)
        hbest_path = path.join(home_path, 'model.best.pth')
        best_h_path = hbest_path

        trainer = IterativeTrainer(config, self.args)

        if not path.isfile(best_h_path):
            raise NotImplementedError("%s not found! Please use setup_model to pretrain the networks first!" % best_h_path)
        else:
            print(colored('Loading H1 model from %s'%best_h_path, 'red'))
            config.model.load_state_dict(torch.load(best_h_path))
        
        trainer.run_epoch(0, phase='all')
        test_loss = config.logger.get_measure('all_loss').mean_epoch(epoch=0)
        print("All average loss %s"%colored('%.4f'%(test_loss), 'red'))

        self.base_model = config.model
        self.base_model.eval()
Example #3
    def propose_H(self, dataset):
        config = self.get_base_config(dataset)

        # Wrap the class in KWLWrapper
        original_class_name = config.model.__class__.__name__
        config.model = KWayLogisticWrapper(config.model)
        config.model = config.model.to(self.args.device)

        h_path = Models.get_ref_model_path(self.args, original_class_name, dataset.name, suffix_str='KLogistic')
        best_h_path = path.join(h_path, 'model.best.pth')
        
        trainer = IterativeTrainer(config, self.args)

        if not path.isfile(best_h_path):      
            raise NotImplementedError("Please use setup_model to pretrain the networks first!")
        else:
            print(colored('Loading H1 model from %s'%best_h_path, 'red'))
            config.model.load_state_dict(torch.load(best_h_path))
        
        trainer.run_epoch(0, phase='all')
        test_average_acc = config.logger.get_measure('all_accuracy').mean_epoch(epoch=0)
        print("All average accuracy %s"%colored('%.4f%%'%(test_average_acc*100), 'red'))

        self.base_model = config.model
        self.base_model.eval()
Example #4
def train_autoencoder(args, model, dataset, BCE_Loss):
    if BCE_Loss:
        model.netid = "BCE." + model.netid
    else:
        model.netid = "MSE." + model.netid

    home_path = Models.get_ref_model_path(args, model.__class__.__name__, dataset.name, model_setup=True, suffix_str=model.netid)
    hbest_path = os.path.join(home_path, 'model.best.pth')
    hlast_path = os.path.join(home_path, 'model.last.pth')

    if not os.path.isdir(home_path):
        os.makedirs(home_path)

    if not os.path.isfile(hbest_path+".done"):
        config = get_ae_config(args, model, dataset, BCE_Loss=BCE_Loss)
        trainer = IterativeTrainer(config, args)
        print(colored('Training from scratch', 'green'))
        best_loss = 999999999
        for epoch in range(1, config.max_epoch+1):

            # Track the learning rates.
            lrs = [float(param_group['lr']) for param_group in config.optim.param_groups]
            config.logger.log('LRs', lrs, epoch)
            config.logger.get_measure('LRs').legend = ['LR%d'%i for i in range(len(lrs))]
            
            # One epoch of train and test.
            trainer.run_epoch(epoch, phase='train')
            trainer.run_epoch(epoch, phase='test')

            train_loss = config.logger.get_measure('train_loss').mean_epoch()
            test_loss = config.logger.get_measure('test_loss').mean_epoch()

            config.scheduler.step(train_loss)

            if config.visualize:
                # Show the average losses for all the phases in one figure.
                config.logger.visualize_average_keys('.*_loss', 'Average Loss', trainer.visdom)
                config.logger.visualize_average_keys('.*_accuracy', 'Average Accuracy', trainer.visdom)
                config.logger.visualize_average('LRs', trainer.visdom)

            # Save the logger for future reference.
            torch.save(config.logger.measures, os.path.join(home_path, 'logger.pth'))

            # Saving a checkpoint. Enable if needed!
            # if args.save and epoch % 10 == 0:
            #     print('Saving a %s at iter %s'%(colored('snapshot', 'yellow'), colored('%d'%epoch, 'yellow')))
            #     torch.save(config.model.state_dict(), os.path.join(home_path, 'model.%d.pth'%epoch))

            if args.save and test_loss < best_loss:
                print('Updating the on file model with %s'%(colored('%.4f'%test_loss, 'red')))
                best_loss = test_loss
                torch.save(config.model.state_dict(), hbest_path)
        
        torch.save({'finished':True}, hbest_path+".done")
        torch.save(config.model.state_dict(), hlast_path)

        if config.visualize:
            trainer.visdom.save([trainer.visdom.env])
    else:
        print("Skipping %s"%(colored(home_path, 'yellow')))
Example #5
    def propose_H(self, dataset):
        config = self.get_base_config(dataset)

        from models import get_ref_model_path
        h_path = get_ref_model_path(self.args, config.model.__class__.__name__,
                                    dataset.name)
        best_h_path = path.join(h_path, 'model.best.pth')

        trainer = IterativeTrainer(config, self.args)

        if not path.isfile(best_h_path):
            raise NotImplementedError(
                "Please use model_setup to pretrain the networks first!")
        else:
            print(colored('Loading H1 model from %s' % best_h_path, 'red'))
            config.model.load_state_dict(torch.load(best_h_path))

        trainer.run_epoch(0, phase='all')
        test_average_acc = config.logger.get_measure(
            'all_accuracy').mean_epoch(epoch=0)
        print("All average accuracy %s" %
              colored('%.4f%%' % (test_average_acc * 100), 'red'))

        self.base_model = config.model
        self.base_model.eval()
Example #6
    def get_base_config(self, dataset):
        print("Preparing training D1 for %s" %
              (dataset.parent_dataset.__class__.__name__))

        all_loader = DataLoader(dataset,
                                batch_size=self.args.batch_size,
                                num_workers=self.args.workers,
                                pin_memory=True)

        # Set up the criterion
        criterion = nn.NLLLoss().cuda()

        # Set up the model
        model_class = Global.get_ref_classifier(
            dataset.name)[self.default_model]
        self.add_identifier = model_class.__name__

        # We must create 5 instances of this class.
        from models import get_ref_model_path
        all_models = []
        for mid in range(5):
            model = model_class()
            model = DeepEnsembleWrapper(model)
            model = model.to(self.args.device)
            h_path = get_ref_model_path(self.args,
                                        model_class.__name__,
                                        dataset.name,
                                        suffix_str='DE.%d' % mid)
            best_h_path = path.join(h_path, 'model.best.pth')
            if not path.isfile(best_h_path):
                raise NotImplementedError(
                    "Please use setup_model to pretrain the networks first! Can't find %s"
                    % best_h_path)
            else:
                print(colored('Loading H1 model from %s' % best_h_path, 'red'))
                model.load_state_dict(torch.load(best_h_path))
                model.eval()
            all_models.append(model)
        master_model = DeepEnsembleMasterWrapper(all_models)

        # Set up the config
        config = IterativeTrainerConfig()

        config.name = '%s-CLS' % (self.args.D1)
        config.phases = {
            'all': {
                'dataset': all_loader,
                'backward': False
            },
        }
        config.criterion = criterion
        config.classification = True
        config.cast_float_label = False
        config.stochastic_gradient = True
        config.model = master_model
        config.optim = None
        config.autoencoder_target = False
        config.visualize = False
        config.logger = Logger()
        return config
Example #7
    def propose_H(self, dataset):
        assert self.default_model > 0, 'KNN needs K>0'
        if self.base_model is not None:
            self.base_model.base_data = None
            self.base_model = None

        # Set up the base0-model
        base_model = Global.get_ref_classifier(dataset.name)[0]().to(
            self.args.device)
        from models import get_ref_model_path
        home_path = get_ref_model_path(self.args,
                                       base_model.__class__.__name__,
                                       dataset.name)

        hbest_path = path.join(home_path, 'model.best.pth')
        best_h_path = hbest_path
        print(colored('Loading H1 model from %s' % best_h_path, 'red'))
        base_model.load_state_dict(torch.load(best_h_path))
        base_model.eval()

        if dataset.name in Global.mirror_augment:
            print(colored("Mirror augmenting %s" % dataset.name, 'green'))
            new_train_ds = dataset + MirroredDataset(dataset)
            dataset = new_train_ds

        # Initialize the multi-threaded loaders.
        all_loader = DataLoader(dataset,
                                batch_size=self.args.batch_size,
                                num_workers=1,
                                pin_memory=True)

        n_data = len(dataset)
        n_dim = base_model.partial_forward(dataset[0][0].to(
            self.args.device).unsqueeze(0)).numel()
        print('nHidden %d' % (n_dim))
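        # Cache the hidden features (partial_forward) of every training sample;
        # the k-NN model built below searches against this matrix.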
        self.base_data = torch.zeros(n_data, n_dim, dtype=torch.float32)
        base_ind = 0
        with torch.set_grad_enabled(False):
            with tqdm(total=len(all_loader),
                      disable=bool(os.environ.get("DISABLE_TQDM",
                                                  False))) as pbar:
                pbar.set_description('Caching X_train for %d-nn' %
                                     self.default_model)
                for i, (x, _) in enumerate(all_loader):
                    n_data = x.size(0)
                    output = base_model.partial_forward(x.to(
                        self.args.device)).data
                    self.base_data[base_ind:base_ind + n_data].copy_(output)
                    base_ind = base_ind + n_data
                    pbar.update()
        # self.base_data = torch.cat([x.view(1, -1) for x,_ in dataset])
        self.base_model = AEKNNModel(base_model,
                                     self.base_data,
                                     k=self.default_model,
                                     SV=True).to(self.args.device)
        self.base_model.eval()
Example #8
def needs_processing(args, dataset_class, models, suffix):
    """
        This function checks whether this model is already trained and can be skipped.
    """
    for model in models:
        for suf in suffix:
            home_path = Models.get_ref_model_path(args, model.__name__, dataset_class.__name__, model_setup=True, suffix_str=suf)
            hbest_path = os.path.join(home_path, 'model.best.pth.done')
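            # The '.done' marker is only written once training has finished,
            # so a missing marker means this model/suffix still needs training.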
            if not os.path.isfile(hbest_path):
                return True
    return False
Example #9
    def get_base_config(self, dataset):
        print("Preparing training D1 for %s" %
              (dataset.parent_dataset.__class__.__name__))

        all_loader = DataLoader(dataset,
                                batch_size=self.args.batch_size,
                                num_workers=self.args.workers,
                                pin_memory=True)

        # Set up the model
        model = Global.get_ref_pixelcnn(dataset.name)[self.default_model]().to(
            self.args.device)
        self.add_identifier = model.__class__.__name__

        # Load the snapshot
        from models import get_ref_model_path
        h_path = get_ref_model_path(self.args,
                                    model.__class__.__name__,
                                    dataset.name,
                                    suffix_str=model.netid)
        best_h_path = path.join(h_path, 'model.best.pth')
        if not path.isfile(best_h_path):
            raise NotImplementedError(
                "Please use setup_model to pretrain the networks first! Can't find %s"
                % best_h_path)
        else:
            print(colored('Loading H1 model from %s' % best_h_path, 'red'))
            model.load_state_dict(torch.load(best_h_path))
            model.eval()

        # Set up the criterion
        criterion = PCNN_Loss(one_d=(model.input_channels == 1)).to(
            self.args.device)

        # Set up the config
        config = IterativeTrainerConfig()

        config.name = '%s-pcnn' % (self.args.D1)
        config.phases = {
            'all': {
                'dataset': all_loader,
                'backward': False
            },
        }
        config.criterion = criterion
        config.classification = False
        config.cast_float_label = False
        config.autoencoder_target = True
        config.stochastic_gradient = True
        config.model = model
        config.optim = None
        config.visualize = False
        config.logger = Logger()
        return config
Example #10
    def run_epoch(self, epoch, phase='train'):
        # Retrieve the appropriate config.
        config = self.config.phases[phase]
        dataset = config['dataset']
        backward = config['backward']
        phase_name = phase
        print("Doing %s" % colored(phase, 'green'))

        model = self.config.model
        visualize = self.config.visualize
        criterion = self.config.criterion
        optimizer = self.config.optim
        logger = self.config.logger
        stochastic = self.config.stochastic_gradient
        classification = self.config.classification

        #print("self.config.name:" + self.config.name)
        home_path = Models.get_ref_model_path(self.args,
                                              model.__class__.__name__,
                                              self.config.name,
                                              model_setup=True,
                                              suffix_str="CCC")
        dump_path = os.path.join(home_path, 'dump')
        if not os.path.isdir(dump_path):
            os.makedirs(dump_path)

        # Set the network to the target mode.
        if backward:
            model.train()
            torch.set_grad_enabled(True)
        else:
            model.eval()
            torch.set_grad_enabled(False)

        start_time = timeit.default_timer()
        last_viz_update = start_time

        # For full gradient optimization we need to rescale the loss
        # to calculate the gradient correctly.
        loss_scaler = 1
        if not stochastic:
            loss_scaler = 1. / len(dataset.dataset)

        try:
            # TQDM sometimes throws IOError exceptions when you
            # try to close it. We ignore those exceptions.
            with tqdm(total=len(dataset)) as pbar:
                if backward and not stochastic:
                    optimizer.zero_grad()

                for i, (image, label) in enumerate(dataset):
                    pbar.update()
                    if backward and stochastic:
                        optimizer.zero_grad()

                    # Get and prepare data.
                    input, target, data_indices = image, None, None
                    if torch.typename(label) == 'list':
                        assert len(
                            label
                        ) == 2, 'There should be two entries in the label'
                        # Need to unpack the label. This is for when the data provider
                        # has the cached flag enabled, therefore the y is now (y, idx).
                        target, data_indices = label
                    else:
                        target = label

                    if self.config.autoencoder_target:
                        target = input.clone()

                    if self.config.cast_float_label:
                        target = target.float().unsqueeze(1)

                    input, target = input.to(self.device), target.to(
                        model.get_output_device())

                    # Do a forward propagation and get the loss.
                    prediction = None
                    if data_indices is None:
                        prediction = model(input)
                    else:
                        # Run in the cached mode. This is necessary to speed up
                        # some of the underlying optimization procedures. It is not
                        # always used though.
                        prediction = model(input,
                                           indices=data_indices,
                                           group=phase_name)

                    loss = criterion(prediction, target)

                    if self.args.dump_images:
                        # pick one from the batch and output it

                        #filename = phase_name + str(i) +"_epoch" + str(epoch) + ".png"
                        #dump_file = os.path.join(dump_path,filename)
                        #self.dump_image(input[0].cpu(),dump_file,True)

                        if self.config.autoencoder_target:

                            home_path = Models.get_ref_model_path(
                                self.args,
                                model.__class__.__name__,
                                self.config.name,
                                model_setup=True,
                                suffix_str="CCC")
                            dump_path = os.path.join(home_path, 'dump')
                            if not os.path.isdir(dump_path):
                                os.makedirs(dump_path)

                            filename = phase_name + str(i) + "_epoch" + str(
                                epoch) + ".png"
                            dump_file = os.path.join(dump_path, filename)
                            self.dump_image(input[0].cpu(), dump_file, True)

                            # if this is an autoencoder run, also output the reconstruction for comparison
                            filename = phase_name + str(i) + "_epoch" + str(
                                epoch) + "_target.png"
                            dump_file = os.path.join(dump_path, filename)
                            self.dump_image(prediction[0].cpu(), dump_file, True)

                    if backward:
                        if stochastic:
                            loss.backward()
                            optimizer.step()
        except IOError as e:
            if e.errno != errno.EINTR:
                raise
            else:
                print(colored("Problem averted :D", 'green'))
Example #11
matplotlib.use('Agg')
import matplotlib.pyplot as plt

if __name__ == "__main__":
    dataset = PCAM(root_path=os.path.join(args.root_path, "pcam"),
                   extract=True,
                   downsample=64).get_D1_train()
    dataloader = torch.utils.data.DataLoader(dataset,
                                             args.batch_size,
                                             True,
                                             num_workers=args.workers,
                                             pin_memory=True)
    model = ALIModel(dims=(3, 64, 64)).cuda()
    home_path = Models.get_ref_model_path(args,
                                          model.__class__.__name__,
                                          dataset.name,
                                          model_setup=True,
                                          suffix_str='base0')
    logger = Logger(home_path)

    hbest_path = os.path.join(home_path, 'model.best.pth')

    if not os.path.isdir(home_path):
        os.makedirs(home_path)
    best_gen_loss = 9999
    if not os.path.isfile(hbest_path + ".done"):
        print(colored('Training from scratch', 'green'))
        best_loss = -1

        optimizerG = optim.Adam([{
            'params': model.GenX.parameters()
Example #12
def get_classifier_config(args, model, dataset, balanced=False):
    print("Preparing training D1 for %s" % (dataset.name))

    # 80%, 20% for local train+test
    train_ds, valid_ds = dataset.split_dataset(0.8)

    if dataset.name in Global.mirror_augment:
        print(colored("Mirror augmenting %s" % dataset.name, 'green'))
        new_train_ds = train_ds + MirroredDataset(train_ds)
        train_ds = new_train_ds

    # Initialize the multi-threaded loaders.
    if balanced:
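        # Weight each sample by the inverse frequency of its class so the
        # WeightedRandomSampler draws every class roughly uniformly.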
        y_train = []
        for x, y in train_ds:
            y_train.append(y.numpy())
        y_train = np.array(y_train)
        class_sample_count = np.array(
            [len(np.where(y_train == t)[0]) for t in np.unique(y_train)])
        print(class_sample_count)
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in y_train])

        samples_weight = torch.from_numpy(samples_weight)
        sampler = WeightedRandomSampler(
            samples_weight.type('torch.DoubleTensor'), len(samples_weight))
        train_loader = DataLoader(train_ds,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True,
                                  sampler=sampler)

        y_val = []
        for x, y in valid_ds:
            y_val.append(y.numpy())
        y_val = np.array(y_val)
        class_sample_count = np.array(
            [len(np.where(y_val == t)[0]) for t in np.unique(y_val)])
        print(class_sample_count)
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in y_val])

        samples_weight = torch.from_numpy(samples_weight)
        sampler = WeightedRandomSampler(
            samples_weight.type('torch.DoubleTensor'), len(samples_weight))
        valid_loader = DataLoader(valid_ds,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True,
                                  sampler=sampler)

    else:
        train_loader = DataLoader(train_ds,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.workers,
                                  pin_memory=True)

        valid_loader = DataLoader(valid_ds,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
    all_loader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            num_workers=args.workers,
                            pin_memory=True)

    # Set up the criterion
    criterion = nn.NLLLoss().to(args.device)

    # Set up the model
    model = model.to(args.device)

    # Set up the config
    config = IterativeTrainerConfig()

    config.name = 'classifier_%s_%s' % (dataset.name, model.__class__.__name__)

    config.train_loader = train_loader
    config.valid_loader = valid_loader
    config.phases = {
        'train': {
            'dataset': train_loader,
            'backward': True
        },
        'test': {
            'dataset': valid_loader,
            'backward': False
        },
        'all': {
            'dataset': all_loader,
            'backward': False
        },
    }
    config.criterion = criterion
    config.classification = True
    config.stochastic_gradient = True
    config.visualize = not args.no_visualize
    config.model = model
    home_path = Models.get_ref_model_path(args,
                                          config.model.__class__.__name__,
                                          dataset.name,
                                          model_setup=True,
                                          suffix_str='base0')
    config.logger = Logger(home_path)

    config.optim = optim.Adam(model.parameters(), lr=1e-3)
    config.scheduler = optim.lr_scheduler.ReduceLROnPlateau(config.optim,
                                                            patience=10,
                                                            threshold=1e-2,
                                                            min_lr=1e-6,
                                                            factor=0.1,
                                                            verbose=True)
    config.max_epoch = 120

    if hasattr(model, 'train_config'):
        model_train_config = model.train_config()
        for key, value in model_train_config.items():
            print('Overriding config.%s' % key)
            setattr(config, key, value)

    return config
Example #13
        D1 = D164.get_D1_train()

        emb = args.embedding_function.lower()
        assert emb in ["vae", "ae", "ali"]
        dummy_args = EasyDict()
        dummy_args.exp = "foo"
        dummy_args.experiment_path = args.experiment_path
        if args.encoder_loss.lower() == "bce":
            tag = "BCE"
        else:
            tag = "MSE"
        if emb == "vae":
            model = Global.dataset_reference_vaes[args.dataset][0]()
            home_path = Models.get_ref_model_path(dummy_args,
                                                  model.__class__.__name__,
                                                  D164.name,
                                                  suffix_str=tag + "." +
                                                  model.netid)
            model_path = os.path.join(home_path, 'model.best.pth')
        elif emb == "ae":
            model = Global.dataset_reference_autoencoders[args.dataset][0]()

            home_path = Models.get_ref_model_path(dummy_args,
                                                  model.__class__.__name__,
                                                  D164.name,
                                                  suffix_str=tag + "." +
                                                  model.netid)
            model_path = os.path.join(home_path, 'model.best.pth')
        else:
            model = Global.dataset_reference_ALI[args.dataset][0]()
            home_path = Models.get_ref_model_path(dummy_args,
Example #14
def Train_ALI(args, model, dataset, BCE_Loss=True):
    dataloader = torch.utils.data.DataLoader(dataset,
                                             args.batch_size,
                                             True,
                                             num_workers=args.workers,
                                             pin_memory=True)
    home_path = Models.get_ref_model_path(args,
                                          model.__class__.__name__,
                                          dataset.name,
                                          model_setup=True,
                                          suffix_str='base0')
    logger = Logger(home_path)

    hbest_path = os.path.join(home_path, 'model.best.pth')

    if not os.path.isdir(home_path):
        os.makedirs(home_path)
    best_gen_loss = 9999
    if not os.path.isfile(hbest_path + ".done"):
        print(colored('Training from scratch', 'green'))

        optimizerG = optim.Adam([{
            'params': model.GenX.parameters()
        }, {
            'params': model.GenZ.parameters()
        }],
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))

        optimizerD = optim.Adam([{
            'params': model.DisZ.parameters()
        }, {
            'params': model.DisX.parameters()
        }, {
            'params': model.DisXZ.parameters()
        }],
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
        if BCE_Loss:
            criterion = nn.BCELoss()
        else:
            criterion = nn.MSELoss()

        for epoch in range(1, 100 + 1):
            model.train()
            with tqdm(total=len(dataloader),
                      disable=bool(os.environ.get("DISABLE_TQDM",
                                                  False))) as pbar:
                for i, (x, y) in enumerate(dataloader):
                    pbar.update()
                    batchsize = x.shape[0]
                    fakeZ = torch.randn(batchsize, 512, 1, 1).cuda()
                    pred_real, pred_fake = model.forward(x.cuda(), fakeZ)
                    truelabel = torch.ones(batchsize) - 0.1
                    fakelabel = torch.zeros(batchsize)

                    if args.random_label:
                        truelabel = torch.randint(
                            low=70, high=110, size=(1, batchsize))[0] / 100
                        fakelabel = torch.randint(
                            low=-10, high=30, size=(1, batchsize))[0] / 100
                    truelabel = truelabel.cuda()
                    fakelabel = fakelabel.cuda()
                    loss_d = criterion(pred_real.view(-1),
                                       truelabel) + criterion(
                                           pred_fake.view(-1), fakelabel)
                    loss_g = criterion(pred_fake.view(-1),
                                       truelabel) + criterion(
                                           pred_real.view(-1), fakelabel)
                    logger.log('Disc_loss', loss_d.item(), epoch, i)
                    logger.log('Gen_loss', loss_g.item(), epoch, i)
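
                    # Balance the adversarial game: if the generator loss is too
                    # high, update only G; if it is too low, update only D;
                    # otherwise take a step on both networks.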

                    if loss_g > args.max_loss_g:
                        optimizerG.zero_grad()
                        loss_g.backward()
                        optimizerG.step()
                        pbar.set_description(
                            "Skipped D, Disc_loss %.4f, Gen_loss %.4f" %
                            (loss_d.item(), loss_g.item()))
                    elif loss_g < args.min_loss_g:
                        optimizerD.zero_grad()
                        loss_d.backward()
                        optimizerD.step()
                        pbar.set_description(
                            "Skipped G, Disc_loss %.4f, Gen_loss %.4f" %
                            (loss_d.item(), loss_g.item()))
                    else:
                        optimizerD.zero_grad()
                        loss_d.backward(retain_graph=True)
                        optimizerD.step()
                        optimizerG.zero_grad()
                        loss_g.backward()
                        optimizerG.step()
                        pbar.set_description("Disc_loss %.4f, Gen_loss %.4f" %
                                             (loss_d.item(), loss_g.item()))

            disc_loss = logger.get_measure('Disc_loss').mean_epoch()
            gen_loss = logger.get_measure('Gen_loss').mean_epoch()
            print("Discriminator loss %.4f, Generator loss %.4f" %
                  (disc_loss, gen_loss))

            logger.writer.add_scalar('disc_loss', disc_loss, epoch)
            logger.writer.add_scalar('gen_loss', gen_loss, epoch)

            # vis in tensorboard
            for (image, label) in dataloader:
                prediction = model(x=image.cuda()).data.cpu().squeeze().numpy()
                N = min(prediction.shape[0], 5)
                fig, ax = plt.subplots(N, 2)
                image = image.data.squeeze().numpy()
                for i in range(N):
                    ax[i, 0].imshow(prediction[i])
                    ax[i, 1].imshow(image[i])
                logger.writer.add_figure('Vis', fig, epoch)
                plt.close(fig)
                break

            torch.save(logger.measures, os.path.join(home_path, 'logger.pth'))

            if args.save and gen_loss < best_gen_loss:
                print('Updating the on file model with %s' %
                      (colored('%.4f' % gen_loss, 'red')))
                best_gen_loss = gen_loss
                torch.save(model.state_dict(), hbest_path)
Example #15
    def propose_H(self, dataset):
        assert self.default_model > 0, 'KNN needs K>0'
        if self.base_model is not None:
            self.base_model.base_data = None
            self.base_model = None

        # Set up the base-model
        if isinstance(self, BCEKNNSVM) or isinstance(self, MSEKNNSVM):
            base_model = Global.get_ref_autoencoder(dataset.name)[0]().to(
                self.args.device)
            if isinstance(self, BCEKNNSVM):
                base_model.netid = "BCE." + base_model.netid
            else:
                base_model.netid = "MSE." + base_model.netid
            home_path = Models.get_ref_model_path(
                self.args,
                base_model.__class__.__name__,
                dataset.name,
                suffix_str=base_model.netid)
        elif isinstance(self, VAEKNNSVM):
            base_model = Global.get_ref_vae(dataset.name)[0]().to(
                self.args.device)
            home_path = Models.get_ref_model_path(
                self.args,
                base_model.__class__.__name__,
                dataset.name,
                suffix_str=base_model.netid)
        else:
            raise NotImplementedError()

        hbest_path = path.join(home_path, 'model.best.pth')
        best_h_path = hbest_path
        print(colored('Loading H1 model from %s' % best_h_path, 'red'))
        base_model.load_state_dict(torch.load(best_h_path))
        base_model.eval()

        if dataset.name in Global.mirror_augment:
            print(colored("Mirror augmenting %s" % dataset.name, 'green'))
            new_train_ds = dataset + MirroredDataset(dataset)
            dataset = new_train_ds

        # Initialize the multi-threaded loaders.
        all_loader = DataLoader(dataset,
                                batch_size=self.args.batch_size,
                                num_workers=1,
                                pin_memory=True)

        n_data = len(dataset)
        n_dim = base_model.encode(dataset[0][0].to(
            self.args.device).unsqueeze(0)).numel()
        print('nHidden %d' % (n_dim))
        self.base_data = torch.zeros(n_data, n_dim, dtype=torch.float32)
        base_ind = 0
        with torch.set_grad_enabled(False):
            with tqdm(total=len(all_loader)) as pbar:
                pbar.set_description('Caching X_train for %d-nn' %
                                     self.default_model)
                for i, (x, _) in enumerate(all_loader):
                    n_data = x.size(0)
                    output = base_model.encode(x.to(self.args.device)).data
                    self.base_data[base_ind:base_ind + n_data].copy_(output)
                    base_ind = base_ind + n_data
                    pbar.update()
        # self.base_data = torch.cat([x.view(1, -1) for x,_ in dataset])
        self.base_model = AEKNNModel(base_model,
                                     self.base_data,
                                     k=self.default_model).to(self.args.device)
        self.base_model.eval()
Example #16
def train_classifier(args, model, dataset):
    config = None

    for mid in range(5):
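        # Each of the five ensemble members is trained and stored under its own
        # 'DE.<mid>' reference-model directory.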
        home_path = Models.get_ref_model_path(args,
                                              model.__class__.__name__,
                                              dataset.name,
                                              model_setup=True,
                                              suffix_str='DE.%d' % mid)
        hbest_path = os.path.join(home_path, 'model.best.pth')

        if not os.path.isdir(home_path):
            os.makedirs(home_path)
        else:
            if os.path.isfile(hbest_path + ".done"):
                print("Skipping %s" % (colored(home_path, 'yellow')))
                continue

        config = get_classifier_config(args,
                                       model.__class__(),
                                       dataset,
                                       mid=mid)

        trainer = IterativeTrainer(config, args)

        if not os.path.isfile(hbest_path + ".done"):
            print(colored('Training from scratch', 'green'))
            best_accuracy = -1
            for epoch in range(1, config.max_epoch + 1):

                # Track the learning rates.
                lrs = [
                    float(param_group['lr'])
                    for param_group in config.optim.param_groups
                ]
                config.logger.log('LRs', lrs, epoch)
                config.logger.get_measure('LRs').legend = [
                    'LR%d' % i for i in range(len(lrs))
                ]

                # One epoch of train and test.
                trainer.run_epoch(epoch, phase='train')
                trainer.run_epoch(epoch, phase='test')

                train_loss = config.logger.get_measure(
                    'train_loss').mean_epoch()
                config.scheduler.step(train_loss)

                if config.visualize:
                    # Show the average losses for all the phases in one figure.
                    config.logger.visualize_average_keys(
                        '.*_loss', 'Average Loss', trainer.visdom)
                    config.logger.visualize_average_keys(
                        '.*_accuracy', 'Average Accuracy', trainer.visdom)
                    config.logger.visualize_average('LRs', trainer.visdom)

                test_average_acc = config.logger.get_measure(
                    'test_accuracy').mean_epoch()

                # Save the logger for future reference.
                torch.save(config.logger.measures,
                           os.path.join(home_path, 'logger.pth'))

                # Saving a checkpoint. Enable if needed!
                # if args.save and epoch % 10 == 0:
                #     print('Saving a %s at iter %s'%(colored('snapshot', 'yellow'), colored('%d'%epoch, 'yellow')))
                #     torch.save(config.model.state_dict(), os.path.join(home_path, 'model.%d.pth'%epoch))

                if args.save and best_accuracy < test_average_acc:
                    print('Updating the on file model with %s' %
                          (colored('%.4f' % test_average_acc, 'red')))
                    best_accuracy = test_average_acc
                    torch.save(config.model.state_dict(), hbest_path)

            torch.save({'finished': True}, hbest_path + ".done")
            if config.visualize:
                trainer.visdom.save([trainer.visdom.env])
        else:
            print("Skipping %s" % (colored(home_path, 'yellow')))

        print("Loading the best model.")
        config.model.load_state_dict(torch.load(hbest_path))
        config.model.eval()

        trainer.run_epoch(0, phase='all')
        test_average_acc = config.logger.get_measure(
            'all_accuracy').mean_epoch(epoch=0)
        print("All average accuracy %s" %
              colored('%.4f%%' % (test_average_acc * 100), 'red'))
Example #17
def train_variational_autoencoder(args, model, dataset, BCE_Loss=True):
    if BCE_Loss:
        model.netid = "BCE." + model.netid
    else:
        model.netid = "MSE." + model.netid
    home_path = Models.get_ref_model_path(args,
                                          model.__class__.__name__,
                                          dataset.name,
                                          model_setup=True,
                                          suffix_str=model.netid)
    hbest_path = os.path.join(home_path, 'model.best.pth')
    hlast_path = os.path.join(home_path, 'model.last.pth')

    if not os.path.isdir(home_path):
        os.makedirs(home_path)

    if not os.path.isfile(hbest_path + ".done"):
        config = get_vae_config(args, model, dataset, home_path, BCE_Loss)
        trainer = IterativeTrainer(config, args)
        print(colored('Training from scratch', 'green'))
        best_loss = 999999999
        for epoch in range(1, config.max_epoch + 1):

            # Track the learning rates.
            lrs = [
                float(param_group['lr'])
                for param_group in config.optim.param_groups
            ]
            config.logger.log('LRs', lrs, epoch)
            config.logger.get_measure('LRs').legend = [
                'LR%d' % i for i in range(len(lrs))
            ]

            # One epoch of train and test.
            trainer.run_epoch(epoch, phase='train')
            trainer.run_epoch(epoch, phase='test')

            train_loss = config.logger.get_measure('train_loss').mean_epoch()
            test_loss = config.logger.get_measure('test_loss').mean_epoch()

            config.logger.writer.add_scalar('train_loss', train_loss, epoch)
            config.logger.writer.add_scalar('test_loss', test_loss, epoch)
            config.scheduler.step(train_loss)

            # vis in tensorboard
            for (image, label) in config.valid_loader:
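                # Reconstruct one validation batch and min-max normalize both the
                # reconstruction and the input so they render side by side in TensorBoard.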
                prediction = model(image.cuda()).data.cpu().squeeze().numpy()
                prediction = (prediction - prediction.min()) / (
                    prediction.max() - prediction.min())
                if len(prediction.shape) > 3 and prediction.shape[1] == 3:
                    prediction = prediction.transpose(
                        (0, 2, 3, 1))  # change to N W H C
                N = min(prediction.shape[0], 5)
                fig, ax = plt.subplots(N, 2)
                image = image.data.squeeze().numpy()
                image = (image - image.min()) / (image.max() - image.min())
                if len(image.shape) > 3 and image.shape[1] == 3:
                    image = image.transpose((0, 2, 3, 1))
                for i in range(N):
                    ax[i, 0].imshow(prediction[i])
                    ax[i, 1].imshow(image[i])
                config.logger.writer.add_figure('Vis', fig, epoch)
                plt.close(fig)
                break

            if config.visualize:
                # Show the average losses for all the phases in one figure.
                config.logger.visualize_average_keys('.*_loss', 'Average Loss',
                                                     trainer.visdom)
                config.logger.visualize_average_keys('.*_accuracy',
                                                     'Average Accuracy',
                                                     trainer.visdom)
                config.logger.visualize_average('LRs', trainer.visdom)

            # Save the logger for future reference.
            torch.save(config.logger.measures,
                       os.path.join(home_path, 'logger.pth'))

            # Saving a checkpoint. Enable if needed!
            # if args.save and epoch % 10 == 0:
            #     print('Saving a %s at iter %s'%(colored('snapshot', 'yellow'), colored('%d'%epoch, 'yellow')))
            #     torch.save(config.model.state_dict(), os.path.join(home_path, 'model.%d.pth'%epoch))

            if args.save and test_loss < best_loss:
                print('Updating the on file model with %s' %
                      (colored('%.4f' % test_loss, 'red')))
                best_loss = test_loss
                torch.save(config.model.state_dict(), hbest_path)

        torch.save({'finished': True}, hbest_path + ".done")
        torch.save(config.model.state_dict(), hlast_path)

        if config.visualize:
            trainer.visdom.save([trainer.visdom.env])
    else:
        print("Skipping %s" % (colored(home_path, 'yellow')))