示例#1
0
    def train(self, model_path, resume_epoch=0, need_val=False):
        """Run meta-training and write a checkpoint after every epoch.

        For every meta-batch yielded by ``self.train_loader`` the learning
        rate is adjusted, ``self.fast_net`` receives the weights of
        ``self.network`` via ``copy_weights`` and is run on each task in turn,
        and the collected per-task outputs are passed to ``self.meta_update``.

        Args:
            model_path: Destination handed to ``torch.save``; it is rewritten
                at the end of each epoch.
            resume_epoch: First epoch index to run. Epochs run up to
                ``self.epoch`` (exclusive).
            need_val: When true, evaluate on ``self.val_loader`` every 1000
                iterations and record the query F1 through
                ``self.tensorboard``.
        """
        for epoch in range(resume_epoch, self.epoch):
            for step, (sup_x, _, sup_y, qry_x, _, qry_y, *_) in enumerate(self.train_loader):
                # Global iteration index, counted across epochs.
                itr = epoch * len(self.train_loader) + step
                self.adjust_learning_rate(itr, self.meta_step_size, self.lr_decay_itr)

                sup_x = sup_x.cuda()
                sup_y = sup_y.cuda()
                qry_x = qry_x.cuda()
                qry_y = qry_y.cuda()

                task_grads = []
                for task_idx in range(sup_x.size(0)):
                    # fast_net only forwards one task's data at a time, always
                    # starting from the weights currently held by self.network.
                    self.fast_net.copy_weights(self.network)
                    task_grads.append(
                        self.fast_net.forward(sup_x[task_idx], qry_x[task_idx],
                                              sup_y[task_idx], qry_y[task_idx]))

                # Perform the meta update, then release the per-task outputs.
                self.meta_update(task_grads, qry_x, qry_y)
                task_grads.clear()

                if need_val and itr % 1000 == 0:
                    val_result = finetune_eval_task_accuracy(
                        self.network, self.val_loader, self.inner_step_size,
                        self.test_finetune_updates, update_BN=True)
                    f1_tensor = torch.Tensor(1).fill_(val_result["query_F1"])
                    self.tensorboard.record_val_query_F1(f1_tensor, itr)

            snapshot = {
                'epoch': epoch + 1,
                'state_dict': self.network.state_dict(),
                'optimizer': self.opt.state_dict(),
            }
            torch.save(snapshot, model_path)
