Example #1
0
    def __init__(self):
        """Read options, open the log files and build model, data reader and optimizer."""
        self.args = self.parse_command_line()
        options = self.args

        log_paths = get_log_files(options.checkpoint_dir)
        (self.checkpoint_dir, self.logfile,
         self.checkpoint_path_validation, self.checkpoint_path_final) = log_paths

        print_and_log(self.logfile, "Options: %s\n" % options)
        print_and_log(self.logfile, "Checkpoint Directory: %s\n" % self.checkpoint_dir)

        # Use the first CUDA device when one is available, otherwise the CPU.
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

        self.model = self.init_model()
        self.train_set, self.validation_set, self.test_set = self.init_data()
        self.metadataset = MetaDatasetReader(self.args.data_path,
                                             self.args.mode,
                                             self.train_set,
                                             self.validation_set,
                                             self.test_set)
        self.loss = loss
        self.accuracy_fn = aggregate_accuracy
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.args.learning_rate)
        self.optimizer.zero_grad()
        self.validation_accuracies = ValidationAccuracies(self.validation_set)
Example #2
0
    def __init__(self):
        """Read options, open the log files and build model, data reader and optimizer.

        When ``--resume_from_checkpoint`` is set, ``load_checkpoint`` is called
        after the optimizer exists; gradients are zeroed last.
        """
        self.args = self.parse_command_line()
        options = self.args

        # Flag passed to get_log_files for runs that only evaluate (test/attack).
        evaluation_only = options.mode in ("test", "attack")
        log_paths = get_log_files(options.checkpoint_dir,
                                  options.resume_from_checkpoint,
                                  evaluation_only)
        (self.checkpoint_dir, self.logfile,
         self.checkpoint_path_validation, self.checkpoint_path_final) = log_paths

        print_and_log(self.logfile, "Options: %s\n" % options)
        print_and_log(self.logfile,
                      "Checkpoint Directory: %s\n" % self.checkpoint_dir)

        # Use the first CUDA device when one is available, otherwise the CPU.
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

        self.model = self.init_model()
        self.train_set, self.validation_set, self.test_set = self.init_data()

        cfg = self.args
        if cfg.dataset != "meta-dataset":
            self.dataset = SingleDatasetReader(cfg.data_path, cfg.mode,
                                               cfg.dataset, cfg.way, cfg.shot,
                                               cfg.query_train, cfg.query_test)
        else:
            self.dataset = MetaDatasetReader(cfg.data_path, cfg.mode,
                                             self.train_set,
                                             self.validation_set,
                                             self.test_set,
                                             cfg.max_way_train,
                                             cfg.max_way_test,
                                             cfg.max_support_train,
                                             cfg.max_support_test)

        self.loss = loss
        self.accuracy_fn = aggregate_accuracy
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=cfg.learning_rate)
        self.validation_accuracies = ValidationAccuracies(self.validation_set)
        self.start_iteration = 0
        if cfg.resume_from_checkpoint:
            self.load_checkpoint()
        self.optimizer.zero_grad()
