def get_img_label_data_loader(datasetname,
                              batch_size,
                              is_train,
                              image_size=None):
    """Return a shuffled DataLoader of (image, label) pairs for a dataset.

    Args:
        datasetname: one of "CIFAR-10", "CIFAR-100", "MNIST",
            "FashionMNIST", "TinyImageNet", "ImageNet".
        batch_size: mini-batch size of the returned loader.
        is_train: True selects the training split, False the
            test/validation split.
        image_size: optional spatial size passed to the preprocessor;
            defaults to IMAGE_SIZE[datasetname].

    Returns:
        torch.utils.data.DataLoader over the selected dataset split.

    Raises:
        ValueError: if ``datasetname`` is not one of the known names.
    """
    workers = 0
    if image_size is None:
        image_size = IMAGE_SIZE[datasetname]
    preprocessor = DataLoaderMaker.get_preprocessor(image_size, is_train)
    if datasetname == "CIFAR-10":
        train_dataset = CIFAR10(IMAGE_DATA_ROOT[datasetname],
                                train=is_train,
                                transform=preprocessor)
    elif datasetname == "CIFAR-100":
        train_dataset = CIFAR100(IMAGE_DATA_ROOT[datasetname],
                                 train=is_train,
                                 transform=preprocessor)
    elif datasetname == "MNIST":
        train_dataset = MNIST(IMAGE_DATA_ROOT[datasetname],
                              train=is_train,
                              transform=preprocessor)
    elif datasetname == "FashionMNIST":
        train_dataset = FashionMNIST(IMAGE_DATA_ROOT[datasetname],
                                     train=is_train,
                                     transform=preprocessor)
    elif datasetname == "TinyImageNet":
        train_dataset = TinyImageNet(IMAGE_DATA_ROOT[datasetname],
                                     preprocessor,
                                     train=is_train)
        workers = 7  # larger dataset benefits from multi-process loading
    elif datasetname == "ImageNet":
        # ImageNet needs a center-crop-aware preprocessor; rebuild it.
        # Note that ImageNet uses pretrainedmodels.utils.TransformImage
        # to apply the transformation.
        preprocessor = DataLoaderMaker.get_preprocessor(image_size,
                                                        is_train,
                                                        center_crop=True)
        sub_folder = "/train" if is_train else "/validation"
        train_dataset = ImageFolder(IMAGE_DATA_ROOT[datasetname] +
                                    sub_folder,
                                    transform=preprocessor)
        workers = 5
    else:
        # Previously an unrecognized name fell through to an
        # UnboundLocalError on train_dataset; fail early and clearly.
        raise ValueError("Unknown dataset: {}".format(datasetname))
    # NOTE(review): shuffle=True is applied even when is_train is False —
    # confirm evaluation callers really expect shuffled batches.
    data_loader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=workers)
    return data_loader
# NOTE(review): this fragment begins mid `if dataset == ...`/`elif` chain;
# the opening branch and the definitions of `dataset`, `train_preprocessor`
# and `args` are outside this view.
    # test_dataset =  CIFAR10(IMAGE_DATA_ROOT[dataset], train=False, transform=test_preprocessor)
elif dataset == "CIFAR-100":
    # CIFAR-100 training split with the shared train-time preprocessor.
    train_dataset = CIFAR100(IMAGE_DATA_ROOT[dataset],
                             train=True,
                             transform=train_preprocessor)
    # test_dataset = CIFAR100(IMAGE_DATA_ROOT[dataset], train=False, transform=test_preprocessor)
elif dataset == "ImageNet":
    # ImageNet rebuilds its own preprocessor (no center crop at train time).
    train_preprocessor = DataLoaderMaker.get_preprocessor(IMAGE_SIZE[dataset],
                                                          True,
                                                          center_crop=False)
    # test_preprocessor = DataLoaderMaker.get_preprocessor(IMAGE_SIZE[dataset], False, center_crop=True)
    train_dataset = ImageFolder(IMAGE_DATA_ROOT[dataset] + "/train",
                                transform=train_preprocessor)
elif dataset == "TinyImageNet":
    # TinyImageNet takes the preprocessor positionally (custom dataset class).
    train_dataset = TinyImageNet(IMAGE_DATA_ROOT[dataset],
                                 train_preprocessor,
                                 train=True)

