def get_D2_valid(self, D1):
     assert self.is_compatible(D1)
     return SubDataset(self.name,
                       self.ds_all,
                       self.val_ind,
                       label=1,
                       transform=D1.conformity_transform())
 def get_D2_test(self, D1):
     """Return the D2 test subset of ds_test (label=1), using D1's
     conformity transform."""
     assert self.is_compatible(D1)
     conform = D1.conformity_transform()
     return SubDataset(self.name, self.ds_test, self.D2_test_ind,
                       label=1, transform=conform)
Example #3
0
 def get_D2_test(self, D1):
     """Return the D2 test subset, drawn here from ds_valid (label=1),
     using D1's conformity transform."""
     assert self.is_compatible(D1)
     conform = D1.conformity_transform()
     return SubDataset(self.name, self.ds_valid, self.D2_test_ind,
                       label=1, transform=conform)
Example #4
0
 def get_D2_valid(self, D1):
     """Return the D2 validation subset taken from the source dataset's
     training split (label=1), using D1's conformity transform."""
     assert self.is_compatible(D1)
     src = self.source_data
     return SubDataset(self.name, src.ds_train, src.D2_valid_ind,
                       label=1, transform=D1.conformity_transform())
Example #5
0
File: CIFAR.py  Project: yw981/OD-test
 def get_D2_test(self, D1):
     """Return the D2 test subset of ds_test (label=1), applying the
     dataset-specific filter rule for this D1 when one exists, and using
     D1's conformity transform."""
     assert self.is_compatible(D1)
     target_indices = self.D2_test_ind
     # BUGFIX: `dict.has_key()` only exists in Python 2 (removed in
     # Python 3); use the `in` operator, matching the sibling
     # get_D2_valid implementation.
     if D1.name in self.filter_rules:
         target_indices = filter_indices(self.ds_test, target_indices,
                                         self.filter_rules[D1.name])
     return SubDataset(self.name,
                       self.ds_test,
                       target_indices,
                       label=1,
                       transform=D1.conformity_transform())
 def get_D1_test(self):
     """Return the D1 test subset of ds_test (label=0)."""
     return SubDataset(self.name, self.ds_test,
                       self.D1_test_ind, label=0)
 def get_D1_valid(self):
     """Return the D1 validation subset, drawn from ds_train (label=0)."""
     return SubDataset(self.name, self.ds_train,
                       self.D1_valid_ind, label=0)
 def get_D1_train(self):
     """Return the D1 training subset of ds_train (default label)."""
     subset = SubDataset(self.name, self.ds_train, self.D1_train_ind)
     return subset