def evaluate_cross_arch(args):
    """Evaluate rotation-detector checkpoints on tasks from another architecture.

    Scans ``PY_ROOT/train_pytorch_model/ROTATE_DET/cv2_rotate_model`` for
    ``IMG_ROTATE_DET*`` checkpoints, keeps those whose split protocol equals
    ``args.protocol`` and whose ``data_`` architecture equals
    ``args.cross_arch_source``, and evaluates each of them on tasks built with
    ``adv_arch=args.cross_arch_target`` for the 0-, 1- and 5-shot settings.
    The collected results are written as JSON into the same directory.

    Args:
        args: Namespace providing ``gpus``, ``protocol``, ``cross_arch_source``,
            ``cross_arch_target``, ``num_updates`` and ``load_mode``.
            For a pure 0-shot run pass ``args.num_updates = 0``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Recognized names look like
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_ALL_TEST_ALL@model_conv3@data_conv3@epoch_10@lr_0.0001@batch_100.pth.tar
    # Names carrying anything between the batch size and ".pth.tar" do not
    # match this pattern and are skipped below.
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)\.pth.tar")
    result = defaultdict(dict)
    update_BN = False
    tot_num_tasks = 20000
    num_classes = 2
    num_query = 15
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            print("skip unrecognized checkpoint name: {}".format(model_path))
            continue
        dataset = ma.group(1)
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        adv_arch = ma.group(4)
        if adv_arch != args.cross_arch_source:
            continue

        lr = float(ma.group(6))
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        # Read the checkpoint from disk once; a fresh model is still built and
        # re-initialised from it for every shot setting.
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        result_key = "{}@{}-->{}".format(dataset, args.cross_arch_source, args.cross_arch_target)
        for requested_shot in (0, 1, 5):
            if requested_shot == 0:
                # 0-shot: tasks are still sampled with one support example,
                # but no fine-tuning update is performed.
                shot, num_update = 1, 0
            else:
                shot, num_update = requested_shot, args.num_updates
            meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                                dataset, is_train=False, load_mode=args.load_mode,
                                                protocol=split_protocol, no_random_way=True, adv_arch=args.cross_arch_target, fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
            img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                                           CLASS_NUM[dataset])
            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
            model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform, layer_number, num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(model_path, checkpoint['epoch']))
            # FIXME: results are reported to be very high with update_BN=False.
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=update_BN)
            # Any run without fine-tuning updates is recorded as 0-shot.
            recorded_shot = 0 if num_update == 0 else shot
            result[result_key][recorded_shot] = evaluate_result
    with open("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/cross_arch_{}--{}_using_{}_result_updateBN_{}.json".format(PY_ROOT, args.cross_arch_source,
                                                                            args.cross_arch_target, args.protocol, update_BN), "w") as file_obj:
        file_obj.write(json.dumps(result))
示例#3
0
def evaluate_finetune(model_path_list, lr, protocol):
    """Evaluate CIFAR-10 ``DL_DET`` checkpoints for 0 to 50 fine-tuning updates.

    The deep-learning models are trained in the all_in / sampled all_in
    setting, but testing has to be done on the task version of the dataset.

    Args:
        model_path_list: Iterable of checkpoint paths. Paths whose name does
            not follow the ``DL_DET@...`` naming scheme, whose dataset is not
            ``CIFAR-10`` or whose protocol differs from ``protocol`` are
            skipped.
        lr: Learning rate passed to ``finetune_eval_task_accuracy``.
        protocol: Split protocol; must be
            ``SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II``.

    Returns:
        ``defaultdict`` mapping ``"<dataset>_<balance|no_balance>"`` to a dict
        keyed by the number of updates; each value is the evaluation result
        with an extra ``"shot"`` entry (0 for zero updates, otherwise 1).
    """
    # Function-scope import, resolved once instead of once per checkpoint.
    import torch

    extract_pattern_detail = re.compile(r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar")
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    assert protocol == SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II, "protocol {} is not TRAIN_I_TEST_II!".format(protocol)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            print("skip unrecognized checkpoint name: {}".format(model_path))
            continue
        dataset = ma.group(1)
        if dataset != "CIFAR-10":
            continue
        if str(protocol) != ma.group(2):
            continue
        adv_arch = ma.group(4)
        balance = "balance" if ma.group(8) == "True" else "no_balance"
        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(model_path, checkpoint['epoch']))
        # Tasks are always sampled with a single support example.
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, 1, query,
                                            dataset, is_train=False,
                                            load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=protocol, no_random_way=True, adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        result_key = "{}_{}".format(dataset, balance)
        for num_update in range(0, 51):
            evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update, update_BN=False)
            # Zero updates means the support set is never used: report 0-shot.
            evaluate_result["shot"] = 0 if num_update == 0 else 1
            result[result_key][num_update] = evaluate_result
    return result
def evaluate_whitebox(dataset, arch, adv_arch, detector, attack_name, num_update, lr, protocol, load_mode,result):
    """Evaluate the white-box ``DL_DET`` checkpoint for 0-, 1- and 5-shot tasks.

    The deep-learning models are trained in the all_in / sampled all_in
    setting, but testing has to be done on the task version of the dataset.

    Args:
        dataset: Dataset name; used to index ``IMAGE_DATA_ROOT``,
            ``IN_CHANNELS`` and ``IMAGE_SIZE`` and to build the checkpoint path.
        arch: Model architecture name used in the checkpoint file name.
        adv_arch: Architecture name used in the checkpoint file name and in
            the adversarial image folder.
        detector: Detector name used in the adversarial image folder.
        attack_name: Attack name; selects the sub-folder of adversarial images.
        num_update: Number of fine-tuning updates for the 1- and 5-shot runs.
        lr: Learning rate passed to ``finetune_eval_task_accuracy``.
        protocol: Split protocol, formatted into the checkpoint file name.
        load_mode: Forwarded to ``MetaTaskDataset``.
        result: Mapping of mappings that receives the results; it is updated
            in place and also returned.

    Returns:
        ``result``, with
        ``result["<dataset>_<attack_name>_<detector>_<adv_arch>"][shot]`` set
        for each evaluated shot setting.

    Raises:
        AssertionError: If the expected checkpoint file does not exist.
    """
    tot_num_tasks = 20000
    way = 2
    query = 15

    model_path = "{}/train_pytorch_model/white_box_model/DL_DET@{}_{}@model_{}@data_{}@epoch_40@class_2@lr_0.0001@balance_True.pth.tar".format(
        PY_ROOT, dataset, protocol, arch, adv_arch)
    assert os.path.exists(model_path), "{} is not exists".format(model_path)
    root_folder = IMAGE_DATA_ROOT[dataset] + "/adversarial_images/white_box@data_{}@det_{}/{}/".format(adv_arch, detector,attack_name)
    print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
    model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
    model = model.cuda()
    checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(model_path, checkpoint['epoch']))

    requested_updates = num_update
    result_key = "{}_{}_{}_{}".format(dataset, attack_name, detector, adv_arch)
    for requested_shot in (0, 1, 5):
        if requested_shot == 0:
            # 0-shot: tasks are still sampled with one support example,
            # but no fine-tuning update is performed.
            shot, num_update = 1, 0
        else:
            shot, num_update = requested_shot, requested_updates

        meta_task_dataset = MetaTaskDataset(tot_num_tasks, way, shot, query,
                                            dataset,
                                            load_mode=load_mode, detector=detector, attack_name=attack_name,root_folder=root_folder)
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_update,update_BN=True)
        # Any run without fine-tuning updates is recorded as 0-shot.
        recorded_shot = 0 if num_update == 0 else shot
        result[result_key][recorded_shot] = evaluate_result

    return result
def meta_zero_shot_evaluate(args):
    """Evaluate MAML checkpoints with zero fine-tuning updates and dump JSON.

    Every ``MAML@*`` checkpoint under
    ``PY_ROOT/train_pytorch_model/<args.study_subject>/`` whose path contains
    ``str(args.split_protocol)`` has its hyper-parameters parsed from the file
    name. A ``MetaLearner`` is built with ``0`` in the position the original
    code labels "zero_shot", the checkpoint weights are loaded, and the
    network is evaluated on ``learner.val_loader``.

    When ``args.cross_domain_source`` is set, only checkpoints trained on that
    dataset are used and they are evaluated on ``args.cross_domain_target``.

    Args:
        args: Namespace providing ``study_subject``, ``split_protocol``,
            ``cross_domain_source``, ``cross_domain_target``, ``lr_decay_itr``,
            ``load_task_mode``, ``tot_num_tasks`` and ``adv_arch``.
    """
    extract_pattern = re.compile(
        r".*/MAML@(.*?)_(.*?)@model_(.*?)@data.*?@epoch_(\d+)@meta_batch_size_(\d+)@way_(\d+)@shot_(\d+)@num_query_(\d+)@num_updates_(\d+)@lr_(.*?)@inner_lr_(.*?)@fixed_way_(.*?)@rotate_(.*?)\.pth.tar"
    )
    extract_param_prefix = re.compile(r".*/MAML@(.*?)\.pth.tar")
    report_result = defaultdict(dict)
    for model_path in glob.glob("{}/train_pytorch_model/{}/MAML@*".format(
            PY_ROOT, args.study_subject)):
        if str(args.split_protocol) not in model_path:
            continue
        ma_prefix = extract_param_prefix.match(model_path)
        ma = extract_pattern.match(model_path)
        if ma_prefix is None or ma is None:
            print("skip unrecognized checkpoint name: {}".format(model_path))
            continue
        param_prefix = ma_prefix.group(1)
        dataset = ma.group(1)
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        arch = ma.group(3)
        epoch = int(ma.group(4))
        meta_batch_size = int(ma.group(5))
        num_classes = int(ma.group(6))
        num_support = int(ma.group(7))
        num_query = int(ma.group(8))  # the query size from the file name is used
        meta_lr = float(ma.group(10))
        inner_lr = float(ma.group(11))

        key = "{}__{}".format(dataset, split_protocol)
        if args.cross_domain_source is not None:
            if dataset != args.cross_domain_source:
                continue
            dataset = args.cross_domain_target
            key = "{}-->{}__{}".format(args.cross_domain_source, dataset,
                                       split_protocol)

        # Load only after all filters passed, so skipped checkpoints are
        # never read from disk.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        learner = MetaLearner(
            dataset,
            num_classes,
            meta_batch_size,
            meta_lr,
            inner_lr,
            args.lr_decay_itr,
            epoch,
            0,  # zero_shot
            args.load_task_mode,
            split_protocol,
            arch,
            args.tot_num_tasks,
            num_support,
            num_query,
            no_random_way=True,
            tensorboard_data_prefix=param_prefix,
            train=False,
            adv_arch=args.adv_arch,
            need_val=True)
        learner.network.load_state_dict(checkpoint['state_dict'], strict=True)
        result_json = finetune_eval_task_accuracy(
            learner.network,
            learner.val_loader,
            learner.inner_step_size,
            learner.test_finetune_updates,
            update_BN=True)
        report_result[dataset][key] = result_json

    file_name = "{}/train_pytorch_model/{}/{}_result.json".format(
        PY_ROOT, args.study_subject, args.study_subject)
    if args.cross_domain_source:
        file_name = "{}/train_pytorch_model/{}/zero_shot_{}--{}_result.json".format(
            PY_ROOT, args.study_subject, args.cross_domain_source,
            args.cross_domain_target)
    with open(file_name, "w") as file_obj:
        file_obj.write(json.dumps(report_result))
示例#6
0
def meta_cross_arch_evaluate(args):
    """Evaluate MAML checkpoints on tasks from another adversarial architecture.

    A model trained with 1 shot can only be tested on 1-shot data, so for each
    shot setting (1 and 5) only checkpoints trained with that many support
    examples are evaluated. A checkpoint is used when its split protocol
    equals ``args.split_protocol``, its ``data_`` architecture equals
    ``args.cross_arch_source`` and its ``fixed_way`` flag is true. Tasks are
    built with ``adv_arch=args.cross_arch_target``. Results are written as
    JSON into ``PY_ROOT/train_pytorch_model/<args.study_subject>/``.

    Args:
        args: Namespace providing ``study_subject``, ``split_protocol``,
            ``cross_arch_source``, ``cross_arch_target``, ``lr_decay_itr``,
            ``test_num_updates``, ``load_task_mode`` and ``tot_num_tasks``.
    """
    extract_pattern = re.compile(
        r".*/MAML@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@meta_batch_size_(\d+)@way_(\d+)@shot_(\d+)@num_query_(\d+)@num_updates_(\d+)@lr_(.*?)@inner_lr_(.*?)@fixed_way_(.*?)@.*"
    )
    extract_param_prefix = re.compile(r".*/MAML@(.*?)\.pth.tar")
    report_result = defaultdict(dict)
    str2bool = lambda v: v.lower() in ("yes", "true", "t", "1")
    updateBN = True
    # The directory listing does not depend on the shot, so take it once.
    model_paths = glob.glob("{}/train_pytorch_model/{}/MAML@*".format(
        PY_ROOT, args.study_subject))
    result_key = dataset_suffix = "@" + args.cross_arch_source + "--" + args.cross_arch_target
    for shot in [1, 5]:
        for model_path in model_paths:
            ma_prefix = extract_param_prefix.match(model_path)
            ma = extract_pattern.match(model_path)
            if ma_prefix is None or ma is None:
                print("skip unrecognized checkpoint name: {}".format(model_path))
                continue
            param_prefix = ma_prefix.group(1)
            dataset = ma.group(1)
            split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
            if split_protocol != args.split_protocol:
                continue
            arch = ma.group(3)
            data_arch = ma.group(4)
            if data_arch != args.cross_arch_source:
                continue
            epoch = int(ma.group(5))
            meta_batch_size = int(ma.group(6))
            num_classes = int(ma.group(7))
            model_train_num_support = int(ma.group(8))
            if shot != model_train_num_support:
                continue
            meta_lr = float(ma.group(11))
            inner_lr = float(ma.group(12))
            fixe_way = str2bool(ma.group(13))
            if not fixe_way:
                continue
            print("=> loading checkpoint '{}'".format(model_path))
            checkpoint = torch.load(
                model_path, map_location=lambda storage, location: storage)
            learner = MetaLearner(dataset,
                                  num_classes,
                                  meta_batch_size,
                                  meta_lr,
                                  inner_lr,
                                  args.lr_decay_itr,
                                  epoch,
                                  args.test_num_updates,
                                  args.load_task_mode,
                                  split_protocol,
                                  arch,
                                  args.tot_num_tasks,
                                  shot,
                                  15,
                                  True,
                                  param_prefix,
                                  train=False,
                                  adv_arch=args.cross_arch_target,
                                  need_val=True)
            learner.network.load_state_dict(checkpoint['state_dict'],
                                            strict=True)
            result_json = finetune_eval_task_accuracy(
                learner.network,
                learner.val_loader,
                learner.inner_step_size,
                learner.test_finetune_updates,
                update_BN=updateBN)
            report_result[dataset + dataset_suffix][shot] = result_json
    with open(
            "{}/train_pytorch_model/{}/{}--{}@finetune_{}_result_updateBN_{}.json"
            .format(PY_ROOT, args.study_subject, args.cross_arch_source,
                    args.cross_arch_target, args.test_num_updates, updateBN),
            "w") as file_obj:
        file_obj.write(json.dumps(report_result))
示例#7
0
def evaluate_zero_shot(model_path_list, lr, protocol, args):
    """Evaluate ``DL_DET`` checkpoints without any fine-tuning update.

    The deep-learning models are trained in the all_in / sampled all_in
    setting, but testing has to be done on the task version of the dataset.

    Args:
        model_path_list: Iterable of checkpoint paths. Paths whose name does
            not follow the ``DL_DET@...`` naming scheme, or whose protocol
            differs from ``protocol``, are skipped.
        lr: Learning rate passed to ``finetune_eval_task_accuracy``.
        protocol: Split protocol; ``str(protocol)`` is compared with the
            protocol embedded in the file name.
        args: Namespace providing ``cross_domain_source`` and
            ``cross_domain_target``. When the source is not ``None`` only
            checkpoints trained on it are used, and they are evaluated on the
            target dataset.

    Returns:
        ``defaultdict`` mapping a result key to ``{0: evaluation_result}``.
    """
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar"
    )
    tot_num_tasks = 20000
    way = 2
    query = 15
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            print("skip unrecognized checkpoint name: {}".format(model_path))
            continue
        dataset = ma.group(1)
        if str(protocol) != ma.group(2):
            continue
        balance = "balance" if ma.group(8) == "True" else "no_balance"

        key = "{}@{}__{}".format(dataset, balance, protocol)
        if args.cross_domain_source is not None:
            if dataset != args.cross_domain_source:
                continue
            # From here on everything is built for the target dataset.
            dataset = args.cross_domain_target
            key = "{}@{}-->{}__{}".format(args.cross_domain_source, balance,
                                          dataset, protocol)

        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        adv_arch = ma.group(4)
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()

        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
        # Tasks are sampled with one support example, but zero updates are
        # performed.
        shot = 1
        num_update = 0
        meta_task_dataset = MetaTaskDataset(tot_num_tasks,
                                            way,
                                            shot,
                                            query,
                                            dataset,
                                            is_train=False,
                                            load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=protocol,
                                            no_random_way=True,
                                            adv_arch=adv_arch)
        data_loader = DataLoader(meta_task_dataset,
                                 batch_size=100,
                                 shuffle=False,
                                 pin_memory=True)
        evaluate_result = finetune_eval_task_accuracy(model,
                                                      data_loader,
                                                      lr,
                                                      num_update,
                                                      update_BN=False)
        result[key][0] = evaluate_result
    return result
示例#8
0
def evaluate_shots(args):
    """Evaluate rotation-detector checkpoints for 0-, 1- and 5-shot tasks.

    Scans ``PY_ROOT/train_pytorch_model/ROTATE_DET/cv2_rotate_model`` for
    ``IMG_ROTATE_DET*`` checkpoints, skips ImageNet ones and those whose split
    protocol differs from ``args.protocol``, and evaluates every remaining
    checkpoint on tasks generated with its own ``data_`` architecture.
    Results are written to ``shots_result.json`` in the same directory.

    Args:
        args: Namespace providing ``gpus``, ``protocol`` and ``num_updates``.
            For a pure 0-shot run pass ``args.num_updates = 0``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Recognized names look like
    # IMG_ROTATE_DET@ImageNet_TRAIN_I_TEST_II@model_resnet10@data_resnet10@epoch_10@lr_0.001@batch_10.pth.tar
    # Names carrying anything between the batch size and ".pth.tar" do not
    # match this pattern and are skipped below.
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)\.pth.tar"
    )
    result = defaultdict(dict)
    tot_num_tasks = 20000
    num_classes = 2
    num_query = 15
    for model_path in glob.glob(
            "{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*"
            .format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            print("skip unrecognized checkpoint name: {}".format(model_path))
            continue
        dataset = ma.group(1)
        if dataset == "ImageNet":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        arch = ma.group(3)
        if arch not in ("resnet10", "resnet18", "conv3"):
            # Without this guard an unknown architecture would silently reuse
            # the classifier built for a previous checkpoint.
            print("skip checkpoint with unsupported architecture '{}': {}".format(
                arch, model_path))
            continue
        data_adv_arch = ma.group(4)

        lr = float(ma.group(6))
        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        # Read the checkpoint from disk once; a fresh model is still built and
        # re-initialised from it for every shot setting.
        checkpoint = torch.load(
            model_path, map_location=lambda storage, location: storage)
        result_key = dataset + "@" + data_adv_arch

        for requested_shot in (0, 1, 5):
            if requested_shot == 0:
                # 0-shot: tasks are still sampled with one support example,
                # but no fine-tuning update is performed.
                shot, num_update = 1, 0
            else:
                shot, num_update = requested_shot, args.num_updates
            meta_task_dataset = MetaTaskDataset(tot_num_tasks,
                                                num_classes,
                                                shot,
                                                num_query,
                                                dataset,
                                                is_train=False,
                                                load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=split_protocol,
                                                no_random_way=True,
                                                adv_arch=data_adv_arch)
            data_loader = DataLoader(meta_task_dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     pin_memory=True)

            if arch == "resnet10":
                img_classifier_network = resnet10(
                    num_classes=CLASS_NUM[dataset],
                    in_channels=IN_CHANNELS[dataset],
                    pretrained=False)
            elif arch == "resnet18":
                img_classifier_network = resnet18(
                    num_classes=CLASS_NUM[dataset],
                    in_channels=IN_CHANNELS[dataset],
                    pretrained=False)
            else:  # "conv3", the only other architecture accepted above
                img_classifier_network = Conv3(IN_CHANNELS[dataset],
                                               IMAGE_SIZE[dataset],
                                               CLASS_NUM[dataset])

            image_transform = ImageTransformCV2(dataset, [1, 2])
            layer_number = 3 if dataset in [
                "ImageNet", "CIFAR-10", "CIFAR-100", "SVHN"
            ] else 2
            model = Detector(dataset,
                             img_classifier_network,
                             CLASS_NUM[dataset],
                             image_transform,
                             layer_number,
                             num_classes=2)
            model.load_state_dict(checkpoint['state_dict'])
            model.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(
                model_path, checkpoint['epoch']))
            evaluate_result = finetune_eval_task_accuracy(model,
                                                          data_loader,
                                                          lr,
                                                          num_update,
                                                          update_BN=False)
            # Any run without fine-tuning updates is recorded as 0-shot.
            recorded_shot = 0 if num_update == 0 else shot
            result[result_key][recorded_shot] = evaluate_result
    with open(
            "{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/shots_result.json"
            .format(PY_ROOT), "w") as file_obj:
        file_obj.write(json.dumps(result))