# NOTE(review): if none of the visible branches matched, `train_dataset` is
# unbound here — confirm the unseen opening branch covers the remaining case.
batch_size = args.batch_size
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0)  # single-process loading

print('==> Building model..')
# Candidate architectures for this dataset (standard training pool).
arch_list = MODELS_TRAIN_STANDARD[args.dataset]
model_dict = {}  # NOTE(review): never populated within the visible lines
for arch in arch_list:
    if StandardModel.check_arch(arch, args.dataset):
        print("begin use arch {}".format(arch))
        # NOTE(review): `model` is rebound on every iteration and never
        # stored in model_dict in this view — the loop body appears
        # truncated by the `def __init__` that follows.
        model = StandardModel(args.dataset, arch, no_grad=True)
    def __init__(self, tot_num_tasks, dataset, inner_batch_size, protocol):
        """Pre-generate meta-learning tasks over a pool of standard models.

        Each task pairs a random sample of image indices with a randomly
        chosen architecture from the model pool selected by ``protocol``.

        Args:
            tot_num_tasks: total number of tasks to pre-generate.
            dataset: dataset name (e.g. "CIFAR-10", "ImageNet") used to
                pick image size, data root, and candidate architectures.
            inner_batch_size: number of image indices sampled per task.
            protocol: a SPLIT_DATA_PROTOCOL member selecting which model
                pool (train / test / both) supplies the architectures.

        Raises:
            ValueError: if ``protocol`` or ``dataset`` is not recognized.
        """
        self.img_size = IMAGE_SIZE[dataset]
        self.dataset = dataset

        if protocol == SPLIT_DATA_PROTOCOL.TRAIN_I_TEST_II:
            self.model_names = MODELS_TRAIN_STANDARD[self.dataset]
        elif protocol == SPLIT_DATA_PROTOCOL.TRAIN_II_TEST_I:
            self.model_names = MODELS_TEST_STANDARD[self.dataset]
        elif protocol == SPLIT_DATA_PROTOCOL.TRAIN_ALL_TEST_ALL:
            self.model_names = MODELS_TRAIN_STANDARD[
                self.dataset] + MODELS_TEST_STANDARD[self.dataset]
        else:
            # Previously an unknown protocol left self.model_names unset,
            # surfacing later as a confusing AttributeError; fail early.
            raise ValueError("Unknown protocol: {}".format(protocol))

        # Instantiate each usable architecture once; all tasks share them.
        self.model_dict = {}
        for arch in self.model_names:
            if StandardModel.check_arch(arch, dataset):
                model = StandardModel(dataset, arch, no_grad=False).eval()
                if dataset != "ImageNet":
                    # NOTE(review): ImageNet models stay on CPU here —
                    # presumably to limit GPU memory; confirm intent.
                    model = model.cuda()
                self.model_dict[arch] = model

        is_train = True
        preprocessor = DataLoaderMaker.get_preprocessor(
            IMAGE_SIZE[dataset], is_train)
        if dataset == "CIFAR-10":
            train_dataset = CIFAR10(IMAGE_DATA_ROOT[dataset],
                                    train=is_train,
                                    transform=preprocessor)
        elif dataset == "CIFAR-100":
            train_dataset = CIFAR100(IMAGE_DATA_ROOT[dataset],
                                     train=is_train,
                                     transform=preprocessor)
        elif dataset == "MNIST":
            train_dataset = MNIST(IMAGE_DATA_ROOT[dataset],
                                  train=is_train,
                                  transform=preprocessor)
        elif dataset == "FashionMNIST":
            train_dataset = FashionMNIST(IMAGE_DATA_ROOT[dataset],
                                         train=is_train,
                                         transform=preprocessor)
        elif dataset == "TinyImageNet":
            train_dataset = TinyImageNet(IMAGE_DATA_ROOT[dataset],
                                         preprocessor,
                                         train=is_train)
        elif dataset == "ImageNet":
            # ImageNet needs a center-crop-aware preprocessor; rebuild it.
            # Note that ImageNet uses pretrainedmodels.utils.TransformImage
            # to apply the transformation.
            preprocessor = DataLoaderMaker.get_preprocessor(
                IMAGE_SIZE[dataset], is_train, center_crop=True)
            sub_folder = "/train" if is_train else "/validation"
            train_dataset = ImageFolder(IMAGE_DATA_ROOT[dataset] + sub_folder,
                                        transform=preprocessor)
        else:
            # Previously an unknown dataset left train_dataset unbound.
            raise ValueError("Unknown dataset: {}".format(dataset))
        self.train_dataset = train_dataset
        self.total_num_images = len(train_dataset)
        # Pre-sample every task up front: a set of image indices plus one
        # randomly chosen architecture per task.
        self.all_tasks = dict()
        all_images_indexes = np.arange(self.total_num_images).tolist()
        for i in range(tot_num_tasks):
            self.all_tasks[i] = {
                "image": random.sample(all_images_indexes, inner_batch_size),
                "arch": random.choice(list(self.model_dict.keys()))
            }