Example #9
0
    def train_H(self, dataset):
        """Grid-search the (temperature, epsilon) hyper-parameters for H.

        The mixture dataset is split 80/20 into a local train/test set; a
        fresh model is trained per hyper-parameter pair, and the weights
        with the best local test accuracy are checkpointed to disk. The
        best checkpoint is then reloaded and evaluated with the 'testU'
        phase.

        Args:
            dataset: the D1/D2 mixture dataset to search over.

        Returns:
            The mean 'testU' accuracy of the loaded best model.
        """
        # Wrap the (mixture)dataset in SubDataset so to easily
        # split it later.
        dataset = SubDataset('%s-%s' % (self.args.D1, self.args.D2), dataset,
                             torch.arange(len(dataset)).int())

        # 80%, 20% for local train+test
        train_ds, valid_ds = dataset.split_dataset(0.8)

        if self.args.D1 in Global.mirror_augment:
            print(colored("Mirror augmenting %s" % self.args.D1, 'green'))
            new_train_ds = train_ds + MirroredDataset(train_ds)
            train_ds = new_train_ds

        # As suggested by the authors.
        all_temperatures = [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]
        all_epsilons = torch.linspace(0, 0.004, 21)
        total_params = len(all_temperatures) * len(all_epsilons)
        best_accuracy = -1

        # Best weights are checkpointed at h_path; the '.done' marker file
        # records that the full grid search finished.
        h_path = path.join(self.args.experiment_path,
                           '%s' % (self.__class__.__name__),
                           '%d' % (self.default_model),
                           '%s-%s.pth' % (self.args.D1, self.args.D2))
        h_parent = path.dirname(h_path)
        if not path.isdir(h_parent):
            os.makedirs(h_parent)

        done_path = h_path + '.done'
        trainer, h_config = None, None

        if self.args.force_train_h or not path.isfile(done_path):
            # Grid search over the temperature and the epsilons.
            for i_eps, eps in enumerate(all_epsilons):
                for i_temp, temp in enumerate(all_temperatures):
                    so_far = i_eps * len(all_temperatures) + i_temp + 1
                    print(
                        colored(
                            'Checking eps=%.2e temp=%.1f (%d/%d)' %
                            (eps, temp, so_far, total_params), 'green'))
                    start_time = timeit.default_timer()

                    # Fresh config (and therefore fresh model/optimizer)
                    # per hyper-parameter pair.
                    h_config = self.get_H_config(train_ds=train_ds,
                                                 valid_ds=valid_ds,
                                                 epsilon=eps,
                                                 temperature=temp)

                    trainer = IterativeTrainer(h_config, self.args)

                    print(colored('Training from scratch', 'green'))
                    trainer.run_epoch(0, phase='test')
                    for epoch in range(1, h_config.max_epoch + 1):
                        trainer.run_epoch(epoch, phase='train')
                        trainer.run_epoch(epoch, phase='test')

                        # The LR scheduler steps on the mean training loss.
                        train_loss = h_config.logger.get_measure(
                            'train_loss').mean_epoch()
                        h_config.scheduler.step(train_loss)

                        # Track the learning rates and threshold.
                        lrs = [
                            float(param_group['lr'])
                            for param_group in h_config.optim.param_groups
                        ]
                        h_config.logger.log('LRs', lrs, epoch)
                        h_config.logger.get_measure('LRs').legend = [
                            'LR%d' % i for i in range(len(lrs))
                        ]

                        if hasattr(h_config.model, 'H') and hasattr(
                                h_config.model.H, 'threshold'):
                            # NOTE(review): threshold is logged at
                            # `epoch - 1` while LRs use `epoch` — confirm
                            # this offset is intentional.
                            h_config.logger.log(
                                'threshold',
                                h_config.model.H.threshold.cpu().numpy(),
                                epoch - 1)
                            h_config.logger.get_measure('threshold').legend = [
                                'threshold'
                            ]
                            if h_config.visualize:
                                h_config.logger.get_measure(
                                    'threshold').visualize_all_epochs(
                                        trainer.visdom)

                        if h_config.visualize:
                            # Show the average losses for all the phases in one figure.
                            h_config.logger.visualize_average_keys(
                                '.*_loss', 'Average Loss', trainer.visdom)
                            h_config.logger.visualize_average_keys(
                                '.*_accuracy', 'Average Accuracy',
                                trainer.visdom)
                            h_config.logger.visualize_average(
                                'LRs', trainer.visdom)

                        test_average_acc = h_config.logger.get_measure(
                            'test_accuracy').mean_epoch()

                        # Checkpoint whenever the local test accuracy
                        # improves, across all epochs and hyper-parameter
                        # pairs.
                        if best_accuracy < test_average_acc:
                            print('Updating the on file model with %s' %
                                  (colored('%.4f' % test_average_acc, 'red')))
                            best_accuracy = test_average_acc
                            torch.save(h_config.model.H.state_dict(), h_path)

                    elapsed = timeit.default_timer() - start_time
                    print('Hyper-param check (%.2e, %.1f) in %.2fs' %
                          (eps, temp, elapsed))

            torch.save({'finished': True}, done_path)

        # If we load the pretrained model directly, we will have to initialize these.
        if trainer is None or h_config is None:
            h_config = self.get_H_config(train_ds=train_ds,
                                         valid_ds=valid_ds,
                                         epsilon=0,
                                         temperature=1,
                                         will_train=False)
            # don't worry about the values of epsilon or temperature. it will be overwritten.
            trainer = IterativeTrainer(h_config, self.args)

        # Load the best model.
        print(colored('Loading H model from %s' % h_path, 'red'))
        state_dict = torch.load(h_path)
        # Promote 0-d tensors to shape (1,) so load_state_dict accepts them.
        for key, val in state_dict.items():
            if val.shape == torch.Size([]):
                state_dict[key] = val.view((1, ))
        h_config.model.H.load_state_dict(state_dict)
        h_config.model.set_eval_direct(False)
        print('Temperature %s Epsilon %s' %
              (colored(h_config.model.H.temperature.item(), 'red'),
               colored(h_config.model.H.epsilon.item(), 'red')))

        # Final evaluation with the 'testU' phase.
        trainer.run_epoch(0, phase='testU')
        test_average_acc = h_config.logger.get_measure(
            'testU_accuracy').mean_epoch(epoch=0)
        print("Valid/Test average accuracy %s" %
              colored('%.4f%%' % (test_average_acc * 100), 'red'))
        self.H_class = h_config.model
        self.H_class.eval()
        self.H_class.set_eval_direct(False)
        return test_average_acc
