def create_model_and_load_chkpt(args, dataset_name, checkpoint_path):

    print("\n", "--" * 20, "MODEL", "--" * 20)

    if args.model_type == 'resnet_12':
        if 'miniImagenet' in dataset_name or 'CUB' in dataset_name:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=5,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection))
        else:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=2,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection))
    elif args.model_type in ['conv64', 'conv48', 'conv32']:
        dim = int(args.model_type[-2:])
        model = shallow_conv.ShallowConv(z_dim=dim,
                                         h_dim=dim,
                                         num_classes=args.num_classes_train,
                                         x_width=args.img_side_len,
                                         classifier_type=args.classifier_type,
                                         projection=str2bool(args.projection))
    elif args.model_type == 'wide_resnet28_10':
        model = wide_resnet.wrn28_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type)
    elif args.model_type == 'wide_resnet16_10':
        model = wide_resnet.wrn16_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type)
    else:
        raise ValueError('Unrecognized model type {}'.format(args.model_type))
    print("Model\n" + "==" * 27)
    print(model)

    print(f"loading model from {checkpoint_path}")
    model_dict = model.state_dict()
    chkpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    chkpt_state_dict = chkpt['model']
    chkpt_state_dict_cpy = chkpt_state_dict.copy()
    # strip the "module." prefix that DataParallel adds when saving checkpoints
    for key in chkpt_state_dict_cpy.keys():
        if 'module.' in key:
            new_key = re.sub(r'module\.', '', key)
            chkpt_state_dict[new_key] = chkpt_state_dict.pop(key)
    chkpt_state_dict = {
        k: v
        for k, v in chkpt_state_dict.items() if k in model_dict
    }
    model_dict.update(chkpt_state_dict)
    updated_keys = set(model_dict).intersection(set(chkpt_state_dict))
    print(f"Updated {len(updated_keys)} keys using chkpt")
    print("Following keys updated :", "\n".join(sorted(updated_keys)))
    missed_keys = set(model_dict).difference(set(chkpt_state_dict))
    print(f"Missed {len(missed_keys)} keys")
    print("Following keys missed :", "\n".join(sorted(missed_keys)))
    model.load_state_dict(model_dict)

    # Multi-gpu support and device setup
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_number
    print('Using GPUs: ', os.environ["CUDA_VISIBLE_DEVICES"])
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    return model
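
str2bool and ensure_path are used throughout these examples but are not defined in this
snippet. A minimal sketch of what they might look like (hypothetical helpers; the actual
implementations in the repository may differ):

import os


def str2bool(v):
    # interpret common argparse-style truthy strings as True
    return str(v).lower() in ('true', '1', 'yes', 'y', 't')


def ensure_path(path):
    # create the output directory if it does not exist and return its path
    os.makedirs(path, exist_ok=True)
    return path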
Code example #2
def main(args):

    ####################################################
    #                LOGGING AND SAVING                #
    ####################################################
    args.output_folder = ensure_path('./runs/{0}'.format(args.output_folder))
    writer = SummaryWriter(args.output_folder)
    with open(f'{args.output_folder}/config.txt', 'w') as config_txt:
        for k, v in sorted(vars(args).items()):
            config_txt.write(f'{k}: {v}\n')
    save_folder = args.output_folder

    ####################################################
    #         DATASET AND DATALOADER CREATION          #
    ####################################################

    # json paths
    dataset_name = args.dataset_path.split('/')[-1]
    image_size = args.img_side_len
    # Needed when the same train config is used for both 5w5s and 5w1s evaluations,
    # e.g. for SVM, where 5w15s5q episodes are used for both 5w5s and 5w1s evaluation.
    all_n_shot_vals = [args.n_shot_val, 1] if str2bool(
        args.do_one_shot_eval_too) else [args.n_shot_val]
    # base_class_generalization = dataset_name.lower() in ['miniimagenet', 'fc100-base', 'cifar-fs-base']
    train_file = os.path.join(args.dataset_path, 'base.json')
    val_file = os.path.join(args.dataset_path, 'val.json')
    test_file = os.path.join(args.dataset_path, 'novel.json')
    '''
    if base_class_generalization:
        base_test_file = os.path.join(args.dataset_path, 'base_test.json')
    '''
    print("Dataset name", dataset_name, "image_size", image_size)
    if args.algorithm != 'SupervisedBaseline':
        print("all_n_shot_vals", all_n_shot_vals)
    # print("base_class_generalization:", base_class_generalization)
    """
    1. Create FedDataset object, which handles preloading of images for every single client
    2. Create FedDataLoader object from FedDataset, which samples a batch of clients.
    """

    print("\n", "--" * 20, "TRAIN", "--" * 20)
    if args.algorithm in ["SupervisedBaseline", "TransferLearning"]:
        """
        For Transfer Learning we create a SimpleFedDataset.
        The augmentation is decided by query_aug flag.
        """

        train_dataset = SimpleFedDataset(
            json_path=train_file,
            image_size=(image_size, image_size),  # has to be a (h, w) tuple
            preload=str2bool(args.preload_train))

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size_train,
            shuffle=True,
            num_workers=6)
    else:
        train_meta_dataset = FedDataset(
            json_path=train_file,
            n_shot_per_class=args.n_shot_train,
            n_query_per_class=args.n_query_train,
            image_size=(image_size, image_size),  # has to be a (h, w) tuple
            randomize_query=str2bool(args.randomize_query),
            preload=str2bool(args.preload_train),
            fixed_sq=str2bool(args.fixed_sq))

        train_loader = FedDataLoader(dataset=train_meta_dataset,
                                     n_batches=args.n_iters_per_epoch,
                                     batch_size=args.batch_size_train)

    print("\n", "--" * 20, "VAL", "--" * 20)
    if args.algorithm in ["SupervisedBaseline", "TransferLearning"]:
        val_dataset = SimpleFedDataset(
            json_path=val_file,
            image_size=(image_size, image_size),  # has to be a (h, w) tuple
            preload=True)

        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size_val,
            shuffle=False,
            drop_last=False,
            num_workers=6)
    else:
        val_meta_datasets = {}
        val_loaders = {}
        for ns_val in all_n_shot_vals:
            print("====", f"n_shots_val {ns_val}", "====")
            val_meta_datasets[ns_val] = FedDataset(
                json_path=val_file,
                n_shot_per_class=ns_val,
                n_query_per_class=args.n_query_val,
                image_size=(image_size, image_size),
                randomize_query=False,
                preload=True,
                fixed_sq=str2bool(args.fixed_sq))
            val_loaders[ns_val] = FedDataLoader(
                dataset=val_meta_datasets[ns_val],
                n_batches=args.n_iterations_val,
                batch_size=args.batch_size_val)

    print("\n", "--" * 20, "TEST", "--" * 20)
    if args.algorithm in ["SupervisedBaseline", "TransferLearning"]:
        test_dataset = SimpleFedDataset(
            json_path=test_file,
            image_size=(image_size, image_size),  # has to be a (h, w) tuple
            preload=True)

        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.batch_size_val,
            shuffle=False,
            drop_last=False,
            num_workers=6)
    else:
        test_meta_datasets = {}
        test_loaders = {}
        for ns_val in all_n_shot_vals:
            print("====", f"n_shots_val {ns_val}", "====")
            test_meta_datasets[ns_val] = FedDataset(
                json_path=test_file,
                n_shot_per_class=ns_val,
                n_query_per_class=args.n_query_val,
                image_size=(image_size, image_size),
                randomize_query=False,
                preload=True,
                fixed_sq=str2bool(args.fixed_sq))
            test_loaders[ns_val] = FedDataLoader(
                dataset=test_meta_datasets[ns_val],
                n_batches=args.n_iterations_val,
                batch_size=args.batch_size_val)
    '''
    # currently for fedlearn_datasets there is no notion of base_class because the classes
    # used for base, val, test are already the same.
    if base_class_generalization:
        # can only do this if there is only one type of evaluation
        print("\n", "--"*20, "BASE TEST", "--"*20)
        base_test_classes = ClassImagesSet(base_test_file)
        base_test_meta_dataset = MetaDataset(
                                    dataset_name=dataset_name,
                                    support_class_images_set=base_test_classes,
                                    query_class_images_set=base_test_classes,
                                    image_size=image_size,
                                    support_aug=False,
                                    query_aug=False,
                                    fix_support=0,
                                    save_folder=save_folder)
        base_test_loader = MetaDataLoader(
                                dataset=base_test_meta_dataset,
                                n_batches=args.n_iterations_val,
                                batch_size=args.batch_size_val,
                                n_way=args.n_way_val,
                                n_shot=args.n_shot_val,
                                n_query=args.n_query_val, 
                                randomize_query=False)


        if args.fix_support > 0:
            base_test_meta_dataset_using_fixS = MetaDataset(
                                                    dataset_name=dataset_name,
                                                    support_class_images_set=train_classes, query_class_images_set=base_test_classes,
                                                    image_size=image_size,
                                                    support_aug=False,
                                                    query_aug=False,
                                                    fix_support=0,
                                                    save_folder=save_folder, 
                                                    fix_support_path=os.path.join(save_folder, "fixed_support_pool.pkl"))

            base_test_loader_using_fixS = MetaDataLoader(
                                            dataset=base_test_meta_dataset_using_fixS,
                                            n_batches=args.n_iterations_val,
                                            batch_size=args.batch_size_val,
                                            n_way=args.n_way_val,
                                            n_shot=args.n_shot_val,
                                            n_query=args.n_query_val, 
                                            randomize_query=False,)
    '''

    ####################################################
    #             MODEL/BACKBONE CREATION              #
    ####################################################

    print("\n", "--" * 20, "MODEL", "--" * 20)

    if args.model_type == 'resnet_12':
        if 'miniImagenet' in dataset_name or 'CUB' in dataset_name or 'celeba' in dataset_name:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=5,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))
        else:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=2,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))

    elif args.model_type in ['conv64', 'conv48', 'conv32']:
        dim = int(args.model_type[-2:])
        model = shallow_conv.ShallowConv(z_dim=dim,
                                         h_dim=dim,
                                         num_classes=args.num_classes_train,
                                         classifier_type=args.classifier_type,
                                         projection=str2bool(args.projection),
                                         learnable_scale=str2bool(
                                             args.learnable_scale))
    elif args.model_type == 'wide_resnet28_10':
        model = wide_resnet.wrn28_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    elif args.model_type == 'wide_resnet16_10':
        model = wide_resnet.wrn16_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    else:
        raise ValueError('Unrecognized model type {}'.format(args.model_type))
    print("Model\n" + "==" * 27)
    print(model)

    ####################################################
    #                OPTIMIZER CREATION                #
    ####################################################

    # optimizer construction
    print("\n", "--" * 20, "OPTIMIZER", "--" * 20)
    print("Optimzer", args.optimizer_type)
    if args.optimizer_type == 'adam':
        optimizer = torch.optim.Adam([{
            'params': model.parameters(),
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }])
    else:
        optimizer = modified_sgd.SGD([
            {
                'params': model.parameters(),
                'lr': args.lr,
                'weight_decay': args.weight_decay,
                'momentum': 0.9,
                'nesterov': True
            },
        ])
    print("Total n_epochs: ", args.n_epochs)

    # learning rate scheduler creation
    if args.lr_scheduler_type == 'deterministic':
        drop_eps = [int(x) for x in args.drop_lr_epoch.split(',')]
        if args.drop_factors != '':
            drop_factors = [float(x) for x in args.drop_factors.split(',')]
        else:
            drop_factors = [0.06, 0.012, 0.0024]
        assert len(drop_factors) >= len(drop_eps), "Not enough drop factors"
        print("Drop lr at epochs", drop_eps)
        print("Drop factors", drop_factors[:len(drop_eps)])
        assert len(drop_eps) <= 3, \
            "Must give less than or equal to three epochs to drop lr"
        if len(drop_eps) == 3:
            lambda_epoch = lambda e: (1.0 if e < drop_eps[0] else
                                      drop_factors[0] if e < drop_eps[1] else
                                      drop_factors[1] if e < drop_eps[2] else
                                      drop_factors[2])
        elif len(drop_eps) == 2:
            lambda_epoch = lambda e: (1.0 if e < drop_eps[0] else
                                      drop_factors[0] if e < drop_eps[1] else
                                      drop_factors[1])
        else:
            lambda_epoch = lambda e: (1.0 if e < drop_eps[0] else
                                      drop_factors[0])

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda_epoch, last_epoch=-1)
        for _ in range(args.restart_iter):
            lr_scheduler.step()
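        # worked example (hypothetical values): with drop_lr_epoch='20,40' and
        # drop_factors='0.06,0.012', LambdaLR keeps the base lr until epoch 20,
        # multiplies it by 0.06 for epochs 20-39 and by 0.012 from epoch 40 on;
        # the restart_iter steps above fast-forward the schedule when resuming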

    elif args.lr_scheduler_type == 'val_based':
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            patience=5,
            factor=0.1,
            min_lr=5e-6,
            threshold=0.5)

    else:
        raise ValueError("Unimplemented lr scheduler")

    print("LR scheduler ", args.lr_scheduler_type)

    ####################################################
    #                LOAD FROM CHECKPOINT              #
    ####################################################

    if args.checkpoint != '':
        print(f"loading model from {args.checkpoint}")
        model_dict = model.state_dict()
        chkpt = torch.load(args.checkpoint, map_location=torch.device('cpu'))
        try:
            print(f"loading optimizer from {args.checkpoint}")
            optimizer.state = chkpt['optimizer'].state
            print("Successfully loaded optimizer")
        except Exception:
            print("Failed to load optimizer")
        chkpt_state_dict = chkpt['model']
        chkpt_state_dict_cpy = chkpt_state_dict.copy()
        # strip the "module." prefix that DataParallel adds when saving checkpoints
        for key in chkpt_state_dict_cpy.keys():
            if 'module.' in key:
                new_key = re.sub(r'module\.', '', key)
                chkpt_state_dict[new_key] = chkpt_state_dict.pop(key)
        chkpt_state_dict = {
            k: v
            for k, v in chkpt_state_dict.items() if k in model_dict
        }
        model_dict.update(chkpt_state_dict)
        updated_keys = set(model_dict).intersection(set(chkpt_state_dict))
        print(f"Updated {len(updated_keys)} keys using chkpt")
        print("Following keys updated :", "\n".join(sorted(updated_keys)))
        print()

        missed_keys = set(model_dict).difference(set(chkpt_state_dict))
        print(f"Missed {len(missed_keys)} keys")
        print("Following keys missed :", "\n".join(sorted(missed_keys)))
        model.load_state_dict(model_dict)

    # Multi-gpu support and device setup
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_number
    print('Using GPUs: ', os.environ["CUDA_VISIBLE_DEVICES"])
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    ####################################################
    #        ALGORITHM AND ALGORITHM TRAINER           #
    ####################################################

    # start tboard from restart iter
    init_global_iteration = 0
    if args.restart_iter:
        init_global_iteration = args.restart_iter * args.n_iters_per_epoch
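    # e.g. (hypothetical values) restart_iter=10 with n_iters_per_epoch=500 resumes
    # tensorboard logging from global iteration 5000 instead of 0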

    # algorithm
    if args.algorithm == 'InitBasedAlgorithm':
        algorithm = InitBasedAlgorithm(
            model=model,
            loss_func=torch.nn.CrossEntropyLoss(),
            method=args.init_meta_algorithm,
            alpha=args.alpha,
            inner_loop_grad_clip=args.grad_clip_inner,
            inner_update_method=args.inner_update_method,
            device='cuda')
    elif args.algorithm == 'ProtoNet':
        algorithm = ProtoNet(model=model,
                             inner_loss_func=torch.nn.CrossEntropyLoss(),
                             device='cuda',
                             scale=args.scale_factor,
                             metric=args.classifier_metric)
    elif args.algorithm == 'SVM':
        algorithm = SVM(model=model,
                        inner_loss_func=torch.nn.CrossEntropyLoss(),
                        scale=args.scale_factor,
                        device='cuda')
    elif args.algorithm == 'Ridge':
        algorithm = Ridge(model=model,
                          inner_loss_func=torch.nn.CrossEntropyLoss(),
                          scale=args.scale_factor,
                          device='cuda')
    elif args.algorithm in ['TransferLearning', 'SupervisedBaseline']:
        """
        We use the ProtoNet algorithm at test time.
        """
        algorithm = ProtoNet(model=model,
                             inner_loss_func=torch.nn.CrossEntropyLoss(),
                             device='cuda',
                             scale=args.scale_factor,
                             metric=args.classifier_metric)
    else:
        raise ValueError('Unrecognized algorithm {}'.format(args.algorithm))

    if args.algorithm == 'InitBasedAlgorithm':
        trainer = Init_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            num_updates_inner_train=args.num_updates_inner_train,
            num_updates_inner_val=args.num_updates_inner_val,
            init_global_iteration=init_global_iteration)
    elif args.algorithm in ['TransferLearning', 'SupervisedBaseline']:
        trainer = TL_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            init_global_iteration=init_global_iteration)
    else:
        trainer = Meta_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            init_global_iteration=init_global_iteration)

    ####################################################
    #                  TRAINER LOOP                    #
    ####################################################

    print("\n", "--" * 20, "BEGIN TRAINING", "--" * 20)

    # iterate over training epochs
    for iter_start in range(args.restart_iter, args.n_epochs):

        # training
        for param_group in optimizer.param_groups:
            print('\n\nlearning rate:', param_group['lr'])

        if args.algorithm in ['SupervisedBaseline']:
            trainer.run(mt_loader=train_loader,
                        evaluate_supervised_classification=True,
                        is_training=True,
                        epoch=iter_start + 1)

        else:
            trainer.run(mt_loader=train_loader,
                        is_training=True,
                        epoch=iter_start + 1)

        if iter_start % args.val_frequency == 0:
            '''
            # On ML train objective
            print("Train Loss on ML objective")
            results = trainer.run(
                mt_loader=no_fixS_train_loader, is_training=False)
            pp = pprint.PrettyPrinter(indent=4)
            pp.pprint(results)
            writer.add_scalar(
                "train_acc_on_ml", results['test_loss_after']['accu'], iter_start + 1)
            writer.add_scalar(
                "train_loss_on_ml", results['test_loss_after']['loss'], iter_start + 1)
            base_train_loss = results['test_loss_after']['loss']
            '''

            # validation/test
            if args.algorithm in ["SupervisedBaseline"]:
                print("Validation")
                results = trainer.run(mt_loader=val_loader,
                                      is_training=False,
                                      evaluate_supervised_classification=True)
                writer.add_scalar(f"val_acc",
                                  results['test_loss_after']['accu'],
                                  iter_start + 1)
                writer.add_scalar(f"val_loss",
                                  results['test_loss_after']['loss'],
                                  iter_start + 1)
                val_accu = results['test_loss_after']['accu']
            else:
                val_accus = {}
                for ns_val in all_n_shot_vals:
                    print("Validation ", f"n_shots_val {ns_val}")
                    results = trainer.run(mt_loader=val_loaders[ns_val],
                                          is_training=False)
                    writer.add_scalar(f"val_acc_{args.n_way_val}w{ns_val}s",
                                      results['test_loss_after']['accu'],
                                      iter_start + 1)
                    writer.add_scalar(f"val_loss_{args.n_way_val}w{ns_val}s",
                                      results['test_loss_after']['loss'],
                                      iter_start + 1)
                    val_accus[ns_val] = results['test_loss_after']['accu']
            pp = pprint.PrettyPrinter(indent=4)
            pp.pprint(results)

            if args.algorithm in ["SupervisedBaseline"]:
                print("Test")
                results = trainer.run(mt_loader=test_loader,
                                      is_training=False,
                                      evaluate_supervised_classification=True)
                writer.add_scalar(f"test_acc",
                                  results['test_loss_after']['accu'],
                                  iter_start + 1)
                writer.add_scalar(f"test_loss",
                                  results['test_loss_after']['loss'],
                                  iter_start + 1)
                novel_test_loss = results['test_loss_after']['loss']
            else:
                novel_test_losses = {}
                for ns_val in all_n_shot_vals:
                    print("Test ", f"n_shots_val {ns_val}")
                    results = trainer.run(mt_loader=test_loaders[ns_val],
                                          is_training=False)
                    writer.add_scalar(
                        f"test_acc_{args.n_way_val}w{ns_val}s",
                        results['test_loss_after']['accu'], iter_start + 1)
                    writer.add_scalar(
                        f"test_loss_{args.n_way_val}w{ns_val}s",
                        results['test_loss_after']['loss'], iter_start + 1)
                    novel_test_losses[ns_val] = results['test_loss_after'][
                        'loss']
            pp = pprint.PrettyPrinter(indent=4)
            pp.pprint(results)

            if args.algorithm != "SupervisedBaseline":
                val_accu = val_accus[
                    args.n_shot_val]  # stick with 5w5s for model selection
                novel_test_loss = novel_test_losses[
                    args.n_shot_val]  # stick with 5w5s for model selection

            # base class generalization
            '''
            if base_class_generalization:
                # can only do this if there is only one type of evaluation
                
                print("Base Test")
                results = trainer.run(
                    mt_loader=base_test_loader, is_training=False)
                pp = pprint.PrettyPrinter(indent=4)
                pp.pprint(results)
                writer.add_scalar(
                    "base_test_acc", results['test_loss_after']['accu'], iter_start + 1)
                writer.add_scalar(
                    "base_test_loss", results['test_loss_after']['loss'], iter_start + 1)
                base_test_loss = results['test_loss_after']['loss']
                writer.add_scalar(
                    "base_gen_gap", base_test_loss - base_train_loss, iter_start + 1)
                writer.add_scalar(
                    "novel_gen_gap", novel_test_loss - base_train_loss, iter_start + 1)
            
                if args.fix_support > 0:
                    print("Base Test using FixSupport, matching train and test for fixml")
                    results = trainer.run(
                        mt_loader=base_test_loader_using_fixS, is_training=False)
                    pp = pprint.PrettyPrinter(indent=4)
                    pp.pprint(results)
                    writer.add_scalar(
                        "base_test_acc_usingFixS", results['test_loss_after']['accu'], iter_start + 1)
                    writer.add_scalar(
                        "base_test_loss_usingFixS", results['test_loss_after']['loss'], iter_start + 1)
            '''

        # scheduler step
        if args.lr_scheduler_type == 'val_based':
            assert args.val_frequency == 1, "eval after every epoch is mandatory for val based lr scheduler"
            lr_scheduler.step(val_accu)
        else:
            lr_scheduler.step()
Code example #3
def main(args):

    ####################################################
    #                LOGGING AND SAVING                #
    ####################################################
    args.output_folder = ensure_path('./runs/{0}'.format(args.output_folder))
    if str2bool(args.eot_model):
        eval_results = f'{args.output_folder}/evaleot_results.txt'
    else:
        eval_results = f'{args.output_folder}/eval_results.txt'
    with open(eval_results, 'w') as f:
        f.write("--" * 20 + "EVALUATION RESULTS" + "--" * 20 + '\n')

    ####################################################
    #         DATASET AND DATALOADER CREATION          #
    ####################################################

    # json paths
    dataset_name = args.dataset_path.split('/')[-1]
    image_size = args.img_side_len
    # Needed when the same train config is used for both 5w5s and 5w1s evaluations,
    # e.g. for SVM, where 5w15s5q episodes are used for both 5w5s and 5w1s evaluation.
    all_n_shot_vals = [args.n_shot_val, 1] if str2bool(
        args.do_one_shot_eval_too) else [args.n_shot_val]
    base_class_generalization = dataset_name.lower() in [
        'miniimagenet', 'fc100-base', 'cifar-fs-base', 'tieredimagenet-base'
    ]
    train_file = os.path.join(args.dataset_path, 'base.json')
    val_file = os.path.join(args.dataset_path, 'val.json')
    test_file = os.path.join(args.dataset_path, 'novel.json')
    if base_class_generalization:
        base_test_file = os.path.join(args.dataset_path, 'base_test.json')
    print("Dataset name", dataset_name, "image_size", image_size,
          "all_n_shot_vals", all_n_shot_vals)
    print("base_class_generalization:", base_class_generalization)
    """
    1. Create ClassImagesSet object, which handles preloading of images
    2. Pass ClassImagesSet to MetaDataset which handles nshot, nquery and fixSupport
    3. Create Dataloader object from MetaDataset
    """

    print("\n", "--" * 20, "BASE", "--" * 20)
    train_classes = ClassImagesSet(train_file,
                                   preload=str2bool(args.preload_train))

    # create a dataloader that has no fixed support
    no_fixS_train_meta_dataset = MetaDataset(
        dataset_name=dataset_name,
        support_class_images_set=train_classes,
        query_class_images_set=train_classes,
        image_size=image_size,
        support_aug=False,
        query_aug=False,
        fix_support=0,  # no fixed support
        save_folder='',
        verbose=False)

    no_fixS_train_loader = MetaDataLoader(dataset=no_fixS_train_meta_dataset,
                                          n_batches=args.n_iterations_val,
                                          batch_size=args.batch_size_val,
                                          n_way=args.n_way_val,
                                          n_shot=args.n_shot_val,
                                          n_query=args.n_query_val,
                                          randomize_query=False)

    print("\n", "--" * 20, "VAL", "--" * 20)
    val_classes = ClassImagesSet(val_file,
                                 preload=str2bool(args.preload_train))
    val_meta_datasets = {}
    val_loaders = {}
    for ns_val in all_n_shot_vals:
        print("====", f"n_shots_val {ns_val}", "====")
        val_meta_datasets[ns_val] = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=val_classes,
            query_class_images_set=val_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder='')

        val_loaders[ns_val] = MetaDataLoader(dataset=val_meta_datasets[ns_val],
                                             n_batches=args.n_iterations_val,
                                             batch_size=args.batch_size_val,
                                             n_way=args.n_way_val,
                                             n_shot=ns_val,
                                             n_query=args.n_query_val,
                                             randomize_query=False)

    print("\n", "--" * 20, "NOVEL", "--" * 20)
    test_classes = ClassImagesSet(test_file,
                                  preload=str2bool(args.preload_train))
    test_meta_datasets = {}
    test_loaders = {}
    for ns_val in all_n_shot_vals:
        print("====", f"n_shots_val {ns_val}", "====")
        test_meta_datasets[ns_val] = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=test_classes,
            query_class_images_set=test_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder='')

        test_loaders[ns_val] = MetaDataLoader(
            dataset=test_meta_datasets[ns_val],
            n_batches=args.n_iterations_val,
            batch_size=args.batch_size_val,
            n_way=args.n_way_val,
            n_shot=ns_val,
            n_query=args.n_query_val,
            randomize_query=False,
        )

    if base_class_generalization:
        # can only do this if there is only one type of evaluation
        print("\n", "--" * 20, "BASE TEST", "--" * 20)
        base_test_classes = ClassImagesSet(base_test_file,
                                           preload=str2bool(
                                               args.preload_train))

        # if args.fix_support > 0:
        #     base_test_meta_dataset_using_fixS = MetaDataset(
        #                                             dataset_name=dataset_name,
        #                                             support_class_images_set=train_classes,
        #                                             query_class_images_set=base_test_classes,
        #                                             image_size=image_size,
        #                                             support_aug=False,
        #                                             query_aug=False,
        #                                             fix_support=0,
        #                                             save_folder='',
        #                                             fix_support_path=os.path.join(args.output_folder,
        #                                                                           "fixed_support_pool.pkl"))

        #     base_test_loader_using_fixS = MetaDataLoader(
        #                                     dataset=base_test_meta_dataset_using_fixS,
        #                                     n_batches=args.n_iterations_val,
        #                                     batch_size=args.batch_size_val,
        #                                     n_way=args.n_way_val,
        #                                     n_shot=args.n_shot_val,
        #                                     n_query=args.n_query_val,
        #                                     randomize_query=False,)

        print("\n", "--" * 20, "BASE + NOVEL TEST", "--" * 20)
        assert len(set(base_test_classes.keys()).intersection(set(test_classes.keys()))) == 0,\
            f"the base and novel classes must have different ids, base: {set(base_test_classes.keys())}, novel: {set(test_classes.keys())}"
        # combine both base and novel classes
        base_novel_test_classes = ClassImagesSet(base_test_file, test_file)
        base_novel_test_meta_dataset = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=base_novel_test_classes,
            query_class_images_set=base_novel_test_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder='')

        # sample classes from base and novel with mix prob. given by lambd
        base_novel_test_loaders_dict = {}
        for lambd in np.arange(0., 1.1, 0.1):
            base_novel_test_loaders_dict[lambd] = MetaDataLoader(
                dataset=base_novel_test_meta_dataset,
                n_batches=args.n_iterations_val,
                batch_size=args.batch_size_val,
                n_way=args.n_way_val,
                n_shot=args.n_shot_val,
                n_query=args.n_query_val,
                randomize_query=False,
                p_dict={
                    k: ((1 - lambd) / len(base_test_classes) if k
                        in base_test_classes else lambd / len(test_classes))
                    for k in list(base_test_classes) + list(test_classes)
                })
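        # e.g. lambd=0.0 draws every way from the base classes and lambd=1.0 draws
        # every way from the novel classes; intermediate values split the sampling
        # probability mass as lambd on novel and (1 - lambd) on base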

    ####################################################
    #             MODEL/BACKBONE CREATION              #
    ####################################################

    print("\n", "--" * 20, "MODEL", "--" * 20)

    if args.model_type == 'resnet_12':
        if 'miniImagenet' in dataset_name or 'CUB' in dataset_name:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=5,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))
        else:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=2,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))
    elif args.model_type in ['conv64', 'conv48', 'conv32']:
        dim = int(args.model_type[-2:])
        model = shallow_conv.ShallowConv(z_dim=dim,
                                         h_dim=dim,
                                         num_classes=args.num_classes_train,
                                         x_width=image_size,
                                         classifier_type=args.classifier_type,
                                         projection=str2bool(args.projection),
                                         learnable_scale=str2bool(
                                             args.learnable_scale))
    elif args.model_type == 'wide_resnet28_10':
        model = wide_resnet.wrn28_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    elif args.model_type == 'wide_resnet16_10':
        model = wide_resnet.wrn16_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    else:
        raise ValueError('Unrecognized model type {}'.format(args.model_type))
    print("Model\n" + "==" * 27)
    print(model)

    ####################################################
    #                LOAD FROM CHECKPOINT              #
    ####################################################

    assert args.checkpoint != '', "Must provide checkpoint"
    print(f"loading model from {args.checkpoint}")
    model_dict = model.state_dict()
    chkpt = torch.load(args.checkpoint, map_location=torch.device('cpu'))

    ### load model
    chkpt_state_dict = chkpt['model']
    chkpt_state_dict_old_keys = list(chkpt_state_dict.keys())
    # strip the "module." prefix that DataParallel adds when saving checkpoints
    for key in chkpt_state_dict_old_keys:
        if 'module.' in key:
            new_key = re.sub(r'module\.', '', key)
            chkpt_state_dict[new_key] = chkpt_state_dict.pop(key)
    load_model_state_dict = {
        k: v
        for k, v in chkpt_state_dict.items() if k in model_dict
    }
    model_dict.update(load_model_state_dict)
    # updated_keys = set(model_dict).intersection(set(chkpt_state_dict))
    print(f"Updated {len(load_model_state_dict.keys())} keys using chkpt")
    print("Following keys updated :",
          "\n".join(sorted(load_model_state_dict.keys())))
    missed_keys = set(model_dict).difference(set(load_model_state_dict))
    print(f"Missed {len(missed_keys)} keys")
    print("Following keys missed :", "\n".join(sorted(missed_keys)))
    model.load_state_dict(model_dict)

    # Multi-gpu support and device setup
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_number
    print('Using GPUs: ', os.environ["CUDA_VISIBLE_DEVICES"])
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    ####################################################
    #        ALGORITHM AND ALGORITHM TRAINER           #
    ####################################################

    # algorithm
    if args.algorithm == 'InitBasedAlgorithm':
        algorithm = InitBasedAlgorithm(
            model=model,
            loss_func=torch.nn.CrossEntropyLoss(),
            method=args.init_meta_algorithm,
            alpha=args.alpha,
            inner_loop_grad_clip=args.grad_clip_inner,
            inner_update_method=args.inner_update_method,
            device='cuda')
    elif args.algorithm == 'ProtoNet':
        algorithm = ProtoNet(model=model,
                             inner_loss_func=torch.nn.CrossEntropyLoss(),
                             device='cuda',
                             scale=args.scale_factor,
                             metric=args.classifier_metric)
    elif args.algorithm == 'SVM':
        algorithm = SVM(model=model,
                        inner_loss_func=torch.nn.CrossEntropyLoss(),
                        scale=args.scale_factor,
                        device='cuda')
    elif args.algorithm == 'Ridge':
        algorithm = Ridge(model=model,
                          inner_loss_func=torch.nn.CrossEntropyLoss(),
                          scale=args.scale_factor,
                          device='cuda')
    else:
        raise ValueError('Unrecognized algorithm {}'.format(args.algorithm))

    if args.algorithm != 'InitBasedAlgorithm':
        trainer = Meta_algorithm_trainer(algorithm=algorithm,
                                         optimizer=None,
                                         writer=None,
                                         log_interval=args.log_interval,
                                         save_folder='',
                                         grad_clip=None,
                                         init_global_iteration=None)
    else:
        trainer = Init_algorithm_trainer(
            algorithm=algorithm,
            optimizer=None,
            writer=None,
            log_interval=args.log_interval,
            save_folder='',
            grad_clip=None,
            num_updates_inner_train=args.num_updates_inner_train,
            num_updates_inner_val=args.num_updates_inner_val,
            init_global_iteration=None)
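    # optimizer, writer and save_folder are intentionally None/empty: this script only
    # runs evaluation, so the trainer takes no optimizer steps and writes no checkpoints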

    ####################################################
    #                    EVALUATION                    #
    ####################################################

    print("\n", "--" * 20, "BEGIN EVALUATION", "--" * 20)

    # On ML train objective
    print("Train Loss on ML objective")
    results = trainer.run(mt_loader=no_fixS_train_loader, is_training=False)
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(results)
    with open(eval_results, 'a') as f:
        f.write(
            f"TrainLossOnML{args.n_way_val}w{ns_val}s: Loss {round(results['test_loss_after']['loss'], 3)} Acc {round(results['test_loss_after']['accu'], 3)}"
            + '\n')

    # validation/test
    # val_accus = {}
    # novel_test_losses = {}
    for ns_val in all_n_shot_vals:
        print("Validation ", f"n_shots_val {ns_val}")
        results = trainer.run(mt_loader=val_loaders[ns_val], is_training=False)
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(results)
        with open(eval_results, 'a') as f:
            f.write(
                f"Val{args.n_way_val}w{ns_val}s: Loss {round(results['test_loss_after']['loss'], 3)} Acc {round(results['test_loss_after']['accu'], 3)}"
                + '\n')

        print("Test ", f"n_shots_val {ns_val}")
        results = trainer.run(mt_loader=test_loaders[ns_val],
                              is_training=False)
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(results)
        with open(eval_results, 'a') as f:
            f.write(
                f"Test{args.n_way_val}w{ns_val}s: Loss {round(results['test_loss_after']['loss'], 3)} Acc {round(results['test_loss_after']['accu'], 3)}"
                + '\n')

    # base class generalization
    if base_class_generalization:
        # can only do this if there is only one type of evaluation

        print("Base Test")

        # if args.fix_support > 0:
        #     print("Base Test using FixSupport, matching train and test for fixml")
        #     results = trainer.run(
        #         mt_loader=base_test_loader_using_fixS, is_training=False)
        #     pp = pprint.PrettyPrinter(indent=4)
        #     pp.pprint(results)
        #     with open(eval_results, 'a') as f:
        #         f.write(f"BaseTestUsingFixSupport: Loss {round(results['test_loss_after']['loss'], 3)} Acc {round(results['test_loss_after']['accu'], 3)}"+'\n')

        for lambd, base_novel_test_loader in base_novel_test_loaders_dict.items(
        ):
            print(
                f"Base + Novel Test lambda={round(lambd, 2)} Novel {round(1-lambd, 2)} Base"
            )
            results = trainer.run(mt_loader=base_novel_test_loader,
                                  is_training=False)
            pp = pprint.PrettyPrinter(indent=4)
            pp.pprint(results)
            with open(eval_results, 'a') as f:
                f.write(
                    f"Base+NovelTestLambda={round(lambd, 2)}Novel{round(1-lambd, 2)}Base: Loss {round(results['test_loss_after']['loss'], 3)} Acc {round(results['test_loss_after']['accu'], 3)}"
                    + '\n')
Code example #4
def compute_norm(args, p=2):

    ####################################################
    #             MODEL/BACKBONE CREATION              #
    ####################################################

    print("\n", "--" * 20, "MODEL", "--" * 20)
    dataset_name = args.dataset_name
    if args.model_type == 'resnet_12':
        if 'miniImagenet' in dataset_name or 'CUB' in dataset_name:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=5,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection))
        else:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=2,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection))
    elif args.model_type in ['conv64', 'conv48', 'conv32']:
        dim = int(args.model_type[-2:])
        model = shallow_conv.ShallowConv(z_dim=dim,
                                         h_dim=dim,
                                         num_classes=args.num_classes_train,
                                         classifier_type=args.classifier_type,
                                         projection=str2bool(args.projection))
    elif args.model_type == 'wide_resnet28_10':
        model = wide_resnet.wrn28_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type)
    elif args.model_type == 'wide_resnet16_10':
        model = wide_resnet.wrn16_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type)
    else:
        raise ValueError('Unrecognized model type {}'.format(args.model_type))
    print("Model\n" + "==" * 27)
    print(model)

    ####################################################
    #                LOAD FROM CHECKPOINT              #
    ####################################################

    if args.checkpoint != '':
        print(f"loading model from {args.checkpoint}")
        model_dict = model.state_dict()
        chkpt = torch.load(args.checkpoint, map_location=torch.device('cpu'))
        # note: unlike the training script, this helper builds no optimizer,
        # so only the model weights are restored from the checkpoint
        chkpt_state_dict = chkpt['model']
        chkpt_state_dict_cpy = chkpt_state_dict.copy()
        # strip the "module." prefix that DataParallel adds when saving checkpoints
        for key in chkpt_state_dict_cpy.keys():
            if 'module.' in key:
                new_key = re.sub(r'module\.', '', key)
                chkpt_state_dict[new_key] = chkpt_state_dict.pop(key)
        chkpt_state_dict = {
            k: v
            for k, v in chkpt_state_dict.items() if k in model_dict
        }
        model_dict.update(chkpt_state_dict)
        updated_keys = set(model_dict).intersection(set(chkpt_state_dict))
        print(f"Updated {len(updated_keys)} keys using chkpt")
        print("Following keys updated :", "\n".join(sorted(updated_keys)))
        missed_keys = set(model_dict).difference(set(chkpt_state_dict))
        print(f"Missed {len(missed_keys)} keys")
        print("Following keys missed :", "\n".join(sorted(missed_keys)))
        model.load_state_dict(model_dict)

    ####################################################
    #                COMPUTE NORM HERE                 #
    ####################################################

    if p == 2:
        return l2_norm(model)
    else:
        raise ValueError('Unspecified norm type')
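
l2_norm is referenced above but not defined in this snippet. A minimal sketch under the
assumption that it returns the L2 norm over all model parameters (hypothetical helper):

import torch


def l2_norm(model):
    # square root of the sum of squared entries across every parameter tensor
    return torch.sqrt(sum((p ** 2).sum() for p in model.parameters())).item()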
Code example #5
def main(args):

    ####################################################
    #                LOGGING AND SAVING                #
    ####################################################
    if args.checkpoint != '':
        # if we are reloading, we don't need to timestamp and create a new folder
        # instead keep writing to the original output_folder
        assert os.path.exists(f'./runs/{args.output_folder}')
        args.output_folder = f'./runs/{args.output_folder}'
        print(f'resume training and will write to {args.output_folder}')
    else:
        args.output_folder = ensure_path('./runs/{0}'.format(
            args.output_folder))
    writer = SummaryWriter(args.output_folder)

    time_now = datetime.now(
        pytz.timezone("America/New_York")).strftime("%d:%b:%Y:%H:%M:%S")
    with open(f'{args.output_folder}/config_{time_now}.txt',
              'w') as config_txt:
        for k, v in sorted(vars(args).items()):
            config_txt.write(f'{k}: {v}\n')
    save_folder = args.output_folder

    # replace stdout with Logger; the original sys.stdout is saved in src.logger.stdout
    sys.stdout = src.logger.Logger(
        log_filename=f'{args.output_folder}/train_{time_now}.log')
    src.logger.stdout.write('hi!')
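    # Logger is assumed to tee each write to both the log file and the original stdout
    # kept in src.logger.stdout; its implementation is not shown in this snippet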

    ####################################################
    #         DATASET AND DATALOADER CREATION          #
    ####################################################

    # json paths
    dataset_name = args.dataset_path.split('/')[-1]
    image_size = args.img_side_len
    # Needed when the same train config is used for both 5w5s and 5w1s evaluations,
    # e.g. for SVM, where 5w15s5q episodes are used for both 5w5s and 5w1s evaluation.
    all_n_shot_vals = [args.n_shot_val, 1] if str2bool(
        args.do_one_shot_eval_too) else [args.n_shot_val]
    base_class_generalization = dataset_name.lower() in [
        'miniimagenet', 'fc100-base', 'cifar-fs-base', 'tieredimagenet-base'
    ]
    train_file = os.path.join(args.dataset_path, 'base.json')
    val_file = os.path.join(args.dataset_path, 'val.json')
    test_file = os.path.join(args.dataset_path, 'novel.json')
    if base_class_generalization:
        base_test_file = os.path.join(args.dataset_path, 'base_test.json')
    print("Dataset name", dataset_name, "image_size", image_size,
          "all_n_shot_vals", all_n_shot_vals)
    print("base_class_generalization:", base_class_generalization)
    """
    1. Create ClassImagesSet object, which handles preloading of images
    2. Pass ClassImagesSet to MetaDataset which handles nshot, nquery and fixSupport
    3. Create Dataloader object from MetaDataset
    """

    print("\n", "--" * 20, "TRAIN", "--" * 20)
    train_classes = ClassImagesSet(train_file,
                                   preload=str2bool(args.preload_train))
    if args.algorithm == 'TransferLearning':
        """
        For Transfer Learning we create a SimpleDataset.
        The augmentation is decided by query_aug flag.
        """
        train_dataset = SimpleDataset(dataset_name=dataset_name,
                                      class_images_set=train_classes,
                                      image_size=image_size,
                                      aug=str2bool(args.query_aug))

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size_train,
            shuffle=True,
            num_workers=6)

    else:
        train_meta_dataset = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=train_classes,
            query_class_images_set=train_classes,
            image_size=image_size,
            support_aug=str2bool(args.support_aug),
            query_aug=str2bool(args.query_aug),
            fix_support=args.fix_support,
            save_folder=save_folder,
            fix_support_path=args.fix_support_path)

        train_loader = MetaDataLoader(dataset=train_meta_dataset,
                                      batch_size=args.batch_size_train,
                                      n_batches=args.n_iters_per_epoch,
                                      n_way=args.n_way_train,
                                      n_shot=args.n_shot_train,
                                      n_query=args.n_query_train,
                                      randomize_query=str2bool(
                                          args.randomize_query))

    # create a dataloader that has no fixed support
    no_fixS_train_meta_dataset = MetaDataset(
        dataset_name=dataset_name,
        support_class_images_set=train_classes,
        query_class_images_set=train_classes,
        image_size=image_size,
        support_aug=False,
        query_aug=False,
        fix_support=0,  # no fixed support
        save_folder='',
        verbose=False)

    no_fixS_train_loader = MetaDataLoader(dataset=no_fixS_train_meta_dataset,
                                          n_batches=args.n_iterations_val,
                                          batch_size=args.batch_size_val,
                                          n_way=args.n_way_val,
                                          n_shot=args.n_shot_val,
                                          n_query=args.n_query_val,
                                          randomize_query=False)

    print("\n", "--" * 20, "VAL", "--" * 20)
    val_classes = ClassImagesSet(val_file, preload=False)
    val_meta_datasets = {}
    val_loaders = {}
    for ns_val in all_n_shot_vals:
        print("====", f"n_shots_val {ns_val}", "====")
        val_meta_datasets[ns_val] = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=val_classes,
            query_class_images_set=val_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder='')

        val_loaders[ns_val] = MetaDataLoader(dataset=val_meta_datasets[ns_val],
                                             n_batches=args.n_iterations_val,
                                             batch_size=args.batch_size_val,
                                             n_way=args.n_way_val,
                                             n_shot=ns_val,
                                             n_query=args.n_query_val,
                                             randomize_query=False)

    print("\n", "--" * 20, "TEST", "--" * 20)
    test_classes = ClassImagesSet(test_file)
    test_meta_datasets = {}
    test_loaders = {}
    for ns_val in all_n_shot_vals:
        print("====", f"n_shots_val {ns_val}", "====")
        test_meta_datasets[ns_val] = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=test_classes,
            query_class_images_set=test_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder='')

        test_loaders[ns_val] = MetaDataLoader(
            dataset=test_meta_datasets[ns_val],
            n_batches=args.n_iterations_val,
            batch_size=args.batch_size_val,
            n_way=args.n_way_val,
            n_shot=ns_val,
            n_query=args.n_query_val,
            randomize_query=False,
        )

    if base_class_generalization:
        # base-class generalization is only evaluated at the primary n_shot_val setting
        print("\n", "--" * 20, "BASE TEST", "--" * 20)
        base_test_classes = ClassImagesSet(base_test_file)
        base_test_meta_dataset = MetaDataset(
            dataset_name=dataset_name,
            support_class_images_set=base_test_classes,
            query_class_images_set=base_test_classes,
            image_size=image_size,
            support_aug=False,
            query_aug=False,
            fix_support=0,
            save_folder=save_folder)
        base_test_loader = MetaDataLoader(dataset=base_test_meta_dataset,
                                          n_batches=args.n_iterations_val,
                                          batch_size=args.batch_size_val,
                                          n_way=args.n_way_val,
                                          n_shot=args.n_shot_val,
                                          n_query=args.n_query_val,
                                          randomize_query=False)

        if args.fix_support > 0:
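            # also evaluate base-class queries against the same fixed support pool
            # used during training, so train and test conditions match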
            base_test_meta_dataset_using_fixS = MetaDataset(
                dataset_name=dataset_name,
                support_class_images_set=train_classes,
                query_class_images_set=base_test_classes,
                image_size=image_size,
                support_aug=False,
                query_aug=False,
                fix_support=0,
                save_folder=save_folder,
                fix_support_path=os.path.join(save_folder,
                                              "fixed_support_pool.pkl"))

            base_test_loader_using_fixS = MetaDataLoader(
                dataset=base_test_meta_dataset_using_fixS,
                n_batches=args.n_iterations_val,
                batch_size=args.batch_size_val,
                n_way=args.n_way_val,
                n_shot=args.n_shot_val,
                n_query=args.n_query_val,
                randomize_query=False,
            )

    ####################################################
    #             MODEL/BACKBONE CREATION              #
    ####################################################

    print("\n", "--" * 20, "MODEL", "--" * 20)

    if args.model_type == 'resnet_12':
        # technically tieredimagenet should also have dropblock size of 5
        if 'miniImagenet' in dataset_name or 'CUB' in dataset_name:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=5,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))
        else:
            model = resnet_12.resnet12(avg_pool=str2bool(args.avg_pool),
                                       drop_rate=0.1,
                                       dropblock_size=2,
                                       num_classes=args.num_classes_train,
                                       classifier_type=args.classifier_type,
                                       projection=str2bool(args.projection),
                                       learnable_scale=str2bool(
                                           args.learnable_scale))
    elif args.model_type in ['conv64', 'conv48', 'conv32']:
        dim = int(args.model_type[-2:])
        model = shallow_conv.ShallowConv(z_dim=dim,
                                         h_dim=dim,
                                         num_classes=args.num_classes_train,
                                         x_width=image_size,
                                         classifier_type=args.classifier_type,
                                         projection=str2bool(args.projection),
                                         learnable_scale=str2bool(
                                             args.learnable_scale))
    elif args.model_type == 'wide_resnet28_10':
        model = wide_resnet.wrn28_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    elif args.model_type == 'wide_resnet16_10':
        model = wide_resnet.wrn16_10(projection=str2bool(args.projection),
                                     classifier_type=args.classifier_type,
                                     learnable_scale=str2bool(
                                         args.learnable_scale))
    else:
        raise ValueError('Unrecognized model type {}'.format(args.model_type))
    print("Model\n" + "==" * 27)
    print(model)

    ####################################################
    #                OPTIMIZER CREATION                #
    ####################################################

    # optimizer construction
    print("\n", "--" * 20, "OPTIMIZER", "--" * 20)
    print("Optimzer", args.optimizer_type)
    if args.optimizer_type == 'adam':
        optimizer = torch.optim.Adam([{
            'params': model.parameters(),
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }])
    else:
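        # any optimizer_type other than 'adam' falls back to (modified) SGD
        # with Nesterov momentum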
        optimizer = modified_sgd.SGD([
            {
                'params': model.parameters(),
                'lr': args.lr,
                'weight_decay': args.weight_decay,
                'momentum': 0.9,
                'nesterov': True
            },
        ])
    print("Total n_epochs: ", args.n_epochs)

    # learning rate scheduler creation
    if args.lr_scheduler_type == 'deterministic':
        drop_eps = [int(x) for x in args.drop_lr_epoch.split(',')]
        if args.drop_factors != '':
            drop_factors = [float(x) for x in args.drop_factors.split(',')]
        else:
            drop_factors = [0.06, 0.012, 0.0024]

        print("Drop lr at epochs", drop_eps)
        print("Drop factors", drop_factors[:len(drop_eps)])

        assert len(drop_factors) >= len(drop_eps), "Not enough drop factors"
        def lr_lambda(x):
            '''
            x is an epoch number.
            drop_eps is assumed to be a list of strictly increasing epoch numbers.
            We require len(drop_factors) >= len(drop_eps); ideally they have the
            same length, and any extra drop factors are simply unused.
            '''
            for i in range(len(drop_eps)):
                if x >= drop_eps[i]:
                    continue
                else:
                    if i == 0:
                        return 1.0
                    else:
                        return drop_factors[i - 1]
            return drop_factors[len(drop_eps) - 1]
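        # worked example (hypothetical values): with drop_eps=[20, 40] and
        # drop_factors=[0.06, 0.012], epochs 0-19 use lr, epochs 20-39 use
        # 0.06 * lr, and epochs >= 40 use 0.012 * lr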

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                         lr_lambda=lr_lambda,
                                                         last_epoch=-1)
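        # fast-forward the scheduler so a restarted run resumes at the correct lr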
        for _ in range(args.restart_iter):
            lr_scheduler.step()

    elif args.lr_scheduler_type == 'val_based':
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            patience=5,
            factor=0.1,
            min_lr=5e-6,
            threshold=0.5)

    else:
        raise ValueError("Unimplemented lr scheduler")

    print("LR scheduler ", args.lr_scheduler_type)

    ####################################################
    #                LOAD FROM CHECKPOINT              #
    ####################################################

    if args.checkpoint != '':
        print(f"loading model from {args.checkpoint}")
        model_dict = model.state_dict()  # new model's state dict
        chkpt = torch.load(args.checkpoint, map_location=torch.device('cpu'))

        ### load model
        chkpt_state_dict = chkpt['model']
        chkpt_state_dict_old_keys = list(chkpt_state_dict.keys())
        # remove "module." from key, possibly present as it was dumped by data-parallel
        for key in chkpt_state_dict_old_keys:
            if 'module.' in key:
                new_key = re.sub(r'module\.', '', key)
                chkpt_state_dict[new_key] = chkpt_state_dict.pop(key)
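        # keep only parameters whose names exist in the current model; anything
        # absent from the checkpoint stays at its fresh initialization and is
        # reported below as "missed"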
        load_model_state_dict = {
            k: v
            for k, v in chkpt_state_dict.items() if k in model_dict
        }
        model_dict.update(load_model_state_dict)
        # updated_keys = set(model_dict).intersection(set(chkpt_state_dict))
        print(f"Updated {len(load_model_state_dict.keys())} keys using chkpt")
        print("Following keys updated :",
              "\n".join(sorted(load_model_state_dict.keys())))
        missed_keys = set(model_dict).difference(set(load_model_state_dict))
        print(f"Missed {len(missed_keys)} keys")
        print("Following keys missed :", "\n".join(sorted(missed_keys)))
        model.load_state_dict(model_dict)

        ### load optimizer
        try:
            print(f"loading optimizer from {args.checkpoint}")
            optimizer.load_state_dict(chkpt['optimizer'].state_dict())
            print("Successfully loaded optimizer")

        except Exception as e:
            print(f"Failed to load optimizer: {e}")

    ### Multi-gpu support and device setup
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_number
    print('Using GPUs: ', os.environ["CUDA_VISIBLE_DEVICES"])
    # move model to cuda
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()
    print("Successfully moved the model to cuda")

    # move the optimizer's states to cuda if loaded
    if args.checkpoint != '':
        # https://github.com/pytorch/pytorch/issues/2830
        # when using gpu, need to move all the statistics of the optimizer to cuda
        # in addition to the model parameters
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()
        print("Successfully moved the optimizer's states to cuda")

    ####################################################
    #        ALGORITHM AND ALGORITHM TRAINER           #
    ####################################################

    # resume the tensorboard global iteration counter from the restart epoch
    init_global_iteration = 0
    if args.restart_iter:
        init_global_iteration = args.restart_iter * args.n_iters_per_epoch

    # algorithm
    if args.algorithm == 'InitBasedAlgorithm':
        algorithm = InitBasedAlgorithm(
            model=model,
            loss_func=torch.nn.CrossEntropyLoss(),
            method=args.init_meta_algorithm,
            alpha=args.alpha,
            inner_loop_grad_clip=args.grad_clip_inner,
            inner_update_method=args.inner_update_method,
            device='cuda')
    elif args.algorithm == 'ProtoNet':
        algorithm = ProtoNet(model=model,
                             inner_loss_func=torch.nn.CrossEntropyLoss(),
                             device='cuda',
                             scale=args.scale_factor,
                             metric=args.classifier_metric)
    elif args.algorithm == 'SVM':
        algorithm = SVM(model=model,
                        inner_loss_func=torch.nn.CrossEntropyLoss(),
                        scale=args.scale_factor,
                        device='cuda')
    elif args.algorithm == 'Ridge':
        algorithm = Ridge(model=model,
                          inner_loss_func=torch.nn.CrossEntropyLoss(),
                          scale=args.scale_factor,
                          device='cuda')
    elif args.algorithm == 'TransferLearning':
        """
        We use the ProtoNet algorithm at test time.
        """
        algorithm = ProtoNet(model=model,
                             inner_loss_func=torch.nn.CrossEntropyLoss(),
                             device='cuda',
                             scale=args.scale_factor,
                             metric=args.classifier_metric)
    else:
        raise ValueError('Unrecognized algorithm {}'.format(args.algorithm))

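    # pick the trainer matching the algorithm: init-based, transfer learning,
    # or the remaining meta-learning algorithms (ProtoNet/SVM/Ridge)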
    if args.algorithm == 'InitBasedAlgorithm':
        trainer = Init_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            num_updates_inner_train=args.num_updates_inner_train,
            num_updates_inner_val=args.num_updates_inner_val,
            init_global_iteration=init_global_iteration)
    elif args.algorithm == 'TransferLearning':
        trainer = TL_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            init_global_iteration=init_global_iteration)
    else:
        trainer = Meta_algorithm_trainer(
            algorithm=algorithm,
            optimizer=optimizer,
            writer=writer,
            log_interval=args.log_interval,
            save_folder=save_folder,
            grad_clip=args.grad_clip,
            init_global_iteration=init_global_iteration)

    ####################################################
    #                  TRAINER LOOP                    #
    ####################################################

    print("\n", "--" * 20, "BEGIN TRAINING", "--" * 20)

    # iterate over training epochs
    for iter_start in range(args.restart_iter, args.n_epochs):

        # training
        for param_group in optimizer.param_groups:
            print('\n\nlearning rate:', param_group['lr'])

        trainer.run(mt_loader=train_loader,
                    is_training=True,
                    epoch=iter_start + 1)  # 1 based instead of 0 based

        if iter_start % args.val_frequency == 0:
            # train metrics on the standard ML objective (randomly sampled support)
            print("Train Loss on ML objective")
            results = trainer.run(mt_loader=no_fixS_train_loader,
                                  is_training=False)
            print(pprint.pformat(results, indent=4))
            writer.add_scalar("train_acc_on_ml",
                              results['test_loss_after']['accu'],
                              iter_start + 1)
            writer.add_scalar("train_loss_on_ml",
                              results['test_loss_after']['loss'],
                              iter_start + 1)
            base_train_loss = results['test_loss_after']['loss']

            # validation/test
            val_accus = {}
            novel_test_losses = {}
            for ns_val in all_n_shot_vals:
                print("Validation ", f"n_shots_val {ns_val}")
                results = trainer.run(mt_loader=val_loaders[ns_val],
                                      is_training=False)
                print(pprint.pformat(results, indent=4))
                writer.add_scalar(f"val_acc_{args.n_way_val}w{ns_val}s",
                                  results['test_loss_after']['accu'],
                                  iter_start + 1)
                writer.add_scalar(f"val_loss_{args.n_way_val}w{ns_val}s",
                                  results['test_loss_after']['loss'],
                                  iter_start + 1)
                val_accus[ns_val] = results['test_loss_after']['accu']

                print("Test ", f"n_shots_val {ns_val}")
                results = trainer.run(mt_loader=test_loaders[ns_val],
                                      is_training=False)
                print(pprint.pformat(results, indent=4))
                writer.add_scalar(f"test_acc_{args.n_way_val}w{ns_val}s",
                                  results['test_loss_after']['accu'],
                                  iter_start + 1)
                writer.add_scalar(f"test_loss_{args.n_way_val}w{ns_val}s",
                                  results['test_loss_after']['loss'],
                                  iter_start + 1)
                novel_test_losses[ns_val] = results['test_loss_after']['loss']

            # use the primary n_shot_val setting (typically 5w5s) for model selection
            val_accu = val_accus[args.n_shot_val]
            novel_test_loss = novel_test_losses[args.n_shot_val]

            # base class generalization
            if base_class_generalization:
                # evaluated only at the primary n_shot_val setting

                print("Base Test")
                results = trainer.run(mt_loader=base_test_loader,
                                      is_training=False)
                print(pprint.pformat(results, indent=4))
                writer.add_scalar("base_test_acc",
                                  results['test_loss_after']['accu'],
                                  iter_start + 1)
                writer.add_scalar("base_test_loss",
                                  results['test_loss_after']['loss'],
                                  iter_start + 1)
                base_test_loss = results['test_loss_after']['loss']
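                # generalization gaps: base/novel test loss minus the train loss
                # measured above on the (no fixed support) ML objective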
                writer.add_scalar("base_gen_gap",
                                  base_test_loss - base_train_loss,
                                  iter_start + 1)
                writer.add_scalar("novel_gen_gap",
                                  novel_test_loss - base_train_loss,
                                  iter_start + 1)

                if args.fix_support > 0:
                    print(
                        "Base Test using FixSupport, matching train and test for fixml"
                    )
                    results = trainer.run(
                        mt_loader=base_test_loader_using_fixS,
                        is_training=False)
                    print(pprint.pformat(results, indent=4))
                    writer.add_scalar("base_test_acc_usingFixS",
                                      results['test_loss_after']['accu'],
                                      iter_start + 1)
                    writer.add_scalar("base_test_loss_usingFixS",
                                      results['test_loss_after']['loss'],
                                      iter_start + 1)

        # scheduler step
        if args.lr_scheduler_type == 'val_based':
            assert args.val_frequency == 1, "eval after every epoch is mandatory for val based lr scheduler"
            lr_scheduler.step(val_accu)
        else:
            lr_scheduler.step()