Example #3
0
class Learner:
    """Meta-train, validate, test and adversarially attack a CNAPs few-shot model.

    All settings come from the command line (see ``parse_command_line``) and
    ``run`` dispatches on ``--mode``.
    """

    def __init__(self):
        """Parse options, open log files and build the model, data reader and optimizer.

        With ``--resume_from_checkpoint`` the start iteration, model/optimizer
        state and best validation accuracies are restored by ``load_checkpoint``.
        """
        self.args = self.parse_command_line()

        # The third argument flags runs that only evaluate an existing model
        # (test or attack); how it is used is decided by get_log_files.
        self.checkpoint_dir, self.logfile, self.checkpoint_path_validation, self.checkpoint_path_final \
            = get_log_files(self.args.checkpoint_dir, self.args.resume_from_checkpoint, self.args.mode == "test" or
                            self.args.mode == "attack")

        print_and_log(self.logfile, "Options: %s\n" % self.args)
        print_and_log(self.logfile,
                      "Checkpoint Directory: %s\n" % self.checkpoint_dir)

        gpu_device = 'cuda:0'
        self.device = torch.device(
            gpu_device if torch.cuda.is_available() else 'cpu')
        self.model = self.init_model()
        self.train_set, self.validation_set, self.test_set = self.init_data()
        if self.args.dataset == "meta-dataset":
            self.dataset = MetaDatasetReader(
                self.args.data_path, self.args.mode, self.train_set,
                self.validation_set, self.test_set, self.args.max_way_train,
                self.args.max_way_test, self.args.max_support_train,
                self.args.max_support_test)
        else:
            self.dataset = SingleDatasetReader(self.args.data_path,
                                               self.args.mode,
                                               self.args.dataset,
                                               self.args.way, self.args.shot,
                                               self.args.query_train,
                                               self.args.query_test)
        self.loss = loss
        self.accuracy_fn = aggregate_accuracy
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.args.learning_rate)
        self.validation_accuracies = ValidationAccuracies(self.validation_set)
        self.start_iteration = 0
        if self.args.resume_from_checkpoint:
            self.load_checkpoint()
        self.optimizer.zero_grad()

    def init_model(self):
        """Build a Cnaps model on ``self.device`` and set its train/eval modes.

        Returns:
            The model; ``distribute_model()`` has been called on it when
            ``use_two_gpus()`` is true.
        """
        use_two_gpus = self.use_two_gpus()
        model = Cnaps(device=self.device,
                      use_two_gpus=use_two_gpus,
                      args=self.args).to(self.device)
        self.register_extra_parameters(model)

        # The set encoder is always in train mode (it only sees context data).
        model.train()
        # The feature extractor is in eval mode by default, but the model
        # switches it depending on args.batch_normalization.
        model.feature_extractor.eval()
        if use_two_gpus:
            model.distribute_model()
        return model

    def init_data(self):
        """Return the (train, validation, test) lists of dataset names.

        For meta-dataset these are fixed lists (test comes from
        ``--test_datasets``); otherwise all three are ``[args.dataset]``.
        """
        if self.args.dataset == "meta-dataset":
            train_set = [
                'ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd',
                'quickdraw', 'fungi', 'vgg_flower'
            ]
            validation_set = [
                'ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd',
                'quickdraw', 'fungi', 'vgg_flower', 'mscoco'
            ]
            test_set = self.args.test_datasets
        else:
            train_set = [self.args.dataset]
            validation_set = [self.args.dataset]
            test_set = [self.args.dataset]

        return train_set, validation_set, test_set

    def parse_command_line(self):
        """Define the command-line options and return the parsed ``argparse.Namespace``."""
        parser = argparse.ArgumentParser()

        parser.add_argument("--dataset",
                            choices=[
                                "meta-dataset", "ilsvrc_2012", "omniglot",
                                "aircraft", "cu_birds", "dtd", "quickdraw",
                                "fungi", "vgg_flower", "traffic_sign",
                                "mscoco", "mnist", "cifar10", "cifar100"
                            ],
                            default="meta-dataset",
                            help="Dataset to use.")
        parser.add_argument('--test_datasets',
                            nargs='+',
                            help='Datasets to use for testing',
                            default=[
                                "ilsvrc_2012", "omniglot", "aircraft",
                                "cu_birds", "dtd", "quickdraw", "fungi",
                                "vgg_flower", "traffic_sign", "mscoco",
                                "mnist", "cifar10", "cifar100"
                            ])
        parser.add_argument("--data_path",
                            default="../datasets",
                            help="Path to dataset records.")
        parser.add_argument("--pretrained_resnet_path",
                            default="../models/pretrained_resnet.pt.tar",
                            help="Path to pretrained feature extractor model.")
        parser.add_argument(
            "--mode",
            choices=["train", "test", "train_test", "attack"],
            default="train_test",
            help=
            "Whether to run training only, testing only, both training and testing, "
            "or an adversarial attack on a trained model."
        )
        parser.add_argument("--learning_rate",
                            "-lr",
                            type=float,
                            default=5e-4,
                            help="Learning rate.")
        parser.add_argument(
            "--tasks_per_batch",
            type=int,
            default=16,
            help="Number of tasks between parameter optimizations.")
        parser.add_argument("--checkpoint_dir",
                            "-c",
                            default='../checkpoints',
                            help="Directory to save checkpoint to.")
        parser.add_argument("--test_model_path",
                            "-m",
                            default=None,
                            help="Path to model to load and test or attack.")
        parser.add_argument(
            "--feature_adaptation",
            choices=["no_adaptation", "film", "film+ar"],
            default="film",
            help="Method to adapt feature extractor parameters.")
        parser.add_argument("--batch_normalization",
                            choices=["basic", "task_norm-i"],
                            default="basic",
                            help="Normalization layer to use.")
        parser.add_argument("--training_iterations",
                            "-i",
                            type=int,
                            default=110000,
                            help="Number of meta-training iterations.")
        parser.add_argument("--attack_tasks",
                            "-a",
                            type=int,
                            default=10,
                            help="Number of tasks when performing attack.")
        parser.add_argument("--val_freq",
                            type=int,
                            default=10000,
                            help="Number of iterations between validations.")
        parser.add_argument(
            "--max_way_train",
            type=int,
            default=40,
            help="Maximum way of meta-dataset meta-train task.")
        parser.add_argument("--max_way_test",
                            type=int,
                            default=50,
                            help="Maximum way of meta-dataset meta-test task.")
        parser.add_argument(
            "--max_support_train",
            type=int,
            default=400,
            help="Maximum support set size of meta-dataset meta-train task.")
        parser.add_argument(
            "--max_support_test",
            type=int,
            default=500,
            help="Maximum support set size of meta-dataset meta-test task.")
        parser.add_argument("--resume_from_checkpoint",
                            "-r",
                            dest="resume_from_checkpoint",
                            default=False,
                            action="store_true",
                            help="Restart from latest checkpoint.")
        parser.add_argument("--way",
                            type=int,
                            default=5,
                            help="Way of single dataset task.")
        parser.add_argument(
            "--shot",
            type=int,
            default=1,
            help="Shots per class for context of single dataset task.")
        parser.add_argument(
            "--query_train",
            type=int,
            default=10,
            help="Shots per class for target of single dataset meta-train task.")
        parser.add_argument(
            "--query_test",
            type=int,
            default=10,
            help="Shots per class for target of single dataset meta-test task.")

        args = parser.parse_args()

        return args

    def run(self):
        """Execute the phase(s) selected by ``--mode`` inside one TF session.

        The session is handed to the dataset reader and the attack code. In
        training, gradients are accumulated over ``tasks_per_batch`` tasks per
        optimizer step; the best-validation model and the final model are
        saved. The log file is closed at the end.
        """
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.compat.v1.Session(config=config) as session:
            if self.args.mode == 'train' or self.args.mode == 'train_test':
                train_accuracies = []
                losses = []
                total_iterations = self.args.training_iterations
                for iteration in range(self.start_iteration, total_iterations):
                    torch.set_grad_enabled(True)
                    task_dict = self.dataset.get_train_task(session)
                    task_loss, task_accuracy = self.train_task(task_dict)
                    train_accuracies.append(task_accuracy)
                    # Keep a plain float so the list does not hold references
                    # to every task's autograd graph.
                    losses.append(task_loss.item())

                    # optimize
                    if ((iteration + 1) % self.args.tasks_per_batch
                            == 0) or (iteration == (total_iterations - 1)):
                        self.optimizer.step()
                        self.optimizer.zero_grad()

                    if (iteration + 1) % PRINT_FREQUENCY == 0:
                        # print training stats
                        print_and_log(
                            self.logfile,
                            'Task [{}/{}], Train Loss: {:.7f}, Train Accuracy: {:.7f}'
                            .format(
                                iteration + 1, total_iterations,
                                torch.Tensor(losses).mean().item(),
                                torch.Tensor(train_accuracies).mean().item()))
                        train_accuracies = []
                        losses = []

                    if ((iteration + 1) % self.args.val_freq
                            == 0) and (iteration + 1) != total_iterations:
                        # validate
                        accuracy_dict = self.validate(session)
                        self.validation_accuracies.print(
                            self.logfile, accuracy_dict)
                        # save the model if validation is the best so far
                        if self.validation_accuracies.is_better(accuracy_dict):
                            self.validation_accuracies.replace(accuracy_dict)
                            torch.save(self.model.state_dict(),
                                       self.checkpoint_path_validation)
                            print_and_log(
                                self.logfile,
                                'Best validation model was updated.')
                            print_and_log(self.logfile, '')
                        self.save_checkpoint(iteration + 1)

                # save the final model
                torch.save(self.model.state_dict(), self.checkpoint_path_final)

            if self.args.mode == 'train_test':
                self.test(self.checkpoint_path_final, session)
                self.test(self.checkpoint_path_validation, session)

            if self.args.mode == 'test':
                self.test(self.args.test_model_path, session)

            if self.args.mode == 'attack':
                self.attack(self.args.test_model_path, session)

            self.logfile.close()

    def train_task(self, task_dict):
        """Run forward and backward passes for one training task.

        Gradients are accumulated into the parameters; ``run`` performs the
        optimizer step every ``tasks_per_batch`` tasks.

        Returns:
            (task_loss, task_accuracy). The loss is divided by
            ``tasks_per_batch`` and includes the scaled feature-adaptation
            regularizer when ``--feature_adaptation`` is 'film' or 'film+ar'.
        """
        context_images, target_images, context_labels, target_labels, _ = self.prepare_task(
            task_dict)

        target_logits = self.model(context_images, context_labels,
                                   target_images)
        # Divided because gradients of tasks_per_batch tasks are summed before
        # each optimizer step (see run).
        task_loss = self.loss(target_logits, target_labels,
                              self.device) / self.args.tasks_per_batch
        if self.args.feature_adaptation in ('film', 'film+ar'):
            regularization_term = (self.model.feature_adaptation_network.
                                   regularization_term())
            if self.use_two_gpus():
                # Presumably the loss lives on cuda:0 in the two-GPU layout.
                regularization_term = regularization_term.cuda(0)
            regularizer_scaling = 0.001
            task_loss += regularizer_scaling * regularization_term
        task_accuracy = self.accuracy_fn(target_logits, target_labels)

        task_loss.backward(retain_graph=False)

        return task_loss, task_accuracy

    def validate(self, session):
        """Evaluate the current model on NUM_VALIDATION_TASKS tasks per validation dataset.

        Returns:
            dict mapping dataset name to ``{"accuracy": mean accuracy in %,
            "confidence": 1.96 * std / sqrt(n) in %}``.
        """
        with torch.no_grad():
            accuracy_dict = {}
            for item in self.validation_set:
                accuracies = self._task_accuracies(
                    self.dataset.get_validation_task, item, session,
                    NUM_VALIDATION_TASKS)
                accuracy, confidence = self._accuracy_and_confidence(
                    accuracies)

                accuracy_dict[item] = {
                    "accuracy": accuracy,
                    "confidence": confidence
                }

        return accuracy_dict

    def test(self, path, session):
        """Load the state dict at ``path`` into a fresh model and log per-dataset test accuracy."""
        print_and_log(self.logfile, "")  # add a blank line
        print_and_log(self.logfile, 'Testing model {0:}: '.format(path))
        self.model = self.init_model()
        self.model.load_state_dict(torch.load(path))

        with torch.no_grad():
            for item in self.test_set:
                accuracies = self._task_accuracies(self.dataset.get_test_task,
                                                   item, session,
                                                   NUM_TEST_TASKS)
                accuracy, accuracy_confidence = self._accuracy_and_confidence(
                    accuracies)

                print_and_log(
                    self.logfile,
                    '{0:}: {1:3.1f}+/-{2:2.1f}'.format(item, accuracy,
                                                       accuracy_confidence))

    def _task_accuracies(self, get_task, item, session, num_tasks):
        """Return the current model's accuracy on ``num_tasks`` tasks from ``get_task(item, session)``."""
        accuracies = []
        for _ in range(num_tasks):
            task_dict = get_task(item, session)
            context_images, target_images, context_labels, target_labels, _images_np = self.prepare_task(
                task_dict)
            target_logits = self.model(context_images, context_labels,
                                       target_images)
            accuracy = self.accuracy_fn(target_logits, target_labels)
            accuracies.append(accuracy.item())
            del target_logits
        return accuracies

    @staticmethod
    def _accuracy_and_confidence(accuracies):
        """Return (mean accuracy * 100, 196 * std / sqrt(n)), i.e. a 95% normal-approximation interval in percent."""
        values = np.array(accuracies)
        accuracy = values.mean() * 100.0
        confidence = (196.0 * values.std()) / np.sqrt(len(accuracies))
        return accuracy, confidence

    def attack(self, path, session):
        """Run cleverhans PGD on one context image per class and log the accuracy drop.

        For each test dataset and each of ``--attack_tasks`` tasks, one context
        image per class is perturbed in turn through a TF wrapper around the
        PyTorch model. Afterwards all perturbed images are used together and
        the resulting target accuracy is logged.
        """
        print_and_log(self.logfile, "")  # add a blank line
        print_and_log(self.logfile, 'Attacking model {0:}: '.format(path))
        self.model = self.init_model()
        self.model.load_state_dict(torch.load(path))
        pgd_parameters = self.pgd_params()

        # State read by model_wrapper. These names are rebound below for every
        # task/class; the closure always sees their current values.
        class_index = 0
        context_images, target_images, context_labels, target_labels, context_images_np = None, None, None, None, None

        def model_wrapper(context_point_x):
            # Replace the context image at class_index with context_point_x.
            context_images_attack = torch.cat([
                context_images[0:class_index], context_point_x,
                context_images[class_index + 1:]
            ],
                                              dim=0)

            target_logits = self.model(context_images_attack, context_labels,
                                       target_images)
            return target_logits[0]

        tf_model_conv = convert_pytorch_model_to_tf(model_wrapper,
                                                    out_dims=self.args.way)
        tf_model = cleverhans.model.CallableModelWrapper(
            tf_model_conv, 'logits')
        pgd = ProjectedGradientDescent(tf_model,
                                       sess=session,
                                       dtypestr='float32')

        for item in self.test_set:

            for t in range(self.args.attack_tasks):

                task_dict = self.dataset.get_test_task(item, session)
                context_images, target_images, context_labels, target_labels, context_images_np = self.prepare_task(
                    task_dict, shuffle=False)
                # clone() rather than detach(): detach shares storage with the
                # original tensor, which isn't what we want.
                context_images_attack_all = context_images.clone()
                # NOTE(review): is requires_grad true here for context_images?

                for c in torch.unique(context_labels):
                    # Adversarial input context image: presumably the first
                    # context image of class c.
                    class_index = extract_class_indices(context_labels,
                                                        c)[0].item()
                    context_x = np.expand_dims(context_images_np[class_index],
                                               0)

                    # Input to the model wrapper is automatically converted to Torch tensor for us
                    x = tf.compat.v1.placeholder(tf.float32,
                                                 shape=context_x.shape)

                    adv_x_op = pgd.generate(x, **pgd_parameters)
                    preds_adv_op = tf_model.get_logits(adv_x_op)

                    feed_dict = {x: context_x}
                    adv_x, preds_adv = session.run((adv_x_op, preds_adv_op),
                                                   feed_dict=feed_dict)

                    context_images_attack_all[class_index] = torch.from_numpy(
                        adv_x)

                    save_image(adv_x,
                               os.path.join(self.checkpoint_dir, 'adv.png'))
                    save_image(context_x,
                               os.path.join(self.checkpoint_dir, 'in.png'))

                    acc_after = torch.mean(
                        torch.eq(
                            target_labels,
                            torch.argmax(torch.from_numpy(preds_adv).to(
                                self.device),
                                         dim=-1)).float()).item()

                    with torch.no_grad():
                        logits = self.model(context_images, context_labels,
                                            target_images)
                        acc_before = torch.mean(
                            torch.eq(target_labels,
                                     torch.argmax(logits,
                                                  dim=-1)).float()).item()
                        del logits

                    diff = acc_before - acc_after
                    print_and_log(
                        self.logfile,
                        "Task = {}, Class = {} \t Diff = {}".format(
                            t, c, diff))

                print_and_log(self.logfile,
                              "Accuracy before {}".format(acc_before))
                # Evaluation only: no autograd graph is needed.
                with torch.no_grad():
                    logits = self.model(context_images_attack_all,
                                        context_labels, target_images)
                    acc_all_attack = torch.mean(
                        torch.eq(target_labels,
                                 torch.argmax(logits, dim=-1)).float()).item()
                    del logits
                print_and_log(self.logfile,
                              "Accuracy after {}".format(acc_all_attack))

    def pgd_params(self,
                   eps=0.3,
                   eps_iter=0.01,
                   ord=np.inf,
                   nb_iter=10,
                   rand_init=True,
                   clip_grad=True):
        """Return keyword arguments for ``ProjectedGradientDescent.generate``.

        ``clip_min``/``clip_max`` are fixed at -1.0 / 1.0. The ``ord`` default
        uses ``np.inf``; the ``np.Inf`` alias was removed in NumPy 2.0.
        """
        return dict(
            eps=eps,
            eps_iter=eps_iter,
            ord=ord,
            nb_iter=nb_iter,
            rand_init=rand_init,
            clip_grad=clip_grad,
            clip_min=-1.0,
            clip_max=1.0,
        )

    def prepare_task(self, task_dict, shuffle=True):
        """Convert a task dict of numpy arrays into tensors on ``self.device``.

        Image arrays have their last axis moved to position 1 via
        ``transpose([0, 3, 1, 2])`` (presumably NHWC -> NCHW). When ``shuffle``
        is true, context and target sets are each permuted together with their
        labels. Target labels are cast to int64.

        Returns:
            (context_images, target_images, context_labels, target_labels,
            context_images_np), where the last item is the transposed (and
            possibly shuffled) numpy context array.
        """
        context_images_np, context_labels_np = task_dict[
            'context_images'], task_dict['context_labels']
        target_images_np, target_labels_np = task_dict[
            'target_images'], task_dict['target_labels']

        context_images_np = context_images_np.transpose([0, 3, 1, 2])

        if shuffle:
            context_images_np, context_labels_np = self.shuffle(
                context_images_np, context_labels_np)
        context_images = torch.from_numpy(context_images_np)
        context_labels = torch.from_numpy(context_labels_np)

        target_images_np = target_images_np.transpose([0, 3, 1, 2])
        if shuffle:
            target_images_np, target_labels_np = self.shuffle(
                target_images_np, target_labels_np)
        target_images = torch.from_numpy(target_images_np)
        target_labels = torch.from_numpy(target_labels_np)

        context_images = context_images.to(self.device)
        target_images = target_images.to(self.device)
        context_labels = context_labels.to(self.device)
        target_labels = target_labels.type(torch.LongTensor).to(self.device)

        return context_images, target_images, context_labels, target_labels, context_images_np

    def shuffle(self, images, labels):
        """
        Return shuffled data.
        """
        permutation = np.random.permutation(images.shape[0])
        return images[permutation], labels[permutation]

    def use_two_gpus(self):
        """Return True for meta-dataset runs using 'film+ar' or 'task_norm-i'."""
        use_two_gpus = False
        if self.args.dataset == "meta-dataset":
            if self.args.feature_adaptation == "film+ar" or \
               self.args.batch_normalization == "task_norm-i":
                use_two_gpus = True  # These models do not fit on one GPU, so use model parallelism.

        return use_two_gpus

    def save_checkpoint(self, iteration):
        """Save iteration, model/optimizer state and best validation accuracies to ``checkpoint.pt``."""
        torch.save(
            {
                'iteration':
                iteration,
                'model_state_dict':
                self.model.state_dict(),
                'optimizer_state_dict':
                self.optimizer.state_dict(),
                'best_accuracy':
                self.validation_accuracies.get_current_best_accuracy_dict(),
            }, os.path.join(self.checkpoint_dir, 'checkpoint.pt'))

    def load_checkpoint(self):
        """Restore the state written by ``save_checkpoint`` from ``checkpoint.pt``."""
        checkpoint = torch.load(
            os.path.join(self.checkpoint_dir, 'checkpoint.pt'))
        self.start_iteration = checkpoint['iteration']
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.validation_accuracies.replace(checkpoint['best_accuracy'])

    def register_extra_parameters(self, model):
        """Call ``register_extra_weights()`` on every TaskNormI module in ``model``."""
        for module in model.modules():
            if isinstance(module, TaskNormI):
                module.register_extra_weights()