Example #10
0
 def get_D2_valid(self, D1):
     """Return the D2 validation subset of ds_train (label=1), applying the
     dataset-specific filter rule for this D1 when one exists, and using
     D1's conformity transform."""
     assert self.is_compatible(D1)
     indices = self.D2_valid_ind
     # Apply the per-D1 filter rule, if any.
     if D1.name in self.filter_rules:
         indices = filter_indices(self.ds_train, indices,
                                  self.filter_rules[D1.name])
     return SubDataset(self.name, self.ds_train, indices, label=1,
                       transform=D1.conformity_transform())
Example #11
0
    def train_H(self, dataset):
        """Train the H network on the mixture `dataset` (unless a finished
        checkpoint already exists and retraining is not forced), keeping
        the best-test-accuracy weights on disk, then reload them and
        evaluate with the 'testU' phase.

        Args:
            dataset: the D1/D2 mixture dataset passed to get_H_config.

        Returns:
            The mean 'testU' accuracy of the loaded best model.
        """
        # Wrap the (mixture)dataset in SubDataset so to easily
        # split it later.
        from datasets import SubDataset
        dataset = SubDataset('%s-%s' % (self.args.D1, self.args.D2), dataset,
                             torch.arange(len(dataset)).int())

        # Best weights are checkpointed at h_path; the '.done' marker file
        # records that training finished.
        h_path = path.join(self.args.experiment_path,
                           '%s' % (self.__class__.__name__),
                           '%d' % (self.default_model),
                           '%s->%s.pth' % (self.args.D1, self.args.D2))
        h_parent = path.dirname(h_path)
        if not path.isdir(h_parent):
            os.makedirs(h_parent)

        done_path = h_path + '.done'
        will_train = self.args.force_train_h or not path.isfile(done_path)

        h_config = self.get_H_config(dataset, will_train)

        trainer = IterativeTrainer(h_config, self.args)

        if will_train:
            print(colored('Training from scratch', 'green'))
            best_accuracy = -1
            trainer.run_epoch(0, phase='test')
            for epoch in range(1, h_config.max_epoch + 1):
                trainer.run_epoch(epoch, phase='train')
                trainer.run_epoch(epoch, phase='test')

                # The LR scheduler steps on the mean training loss.
                train_loss = h_config.logger.get_measure(
                    'train_loss').mean_epoch()
                h_config.scheduler.step(train_loss)

                # Track the learning rates and threshold.
                lrs = [
                    float(param_group['lr'])
                    for param_group in h_config.optim.param_groups
                ]
                h_config.logger.log('LRs', lrs, epoch)
                h_config.logger.get_measure('LRs').legend = [
                    'LR%d' % i for i in range(len(lrs))
                ]

                # Log/visualize optional model attributes when present.
                viz_params = ['threshold', 'transfer']
                for viz_param in viz_params:
                    if hasattr(h_config.model, 'H') and hasattr(
                            h_config.model.H, viz_param):
                        # NOTE(review): logged at `epoch - 1` while LRs use
                        # `epoch` — confirm this offset is intentional.
                        h_config.logger.log(
                            viz_param,
                            getattr(h_config.model.H, viz_param).cpu().numpy(),
                            epoch - 1)
                        h_config.logger.get_measure(viz_param).legend = [
                            viz_param
                        ]
                        if h_config.visualize:
                            h_config.logger.get_measure(
                                viz_param).visualize_all_epochs(trainer.visdom)

                if h_config.visualize:
                    # Show the average losses for all the phases in one figure.
                    h_config.logger.visualize_average_keys(
                        '.*_loss', 'Average Loss', trainer.visdom)
                    h_config.logger.visualize_average_keys(
                        '.*_accuracy', 'Average Accuracy', trainer.visdom)
                    h_config.logger.visualize_average('LRs', trainer.visdom)

                test_average_acc = h_config.logger.get_measure(
                    'test_accuracy').mean_epoch()

                # Save the logger for future reference.
                torch.save(
                    h_config.logger.measures,
                    path.join(
                        h_parent,
                        'logger.%s->%s.pth' % (self.args.D1, self.args.D2)))

                # Checkpoint whenever the test accuracy improves.
                if best_accuracy < test_average_acc:
                    print('Updating the on file model with %s' %
                          (colored('%.4f' % test_average_acc, 'red')))
                    best_accuracy = test_average_acc
                    torch.save(h_config.model.H.state_dict(), h_path)

            torch.save({'finished': True}, done_path)

            if h_config.visualize:
                trainer.visdom.save([trainer.visdom.env])

        # Load the best model.
        print(colored('Loading H model from %s' % h_path, 'red'))
        h_config.model.H.load_state_dict(torch.load(h_path))
        h_config.model.set_eval_direct(False)

        # Final evaluation with the 'testU' phase.
        trainer.run_epoch(0, phase='testU')
        test_average_acc = h_config.logger.get_measure(
            'testU_accuracy').mean_epoch(epoch=0)
        print("Valid/Test average accuracy %s" %
              colored('%.4f%%' % (test_average_acc * 100), 'red'))
        self.H_class = h_config.model
        self.H_class.eval()
        self.H_class.set_eval_direct(False)
        return test_average_acc
Example #12
0
 def get_D1_test(self):
     """Return the source dataset's test subset (label=0) with this
     wrapper's transform applied."""
     src = self.source_data
     return SubDataset(self.name, src.ds_test, src.D1_test_ind,
                       label=0, transform=self.transform)
Example #13
0
 def get_D1_valid(self):
     """Return the source dataset's validation subset (label=0) with this
     wrapper's transform applied."""
     src = self.source_data
     return SubDataset(self.name, src.ds_valid, src.D1_valid_ind,
                       label=0, transform=self.transform)
Example #14
0
 def get_D1_train(self):
     """Return the source dataset's training subset with this wrapper's
     transform applied (default label)."""
     src = self.source_data
     return SubDataset(self.name, src.ds_train, src.D1_train_ind,
                       transform=self.transform)
Example #15
0
 def get_D1_valid(self):
     """Return the D1 validation subset of ds_all selected by val_d1_ind
     (label=0)."""
     subset = SubDataset(self.name, self.ds_all, self.val_d1_ind, label=0)
     return subset