def main_train_worker(args):
    """Train one image classifier selected by ``args.arch`` and checkpoint it.

    Builds the network (adapting torchvision backbones to the dataset's
    class count), trains with SGD + cross-entropy on TinyImageNet data,
    validates each epoch, and saves a checkpoint whose path encodes the
    hyper-parameters.

    Args:
        args: parsed namespace providing gpu, arch, dataset, epochs, lr,
            momentum, weight_decay, batch_size and workers.
    """
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    print("=> creating model '{}'".format(args.arch))
    # Architectures known to torchvision start from a pretrained backbone;
    # the remaining arches (resnext*, densenet*, inception*) are
    # constructed in the branches below.
    if args.arch in models.__dict__:
        network = models.__dict__[args.arch](pretrained=True)
    num_classes = CLASS_NUM[args.dataset]
    if args.arch.startswith("resnet"):
        # Replace the final FC layer to match the dataset's class count.
        num_ftrs = network.fc.in_features
        network.fc = nn.Linear(num_ftrs, num_classes)
    elif args.arch.startswith("densenet"):
        if args.arch == "densenet161":
            network = densenet161(pretrained=True)
        elif args.arch == "densenet121":
            network = densenet121(pretrained=True)
        elif args.arch == "densenet169":
            network = densenet169(pretrained=True)
        elif args.arch == "densenet201":
            network = densenet201(pretrained=True)
    elif args.arch == "resnext32_4":
        # NOTE(review): duplicate, unreachable elif branches for these two
        # arches (passing pretrained="imagenet") were removed; the first,
        # reachable branches used pretrained=None and that behavior is kept.
        network = resnext101_32x4d(pretrained=None)
    elif args.arch == "resnext64_4":
        network = resnext101_64x4d(pretrained=None)
    elif args.arch.startswith("squeezenet"):
        network.classifier[-1] = nn.AdaptiveAvgPool2d(1)
        network.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=1)
    elif args.arch.startswith("inception"):
        network = inception_v3(pretrained=True)
    elif args.arch.startswith("vgg"):
        # 64px input: 64 / 2**5 = 2, so the flattened feature map is 512*2*2.
        network.avgpool = Identity()
        network.classifier[0] = nn.Linear(512 * 2 * 2, 4096)
        network.classifier[-1] = nn.Linear(4096, num_classes)

    # densenet and inception require their own modified implementations,
    # because their forward() uses F.avg_pool2d (fixed-size pooling).
    model_path = '{}/train_pytorch_model/real_image_model/{}@{}@epoch_{}@lr_{}@batch_{}.pth.tar'.format(
        PY_ROOT, args.dataset, args.arch, args.epochs, args.lr,
        args.batch_size)
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    print("after train_simulate_grad_mode, model will be saved to {}".format(
        model_path))
    preprocessor = get_preprocessor(IMAGE_SIZE[args.dataset], use_flip=True)
    network.cuda()
    image_classifier_loss = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(network.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True
    # Same preprocessor is reused for train and test splits.
    train_dataset = TinyImageNet(IMAGE_DATA_ROOT[args.dataset],
                                 preprocessor,
                                 train=True)
    test_dataset = TinyImageNet(IMAGE_DATA_ROOT[args.dataset],
                                preprocessor,
                                train=False)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    for epoch in range(0, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)
        # train for one epoch
        train(train_loader, network, image_classifier_loss, optimizer, epoch,
              args)
        # evaluate accuracy on the validation set
        val_acc = validate(val_loader, network, image_classifier_loss, args)
        # checkpoint every epoch (overwrites the same file each time)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                "val_acc": val_acc,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            filename=model_path)