Example #4
0
class Learner:
    """Command-line driver that meta-trains, validates and tests a Cnaps model.

    Options come from ``parse_command_line``. Episodes are read from a ``MetaDatasetReader``
    when ``--dataset meta-dataset`` is given, and from a ``SingleDatasetReader`` otherwise.
    """

    def __init__(self):
        """Parse options, set up logging and checkpoint paths, and build every training component.

        With ``--resume_from_checkpoint``, the model, optimizer, best validation accuracies
        and start iteration are restored from ``<checkpoint_dir>/checkpoint.pt``.
        """
        self.args = self.parse_command_line()

        self.checkpoint_dir, self.logfile, self.checkpoint_path_validation, self.checkpoint_path_final \
            = get_log_files(self.args.checkpoint_dir, self.args.resume_from_checkpoint, self.args.mode == "test")

        print_and_log(self.logfile, "Options: %s\n" % self.args)
        print_and_log(self.logfile,
                      "Checkpoint Directory: %s\n" % self.checkpoint_dir)

        gpu_device = 'cuda:0'
        self.device = torch.device(
            gpu_device if torch.cuda.is_available() else 'cpu')
        self.model = self.init_model()
        self.train_set, self.validation_set, self.test_set = self.init_data()
        if self.args.dataset == "meta-dataset":
            self.dataset = MetaDatasetReader(
                self.args.data_path, self.args.mode, self.train_set,
                self.validation_set, self.test_set, self.args.max_way_train,
                self.args.max_way_test, self.args.max_support_train,
                self.args.max_support_test)
        else:
            self.dataset = SingleDatasetReader(self.args.data_path,
                                               self.args.mode,
                                               self.args.dataset,
                                               self.args.way, self.args.shot,
                                               self.args.query_train,
                                               self.args.query_test)
        self.loss = loss
        self.accuracy_fn = aggregate_accuracy
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.args.learning_rate)
        self.validation_accuracies = ValidationAccuracies(self.validation_set)
        self.start_iteration = 0
        if self.args.resume_from_checkpoint:
            self.load_checkpoint()
        self.optimizer.zero_grad()

    def init_model(self):
        """Create a Cnaps model on ``self.device`` with its submodules in the intended modes.

        Returns:
            The model. It is distributed over two GPUs when ``use_two_gpus()`` is true.
        """
        use_two_gpus = self.use_two_gpus()
        model = Cnaps(device=self.device,
                      use_two_gpus=use_two_gpus,
                      args=self.args).to(self.device)
        self.register_extra_parameters(model)

        # The set encoder is always in train mode (it only sees context data).
        model.train()
        # The feature extractor is in eval mode by default, but the model switches it
        # depending on args.batch_normalization.
        model.feature_extractor.eval()
        if use_two_gpus:
            model.distribute_model()
        return model

    def init_data(self):
        """Return the ``(train_set, validation_set, test_set)`` dataset-name lists.

        For the meta-dataset these are fixed train/validation lists plus ``--test_datasets``.
        For a single dataset, all three lists contain just that dataset.
        """
        if self.args.dataset == "meta-dataset":
            train_set = [
                'ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd',
                'quickdraw', 'fungi', 'vgg_flower'
            ]
            validation_set = [
                'ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd',
                'quickdraw', 'fungi', 'vgg_flower', 'mscoco'
            ]
            test_set = self.args.test_datasets
        else:
            train_set = [self.args.dataset]
            validation_set = [self.args.dataset]
            test_set = [self.args.dataset]

        return train_set, validation_set, test_set

    def parse_command_line(self):
        """Command line parser: define every option and parse ``sys.argv``.

        Returns:
            argparse.Namespace holding the parsed options.
        """
        parser = argparse.ArgumentParser()

        parser.add_argument("--dataset",
                            choices=[
                                "meta-dataset", "ilsvrc_2012", "omniglot",
                                "aircraft", "cu_birds", "dtd", "quickdraw",
                                "fungi", "vgg_flower", "traffic_sign",
                                "mscoco", "mnist", "cifar10", "cifar100"
                            ],
                            default="meta-dataset",
                            help="Dataset to use.")
        parser.add_argument('--test_datasets',
                            nargs='+',
                            help='Datasets to use for testing',
                            default=[
                                "ilsvrc_2012", "omniglot", "aircraft",
                                "cu_birds", "dtd", "quickdraw", "fungi",
                                "vgg_flower", "traffic_sign", "mscoco",
                                "mnist", "cifar10", "cifar100"
                            ])
        parser.add_argument("--data_path",
                            default="../datasets",
                            help="Path to dataset records.")
        parser.add_argument("--pretrained_resnet_path",
                            default="../models/pretrained_resnet.pt.tar",
                            help="Path to pretrained feature extractor model.")
        parser.add_argument(
            "--mode",
            choices=["train", "test", "train_test"],
            default="train_test",
            help=
            "Whether to run training only, testing only, or both training and testing."
        )
        parser.add_argument("--learning_rate",
                            "-lr",
                            type=float,
                            default=5e-4,
                            help="Learning rate.")
        parser.add_argument(
            "--tasks_per_batch",
            type=int,
            default=16,
            help="Number of tasks between parameter optimizations.")
        parser.add_argument("--checkpoint_dir",
                            "-c",
                            default='../checkpoints',
                            help="Directory to save checkpoint to.")
        parser.add_argument("--test_model_path",
                            "-m",
                            default=None,
                            help="Path to model to load and test.")
        parser.add_argument(
            "--feature_adaptation",
            choices=["no_adaptation", "film", "film+ar"],
            default="film",
            help="Method to adapt feature extractor parameters.")
        parser.add_argument("--batch_normalization",
                            choices=["basic", "task_norm-i"],
                            default="basic",
                            help="Normalization layer to use.")
        parser.add_argument("--training_iterations",
                            "-i",
                            type=int,
                            default=110000,
                            help="Number of meta-training iterations.")
        parser.add_argument("--val_freq",
                            type=int,
                            default=10000,
                            help="Number of iterations between validations.")
        parser.add_argument(
            "--max_way_train",
            type=int,
            default=40,
            help="Maximum way of meta-dataset meta-train task.")
        parser.add_argument("--max_way_test",
                            type=int,
                            default=50,
                            help="Maximum way of meta-dataset meta-test task.")
        parser.add_argument(
            "--max_support_train",
            type=int,
            default=400,
            help="Maximum support set size of meta-dataset meta-train task.")
        parser.add_argument(
            "--max_support_test",
            type=int,
            default=500,
            help="Maximum support set size of meta-dataset meta-test task.")
        parser.add_argument("--resume_from_checkpoint",
                            "-r",
                            dest="resume_from_checkpoint",
                            default=False,
                            action="store_true",
                            help="Restart from latest checkpoint.")
        parser.add_argument("--way",
                            type=int,
                            default=5,
                            help="Way of single dataset task.")
        parser.add_argument(
            "--shot",
            type=int,
            default=1,
            help="Shots per class for context of single dataset task.")
        parser.add_argument(
            "--query_train",
            type=int,
            default=10,
            help="Shots per class for target of single dataset meta-train task.")
        parser.add_argument(
            "--query_test",
            type=int,
            default=10,
            help="Shots per class for target of single dataset meta-test task.")

        args = parser.parse_args()

        return args

    def run(self):
        """Train and/or test according to ``--mode``, then close the log file.

        Training does the following:
        - accumulates gradients over ``--tasks_per_batch`` tasks per optimizer step;
        - validates every ``--val_freq`` iterations, saving the best model to
          ``checkpoint_path_validation`` plus a resumable ``checkpoint.pt``;
        - saves the final model to ``checkpoint_path_final``.
        """
        if self.args.mode == 'train' or self.args.mode == 'train_test':
            train_accuracies = []
            losses = []
            total_iterations = self.args.training_iterations
            for iteration in range(self.start_iteration, total_iterations):
                torch.set_grad_enabled(True)
                task_dict = self.dataset.get_train_task()
                task_loss, task_accuracy = self.train_task(task_dict)
                train_accuracies.append(task_accuracy)
                losses.append(task_loss)

                # optimize
                if ((iteration + 1) % self.args.tasks_per_batch
                        == 0) or (iteration == (total_iterations - 1)):
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                if (iteration + 1) % PRINT_FREQUENCY == 0:
                    # print training stats
                    print_and_log(
                        self.logfile,
                        'Task [{}/{}], Train Loss: {:.7f}, Train Accuracy: {:.7f}'
                        .format(iteration + 1, total_iterations,
                                torch.Tensor(losses).mean().item(),
                                torch.Tensor(train_accuracies).mean().item()))
                    train_accuracies = []
                    losses = []

                if ((iteration + 1) % self.args.val_freq
                        == 0) and (iteration + 1) != total_iterations:
                    # validate
                    accuracy_dict = self.validate()
                    self.validation_accuracies.print(self.logfile,
                                                     accuracy_dict)
                    # save the model if validation is the best so far
                    if self.validation_accuracies.is_better(accuracy_dict):
                        self.validation_accuracies.replace(accuracy_dict)
                        torch.save(self.model.state_dict(),
                                   self.checkpoint_path_validation)
                        print_and_log(self.logfile,
                                      'Best validation model was updated.')
                        print_and_log(self.logfile, '')
                    self.save_checkpoint(iteration + 1)

            # save the final model
            torch.save(self.model.state_dict(), self.checkpoint_path_final)

        if self.args.mode == 'train_test':
            self.test(self.checkpoint_path_final)
            # The best-validation model is written only when a validation improves on the
            # stored best (e.g. never when training_iterations <= val_freq). Testing a missing
            # file would crash here and leave the log file open.
            if os.path.exists(self.checkpoint_path_validation):
                self.test(self.checkpoint_path_validation)
            else:
                print_and_log(self.logfile,
                              'No best validation model was saved, so it was not tested.')

        if self.args.mode == 'test':
            self.test(self.args.test_model_path)

        self.logfile.close()

    def train_task(self, task_dict):
        """Run one training task forward and accumulate its gradients (no optimizer step).

        The loss is divided by ``--tasks_per_batch`` because gradients are summed over that
        many tasks before each optimizer step. A scaled FiLM regularization term is added
        for the film / film+ar adaptation methods.

        Returns:
            ``(task_loss, task_accuracy)`` as produced by ``self.loss`` / ``self.accuracy_fn``.
        """
        context_images, target_images, context_labels, target_labels = self.prepare_task(
            task_dict)

        target_logits = self.model(context_images, context_labels,
                                   target_images)
        task_loss = self.loss(target_logits, target_labels,
                              self.device) / self.args.tasks_per_batch
        if self.args.feature_adaptation in ('film', 'film+ar'):
            regularization_term = (self.model.feature_adaptation_network.
                                   regularization_term())
            if self.use_two_gpus():
                # Keep the term on the same GPU as the loss.
                regularization_term = regularization_term.cuda(0)
            regularizer_scaling = 0.001
            task_loss += regularizer_scaling * regularization_term
        task_accuracy = self.accuracy_fn(target_logits, target_labels)

        task_loss.backward(retain_graph=False)

        return task_loss, task_accuracy

    def _evaluate_tasks(self, get_task, item, num_tasks):
        """Score the model on ``num_tasks`` tasks from ``get_task(item)``.

        Callers are expected to wrap this in ``torch.no_grad()``.

        Returns:
            ``(accuracy, confidence)`` in percent: the mean task accuracy, and
            1.96 * std / sqrt(num_tasks) of the task accuracies.
        """
        accuracies = []
        for _ in range(num_tasks):
            task_dict = get_task(item)
            context_images, target_images, context_labels, target_labels = self.prepare_task(
                task_dict)
            target_logits = self.model(context_images, context_labels,
                                       target_images)
            accuracies.append(
                self.accuracy_fn(target_logits, target_labels).item())
            del target_logits

        accuracies = np.array(accuracies)
        accuracy = accuracies.mean() * 100.0
        # 196 = 1.96 (normal 95% z-value) * 100 (fraction -> percent).
        confidence = (196.0 * accuracies.std()) / np.sqrt(len(accuracies))
        return accuracy, confidence

    def validate(self):
        """Evaluate on NUM_VALIDATION_TASKS tasks for every validation dataset.

        Returns:
            dict mapping each dataset name to ``{"accuracy": ..., "confidence": ...}`` (percent).
        """
        accuracy_dict = {}
        with torch.no_grad():
            for item in self.validation_set:
                accuracy, confidence = self._evaluate_tasks(
                    self.dataset.get_validation_task, item,
                    NUM_VALIDATION_TASKS)
                accuracy_dict[item] = {
                    "accuracy": accuracy,
                    "confidence": confidence
                }

        return accuracy_dict

    def test(self, path):
        """Rebuild the model, load weights from ``path`` and log accuracy on each test dataset."""
        print_and_log(self.logfile, "")  # add a blank line
        print_and_log(self.logfile, 'Testing model {0:}: '.format(path))
        self.model = self.init_model()
        # map_location remaps the stored tensors onto self.device, so weights saved on a
        # GPU can be tested on a CPU-only machine or one without that GPU index.
        self.model.load_state_dict(torch.load(path, map_location=self.device))

        with torch.no_grad():
            for item in self.test_set:
                accuracy, accuracy_confidence = self._evaluate_tasks(
                    self.dataset.get_test_task, item, NUM_TEST_TASKS)
                print_and_log(
                    self.logfile,
                    '{0:}: {1:3.1f}+/-{2:2.1f}'.format(item, accuracy,
                                                       accuracy_confidence))

    def prepare_task(self, task_dict):
        """Convert a task dict of NumPy arrays into shuffled tensors on ``self.device``.

        The image arrays get their last axis moved to position 1 (presumably NHWC -> NCHW;
        confirm against the readers). Context and target sets are each shuffled, and target
        labels are cast to int64.

        Returns:
            ``(context_images, target_images, context_labels, target_labels)``.
        """
        context_images_np, context_labels_np = task_dict[
            'context_images'], task_dict['context_labels']
        target_images_np, target_labels_np = task_dict[
            'target_images'], task_dict['target_labels']

        context_images_np = context_images_np.transpose([0, 3, 1, 2])
        context_images_np, context_labels_np = self.shuffle(
            context_images_np, context_labels_np)
        context_images = torch.from_numpy(context_images_np)
        context_labels = torch.from_numpy(context_labels_np)

        target_images_np = target_images_np.transpose([0, 3, 1, 2])
        target_images_np, target_labels_np = self.shuffle(
            target_images_np, target_labels_np)
        target_images = torch.from_numpy(target_images_np)
        target_labels = torch.from_numpy(target_labels_np)

        context_images = context_images.to(self.device)
        target_images = target_images.to(self.device)
        context_labels = context_labels.to(self.device)
        target_labels = target_labels.type(torch.LongTensor).to(self.device)

        return context_images, target_images, context_labels, target_labels

    def shuffle(self, images, labels):
        """
        Return shuffled data: images and labels reordered by one shared random permutation.
        """
        permutation = np.random.permutation(images.shape[0])
        return images[permutation], labels[permutation]

    def use_two_gpus(self):
        """Return True when the model must be split across two GPUs.

        This only happens for the meta-dataset, with film+ar adaptation or task_norm-i
        normalization; these models do not fit on one GPU, so model parallelism is used.
        """
        return self.args.dataset == "meta-dataset" and (
            self.args.feature_adaptation == "film+ar"
            or self.args.batch_normalization == "task_norm-i")

    def save_checkpoint(self, iteration):
        """Write a resumable checkpoint (iteration, model, optimizer, best accuracies).

        The file is ``<checkpoint_dir>/checkpoint.pt``; it overwrites any previous one.
        """
        torch.save(
            {
                'iteration':
                iteration,
                'model_state_dict':
                self.model.state_dict(),
                'optimizer_state_dict':
                self.optimizer.state_dict(),
                'best_accuracy':
                self.validation_accuracies.get_current_best_accuracy_dict(),
            }, os.path.join(self.checkpoint_dir, 'checkpoint.pt'))

    def load_checkpoint(self):
        """Restore the state written by ``save_checkpoint``, including ``self.start_iteration``."""
        # map_location remaps every stored tensor onto self.device, so training can resume
        # on a machine whose devices differ from the one that wrote the checkpoint.
        # NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
        # which may reject the numpy scalars that validate() stores in the accuracy dicts.
        # Confirm against the installed PyTorch version.
        checkpoint = torch.load(
            os.path.join(self.checkpoint_dir, 'checkpoint.pt'),
            map_location=self.device)
        self.start_iteration = checkpoint['iteration']
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.validation_accuracies.replace(checkpoint['best_accuracy'])

    def register_extra_parameters(self, model):
        """Call ``register_extra_weights()`` on every TaskNormI layer in ``model``."""
        for module in model.modules():
            if isinstance(module, TaskNormI):
                module.register_extra_weights()
