    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        transform = transforms.Compose([
                                       transforms.Scale(opt.loadSize),
                                       transforms.CenterCrop(opt.fineSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5))])

        # Dataset A
        dataset_A = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'A',
                                transform=transform, return_paths=True)
        data_loader_A = torch.utils.data.DataLoader(
            dataset_A,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))

        # Dataset B
        dataset_B = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'B',
                                transform=transform, return_paths=True)
        data_loader_B = torch.utils.data.DataLoader(
            dataset_B,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))
        self.dataset_A = dataset_A
        self.dataset_B = dataset_B
        self.paired_data = PairedData(data_loader_A, data_loader_B)
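
# PairedData is not defined in these snippets; the class below is only an
# illustrative sketch of what such a wrapper could look like (zip the two
# loaders and yield aligned A/B batches), not the project's implementation.
class PairedDataSketch(object):
    def __init__(self, data_loader_A, data_loader_B, max_dataset_size=float('inf')):
        self.data_loader_A = data_loader_A
        self.data_loader_B = data_loader_B
        self.max_dataset_size = max_dataset_size

    def __iter__(self):
        # zip stops at the shorter of the two loaders
        for i, ((A, A_paths), (B, B_paths)) in enumerate(
                zip(self.data_loader_A, self.data_loader_B)):
            if i * A.size(0) >= self.max_dataset_size:
                break
            yield {'A': A, 'A_paths': A_paths,
                   'B': B, 'B_paths': B_paths}
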
    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        transformations = [
            #transforms.Scale(opt.loadSize),
            MyScale(size=(256, 256), pad=True),
            transforms.RandomCrop(opt.fineSize),
            transforms.ToTensor(),
            # this is wrong! the fake samples are not normalized like this,
            # yet they are fed through the same network
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            #lambda x: (x - x.min()) / x.max() * 2 - 1,  # [-1., 1.]
        ]  # normalization produces negative values (see the sketch after this example)

        #transformations = [transforms.Scale(opt.loadSize), transforms.RandomCrop(opt.fineSize),
        #                    transforms.ToTensor()]
        transform = transforms.Compose(transformations)

        # Dataset A, e.g. the trainA directory
        dataset_A = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'A',
                                transform=transform,
                                return_paths=True)
        data_loader_A = torch.utils.data.DataLoader(
            dataset_A,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))

        # Dataset B
        dataset_B = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'B',
                                transform=transform,
                                return_paths=True)
        data_loader_B = torch.utils.data.DataLoader(
            dataset_B,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))

        # Dataset C
        dataset_C = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'C',
                                transform=transform,
                                return_paths=True)
        data_loader_C = torch.utils.data.DataLoader(
            dataset_C,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))

        # How do we guarantee that A, B and C stay in one-to-one correspondence?
        # shuffle=not self.opt.serial_batches: when serial_batches is True the
        # loaders read samples in order, otherwise each one shuffles randomly
        # ('shuffle' here means mixing up the order, as when shuffling cards).

        self.dataset_A = dataset_A
        self.dataset_B = dataset_B
        self.dataset_C = dataset_C
        flip = opt.isTrain and not opt.no_flip
        self.three_paired_data = ThreePairedData(data_loader_A, data_loader_B,
                                                 data_loader_C,
                                                 self.opt.max_dataset_size,
                                                 flip)
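
# A minimal standalone sketch (not part of the example above) illustrating the
# normalization comment: ToTensor() yields values in [0, 1], and
# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) computes (x - 0.5) / 0.5, i.e.
# maps them to [-1, 1], which is why negative values appear.
import torch
from torchvision import transforms

to_signed = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
x = torch.rand(3, 8, 8)        # stand-in for a ToTensor() output in [0, 1]
y = to_signed(x)               # now in [-1, 1]
assert -1.0 <= y.min() and y.max() <= 1.0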
Example #3
    def __init__(self, opt):
        super(ImageEdgeDataset, self).__init__(opt)
        self.image_paths = ImageFolder(os.path.join(
            self.root, opt.path_image))  # image path list
        self.edge_paths = ImageFolder(os.path.join(
            self.root, opt.path_edge))  # edge path list
        assert len(self.image_paths) == len(self.edge_paths)

        self.transform = get_transform()
Example #4
    def build_model(self):
        """ DataLoader """
        train_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        test_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        label_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor()])

        self.trainA = ImageFolder(os.path.join(self.dataset, 'trainA'), train_transform, extend_paths=True, return_paths=False)
        self.trainB = ImageFolder(os.path.join(self.dataset, 'trainB'), train_transform, extend_paths=True, return_paths=False)
        self.label = ImageFolder(os.path.join(self.dataset, 'label'), label_transform, extend_paths=True, loader="gray")
        self.testA = ImageFolder(os.path.join(self.dataset, 'testA'), test_transform, extend_paths=True, return_paths=True)
        self.testB = ImageFolder(os.path.join(self.dataset, 'testB'), test_transform, extend_paths=True, return_paths=True)
        self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True, pin_memory=True)
        self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True, pin_memory=True)
        self.label_loader = DataLoader(self.label, batch_size=self.batch_size, shuffle=True, pin_memory=True)
        self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False, pin_memory=True)
        self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False, pin_memory=True)

        """ Define Generator, Discriminator """
        self.gen2B = networks.NiceResnetGenerator(input_nc=self.img_ch, output_nc=self.img_ch, ngf=self.ch, n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device)
        self.gen2A = networks.NiceResnetGenerator(input_nc=self.img_ch, output_nc=self.img_ch, ngf=self.ch, n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device)
        self.disA = networks.NiceDiscriminator(input_nc=self.img_ch, ndf=self.ch, n_layers=self.n_dis).to(self.device)
        self.disB = networks.NiceDiscriminator(input_nc=self.img_ch, ndf=self.ch, n_layers=self.n_dis).to(self.device)

        print('-----------------------------------------------')
        input = torch.randn([1, self.img_ch, self.img_size, self.img_size]).to(self.device)
        macs, params = profile(self.disA, inputs=(input, ))
        macs, params = clever_format([macs*2, params*2], "%.3f")
        print('[Network %s] Total number of parameters: ' % 'disA', params)
        print('[Network %s] Total number of FLOPs: ' % 'disA', macs)
        print('-----------------------------------------------')
        _, _, _, _, real_A_ae = self.disA(input)
        macs, params = profile(self.gen2B, inputs=(real_A_ae, ))
        macs, params = clever_format([macs*2, params*2], "%.3f")
        print('[Network %s] Total number of parameters: ' % 'gen2B', params)
        print('[Network %s] Total number of FLOPs: ' % 'gen2B', macs)
        print('-----------------------------------------------')

        """ Define Loss """
        self.L1_loss = nn.L1Loss().to(self.device)
        self.MSE_loss = nn.MSELoss().to(self.device)

        """ Trainer """
        self.G_optim = torch.optim.Adam(itertools.chain(self.gen2B.parameters(), self.gen2A.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
        self.D_optim = torch.optim.Adam(itertools.chain(self.disA.parameters(), self.disB.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
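
# A standalone sketch of the thop-based MACs/parameter report used in build_model
# above, run on a tiny placeholder module (NiceResnetGenerator / NiceDiscriminator
# are not shown here, so nn.Conv2d stands in for them). The *2 factors in the
# example above presumably report the two symmetric networks together.
import torch
import torch.nn as nn
from thop import profile, clever_format

model = nn.Conv2d(3, 64, kernel_size=3, padding=1)     # placeholder network
dummy = torch.randn(1, 3, 256, 256)
macs, params = profile(model, inputs=(dummy,))
macs, params = clever_format([macs, params], "%.3f")
print('[Network %s] Total number of parameters: ' % 'conv', params)
print('[Network %s] Total number of MACs: ' % 'conv', macs)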
Example #5
def CreateDataset(opt, test=False):
    from data.image_folder import ImageFolder
    if opt.phase == 'train':
        target_dataset = ImageFolder(opt, opt.dataroot)
        print("dataset was created")
        return target_dataset
    source_dataset = ImageFolder(opt, opt.dataroot)
    print("dataset was created")
    return source_dataset
    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.fineSize = opt.fineSize

        transformations = [
            # TODO: Scale
            transforms.Scale(opt.loadSize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
        transform = transforms.Compose(transformations)

        # Dataset A
        dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase,
                              transform=transform,
                              return_paths=True)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))

        self.dataset = dataset

        flip = opt.isTrain and not opt.no_flip
        self.paired_data = PairedData(data_loader, opt.fineSize,
                                      opt.max_dataset_size, flip)
Example #7
def get_data_loader_folder(input_folder,
                           batch_size,
                           train,
                           new_size=None,
                           height=256,
                           width=256,
                           num_workers=4,
                           crop=True):
    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
    if crop:
        transform_list = [transforms.RandomCrop((height, width))] + transform_list
    if new_size is not None:
        transform_list = [transforms.Resize(new_size)] + transform_list
    if train:
        transform_list = [transforms.RandomHorizontalFlip()] + transform_list
    transform = transforms.Compose(transform_list)
    dataset = ImageFolder(input_folder, transform=transform)
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=train,
                        drop_last=True,
                        num_workers=num_workers)
    return loader
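
# Hedged usage sketch for the helper above; the folder path and sizes are
# placeholders, not values taken from the original project.
train_loader = get_data_loader_folder('datasets/example/trainA',   # hypothetical path
                                      batch_size=1,
                                      train=True,
                                      new_size=286,
                                      height=256,
                                      width=256,
                                      num_workers=4,
                                      crop=True)
images = next(iter(train_loader))   # one (1, 3, 256, 256) batch, roughly in [-1, 1]
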
    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.fineSize = opt.fineSize

        # transform the image data to a tensor and normalize it to [-1, 1]
        transform = transforms.Compose([
            # this is fake, we will delete it later
            # transforms.Scale((opt.loadSize, opt.loadSize), interpolation=Image.BILINEAR),  # BICUBIC or ANTIALIAS
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5),
                                 (0.5, 0.5, 0.5))])

        # Dataset A
        dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase,
                              caption=opt.caption,
                              caption_bucket=opt.bucket_description,
                              vocab=opt.vocab,
                              data_augment=opt.augment_data,
                              transform=transform,
                              return_paths=True)


        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads),
            drop_last=False)

        self.dataset = dataset
        self.paired_data = PairedData(data_loader, opt.fineSize)
    def __init__(self, params):
        transform = transforms.Compose([
            transforms.Scale(size=(params.load_size, params.load_size)),
            transforms.RandomCrop(size=(params.height, params.width)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset_A = torch.utils.data.DataLoader(
            ImageFolder(root=params.data_root + '/' + params.phase + 'A',
                        transform=transform),
            num_workers=params.num_workers,
            shuffle=params.shuffle)

        dataset_B = torch.utils.data.DataLoader(
            ImageFolder(root=params.data_root + '/' + params.phase + 'B',
                        transform=transform),
            num_workers=params.num_workers,
            shuffle=params.shuffle)

        self.dataset_A = dataset_A
        self.dataset_B = dataset_B
        self.paired_data = PairedData(self.dataset_A, self.dataset_B)
Example #10
    def __init__(self, _root, _list_dir, _input_height, _input_width, _is_flip,
                 _shuffle):
        transform = None
        dataset = ImageFolder(root=_root,
                              list_dir=_list_dir,
                              input_height=_input_height,
                              input_width=_input_width,
                              transform=transform,
                              is_flip=_is_flip)

        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=32,
                                                  shuffle=_shuffle,
                                                  num_workers=int(3))

        self.dataset = dataset
        flip = False
        self.paired_data = PairedData(data_loader, flip)
Example #11
    def test(self):
        args = self.args
        input_dim = args.input_nc
        if args.direction == 'AtoB':
            a2b = 1
            test_type = 'testA'
            output_type = 'fakeB'
        else:
            a2b = 0
            test_type = 'testB'
            output_type = 'fakeA'
        # Setup model and data loader
        image_path = os.path.join(self.args.dataroot, test_type)
        image_names = ImageFolder(image_path,
                                  transform=None,
                                  return_paths=True)
        data_loader = get_data_loader_folder(image_path,
                                             1,
                                             False,
                                             new_size=args.load_size,
                                             crop=False)
        checkpoint = os.path.join(args.result_dir, args.name, 'model',
                                  'gen_%08d.pt' % (args.iteration))
        try:
            state_dict = torch.load(checkpoint)
            self.gen_a.load_state_dict(state_dict['a'])
            self.gen_b.load_state_dict(state_dict['b'])
        except Exception:
            state_dict = pytorch03_to_pytorch04(torch.load(checkpoint))
            self.gen_a.load_state_dict(state_dict['a'])
            self.gen_b.load_state_dict(state_dict['b'])

        self.to(self.device)
        #self.eval()
        encode = self.gen_a.encode if a2b else self.gen_b.encode  # encode function
        decode = self.gen_b.decode if a2b else self.gen_a.decode  # decode function
        for i, (images, names) in enumerate(zip(data_loader, image_names)):
            print(names[1])
            images = Variable(images.to(self.device), volatile=True)
            content, _ = encode(images)

            outputs = decode(content)
            outputs = (outputs + 1) / 2.
            # path = os.path.join(opts.output_folder, 'input{:03d}_output{:03d}.jpg'.format(i, j))
            basename = os.path.basename(names[1])
            path = os.path.join(args.result_dir, args.name, output_type,
                                basename)
            if not os.path.exists(os.path.dirname(path)):
                os.makedirs(os.path.dirname(path))
            vutils.save_image(outputs.data, path, padding=0, normalize=True)
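
# Side note (illustrative, not from the project): (outputs + 1) / 2. already maps a
# tanh output into [0, 1], and save_image(..., normalize=True) then rescales once
# more by the tensor's min/max. A minimal sketch of saving with only the explicit rescale:
import torch
import torchvision.utils as vutils

fake = torch.tanh(torch.randn(1, 3, 64, 64))                  # stand-in generator output in [-1, 1]
vutils.save_image((fake + 1) / 2.0, 'fake.png', padding=0)    # no second normalization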
Example #12
    def initialize(self, opt):

        BaseDataLoader.initialize(self, opt)
        self.fineSize = opt.fineSize
        transform = transforms.Compose([
            # TODO: Scale
            transforms.Resize(opt.loadSize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Dataset A
        dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase,
                              transform=transform,
                              return_paths=True,
                              font_trans=(not opt.flat),
                              rgb=opt.rgb,
                              fineSize=opt.fineSize,
                              loadSize=opt.loadSize)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))

        self.dataset = dataset
        dict_inds = {}
        test_dict = opt.dataroot + '/test_dict/dict.pkl'
        print("test_dict", test_dict)
        if opt.phase == 'test':
            if os.path.isfile(test_dict):
                with open(test_dict, 'rb') as f:
                    dict_inds = pickle.load(f, encoding='bytes')
            else:
                warnings.warn(
                    'Blanks in test data are random. Create a pkl file at ~/data_path/test_dict/dict.pkl containing predefined random indices.'
                )

        if opt.flat:
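            # note: data_loader_base is not defined in this excerpt; in the original
            # project it is presumably built from a base-font ImageFolder when
            # opt.base_font is set (compare the opt.base_font branch in Example #14).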
            self._data = FlatData(data_loader, data_loader_base, opt.fineSize,
                                  opt.max_dataset_size, opt.rgb, dict_inds,
                                  opt.base_font, opt.blanks)
        else:
            self._data = Data(data_loader, opt.fineSize, opt.max_dataset_size,
                              opt.rgb, dict_inds, opt.blanks)
    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.fineSize = opt.fineSize
        transform = transforms.Compose([
            # TODO: Scale
            #transforms.Scale((opt.loadSize * 2, opt.loadSize)),
            #transforms.CenterCrop(opt.fineSize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Dataset A
        dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase,
                              transform=transform,
                              return_paths=True)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))

        self.dataset = dataset
        self.paired_data = PairedData(data_loader, opt.fineSize)
Example #14
    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        transform = transforms.Compose([
            # TODO: Scale
            transforms.Resize(opt.loadSize),
            transforms.ToTensor(),
            # transforms.Normalize((0.5, 0.5, 0.5),
            # (0.5, 0.5, 0.5))
        ])
        dic_phase = {'train': 'Train', 'test': 'Test'}
        # Dataset A

        dataset_A = ImageFolder(root=opt.dataroot + '/A/' + opt.phase,
                                transform=transform,
                                return_paths=True,
                                rgb=opt.rgb,
                                fineSize=opt.fineSize,
                                loadSize=opt.loadSize,
                                font_trans=False,
                                no_permutation=opt.no_permutation)
        len_A = len(dataset_A.imgs)
        if not opt.no_permutation:
            shuffle_inds = np.random.permutation(len_A)
        else:
            shuffle_inds = range(len_A)

        dataset_B = ImageFolder(root=opt.dataroot + '/B/' + opt.phase,
                                transform=transform,
                                return_paths=True,
                                rgb=opt.rgb,
                                fineSize=opt.fineSize,
                                loadSize=opt.loadSize,
                                font_trans=False,
                                no_permutation=opt.no_permutation)

        if len(dataset_A.imgs) != len(dataset_B.imgs):
            raise Exception(
                "number of images in source folder and target folder does not match"
            )

        if opt.partial and not self.opt.serial_batches:
            dataset_A.imgs = [dataset_A.imgs[i] for i in shuffle_inds]
            dataset_B.imgs = [dataset_B.imgs[i] for i in shuffle_inds]
            dataset_A.img_crop = [dataset_A.img_crop[i] for i in shuffle_inds]
            dataset_B.img_crop = [dataset_B.img_crop[i] for i in shuffle_inds]
            shuffle = False
        else:
            shuffle = not self.opt.serial_batches
        data_loader_A = torch.utils.data.DataLoader(
            dataset_A,
            batch_size=self.opt.batchSize,
            shuffle=shuffle,
            num_workers=int(self.opt.nThreads))

        data_loader_B = torch.utils.data.DataLoader(
            dataset_B,
            batch_size=self.opt.batchSize,
            shuffle=shuffle,
            num_workers=int(self.opt.nThreads))

        if opt.base_font:
            # Read and apply transformation on the BASE font
            dataset_base = ImageFolder(root=opt.base_root,
                                       transform=transform,
                                       return_paths=True,
                                       font_trans=True,
                                       rgb=opt.rgb,
                                       fineSize=opt.fineSize,
                                       loadSize=opt.loadSize)
            data_loader_base = torch.utils.data.DataLoader(
                dataset_base,
                batch_size=1,
                shuffle=False,
                num_workers=int(self.opt.nThreads))
        else:
            data_loader_base = None

        self.dataset_A = dataset_A
        self._data = PartialData(data_loader_A, data_loader_B,
                                 data_loader_base, opt.fineSize, opt.loadSize,
                                 opt.max_dataset_size, opt.phase,
                                 opt.base_font)
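
# A standalone sketch of the paired-shuffle trick used above: apply one random
# permutation to both file lists so source/target pairs stay aligned while the
# overall order is still randomized; the DataLoaders are then created with
# shuffle=False. The helper name and seed argument are illustrative only.
import numpy as np

def shuffle_in_pairs(imgs_A, imgs_B, seed=None):
    # the same permutation keeps index i in A matched with index i in B
    assert len(imgs_A) == len(imgs_B)
    inds = np.random.RandomState(seed).permutation(len(imgs_A))
    return [imgs_A[i] for i in inds], [imgs_B[i] for i in inds]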
Example #15
    #     fineSize=176, gpu_ids=[0], half=False, how_many=200, init_type='normal',
    #     input_nc=1, isTrain=True, l_cent=50.0, l_norm=100.0, lambda_A=1.0, lambda_B=1.0,
    #     lambda_GAN=0.0, lambda_identity=0.5, loadSize=256, load_model=True, lr=0.0001,
    #     lr_decay_iters=50, lr_policy='lambda', mask_cent=0.5, max_dataset_size=inf,
    #     model='pix2pix', n_layers_D=3, name='siggraph_retrained',
    #     ndf=64, ngf=64, niter=100, niter_decay=100, no_dropout=False, no_flip=False,
    #     no_html=False, no_lsgan=False, norm='batch', num_threads=1, output_nc=2, phase='val',
    #     pool_size=50, print_freq=200, resize_or_crop='resize_and_crop',
    #     results_dir='./results/', sample_Ps=[1, 2, 3, 4, 5, 6, 7, 8, 9],
    #     sample_p=1.0, save_epoch_freq=1, save_latest_freq=5000, serial_batches=True,
    #     suffix='', update_html_freq=10000, verbose=False, which_direction='AtoB',
    #     which_epoch='latest', which_model_netD='basic', which_model_netG='siggraph')

    dataset = ImageFolder(opt.dataroot,
                          transform=transforms.Compose([
                              transforms.Resize((opt.loadSize, opt.loadSize)),
                              transforms.ToTensor()
                          ]))

    dataset_loader = torch.utils.data.DataLoader(
        dataset, batch_size=opt.batch_size, shuffle=not opt.serial_batches)
    model = create_model(opt)

    model.setup(opt)
    model.eval()

    # pdb.set_trace();

    # create website
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
Example #16
def CreateDataset():
    from data.image_folder import ImageFolder
    dataset = ImageFolder('../../dataset/sbu')
    return dataset
Example #17
    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        transformations = [
            #transforms.Scale(opt.loadSize),
            MyScale(size=(256, 256), pad=True),
            transforms.RandomCrop(opt.fineSize),
            transforms.ToTensor(),
            # this is wrong! the fake samples are not normalized like this,
            # yet they are fed through the same network
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            #lambda x: (x - x.min()) / x.max() * 2 - 1,  # [-1., 1.]
        ]  # normalization produces negative values

        #transformations = [transforms.Scale(opt.loadSize), transforms.RandomCrop(opt.fineSize),
        #                    transforms.ToTensor()]
        transform = transforms.Compose(transformations)

        # Dataset A1, e.g. the train/A1 directory
        dataset_A1 = ImageFolder(root=opt.dataroot + '/' + opt.phase + '/A1',
                                 transform=transform,
                                 return_paths=True)
        data_loader_A1 = torch.utils.data.DataLoader(
            dataset_A1,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))

        dataset_A2 = ImageFolder(root=opt.dataroot + '/' + opt.phase + '/A2',
                                 transform=transform,
                                 return_paths=True)
        data_loader_A2 = torch.utils.data.DataLoader(
            dataset_A2,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))

        dataset_B1 = ImageFolder(root=opt.dataroot + '/' + opt.phase + '/B1',
                                 transform=transform,
                                 return_paths=True)
        data_loader_B1 = torch.utils.data.DataLoader(
            dataset_B1,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))

        dataset_B2 = ImageFolder(root=opt.dataroot + '/' + opt.phase + '/B2',
                                 transform=transform,
                                 return_paths=True)
        data_loader_B2 = torch.utils.data.DataLoader(
            dataset_B2,
            batch_size=self.opt.batchSize,
            shuffle=not self.opt.serial_batches,
            num_workers=int(self.opt.nThreads))

        # How do we guarantee that A1, A2, B1 and B2 stay in one-to-one correspondence?
        # shuffle=not self.opt.serial_batches: when serial_batches is True the
        # loaders read samples in order, otherwise each one shuffles randomly
        # ('shuffle' here means mixing up the order, as when shuffling cards).
        #
        # Odd case: what if A and B contain different numbers of images? The number
        # of samples to load can also be adjusted through opt.

        self.dataset_A1 = dataset_A1
        self.dataset_A2 = dataset_A2
        self.dataset_B1 = dataset_B1
        self.dataset_B2 = dataset_B2
        flip = opt.isTrain and not opt.no_flip
        self.four_paired_data = FourPairedData(data_loader_A1, data_loader_A2,
                                               data_loader_B1, data_loader_B2,
                                               self.opt.max_dataset_size, flip)