Example #1
 def evaluate_dataloader(self,
                         data: DataLoader,
                         criterion: Callable,
                         metric,
                         logger,
                         ratio_width=None,
                         output=False,
                         official=False,
                         confusion_matrix=False,
                         **kwargs):
     self.model.eval()
     self.reset_metrics(metric)
     timer = CountdownTimer(len(data))
     total_loss = 0
     if official:
         sentences = []
         gold = []
         pred = []
     for batch in data:
         output_dict = self.feed_batch(batch)
         if official:
             sentences += batch['token']
             gold += batch['srl']
             pred += output_dict['prediction']
         self.update_metrics(batch, output_dict, metric)
         loss = output_dict['loss']
         total_loss += loss.item()
         timer.log(self.report_metrics(total_loss / (timer.current + 1), metric), ratio_percentage=None,
                   logger=logger,
                   ratio_width=ratio_width)
         del loss
     if official:
         scores = compute_srl_f1(sentences, gold, pred)
         if logger:
             if confusion_matrix:
                 labels = sorted(set(y for x in scores.label_confusions.keys() for y in x))
                 headings = ['GOLD↓PRED→'] + labels
                 matrix = []
                 for gold_label in labels:
                     row = [gold_label]
                     matrix.append(row)
                     for pred_label in labels:
                         row.append(scores.label_confusions.get((gold_label, pred_label), 0))
                 matrix = markdown_table(headings, matrix)
                 logger.info(f'{"Confusion Matrix": ^{len(matrix.splitlines()[0])}}')
                 logger.info(matrix)
             headings = ['Settings', 'Precision', 'Recall', 'F1']
             data = []
             for h, (p, r, f) in zip(['Unlabeled', 'Labeled', 'Official'], [
                 [scores.unlabeled_precision, scores.unlabeled_recall, scores.unlabeled_f1],
                 [scores.precision, scores.recall, scores.f1],
                 [scores.conll_precision, scores.conll_recall, scores.conll_f1],
             ]):
                 data.append([h] + [f'{x:.2%}' for x in [p, r, f]])
             table = markdown_table(headings, data)
             logger.info(f'{"Scores": ^{len(table.splitlines()[0])}}')
             logger.info(table)
     else:
         scores = metric
     return total_loss / timer.total, scores
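
# A minimal, self-contained sketch of the confusion-matrix table built in Example #1.
# `toy_confusions` is made-up data standing in for `scores.label_confusions`, and
# `to_markdown` is a local stand-in for the `markdown_table` helper used above
# (whose real signature is not shown in the example).
toy_confusions = {('ARG0', 'ARG0'): 7, ('ARG0', 'ARG1'): 2, ('ARG1', 'ARG1'): 5}
labels = sorted(set(y for key in toy_confusions for y in key))
headings = ['GOLD↓PRED→'] + labels
rows = [[g] + [toy_confusions.get((g, p), 0) for p in labels] for g in labels]

def to_markdown(headings, rows):
    lines = ['| ' + ' | '.join(str(c) for c in headings) + ' |',
             '|' + '---|' * len(headings)]
    lines += ['| ' + ' | '.join(str(c) for c in row) + ' |' for row in rows]
    return '\n'.join(lines)

print(to_markdown(headings, rows))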
Example #2
def loading_data(args):
    mean_std = cfg_data.MEAN_STD
    log_para = cfg_data.LOG_PARA

    sou_main_transform = own_transforms.Compose([
        own_transforms.RandomCrop(cfg_data.TRAIN_SIZE),
        own_transforms.RandomHorizontallyFlip(),
        # Rand_Augment()
    ])

    # converts a PIL Image (H x W x C) in the range [0, 255]
    # to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    gt_transform = standard_transforms.Compose(
        [own_transforms.LabelNormalize(log_para)])
    restore_transform = standard_transforms.Compose([
        own_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    if args.phase == 'DA_train' or args.phase == 'fine_tune':
        # Load meta-train set
        # NOTE: only the last of these IFS_path assignments survives, and even that
        # value is unused below because the GCC dataset is built with IFS_path=None
        IFS_path = '/media/D/ht/C-3-Framework-trans/trans-display/GCC2SHHB/s2t'
        IFS_path = '/media/D/ht/C-3-Framework-trans/trans-display/GCC2QNRF/s2t'
        IFS_path = '/media/D/ht/C-3-Framework-trans/trans-display/GCC2WE/s2t'
        trainset = GCC('train',
                       main_transform=sou_main_transform,
                       img_transform=img_transform,
                       gt_transform=gt_transform,
                       filter_rule=cfg_data.FILTER_RULE,
                       IFS_path=None)
        sou_loader = DataLoader(trainset,
                                batch_size=cfg_data.sou_batch_size,
                                shuffle=True,
                                num_workers=12,
                                drop_last=True,
                                pin_memory=True)

        if args.target_dataset == 'QNRF':
            tar_main_transform = own_transforms.Compose(
                [own_transforms.RandomHorizontallyFlip()])
            trainset = QNRF('train',
                            main_transform=tar_main_transform,
                            img_transform=img_transform,
                            gt_transform=gt_transform)
            tar_shot_loader = DataLoader(trainset,
                                         batch_size=cfg_data.target_shot_size,
                                         shuffle=True,
                                         num_workers=12,
                                         collate_fn=SHHA_collate,
                                         drop_last=True)

            valset = QNRF('val',
                          img_transform=img_transform,
                          gt_transform=gt_transform)
            tar_val_loader = DataLoader(valset,
                                        batch_size=1,
                                        num_workers=8,
                                        pin_memory=True)

            testset = QNRF('test',
                           img_transform=img_transform,
                           gt_transform=gt_transform)
            tar_test_loader = DataLoader(testset,
                                         batch_size=1,
                                         num_workers=8,
                                         pin_memory=True)
        elif args.target_dataset == 'SHHA':
            tar_main_transform = own_transforms.Compose(
                [own_transforms.RandomHorizontallyFlip()])
            trainset = SHHA('train',
                            main_transform=tar_main_transform,
                            img_transform=img_transform,
                            gt_transform=gt_transform)
            tar_shot_loader = DataLoader(trainset,
                                         batch_size=cfg_data.target_shot_size,
                                         shuffle=True,
                                         num_workers=12,
                                         collate_fn=SHHA_collate,
                                         drop_last=True)

            valset = SHHA('val',
                          img_transform=img_transform,
                          gt_transform=gt_transform)
            tar_val_loader = DataLoader(valset,
                                        batch_size=1,
                                        num_workers=8,
                                        pin_memory=True)

            testset = SHHA('test',
                           img_transform=img_transform,
                           gt_transform=gt_transform)
            tar_test_loader = DataLoader(testset,
                                         batch_size=1,
                                         num_workers=8,
                                         pin_memory=True)

        elif args.target_dataset == 'MALL':
            tar_main_transform = own_transforms.Compose([
                own_transforms.RandomCrop(cfg_data.MALL_TRAIN_SIZE),
                own_transforms.RandomHorizontallyFlip()
            ])
            trainset = MALL('train',
                            main_transform=tar_main_transform,
                            img_transform=img_transform,
                            gt_transform=gt_transform)
            tar_shot_loader = DataLoader(trainset,
                                         batch_size=cfg_data.target_shot_size,
                                         shuffle=True,
                                         num_workers=12,
                                         drop_last=True,
                                         pin_memory=True)

            valset = MALL('val',
                          img_transform=img_transform,
                          gt_transform=gt_transform)
            tar_val_loader = DataLoader(valset,
                                        batch_size=8,
                                        num_workers=8,
                                        pin_memory=True)

            testset = MALL('test',
                           img_transform=img_transform,
                           gt_transform=gt_transform)
            tar_test_loader = DataLoader(testset,
                                         batch_size=12,
                                         num_workers=8,
                                         pin_memory=True)

        elif args.target_dataset == 'UCSD':
            tar_main_transform = own_transforms.Compose([
                own_transforms.RandomCrop(cfg_data.UCSD_TRAIN_SIZE),
                own_transforms.RandomHorizontallyFlip(),
            ])
            trainset = UCSD('train',
                            main_transform=tar_main_transform,
                            img_transform=img_transform,
                            gt_transform=gt_transform)
            tar_shot_loader = DataLoader(trainset,
                                         batch_size=cfg_data.target_shot_size,
                                         shuffle=True,
                                         num_workers=12,
                                         drop_last=True,
                                         pin_memory=True)

            valset = UCSD('val',
                          img_transform=img_transform,
                          gt_transform=gt_transform)
            tar_val_loader = DataLoader(valset,
                                        batch_size=8,
                                        num_workers=8,
                                        pin_memory=True)

            testset = UCSD('test',
                           img_transform=img_transform,
                           gt_transform=gt_transform)
            tar_test_loader = DataLoader(testset,
                                         batch_size=12,
                                         num_workers=8,
                                         pin_memory=True)
        elif args.target_dataset == 'SHHB':
            tar_main_transform = own_transforms.Compose([
                own_transforms.RandomCrop(cfg_data.SHHB_TRAIN_SIZE),
                own_transforms.RandomHorizontallyFlip(),
                # Rand_Augment()
            ])

            trainset = SHHB('train',
                            main_transform=tar_main_transform,
                            img_transform=img_transform,
                            gt_transform=gt_transform)
            tar_shot_loader = DataLoader(trainset,
                                         batch_size=cfg_data.target_shot_size,
                                         shuffle=True,
                                         num_workers=8,
                                         drop_last=True,
                                         pin_memory=True)

            valset = SHHB('val',
                          img_transform=img_transform,
                          gt_transform=gt_transform)
            tar_val_loader = DataLoader(valset,
                                        batch_size=8,
                                        num_workers=8,
                                        pin_memory=True)

            testset = SHHB('test',
                           img_transform=img_transform,
                           gt_transform=gt_transform)
            tar_test_loader = DataLoader(testset,
                                         batch_size=8,
                                         num_workers=8,
                                         pin_memory=True)

        elif args.target_dataset == 'WE':
            tar_test_loader = []
            tar_main_transform = own_transforms.Compose([
                own_transforms.RandomCrop(cfg_data.WE_TRAIN_SIZE),
                own_transforms.RandomHorizontallyFlip(),
                # Rand_Augment()
            ])
            trainset = WE(None,
                          'train',
                          main_transform=tar_main_transform,
                          img_transform=img_transform,
                          gt_transform=gt_transform)
            tar_shot_loader = DataLoader(trainset,
                                         batch_size=cfg_data.target_shot_size,
                                         shuffle=True,
                                         num_workers=8,
                                         drop_last=True,
                                         pin_memory=True)
            valset = WE(None,
                        'val',
                        main_transform=tar_main_transform,
                        img_transform=img_transform,
                        gt_transform=gt_transform)
            tar_val_loader = DataLoader(valset,
                                        batch_size=12,
                                        shuffle=False,
                                        num_workers=8,
                                        drop_last=False,
                                        pin_memory=True)

            for subname in cfg_data.WE_test_list:
                sub_set = WE(subname,
                             'test',
                             img_transform=img_transform,
                             gt_transform=gt_transform)
                tar_test_loader.append(
                    DataLoader(sub_set,
                               batch_size=12,
                               num_workers=8,
                               pin_memory=True))
        else:
            print(
                "Please set the target dataset to one of: QNRF, SHHA, MALL, UCSD, SHHB, WE"
            )

        return sou_loader, tar_shot_loader, tar_val_loader, tar_test_loader, restore_transform

    if args.phase == 'pre_train':
        trainset = GCC('train',
                       main_transform=sou_main_transform,
                       img_transform=img_transform,
                       gt_transform=gt_transform)
        train_loader = DataLoader(trainset,
                                  batch_size=args.pre_batch_size,
                                  shuffle=True,
                                  num_workers=8,
                                  drop_last=True,
                                  pin_memory=True)

        valset = GCC('val',
                     img_transform=img_transform,
                     gt_transform=gt_transform)
        val_loader = DataLoader(valset,
                                batch_size=12,
                                num_workers=8,
                                pin_memory=True)

        return train_loader, val_loader, restore_transform
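
# A minimal, runnable sketch of the loader pattern repeated throughout Example #2:
# a Dataset wrapped in a DataLoader with shuffle / num_workers / drop_last / pin_memory.
# `DummyCrowdSet` is a made-up stand-in for GCC / QNRF / SHHA / MALL / UCSD / SHHB / WE.
import torch
from torch.utils.data import Dataset, DataLoader

class DummyCrowdSet(Dataset):
    def __len__(self):
        return 32

    def __getitem__(self, idx):
        img = torch.rand(3, 64, 64)      # fake image tensor
        density = torch.rand(1, 64, 64)  # fake density map
        return img, density

loader = DataLoader(DummyCrowdSet(),
                    batch_size=4,
                    shuffle=True,
                    num_workers=0,  # 0 keeps the sketch runnable anywhere
                    drop_last=True,
                    pin_memory=True)
for imgs, maps in loader:
    print(imgs.shape, maps.shape)  # torch.Size([4, 3, 64, 64]) torch.Size([4, 1, 64, 64])
    break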
Example #3
def main(argv):
    # Read arguments passed
    (opts, args) = parser.parse_args(argv)

    # Reading config
    cfg = config(opts.config,
                 debugging=False,
                 additionalText="training_ERM_seen_resnet18")

    # Use CUDA
    # os.environ['CUDA_VISIBLE_DEVICES'] = 1
    use_cuda = torch.cuda.is_available()

    # If the manual seed has not been chosen yet
    if cfg.manualSeed is None:
        cfg.manualSeed = 1

    # Set seed for reproducibility of the CPU and GPU randomization processes
    random.seed(cfg.manualSeed)
    torch.manual_seed(cfg.manualSeed)

    if use_cuda:
        torch.cuda.manual_seed_all(cfg.manualSeed)

    dataloader_train = None
    if hasattr(cfg, "train_mode"):

        # Preprocessing (transformation) instantiation for training groupwise
        transformation_train = torchvision.transforms.Compose([
            transforms.GroupMultiScaleCrop(224, [1, 0.875, 0.75, 0.66]),
            transforms.GroupRandomHorizontalFlip(is_flow=False),
            transforms.Stack(),  # concatenation of images
            transforms.ToTorchFormatTensor(),  # to torch
            transforms.GroupNormalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224,
                                           0.225]),  # Normalization
        ])

        if cfg.algo == "ERM" or cfg.algo == "MTGA":
            # Loading training Dataset with N segment for TSN
            EPICdata_train = EPIC(
                mode=cfg.train_mode,
                cfg=cfg,
                transforms=transformation_train,
            )

            # Creating training dataloader
            # batch_size = 16, num_workers = 8 are a good fit for a 12 GB GPU and >= 16 GB RAM
            dataloader_train = DataLoader(
                EPICdata_train,
                batch_size=cfg.train_batch_size,
                shuffle=True,
                num_workers=cfg.num_worker_train,
                pin_memory=True,
            )
        elif cfg.algo == "IRM":
            df = pd.read_csv(cfg.anno_path)
            p_ids = list(set(df["participant_id"].tolist()))

            dataloader_train = []
            for p_id in p_ids:
                tmp_dataset = EPIC(
                    mode=cfg.train_mode,
                    cfg=cfg,
                    transforms=transformation_train,
                    participant_id=p_id,
                )

                if tmp_dataset.haveData:
                    dataloader_train.append(
                        DataLoader(
                            tmp_dataset,
                            batch_size=cfg.train_batch_size,
                            shuffle=True,
                            num_workers=cfg.num_worker_train,
                            pin_memory=True,
                        ))
        elif cfg.algo == "FSL":
            dataloader_train = {}
            # Loading training Dataset with N segment for TSN
            EPICdata_train_verb = EPIC(mode=cfg.train_mode,
                                       cfg=cfg,
                                       transforms=transformation_train)
            sampler = CategoriesSampler(EPICdata_train_verb.verb_label, 200,
                                        cfg.way, cfg.shot + cfg.query)
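            # (assumed) episodic few-shot sampler: 200 episodes, each drawing
            # cfg.way classes with cfg.shot + cfg.query examples per class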
            dataloader_train["verb"] = DataLoader(
                dataset=EPICdata_train_verb,
                batch_sampler=sampler,
                num_workers=cfg.num_worker_train,
                pin_memory=True,
            )

            EPICdata_train_noun = EPIC(mode=cfg.train_mode,
                                       cfg=cfg,
                                       transforms=transformation_train)
            sampler = CategoriesSampler(EPICdata_train_noun.noun_label, 200,
                                        cfg.way, cfg.shot + cfg.query)
            dataloader_train["noun"] = DataLoader(
                dataset=EPICdata_train_noun,
                batch_sampler=sampler,
                num_workers=cfg.num_worker_train,
                pin_memory=True,
            )

    dataloader_val = None
    if hasattr(cfg, "val_mode") and hasattr(cfg, "train_mode"):
        # Preprocessing (transformation) instantiation for validation groupwise
        transformation_val = torchvision.transforms.Compose([
            transforms.GroupOverSample(
                224, 256),  # group sampling from images using multiple crops
            transforms.Stack(),  # concatenation of images
            transforms.ToTorchFormatTensor(),  # to torch
            transforms.GroupNormalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224,
                                           0.225]),  # Normalization
        ])

        # Loading validation Dataset with N segment for TSN
        EPICdata_val = EPIC(
            mode=cfg.val_mode,
            cfg=cfg,
            transforms=transformation_val,
        )

        # Creating validation dataloader
        dataloader_val = DataLoader(
            EPICdata_val,
            batch_size=cfg.val_batch_size,
            shuffle=False,
            num_workers=cfg.num_worker_val,
            pin_memory=True,
        )

    # Loading Models (Resnet50)
    model = EPICModel(config=cfg)

    if not cfg.feature_extraction:
        if hasattr(cfg, "train_mode"):
            policies = model.get_optim_policies()

            # for group in policies:
            #     print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            #         group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

            # Optimizer
            # initial lr = 0.01
            # momentum = 0.9
            # weight_decay = 5e-4
            optimizer = torch.optim.SGD(policies,
                                        lr=cfg.lr,
                                        momentum=cfg.momentum,
                                        weight_decay=cfg.weight_decay)

            # Loss function (CrossEntropy)
            if cfg.algo == "IRM":
                criterion = torch.nn.CrossEntropyLoss(reduction="none")
            elif cfg.algo == "ERM" or cfg.algo == "MTGA":
                criterion = torch.nn.CrossEntropyLoss()
            elif cfg.algo == "FSL":
                criterion = torch.nn.CrossEntropyLoss()

            # If multiple GPUs are available (and bridged)
            # if torch.cuda.device_count() > 1:
            #     print("Let's use", torch.cuda.device_count(), "GPUs!")
            #     model = torch.nn.DataParallel(model)

            # Convert model and loss function to GPU if available for faster computation
            if use_cuda:
                model = model.cuda()
                criterion = criterion.cuda()

            # Loading Trainer
            experiment = Experiment(
                cfg=cfg,
                model=model,
                loss=criterion,
                optimizer=optimizer,
                use_cuda=use_cuda,
                data_train=dataloader_train,
                data_val=dataloader_val,
                debugging=False,
            )

            # Train the model
            experiment.train()

        else:

            # Load Model Checkpoint
            checkpoint = torch.load(cfg.checkpoint_filename_final)
            model.load_state_dict(checkpoint["model_state_dict"])

            if use_cuda:
                model = model.cuda()

            transformation = torchvision.transforms.Compose([
                transforms.GroupOverSample(
                    224,
                    256),  # group sampling from images using multiple crops
                transforms.Stack(),  # concatenation of images
                transforms.ToTorchFormatTensor(),  # to torch
                transforms.GroupNormalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224,
                                               0.225]),  # Normalization
            ])

            # Loading Predictor
            experiment = Experiment(cfg=cfg,
                                    model=model,
                                    use_cuda=use_cuda,
                                    debugging=False)

            filenames = ["seen.json", "unseen.json"]
            for filename in filenames:
                EPICdata = EPIC(
                    mode=cfg.val_mode,
                    cfg=cfg,
                    transforms=transformation,
                    test_mode=filename[:-5],
                )

                data_loader = torch.utils.data.DataLoader(EPICdata,
                                                          batch_size=8,
                                                          shuffle=False,
                                                          num_workers=4,
                                                          pin_memory=True)
                experiment.data_val = data_loader
                experiment.predict(filename)
    else:
        # Load Model Checkpoint
        checkpoint = torch.load(cfg.checkpoint_filename_final)
        model.load_state_dict(checkpoint["model_state_dict"])

        if use_cuda:
            model = model.cuda()

        model.eval()

        transformation = torchvision.transforms.Compose([
            transforms.GroupOverSample(
                224, 256),  # group sampling from images using multiple crops
            transforms.Stack(),  # concatenation of images
            transforms.ToTorchFormatTensor(),  # to torch
            transforms.GroupNormalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224,
                                           0.225]),  # Normalization
        ])

        # Loading Predictor
        experiment = Experiment(cfg=cfg,
                                model=model,
                                use_cuda=use_cuda,
                                debugging=False)

        with torch.no_grad():
            modes = ["train-unseen", "val-unseen"]
            for mode in modes:
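                # NOTE: np.empty leaves an uninitialized first row that is kept in the
                # saved array; each appended row is the pooled feature vector plus the
                # verb and noun ids (2050 columns in total)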
                data = np.empty((1, 2050))
                EPICdata = EPIC(
                    mode=mode,
                    cfg=cfg,
                    transforms=transformation,
                )

                data_loader = torch.utils.data.DataLoader(EPICdata,
                                                          batch_size=1,
                                                          shuffle=False,
                                                          num_workers=0,
                                                          pin_memory=True)

                for i, sample_batch in enumerate(data_loader):
                    output = experiment.extract_features(sample_batch)
                    verb_ann = sample_batch["verb_id"].data.item()
                    noun_ann = sample_batch["noun_id"].data.item()
                    out = np.append(np.mean(output, 0), verb_ann)
                    out = np.append(out, noun_ann)
                    data = np.concatenate((data, np.expand_dims(out, 0)), 0)
                np.save(mode, data)
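
# A small sketch of the seeding pattern at the top of Example #3: a single manual
# seed is applied to Python's RNG and to PyTorch's CPU/GPU RNGs so that runs are
# repeatable. The seed value 1 mirrors the default used above.
import random
import torch

manual_seed = 1
random.seed(manual_seed)
torch.manual_seed(manual_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(manual_seed)
print(torch.rand(2))  # same values on every run with the same seed and PyTorch version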
Example #4
def bert_train_loader(data_name, validation_size, batch_size):

    if data_name in ['yelp','amazon','yahoo','dbpedia','agnews']:
        y_train,train_text = create_tc_data(data_name,mode = "train")
        y_test,test_text = create_tc_data(data_name,mode = 'test')
        y_train = list(np.array(y_train) - 1)  # shift labels from 1-based to 0-based
        y_test = list(np.array(y_test) - 1)
        MAX_LEN=128

    elif data_name == "text_mutitask":
        y_train,train_text,y_test,test_text = [],[],[],[]
        num_class = 0
        for name in ['yelp','amazon','yahoo','dbpedia','agnews']:
            y_train_s,train_text_s = create_tc_data(name,mode = "train")
            y_test_s,test_text_s = create_tc_data(name,mode = 'test')
            y_train_s = list(np.array(y_train_s) + num_class - 1)  # offset labels so classes from different datasets do not collide
            y_test_s = list(np.array(y_test_s) + num_class - 1)
            num_class += TC_NUM_CLASSES[name]

            y_train.extend(y_train_s)
            train_text.extend(train_text_s)
            y_test.extend(y_test_s)
            test_text.extend(test_text_s)
        MAX_LEN=128

    elif data_name == "sentiment_mutitask":
        y_train,train_text,y_test,test_text = [],[],[],[]
        num_class = 0
        for name in ['magazines.task','apparel.task','health_personal_care.task','camera_photo.task','toys_games.task','software.task','baby.task','kitchen_housewares.task','sports_outdoors.task',
                    'electronics.task','books.task','video.task','imdb.task','dvd.task','music.task','MR.task']:
            y_train_s,y_test_s,train_text_s,test_text_s=read_files(name)

            y_train.extend(y_train_s)
            train_text.extend(train_text_s)
            y_test.extend(y_test_s)
            test_text.extend(test_text_s)
        MAX_LEN=256

    else:
        y_train,y_test,train_text,test_text=read_files(data_name)
        MAX_LEN=256

    # Load the BERT tokenizer.
    print('Loading BERT tokenizer...')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
    sentences = train_text
    labels = y_train
    test_sentences=test_text
    test_labels=y_test



    input_ids = [tokenizer.encode(sent,add_special_tokens=True,max_length=MAX_LEN,truncation=True) for sent in sentences]
    test_input_ids=[tokenizer.encode(sent,add_special_tokens=True,max_length=MAX_LEN,truncation=True) for sent in test_sentences]


    print('\nPadding token: "{:}", ID: {:}'.format(tokenizer.pad_token, tokenizer.pad_token_id))

    input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype="long",
                              value=0, truncating="post", padding="post")

    test_input_ids = pad_sequences(test_input_ids, maxlen=MAX_LEN, dtype="long",
                                   value=0, truncating="post", padding="post")


    # Create attention masks
    attention_masks = []

    # For each sentence...
    for sent in input_ids:
        
        # Create the attention mask.
        #   - If a token ID is 0, then it's padding, set the mask to 0.
        #   - If a token ID is > 0, then it's a real token, set the mask to 1.
        att_mask = [int(token_id > 0) for token_id in sent]
        
        # Store the attention mask for this sentence.
        attention_masks.append(att_mask)

    test_attention_masks = []

    # For each sentence...
    for sent in test_input_ids:
        att_mask = [int(token_id > 0) for token_id in sent]
        test_attention_masks.append(att_mask)



    # Split off a validation set when validation_size != 0 (e.g., 0.1 keeps 90% for training)
    if validation_size != 0:
        train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(input_ids, labels, 
                                                                    random_state=2020, test_size=validation_size)
        # Do the same for the masks.
        train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels,
                                                    random_state=2020, test_size=validation_size)
    else:
        train_inputs,train_labels = input_ids,labels
        train_masks = attention_masks
        validation_inputs,validation_labels = test_input_ids,test_labels
        validation_masks = test_attention_masks

    train_inputs = torch.LongTensor(train_inputs)
    validation_inputs = torch.LongTensor(validation_inputs)
    test_inputs=torch.LongTensor(test_input_ids)

    train_labels = torch.LongTensor(train_labels)
    validation_labels = torch.LongTensor(validation_labels)
    test_labels=torch.LongTensor(test_labels)

    train_masks = torch.LongTensor(train_masks)
    validation_masks = torch.LongTensor(validation_masks)
    test_masks=torch.LongTensor(test_attention_masks)

    print(train_inputs.size())

    # The DataLoader needs to know our batch size for training, so we specify it 
    # here.
    # For fine-tuning BERT on a specific task, the authors recommend a batch size of
    # 16 or 32.


    # Create the DataLoader for our training set.
    train_data = TensorDataset(train_inputs, train_masks, train_labels)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size,num_workers=2,pin_memory=True)

    # # Create the DataLoader for our validation set.
    validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
    validation_sampler = SequentialSampler(validation_data)
    validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size,num_workers=2,pin_memory=True)

    # Create the DataLoader for our test set.
    test_data = TensorDataset(test_inputs, test_masks, test_labels)
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size,num_workers=2,pin_memory=True)
    if data_name == 'sentiment_mutitask':
        test_dataloader = []
        for i in range(16):
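            # split the concatenated test set into one 200-example chunk per sentiment task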
            test_data = TensorDataset(test_inputs[i*200:(i+1)*200], test_masks[i*200:(i+1)*200], test_labels[i*200:(i+1)*200])
            test_sampler = SequentialSampler(test_data)
            test_dataloader.append(DataLoader(test_data, sampler=test_sampler, batch_size=batch_size,num_workers=2,pin_memory=True))
    if data_name == 'text_mutitask':
        test_dataloader = []
        for i in range(5):
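            # split the concatenated test set into one 7600-example chunk per text-classification dataset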
            test_data = TensorDataset(test_inputs[i*7600:(i+1)*7600], test_masks[i*7600:(i+1)*7600], test_labels[i*7600:(i+1)*7600])
            test_sampler = SequentialSampler(test_data)
            test_dataloader.append(DataLoader(test_data, sampler=test_sampler, batch_size=batch_size,num_workers=2,pin_memory=True))

    return train_dataloader,validation_dataloader, test_dataloader
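
# A self-contained sketch of the padding / attention-mask step from Example #4,
# using toy token ids in place of real BertTokenizer output. Pad id 0 is assumed,
# matching the code above (mask = 1 for real tokens, 0 for padding).
import numpy as np

MAX_LEN = 8
toy_ids = [[101, 2023, 2003, 102], [101, 2307, 102]]
padded = np.array([ids + [0] * (MAX_LEN - len(ids)) for ids in toy_ids], dtype=np.int64)
attention_masks = [[int(token_id > 0) for token_id in row] for row in padded]
print(padded)
print(attention_masks)  # [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0]]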
Example #5
def get_photo_sketch_dataset(path, img_size, batch_size, sketch_path):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    img_transform = transforms.Compose(
        [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(), normalize
        ]
        #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    dataset = datasets.ImageFolder(path, transform=img_transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    classes, class_to_idx, rev_class = get_classes(dataset)
    photos = {class_name: [item.split('.')[0] for item in os.listdir(path + '/' + class_name)]
              for class_name in classes}
    sketches = {class_name: [item.split('.')[0] for item in os.listdir(sketch_path + '/' + class_name)]
                for class_name in classes}

    # create tuples = (batch_idx, img1, target_tensor, img_2)
    # first create a list of (img_1, target, img_2) and shuffle
    # then split by batch and create the first tuple
    # NOTE: data_loader is rebound here, so the DataLoader created above is discarded;
    # batches of (photo, target, sketch) triples are assembled manually instead
    pairs, data_loader = [], []
    for class_name in photos:
        class_idx = class_to_idx[class_name]
        for photo_idx in photos[class_name]:
            if os.path.isfile(path + '/' + class_name + '/' + photo_idx +
                              '.jpg') is False:
                continue
            photo = Im.open(path + '/' + class_name + '/' + photo_idx + '.jpg')
            photo_tr = img_transform(photo)
            #### change this!
            #sketch_idxs = [item for item in sketches[class_name]]
            sketch_idxs = [item for item in sketches[class_name] \
                           if item.split('-')[0] == photo_idx]
            for sketch_idx in sketch_idxs:
                sketch = Im.open(sketch_path + '/' + class_name + '/' +
                                 sketch_idx + '.png')
                sketch_tr = img_transform(sketch)
                pairs.append((torch.tensor(photo_tr), class_idx,
                              torch.tensor(sketch_tr)))

            #break

    np.random.shuffle(pairs)
    batches = [
        pairs[i:i + batch_size] for i in range(0, len(pairs), batch_size)
    ]

    for batch_idx in range(len(batches)):
        batch = batches[batch_idx]
        targets, photos, sketches = [], [], []
        for item in batch:
            targets.append(item[1])
            photos.append(item[0].data.numpy())
            sketches.append(item[2].data.numpy())

        data_loader.append((batch_idx, (torch.tensor(photos),
                                        torch.tensor(targets),
                                        torch.tensor(sketches))))

    return dataset, data_loader
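
# A minimal sketch of the manual batching done at the end of Example #5: shuffle a
# list of (photo, target, sketch) triples, slice it into fixed-size chunks, and stack
# each chunk into tensors. The toy random tensors below stand in for the real pairs.
import numpy as np
import torch

batch_size = 4
pairs = [(torch.rand(3, 8, 8), i % 2, torch.rand(3, 8, 8)) for i in range(10)]
np.random.shuffle(pairs)

batches = [pairs[i:i + batch_size] for i in range(0, len(pairs), batch_size)]
for batch_idx, batch in enumerate(batches):
    photo_batch = torch.stack([p for p, _, _ in batch])
    target_batch = torch.tensor([t for _, t, _ in batch])
    sketch_batch = torch.stack([s for _, _, s in batch])
    print(batch_idx, photo_batch.shape, target_batch.shape, sketch_batch.shape)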