Example #5
0
class Learner:
    def __init__(self):
        """Parse options and open the log.

        Then create the model, the meta-dataset reader, the loss/accuracy functions and the
        optimizer (with gradients zeroed), plus the best-validation-accuracy tracker.
        """
        self.args = self.parse_command_line()

        log_paths = get_log_files(self.args.checkpoint_dir)
        (self.checkpoint_dir, self.logfile,
         self.checkpoint_path_validation, self.checkpoint_path_final) = log_paths

        print_and_log(self.logfile, "Options: %s\n" % self.args)
        print_and_log(self.logfile, "Checkpoint Directory: %s\n" % self.checkpoint_dir)

        # Run on the first GPU when CUDA is available, otherwise on the CPU.
        device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)

        self.model = self.init_model()
        self.train_set, self.validation_set, self.test_set = self.init_data()
        self.metadataset = MetaDatasetReader(self.args.data_path, self.args.mode,
                                             self.train_set, self.validation_set, self.test_set)

        self.loss = loss
        self.accuracy_fn = aggregate_accuracy
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        self.optimizer.zero_grad()
        self.validation_accuracies = ValidationAccuracies(self.validation_set)

    def init_model(self):
        """Build a Cnaps model on ``self.device`` and set the train/eval modes of its parts.

        Returns:
            The model, distributed over two GPUs when ``use_two_gpus()`` is true.
        """
        two_gpus = self.use_two_gpus()
        model = Cnaps(device=self.device, use_two_gpus=two_gpus, args=self.args)
        model = model.to(self.device)

        # The set encoder always stays in train mode to process context data ...
        model.train()
        # ... while the feature extractor always stays in eval mode.
        model.feature_extractor.eval()

        if two_gpus:
            model.distribute_model()
        return model

    def init_data(self):

        train_set = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd', 'quickdraw', 'fungi', 'vgg_flower']
        validation_set = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd', 'quickdraw', 'fungi', 'vgg_flower',
                          'mscoco']
        test_set = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd', 'quickdraw', 'fungi', 'vgg_flower',
                    'traffic_sign', 'mscoco', 'mnist', 'cifar10', 'cifar100']

        return train_set, validation_set, test_set

    """
    Command line parser
    """
    def parse_command_line(self):
        parser = argparse.ArgumentParser()

        parser.add_argument("--data_path", default="../datasets", help="Path to dataset records.")
        parser.add_argument("--pretrained_resnet_path", default="../models/pretrained_resnet.pt.tar",
                            help="Path to pretrained feature extractor model.")
        parser.add_argument("--mode", choices=["train", "test", "train_test"], default="train_test",
                            help="Whether to run training only, testing only, or both training and testing.")
        parser.add_argument("--learning_rate", "-lr", type=float, default=5e-4, help="Learning rate.")
        parser.add_argument("--tasks_per_batch", type=int, default=16,
                            help="Number of tasks between parameter optimizations.")
        parser.add_argument("--checkpoint_dir", "-c", default='../checkpoints', help="Directory to save checkpoint to.")
        parser.add_argument("--test_model_path", "-m", default=None, help="Path to model to load and test.")
        parser.add_argument("--feature_adaptation", choices=["no_adaptation", "film", "film+ar"], default="film+ar",
                            help="Method to adapt feature extractor parameters.")

        args = parser.parse_args()

        return args

    def run(self):
        """Train and/or test according to ``--mode`` inside one TensorFlow session, then close the log.

        The session is created with GPU memory growth enabled and passed to every
        MetaDatasetReader call. Training does the following:
        - steps the optimizer every ``--tasks_per_batch`` tasks;
        - validates every VALIDATION_FREQUENCY iterations, keeping the best model at
          ``checkpoint_path_validation``;
        - saves the final model at ``checkpoint_path_final``.
        """
        import os  # used to check whether a best-validation model was ever written

        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.compat.v1.Session(config=config) as session:
            if self.args.mode == 'train' or self.args.mode == 'train_test':
                train_accuracies = []
                losses = []
                total_iterations = NUM_TRAIN_TASKS
                for iteration in range(total_iterations):
                    torch.set_grad_enabled(True)
                    task_dict = self.metadataset.get_train_task(session)
                    task_loss, task_accuracy = self.train_task(task_dict)
                    train_accuracies.append(task_accuracy)
                    losses.append(task_loss)

                    # optimize
                    if ((iteration + 1) % self.args.tasks_per_batch == 0) or (iteration == (total_iterations - 1)):
                        self.optimizer.step()
                        self.optimizer.zero_grad()

                    if (iteration + 1) % 1000 == 0:
                        # print training stats
                        print_and_log(self.logfile, 'Task [{}/{}], Train Loss: {:.7f}, Train Accuracy: {:.7f}'
                                      .format(iteration + 1, total_iterations, torch.Tensor(losses).mean().item(),
                                              torch.Tensor(train_accuracies).mean().item()))
                        train_accuracies = []
                        losses = []

                    if ((iteration + 1) % VALIDATION_FREQUENCY == 0) and (iteration + 1) != total_iterations:
                        # validate
                        accuracy_dict = self.validate(session)
                        self.validation_accuracies.print(self.logfile, accuracy_dict)
                        # save the model if validation is the best so far
                        if self.validation_accuracies.is_better(accuracy_dict):
                            self.validation_accuracies.replace(accuracy_dict)
                            torch.save(self.model.state_dict(), self.checkpoint_path_validation)
                            print_and_log(self.logfile, 'Best validation model was updated.')
                            print_and_log(self.logfile, '')

                # save the final model
                torch.save(self.model.state_dict(), self.checkpoint_path_final)

            if self.args.mode == 'train_test':
                self.test(self.checkpoint_path_final, session)
                # The best-validation model is written only when a validation improves on the
                # stored best, which can be never (e.g. NUM_TRAIN_TASKS <= VALIDATION_FREQUENCY).
                # Testing a missing file would crash and leave the log file open.
                if os.path.exists(self.checkpoint_path_validation):
                    self.test(self.checkpoint_path_validation, session)
                else:
                    print_and_log(self.logfile, 'No best validation model was saved, so it was not tested.')

            if self.args.mode == 'test':
                self.test(self.args.test_model_path, session)

            self.logfile.close()

    def train_task(self, task_dict):
        context_images, target_images, context_labels, target_labels = self.prepare_task(task_dict)

        target_logits = self.model(context_images, context_labels, target_images)
        task_loss = self.loss(target_logits, target_labels, self.device) / self.args.tasks_per_batch
        if self.args.feature_adaptation == 'film' or self.args.feature_adaptation == 'film+ar':
            if self.use_two_gpus():
                regularization_term = (self.model.feature_adaptation_network.regularization_term()).cuda(0)
            else:
                regularization_term = (self.model.feature_adaptation_network.regularization_term())
            regularizer_scaling = 0.001
            task_loss += regularizer_scaling * regularization_term
        task_accuracy = self.accuracy_fn(target_logits, target_labels)

        task_loss.backward(retain_graph=False)

        return task_loss, task_accuracy

    def validate(self, session):
        """Score the model on NUM_VALIDATION_TASKS tasks per validation dataset, without gradients.

        Returns:
            dict mapping dataset name -> ``{"accuracy": mean % accuracy,
            "confidence": 1.96 * 100 * std / sqrt(num tasks)}``.
        """
        accuracy_dict = {}
        with torch.no_grad():
            for item in self.validation_set:
                accuracies = []
                for _ in range(NUM_VALIDATION_TASKS):
                    task_dict = self.metadataset.get_validation_task(item, session)
                    context_images, target_images, context_labels, target_labels = self.prepare_task(task_dict)
                    target_logits = self.model(context_images, context_labels, target_images)
                    accuracies.append(self.accuracy_fn(target_logits, target_labels).item())
                    del target_logits

                scores = np.array(accuracies)
                accuracy_dict[item] = {
                    "accuracy": scores.mean() * 100.0,
                    "confidence": (196.0 * scores.std()) / np.sqrt(len(scores)),
                }

        return accuracy_dict

    def test(self, path, session):
        """Rebuild the model, load weights from ``path`` and log accuracy on each test dataset.

        Each dataset is scored over NUM_TEST_TASKS tasks. The log line shows the mean accuracy
        in percent +/- 1.96 * 100 * std / sqrt(num tasks).
        """
        self.model = self.init_model()
        # map_location remaps the stored tensors onto self.device, so weights saved on a GPU
        # (including a second GPU under model parallelism) load on the current machine.
        self.model.load_state_dict(torch.load(path, map_location=self.device))

        print_and_log(self.logfile, "")  # add a blank line
        print_and_log(self.logfile, 'Testing model {0:}: '.format(path))

        with torch.no_grad():
            for item in self.test_set:
                accuracies = []
                for _ in range(NUM_TEST_TASKS):
                    task_dict = self.metadataset.get_test_task(item, session)
                    context_images, target_images, context_labels, target_labels = self.prepare_task(task_dict)
                    target_logits = self.model(context_images, context_labels, target_images)
                    accuracies.append(self.accuracy_fn(target_logits, target_labels).item())
                    del target_logits

                scores = np.array(accuracies)
                accuracy = scores.mean() * 100.0
                accuracy_confidence = (196.0 * scores.std()) / np.sqrt(len(scores))

                print_and_log(self.logfile, '{0:}: {1:3.1f}+/-{2:2.1f}'.format(item, accuracy, accuracy_confidence))

    def prepare_task(self, task_dict):
        context_images_np, context_labels_np = task_dict['context_images'], task_dict['context_labels']
        target_images_np, target_labels_np = task_dict['target_images'], task_dict['target_labels']

        context_images_np = context_images_np.transpose([0, 3, 1, 2])
        context_images_np, context_labels_np = self.shuffle(context_images_np, context_labels_np)
        context_images = torch.from_numpy(context_images_np)
        context_labels = torch.from_numpy(context_labels_np)

        target_images_np = target_images_np.transpose([0, 3, 1, 2])
        target_images_np, target_labels_np = self.shuffle(target_images_np, target_labels_np)
        target_images = torch.from_numpy(target_images_np)
        target_labels = torch.from_numpy(target_labels_np)

        context_images = context_images.to(self.device)
        target_images = target_images.to(self.device)
        context_labels = context_labels.to(self.device)
        target_labels = target_labels.type(torch.LongTensor).to(self.device)

        return context_images, target_images, context_labels, target_labels

    def shuffle(self, images, labels):
        """
        Return shuffled data.
        """
        permutation = np.random.permutation(images.shape[0])
        return images[permutation], labels[permutation]

    def use_two_gpus(self):
        use_two_gpus = False
        if self.args.feature_adaptation == "film+ar":
            use_two_gpus = True  # film+ar model does not fit on one GPU, so use model parallelism
        return use_two_gpus