def evaluate_zero_shot(args):
    """Evaluate rotation-detector checkpoints without any fine-tuning update.

    Scans ``PY_ROOT/train_pytorch_model/ROTATE_DET/cv2_rotate_model`` for
    ``IMG_ROTATE_DET*.tar`` checkpoints, keeps those whose split protocol
    equals ``args.protocol`` and whose ``data_`` architecture equals
    ``args.adv_arch``, and evaluates each with zero updates. When
    ``args.cross_domain_source`` is set, only checkpoints trained on that
    dataset are used and they are evaluated on ``args.cross_domain_target``.
    Results are written as JSON into the same directory.

    Args:
        args: Namespace providing ``gpus``, ``protocol``, ``adv_arch``,
            ``cross_domain_source`` and ``cross_domain_target``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus[0][0])
    # Recognized names look like
    # IMG_ROTATE_DET@CIFAR-10_TRAIN_II_TEST_I@model_conv3@data_conv3@epoch_20@lr_0.001@batch_100@no_fix_cnn_params.pth.tar
    extract_pattern_detail = re.compile(
        r".*?IMG_ROTATE_DET@(.*?)_(.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@lr_(.*?)@batch_(\d+)@(.*?)\.pth.tar")
    result = defaultdict(dict)
    tot_num_tasks = 20000
    num_classes = 2
    num_query = 15
    for model_path in glob.glob("{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/IMG_ROTATE_DET*.tar".format(PY_ROOT)):
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            print("skip unrecognized checkpoint name: {}".format(model_path))
            continue
        dataset = ma.group(1)
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        if split_protocol != args.protocol:
            continue
        adv_arch = ma.group(4)
        if adv_arch != args.adv_arch:
            continue
        lr = float(ma.group(6))

        # Tasks are sampled with one support example, but zero updates are
        # performed.
        shot = 1
        key = "{}@{}".format(dataset, split_protocol)
        if args.cross_domain_source is not None:
            if dataset != args.cross_domain_source:
                continue
            # From here on everything is built for the target dataset.
            dataset = args.cross_domain_target
            key = "{}-->{}___{}".format(args.cross_domain_source, dataset, split_protocol)

        print("evaluate_accuracy model :{}".format(os.path.basename(model_path)))
        # FIXME: code for other adversarial architectures is not done yet, so
        # the task data architecture is hard-coded to conv3.
        meta_task_dataset = MetaTaskDataset(tot_num_tasks, num_classes, shot, num_query,
                                            dataset, is_train=False, load_mode=LOAD_TASK_MODE.LOAD,
                                            protocol=split_protocol, no_random_way=True, adv_arch="conv3")
        data_loader = DataLoader(meta_task_dataset, batch_size=100, shuffle=False, pin_memory=True)
        img_classifier_network = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset],
                                       CLASS_NUM[dataset])
        image_transform = ImageTransformCV2(dataset, [1, 2])
        layer_number = 3 if dataset in ["CIFAR-10", "CIFAR-100", "SVHN"] else 2
        model = Detector(dataset, img_classifier_network, CLASS_NUM[dataset], image_transform, layer_number, num_classes=2)
        checkpoint = torch.load(model_path, map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        model.cuda()
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(model_path, checkpoint['epoch']))
        evaluate_result = finetune_eval_task_accuracy(model, data_loader, lr, num_updates=0, update_BN=False)
        result[key][0] = evaluate_result
    file_name = "{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/zero_shot_result.json".format(PY_ROOT)
    if args.cross_domain_source:
        file_name = "{}/train_pytorch_model/ROTATE_DET/cv2_rotate_model/zero_shot_{}--{}_result.json".format(PY_ROOT, args.cross_domain_source, args.cross_domain_target)
    with open(file_name, "w") as file_obj:
        file_obj.write(json.dumps(result))
示例#10
0
def meta_ablation_study_evaluate(args):
    """Evaluate the MAML checkpoints of one ablation study and dump JSON.

    Every non-ImageNet ``MAML@*`` checkpoint under
    ``PY_ROOT/train_pytorch_model/<args.study_subject>/`` whose path contains
    ``str(args.split_protocol)`` has its hyper-parameters parsed from the file
    name. How it is evaluated depends on ``args.study_subject``:

    * ``"fine_tune_update_ablation_study"``: evaluated for 1 to 50 test-time
      updates; results are keyed by the number of updates.
    * ``"zero_shot"``: evaluated through
      ``learner.test_zero_shot_with_finetune_trainset()``.
    * anything else: evaluated once with ``args.test_num_updates`` updates.

    In the last two cases the result is keyed by the hyper-parameter that the
    study varies, falling back to the checkpoint's parameter prefix.

    Args:
        args: Namespace providing ``study_subject``, ``split_protocol``,
            ``lr_decay_itr``, ``load_task_mode``, ``tot_num_tasks``,
            ``test_num_updates``, ``adv_arch`` and, for the ``cross_domain``
            study, ``cross_domain_source`` / ``cross_domain_target``.
    """
    extract_pattern = re.compile(
        r".*/MAML@(.*?)_(.*?)@model_(.*?)@data.*?@epoch_(\d+)@meta_batch_size_(\d+)@way_(\d+)@shot_(\d+)@num_query_(\d+)@num_updates_(\d+)@lr_(.*?)@inner_lr_(.*?)@fixed_way_(.*?)@rotate_(.*?)\.pth.tar"
    )
    extract_param_prefix = re.compile(r".*/MAML@(.*?)\.pth.tar")
    report_result = defaultdict(dict)
    str2bool = lambda v: v.lower() in ("yes", "true", "t", "1")
    for model_path in glob.glob("{}/train_pytorch_model/{}/MAML@*".format(
            PY_ROOT, args.study_subject)):
        if str(args.split_protocol) not in model_path:
            continue
        ma_prefix = extract_param_prefix.match(model_path)
        ma = extract_pattern.match(model_path)
        if ma_prefix is None or ma is None:
            print("skip unrecognized checkpoint name: {}".format(model_path))
            continue
        param_prefix = ma_prefix.group(1)
        dataset = ma.group(1)
        if dataset == "ImageNet":
            continue
        split_protocol = SPLIT_DATA_PROTOCOL[ma.group(2)]
        arch = ma.group(3)
        epoch = int(ma.group(4))
        meta_batch_size = int(ma.group(5))
        num_classes = int(ma.group(6))
        num_support = int(ma.group(7))
        num_query = int(ma.group(8))
        num_updates = int(ma.group(9))
        meta_lr = float(ma.group(10))
        inner_lr = float(ma.group(11))
        fixe_way = str2bool(ma.group(12))

        # Result key per study: the hyper-parameter that the study varies.
        key_by_subject = {
            "inner_update_ablation_study": num_updates,
            "shots_ablation_study": num_support,
            "cross_adv_group": num_support,
            "tasks_ablation_study": meta_batch_size,
            "ways_ablation_study": num_classes,
            "random_vs_fix_way": "shots_{}_fixed_way_{}".format(num_support, fixe_way),
            "query_size_ablation_study": num_query,
            "vs_deep_MAX": num_support,
            "zero_shot": "0-shot",
        }
        extract_key = key_by_subject.get(args.study_subject, param_prefix)
        print("=> loading checkpoint '{}'".format(model_path))
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)

        if args.study_subject == "fine_tune_update_ablation_study":
            # Original note: this experiment is to be redone.
            learner = MetaLearner(dataset,
                                  num_classes,
                                  meta_batch_size,
                                  meta_lr,
                                  inner_lr,
                                  args.lr_decay_itr,
                                  epoch,
                                  num_updates,
                                  args.load_task_mode,
                                  split_protocol,
                                  arch,
                                  args.tot_num_tasks,
                                  num_support,
                                  15,
                                  True,
                                  param_prefix,
                                  train=False,
                                  adv_arch=args.adv_arch,
                                  need_val=True)
            learner.network.load_state_dict(checkpoint['state_dict'],
                                            strict=True)
            for test_num_updates in range(1, 51):
                result_json = finetune_eval_task_accuracy(learner.network,
                                                          learner.val_loader,
                                                          inner_lr,
                                                          test_num_updates,
                                                          update_BN=True)
                report_result[dataset][test_num_updates] = result_json
        elif args.study_subject == "zero_shot":
            learner = MetaLearner(
                dataset,
                num_classes,
                meta_batch_size,
                meta_lr,
                inner_lr,
                args.lr_decay_itr,
                epoch,
                args.test_num_updates,
                args.load_task_mode,
                split_protocol,
                arch,
                args.tot_num_tasks,
                num_support,
                num_query,  # query size parsed from the file name
                no_random_way=True,
                tensorboard_data_prefix=param_prefix,
                train=True,
                adv_arch=args.adv_arch,
                need_val=True)
            learner.network.load_state_dict(checkpoint['state_dict'],
                                            strict=True)
            result_json = learner.test_zero_shot_with_finetune_trainset()
            report_result[dataset][extract_key] = result_json
        else:
            load_mode = args.load_task_mode
            learner = MetaLearner(
                dataset,
                num_classes,
                meta_batch_size,
                meta_lr,
                inner_lr,
                args.lr_decay_itr,
                epoch,
                args.test_num_updates,
                load_mode,
                split_protocol,
                arch,
                args.tot_num_tasks,
                num_support,
                15,  # the query size is fixed to 15 here
                True,
                param_prefix,
                train=False,
                adv_arch=args.adv_arch,
                need_val=True)
            learner.network.load_state_dict(checkpoint['state_dict'],
                                            strict=True)
            result_json = finetune_eval_task_accuracy(learner.network,
                                                      learner.val_loader,
                                                      inner_lr,
                                                      args.test_num_updates,
                                                      update_BN=True)
            report_result[dataset][extract_key] = result_json

    file_name = "{}/train_pytorch_model/{}/{}_result.json".format(
        PY_ROOT, args.study_subject, args.study_subject)
    if args.study_subject == "cross_domain":
        file_name = "{}/train_pytorch_model/{}/{}--{}_result.json".format(
            PY_ROOT, args.study_subject, args.cross_domain_source,
            args.cross_domain_target)
    with open(file_name, "w") as file_obj:
        file_obj.write(json.dumps(report_result))
def evaluate_cross_arch(model_path_list, num_update, lr, protocol, src_arch,
                        target_arch, updateBN):
    """Evaluate checkpoints trained on ``src_arch`` data against ``target_arch`` tasks.

    Each path in ``model_path_list`` is parsed with the ``DL_DET@...`` file-name
    pattern. A checkpoint is evaluated only when its file name matches the
    pattern, its protocol field equals ``str(protocol)`` and its ``data_``
    field equals ``src_arch``. For every accepted checkpoint a ``Conv3`` model
    is loaded and evaluated with ``finetune_eval_task_accuracy`` under three
    settings: zero-shot (1-shot tasks, 0 fine-tune updates), 1-shot and 5-shot
    (both with ``num_update`` updates).

    Args:
        model_path_list: iterable of checkpoint file paths.
        num_update: number of fine-tune updates for the 1-shot / 5-shot runs.
        lr: learning rate forwarded to ``finetune_eval_task_accuracy``.
        protocol: split protocol; compared (as ``str``) with the file name
            and forwarded to ``MetaTaskDataset``.
        src_arch: value the ``data_`` field of the file name must equal.
        target_arch: passed as ``adv_arch`` when building the evaluation tasks.
        updateBN: forwarded as ``update_BN`` to ``finetune_eval_task_accuracy``.

    Returns:
        ``defaultdict(dict)`` mapping
        ``"{src_arch}-->{target_arch}@{dataset}_{balance|no_balance}"`` to a
        dict ``{0: ..., 1: ..., 5: ...}`` of evaluation results.
    """
    # Raw string: "\d" and "\." are regex escapes, not Python string escapes.
    extract_pattern_detail = re.compile(
        r".*?DL_DET@(.*?)_(TRAIN_.*?)@model_(.*?)@data_(.*?)@epoch_(\d+)@class_(\d+)@lr_(.*?)@balance_(.*?)\.pth\.tar"
    )
    tot_num_tasks = 20000
    way = 2
    query = 15
    # (key in the result dict, support shots used for the tasks, fine-tune steps).
    # Zero-shot reuses the 1-shot tasks but performs no fine-tune update. Keeping
    # the key explicit means a caller-supplied num_update of 0 can no longer
    # collapse the 1-shot and 5-shot results onto key 0.
    shot_settings = (
        (0, 1, 0),
        (1, 1, num_update),
        (5, 5, num_update),
    )
    result = defaultdict(dict)
    for model_path in model_path_list:
        ma = extract_pattern_detail.match(model_path)
        if ma is None:
            # Not a checkpoint following the expected naming scheme.
            continue
        dataset = ma.group(1)
        file_protocol = ma.group(2)
        if str(protocol) != file_protocol:
            continue

        adv_arch = ma.group(4)
        if adv_arch != src_arch:
            continue
        balance = "balance" if ma.group(8) == "True" else "no_balance"
        print("evaluate_accuracy model :{}".format(
            os.path.basename(model_path)))
        model = Conv3(IN_CHANNELS[dataset], IMAGE_SIZE[dataset], 2)
        model = model.cuda()
        # Load the tensors onto CPU storage first; load_state_dict copies
        # them into the (already CUDA-resident) model parameters.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, location: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
        result_key = "{}-->{}@{}_{}".format(src_arch, target_arch, dataset,
                                            balance)
        for shot_key, shot, updates in shot_settings:
            meta_task_dataset = MetaTaskDataset(tot_num_tasks,
                                                way,
                                                shot,
                                                query,
                                                dataset,
                                                is_train=False,
                                                load_mode=LOAD_TASK_MODE.LOAD,
                                                protocol=protocol,
                                                no_random_way=True,
                                                adv_arch=target_arch,
                                                fetch_attack_name=False)
            data_loader = DataLoader(meta_task_dataset,
                                     batch_size=100,
                                     shuffle=False,
                                     pin_memory=True)
            evaluate_result = finetune_eval_task_accuracy(model,
                                                          data_loader,
                                                          lr,
                                                          updates,
                                                          update_BN=updateBN)
            result[result_key][shot_key] = evaluate_result
    return result