Example #1
def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    else:
        return data.SequentialSampler(dataset)
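A minimal usage sketch for a helper like this (the dataset and batch size below are illustrative, not part of the original snippet):

# Hypothetical usage: pick a sampler, then hand it to a DataLoader.
# `shuffle` must stay at its default (False) once an explicit sampler is passed.
import torch
from torch.utils import data

my_dataset = data.TensorDataset(torch.arange(10))
sampler = data_sampler(my_dataset, shuffle=False, distributed=False)
loader = data.DataLoader(my_dataset, batch_size=4, sampler=sampler)
for batch in loader:
    print(batch)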
Example #2
def get_data(transform, mode='train'):
    print('[embeddings.py] Loading data for "%s" ...' % mode)
    global dataset
    if args.dataset == 'ucf101':
        dataset = UCF101_3d(mode=mode,
                            transform=transform,
                            seq_len=args.seq_len,
                            num_seq=args.num_seq,
                            downsample=args.ds,
                            which_split=args.split)
    elif args.dataset == 'hmdb51':
        dataset = HMDB51_3d(mode=mode,
                            transform=transform,
                            seq_len=args.seq_len,
                            num_seq=args.num_seq,
                            downsample=args.ds,
                            which_split=args.split)
    elif args.dataset == 'ucf11':
        # no split here
        dataset = UCF11_3d(mode=mode,
                           transform=transform,
                           num_seq=args.num_seq,
                           seq_len=args.seq_len,
                           downsample=args.ds)
    elif args.dataset == 'block_toy':
        # no split here
        dataset = block_toy(mode=mode,
                            transform=transform,
                            num_seq=args.num_seq,
                            seq_len=args.seq_len,
                            downsample=args.ds,
                            background=args.background)
    else:
        raise ValueError('dataset not supported')

    # embeddings are extracted in order, so use a sequential (non-shuffling) sampler
    my_sampler = data.SequentialSampler(dataset)

    if mode == 'val':
        data_loader = data.DataLoader(dataset,
                                      batch_size=args.batch_size,
                                      sampler=my_sampler,
                                      shuffle=False,
                                      num_workers=16,
                                      pin_memory=True,
                                      drop_last=True)
    elif mode == 'test':
        data_loader = data.DataLoader(dataset,
                                      batch_size=args.batch_size,
                                      sampler=my_sampler,
                                      shuffle=False,
                                      num_workers=16,
                                      pin_memory=True)
    else:
        raise ValueError('mode must be "val" or "test"')
    print('"%s" dataset size: %d' % (mode, len(dataset)))
    return data_loader
Example #3
def make_data_sampler(dataset,
                      is_train=True,
                      shuffle=True,
                      is_distributed=False):
    if is_train:
        sampler = dutils.RandomSampler(dataset)
    else:
        sampler = dutils.SequentialSampler(dataset)
    return sampler
Example #4
    def param_calculation(self, args):
        if args.seed is not None:
            self.__make_reproducible(args.seed)

        transform = transforms.Compose(
            [transforms.ToPILImage(), transforms.Resize((160, 64)), transforms.ToTensor()]
        )

        labels = pd.read_csv(os.path.join('/home/edisn/edisn/TeamClassifier/team124_dataset/', 'train.csv'))
        train_data, _ = train_test_split(labels, stratify=labels.cls, test_size=0.1)

        # dataset = DigitDataset(args.root, split="train", transform=transform)
        dataset = TeamDataset(train_data, '/home/edisn/edisn/TeamClassifier/team124_dataset/', transform=transform)

        num_samples = args.num_samples
        if num_samples is None:
            num_samples = len(dataset)
        if num_samples < len(dataset):
            sampler = FiniteRandomSampler(dataset, num_samples)
        else:
            sampler = data.SequentialSampler(dataset)

        loader = data.DataLoader(
            dataset,
            sampler=sampler,
            num_workers=args.num_workers,
            batch_size=args.batch_size,
        )

        running_mean = RunningAverage(device=args.device)
        running_std = RunningAverage(device=args.device)
        num_batches = ceil(num_samples / args.batch_size)

        with torch.no_grad():
            for batch, (images, _) in enumerate(loader, 1):
                images = images.to(args.device)
                images_flat = torch.flatten(images, 2)

                mean = torch.mean(images_flat, dim=2)
                running_mean.update(mean)

                std = torch.std(images_flat, dim=2)
                running_std.update(std)

                if not args.quiet and batch % args.print_freq == 0:
                    print(
                        (
                            f"[{batch:6d}/{num_batches}] "
                            f"mean={running_mean}, std={running_std}"
                        )
                    )

        print(f"mean={running_mean}, std={running_std}")

        return running_mean.tolist(), running_std.tolist()
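Examples 4 and 5 rely on two project-specific helpers, FiniteRandomSampler and RunningAverage, that are not shown on this page. A minimal sketch of what they could look like (an assumption, not the original implementation):

import torch
from torch.utils import data

class FiniteRandomSampler(data.Sampler):
    # Yields num_samples random indices from data_source, without replacement.
    def __init__(self, data_source, num_samples):
        self.data_source = data_source
        self.num_samples = num_samples

    def __iter__(self):
        perm = torch.randperm(len(self.data_source))
        return iter(perm[:self.num_samples].tolist())

    def __len__(self):
        return self.num_samples


class RunningAverage:
    # Keeps a running per-channel average of the batch statistics passed to update().
    def __init__(self, device=None):
        self.device = device  # kept only for API parity with the call sites above
        self.total = None
        self.count = 0

    def update(self, values):  # values: tensor of shape (batch, channels)
        batch_sum = values.sum(dim=0)
        self.total = batch_sum if self.total is None else self.total + batch_sum
        self.count += values.size(0)

    def tolist(self):
        return (self.total / self.count).tolist()

    def __str__(self):
        return str([round(v, 4) for v in self.tolist()])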
Example #5
    def param_calculation(self, config, dataset_path, csv_path):
        if ast.literal_eval(config.normalization_param.seed) is not None:
            self.__make_reproducible(config.normalization_param.seed)

        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((160, 64)),
            transforms.ToTensor()
        ])

        labels = pd.read_csv(csv_path)
        train_data, _ = train_test_split(labels,
                                         stratify=labels.cls,
                                         test_size=0.2)

        dataset = TeamDataset(train_data, dataset_path, transform=transform)

        num_samples = ast.literal_eval(config.normalization_param.num_samples)
        if num_samples is None:
            num_samples = len(dataset)
        if num_samples < len(dataset):
            sampler = FiniteRandomSampler(dataset, num_samples)
        else:
            sampler = data.SequentialSampler(dataset)

        loader = data.DataLoader(
            dataset,
            sampler=sampler,
            num_workers=config.normalization_param.num_workers,
            batch_size=config.normalization_param.batch_size,
        )

        running_mean = RunningAverage(device=config.normalization_param.device)
        running_std = RunningAverage(device=config.normalization_param.device)
        num_batches = ceil(num_samples / config.normalization_param.batch_size)

        with torch.no_grad():
            for batch, (images, _) in enumerate(loader, 1):
                images = images.to(config.normalization_param.device)
                images_flat = torch.flatten(images, 2)

                mean = torch.mean(images_flat, dim=2)
                running_mean.update(mean)

                std = torch.std(images_flat, dim=2)
                running_std.update(std)

                if not config.normalization_param.quiet and batch % config.normalization_param.print_freq == 0:
                    print((f"[{batch:6d}/{num_batches}] "
                           f"mean={running_mean}, std={running_std}"))

        print(f"mean={running_mean}, std={running_std}")

        return running_mean.tolist(), running_std.tolist()
Example #6
def create_data_loader(config, idxs, shuffle=True):
    if shuffle:
        sampler = TD.SubsetRandomSampler(idxs)
        dataset = config.dataset
    else:
        sampler = TD.SequentialSampler(idxs)
        dataset = TD.Subset(config.dataset, idxs)
    return TD.DataLoader(
        dataset,
        batch_size=config.batch_size,
        sampler=sampler,
        pin_memory=True, num_workers=config.data_loader_num_workers
    )
Example #7
def get_video_data_loader(path, vid_range, save_path, batch_size=2):
    dataset = VideoDataset(path, vid_range, save_path)
    data_loader = data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=data.SequentialSampler(dataset),
        shuffle=False,
        num_workers=2,
        collate_fn=individual_collate,
        pin_memory=True,
        drop_last=True
    )
    return data_loader
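The individual_collate function passed as collate_fn above (and again in Example 19) is project-specific and not shown here. A plausible minimal sketch, assuming its job is simply to keep the samples of a batch as a list rather than stacking them into one tensor:

# Hypothetical sketch: return the batch as a plain list of individual samples,
# leaving any stacking or padding to the caller.
def individual_collate(batch):
    return [sample for sample in batch]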
Example #8
def data_sampler(dataset, shuffle, distributed, weights=None):
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if weights is not None:
        return data.WeightedRandomSampler(weights,
                                          len(weights),
                                          replacement=True)

    if shuffle:
        return data.RandomSampler(dataset)
    else:
        return data.SequentialSampler(dataset)
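For the weights branch above, the per-sample weights are usually derived from inverse class frequencies so that rare classes are drawn more often. A small sketch (the label list and helper name are illustrative, not from the original snippet):

from collections import Counter

def make_balancing_weights(labels):
    # One weight per sample: the rarer the class, the larger the weight.
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]

# weights = make_balancing_weights(per_sample_class_labels)
# sampler = data_sampler(dataset, shuffle=True, distributed=False, weights=weights)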
Example #9
    def create_itrator_for_dataset(self,
                                   input_ids=None,
                                   attention_masks=None,
                                   label_arg=None):
        '''
        Convert input_ids, attention_masks and labels (if any) into a DataLoader.
        Inherited by labelDataset, in which case label_arg will be set.
        :return: dataloader
        '''
        assert input_ids and attention_masks, 'input_ids and attention_masks must be provided!'

        inputs, masks = torch.tensor(input_ids), torch.tensor(attention_masks)

        if self.use_variable_batch:
            # variable batch size, based on the number of Weibo posts per day
            if not label_arg:
                input_data = TensorDataset(
                    inputs,
                    masks,
                    preprocessed_data=self.preprocessed_data.cleaned_data)
            else:
                labels = torch.tensor(label_arg)
                input_data = TensorDataset(
                    inputs,
                    masks,
                    labels,
                    preprocessed_data=self.preprocessed_data.cleaned_data)

            input_sampler = SequentialSampler(input_data)

            return tud.DataLoader(input_data,
                                  batch_sampler=input_sampler,
                                  shuffle=False,
                                  num_workers=4)
        else:
            # load the dataset normally
            if label_arg is None:
                input_data = tud.TensorDataset(inputs, masks)
            else:
                labels = torch.tensor(label_arg)
                input_data = tud.TensorDataset(inputs, masks, labels)

            input_sampler = tud.SequentialSampler(input_data)

            return tud.DataLoader(input_data,
                                  sampler=input_sampler,
                                  shuffle=False,
                                  batch_size=int(
                                      utils.cfg.get('HYPER_PARAMETER',
                                                    'batch_size')),
                                  num_workers=4)
Example #10
    def test_sampler_wrapper(self, mock_len, mock_get_item):
        def side_effect(idx):
            return [0, 1, None, 3, 4, 5][idx]

        mock_get_item.side_effect = side_effect
        mock_len.return_value = 6
        dataset = data.TensorDataset(torch.arange(0, 10))
        dataset = nonechucks.SafeDataset(dataset)
        self.assertEqual(len(dataset), 6)
        sequential_sampler = data.SequentialSampler(dataset)
        dataloader = data.DataLoader(dataset,
                                     sampler=nonechucks.SafeSampler(
                                         dataset, sequential_sampler))
        for i_batch, sample_batched in enumerate(dataloader):
            print('Sample {}: {}'.format(i_batch, sample_batched))
Example #11
def data_loader(root, phase, batch_size, tokenizer, config):
    dataset = load_and_cache_examples(root, tokenizer, config=config, mode=phase)

    if phase == 'train':
        sampler = data.RandomSampler(dataset)
    else:
        sampler = data.SequentialSampler(dataset)

    dataloader = data.DataLoader(dataset=dataset, sampler=sampler, batch_size=batch_size)
    return dataloader


# from transformers import AutoTokenizer
# dataloader = data_loader('/home/ubuntu/aikorea/sbs/data', 'train', 32, AutoTokenizer.from_pretrained(config.bert_model_name))
# print(len(dataloader))
Example #12
def create_dataloader(source_strings, target_strings, text_encoder, batch_size,
                      shuffle_batches_each_epoch):
    '''target_strings parameter can be None'''
    dataset = TranslitData(source_strings,
                           target_strings,
                           text_encoder=text_encoder)
    seq_sampler = torch_data.SequentialSampler(dataset)
    batch_sampler = BatchSampler(seq_sampler,
                                 batch_size=batch_size,
                                 drop_last=False,
                                 shuffle_each_epoch=shuffle_batches_each_epoch)
    dataloader = torch_data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       collate_fn=collate_fn)
    return dataloader
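The BatchSampler used here takes a shuffle_each_epoch flag, so it is a custom class rather than torch.utils.data.BatchSampler (which has no such argument). A rough sketch of what it might do (an assumption): build fixed batches from the sequential sampler, then optionally reshuffle the order of those batches each epoch:

import random
from torch.utils import data as torch_data

class BatchSampler(torch_data.Sampler):
    def __init__(self, sampler, batch_size, drop_last, shuffle_each_epoch=False):
        self.base = torch_data.BatchSampler(sampler, batch_size, drop_last)
        self.shuffle_each_epoch = shuffle_each_epoch

    def __iter__(self):
        batches = list(self.base)      # each element is a list of indices
        if self.shuffle_each_epoch:
            random.shuffle(batches)    # reorder whole batches, keep their contents
        return iter(batches)

    def __len__(self):
        return len(self.base)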
Example #13
 def get_latents(self, dset):
     self.enc.eval()
     collected_latents = []
     determin_dl = data.DataLoader(dset,
                                   batch_sampler=data.BatchSampler(
                                       data.SequentialSampler(dset),
                                       self.batch_size,
                                       drop_last=False),
                                   pin_memory=False)
     for idx, (xb, yb, tb) in enumerate(determin_dl):
         batch_latents = self.enc(xb)
         batch_latents = batch_latents.view(batch_latents.shape[0],
                                            -1).detach().cpu().numpy()
         collected_latents.append(batch_latents)
     collected_latents = np.concatenate(collected_latents, axis=0)
     return collected_latents
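The explicit SequentialSampler + BatchSampler combination in Example 13 is equivalent to letting DataLoader build the batching itself. A self-contained sketch of the shorter form (the dummy dataset and batch size are only for illustration):

import torch
from torch.utils import data

dset = data.TensorDataset(torch.randn(10, 3), torch.zeros(10), torch.zeros(10))
batch_size = 4

# With shuffle=False and no sampler given, DataLoader internally creates the same
# SequentialSampler wrapped in a BatchSampler with drop_last=False.
determin_dl = data.DataLoader(dset, batch_size=batch_size, shuffle=False,
                              drop_last=False, pin_memory=False)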
Example #14
    def __init__(self, dataset, output_size, device, config):
        self.dataset = dataset
        self.output_size = output_size
        self.device = device

        if hasattr(config, "task") and hasattr(config.task, "init_index"):
            from utils import data_sampler  # Avoid import error from COCO-GAN
            sampler = data_sampler(dataset,
                                   shuffle=False,
                                   init_index=config.task.init_index)
        else:
            sampler = data.SequentialSampler(dataset)

        self.dataloader_proto = data.DataLoader(
            dataset,
            batch_size=config.train_params.batch_size,
            sampler=sampler,
            drop_last=False,
            num_workers=16,
        )
        self.dataloader = None
Example #15
 def _setup_dataloader_from_config(self, cfg: DictConfig):
     if cfg.get("load_from_cached_dataset", False):
         logging.info('Loading from cached dataset %s' %
                      (cfg.src_file_name))
         if cfg.src_file_name != cfg.tgt_file_name:
             raise ValueError(
                 "src must be equal to target for cached dataset")
         dataset = pickle.load(open(cfg.src_file_name, 'rb'))
         dataset.reverse_lang_direction = cfg.get("reverse_lang_direction",
                                                  False)
     else:
         dataset = TranslationDataset(
             dataset_src=str(Path(cfg.src_file_name).expanduser()),
             dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
             tokens_in_batch=cfg.tokens_in_batch,
             clean=cfg.get("clean", False),
             max_seq_length=cfg.get("max_seq_length", 512),
             min_seq_length=cfg.get("min_seq_length", 1),
             max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
             max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
             cache_ids=cfg.get("cache_ids", False),
             cache_data_per_node=cfg.get("cache_data_per_node", False),
             use_cache=cfg.get("use_cache", False),
             reverse_lang_direction=cfg.get("reverse_lang_direction",
                                            False),
         )
         dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
     if cfg.shuffle:
         sampler = pt_data.RandomSampler(dataset)
     else:
         sampler = pt_data.SequentialSampler(dataset)
     return torch.utils.data.DataLoader(
         dataset=dataset,
         batch_size=1,
         sampler=sampler,
         num_workers=cfg.get("num_workers", 2),
         pin_memory=cfg.get("pin_memory", False),
         drop_last=cfg.get("drop_last", False),
     )
Example #16
def main():
    np.random.seed(50)
    X_test = scipy.io.loadmat('Processed_data/HRF_test_noised.mat')
    X_test = X_test['HRF_test_noised']
    n = X_test.shape[0]
    X_test = np.concatenate((X_test[0:int(n / 2), :], X_test[int(n / 2):, :]),
                            axis=1)
    Y_test = scipy.io.loadmat('Processed_data/HRF_test.mat')
    Y_test = Y_test['HRF_test']
    n = Y_test.shape[0]
    Y_test = np.concatenate((Y_test[0:int(n / 2), :], Y_test[int(n / 2):, :]),
                            axis=1)
    X_test = X_test * 1000000
    Y_test = Y_test * 1000000
    X_test = X_test[0:100, :]
    Y_test = Y_test[0:100, :]
    test_set = Dataset(X_test, Y_test)

    testloader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=512,
        sampler=tordata.SequentialSampler(test_set),
        num_workers=2)

    net = Net_8layers()
    hdf5_filepath = "networks/8layers"
    net.load_state_dict(
        torch.load(hdf5_filepath, map_location=torch.device('cpu')))
    print('loaded nn file')
    for i, data in enumerate(testloader, 0):
        print(i)
        inputs = data[0]
        inputs = inputs.float()
        start_time = time.time()
        net(inputs)
        print("--- %s seconds ---" % (time.time() - start_time))

    return
Example #17
def main():
    args = parse_args()

    config = load_config(os.path.join(args.models_path, args.exp_name))

    if config['model'] == 'TextTransposeModel':
        test_set = DatasetTextImages(path=args.test_data_path, patch_size=None,
                                     aug_resize_factor_range=None, scale=config['scale'])
    else:
        test_set = DatasetFromSingleImages(path=args.test_data_path, patch_size=None,
                                           aug_resize_factor_range=None, scale=config['scale'])

    batch_sampler = Data.BatchSampler(
        sampler=Data.SequentialSampler(test_set),
        batch_size=1,
        drop_last=True
    )

    evaluation_data_loader = Data.DataLoader(dataset=test_set, num_workers=0, batch_sampler=batch_sampler)

    trainer = Trainer(name=args.exp_name, models_root=args.models_path, resume=True)
    trainer.load_best()

    psnr = PSNR(name='PSNR', border=config['border'])

    tic = time.time()
    count = 0
    for batch in tqdm(evaluation_data_loader):
        output = trainer.predict(batch=batch)
        psnr.update(batch[1], output)
        count += 1

    toc = time.time()

    print('FPS: {}, SAMPLES: {}'.format(float(count) / (toc - tic), count))
    print('PSNR: {}'.format(psnr.get()))
Example #18
                self.worker_result_queue.close()

            # Exit workers now.
            for q in self.index_queues:
                q.put(None)
                # Indicate that no more data will be put on this queue by the
                # current process.
                q.close()
            for w in self.workers:
                w.join()

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()


if __name__ == '__main__':
    from torchvision import transforms
    from data.base import DemoDataset
    from torch.utils import data

    dataset = DemoDataset(20)
    transform_fn = [transforms.Compose([transforms.Resize((s, s)), transforms.ToTensor()]) for s in [10, 15, 20, 25,
                                                                                                     30, 35, 40]]
    dataset = dataset.transform(transform_fn[0])
    sampler = data.SequentialSampler(dataset)
    batch_sampler = data.sampler.BatchSampler(sampler, batch_size=4, drop_last=False)
    loader = RandomTransformDataLoader(transform_fn, dataset, batch_sampler=batch_sampler, interval=1, num_workers=4)
    for i, batch in enumerate(loader):
        print(batch.shape)
Example #19
def get_dataset_loaders(args, transform, mode='train'):
    print('Loading data for "%s" ...' % mode)

    if type(args) != dict:
        args_dict = deepcopy(vars(args))
    else:
        args_dict = args

    if args_dict['debug']:
        orig_mode = mode
        mode = 'train'

    use_big_K400 = False
    if args_dict["dataset"] == 'kinetics':
        use_big_K400 = args_dict["img_dim"] > 150
        dataset = Kinetics_3d(
            mode=mode,
            transform=transform,
            seq_len=args_dict["seq_len"],
            num_seq=args_dict["num_seq"],
            downsample=args_dict["ds"],
            vals_to_return=args_dict["data_sources"].split('_'),
            use_big=use_big_K400,
        )
    elif args_dict["dataset"] == 'ucf101':
        dataset = UCF101_3d(
            mode=mode,
            transform=transform,
            seq_len=args_dict["seq_len"],
            num_seq=args_dict["num_seq"],
            downsample=args_dict["ds"],
            vals_to_return=args_dict["data_sources"].split('_'),
            debug=args_dict["debug"])
    elif args_dict["dataset"] == 'jhmdb':
        dataset = JHMDB_3d(mode=mode,
                           transform=transform,
                           seq_len=args_dict["seq_len"],
                           num_seq=args_dict["num_seq"],
                           downsample=args_dict["ds"],
                           vals_to_return=args_dict["data_sources"].split('_'),
                           sampling_method=args_dict["sampling"])
    elif args_dict["dataset"] == 'hmdb51':
        dataset = HMDB51_3d(
            mode=mode,
            transform=transform,
            seq_len=args_dict["seq_len"],
            num_seq=args_dict["num_seq"],
            downsample=args_dict["ds"],
            vals_to_return=args_dict["data_sources"].split('_'),
            sampling_method=args_dict["sampling"])
    else:
        raise ValueError('dataset not supported')

    val_sampler = data.SequentialSampler(dataset)
    if use_big_K400:
        train_sampler = data.RandomSampler(dataset,
                                           replacement=True,
                                           num_samples=int(0.2 * len(dataset)))
    else:
        train_sampler = data.RandomSampler(dataset)

    if args_dict["debug"]:
        if orig_mode == 'val':
            train_sampler = data.RandomSampler(dataset,
                                               replacement=True,
                                               num_samples=200)
        else:
            train_sampler = data.RandomSampler(dataset,
                                               replacement=True,
                                               num_samples=2000)
        val_sampler = data.RandomSampler(dataset)
        # train_sampler = data.RandomSampler(dataset, replacement=True, num_samples=100)

    data_loader = None
    if mode == 'train':
        data_loader = data.DataLoader(dataset,
                                      batch_size=args_dict["batch_size"],
                                      sampler=train_sampler,
                                      shuffle=False,
                                      num_workers=args_dict["num_workers"],
                                      collate_fn=data_utils.individual_collate,
                                      pin_memory=True,
                                      drop_last=True)
    elif mode == 'val':
        data_loader = data.DataLoader(dataset,
                                      sampler=val_sampler,
                                      batch_size=args_dict["batch_size"],
                                      shuffle=False,
                                      num_workers=args_dict["num_workers"],
                                      collate_fn=data_utils.individual_collate,
                                      pin_memory=True,
                                      drop_last=True)
    elif mode == 'test':
        data_loader = data.DataLoader(dataset,
                                      sampler=val_sampler,
                                      batch_size=args_dict["batch_size"],
                                      shuffle=False,
                                      num_workers=args_dict["num_workers"],
                                      collate_fn=data_utils.individual_collate,
                                      pin_memory=True,
                                      drop_last=False)

    print('"%s" dataset size: %d' % (mode, len(dataset)))
    return data_loader
Example #20
 def _setup_dataloader_from_config(self, cfg: DictConfig):
     if cfg.get("load_from_cached_dataset", False):
         logging.info('Loading from cached dataset %s' %
                      (cfg.src_file_name))
         if cfg.src_file_name != cfg.tgt_file_name:
             raise ValueError(
                 "src must be equal to target for cached dataset")
         dataset = pickle.load(open(cfg.src_file_name, 'rb'))
         dataset.reverse_lang_direction = cfg.get("reverse_lang_direction",
                                                  False)
     elif cfg.get("use_tarred_dataset", False):
         if cfg.get('tar_files') is None:
             raise FileNotFoundError("Could not find tarred dataset.")
         logging.info(f'Loading from tarred dataset {cfg.get("tar_files")}')
         if cfg.get("metadata_file", None) is None:
             raise FileNotFoundError(
                 "Could not find metadata path in config")
         dataset = TarredTranslationDataset(
             text_tar_filepaths=cfg.tar_files,
             metadata_path=cfg.metadata_file,
             encoder_tokenizer=self.encoder_tokenizer,
             decoder_tokenizer=self.decoder_tokenizer,
             shuffle_n=cfg.get("tar_shuffle_n", 100),
             shard_strategy=cfg.get("shard_strategy", "scatter"),
             global_rank=self.global_rank,
             world_size=self.world_size,
             reverse_lang_direction=cfg.get("reverse_lang_direction",
                                            False),
         )
         return torch.utils.data.DataLoader(
             dataset=dataset,
             batch_size=1,
             num_workers=cfg.get("num_workers", 2),
             pin_memory=cfg.get("pin_memory", False),
             drop_last=cfg.get("drop_last", False),
         )
     else:
         dataset = TranslationDataset(
             dataset_src=str(Path(cfg.src_file_name).expanduser()),
             dataset_tgt=str(Path(cfg.tgt_file_name).expanduser()),
             tokens_in_batch=cfg.tokens_in_batch,
             clean=cfg.get("clean", False),
             max_seq_length=cfg.get("max_seq_length", 512),
             min_seq_length=cfg.get("min_seq_length", 1),
             max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
             max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
             cache_ids=cfg.get("cache_ids", False),
             cache_data_per_node=cfg.get("cache_data_per_node", False),
             use_cache=cfg.get("use_cache", False),
             reverse_lang_direction=cfg.get("reverse_lang_direction",
                                            False),
         )
         dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)
     if cfg.shuffle:
         sampler = pt_data.RandomSampler(dataset)
     else:
         sampler = pt_data.SequentialSampler(dataset)
     return torch.utils.data.DataLoader(
         dataset=dataset,
         batch_size=1,
         sampler=sampler,
         num_workers=cfg.get("num_workers", 2),
         pin_memory=cfg.get("pin_memory", False),
         drop_last=cfg.get("drop_last", False),
     )
Example #21
    train_set = Dataset(X_train, Y_train)
    val_set = Dataset(X_val, Y_val)
    test_set = Dataset(X_test)

    # %% define data loaders
    trainloader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=512,
        sampler=tordata.RandomSampler(train_set),
        num_workers=2)

    valloader = torch.utils.data.DataLoader(
        dataset=val_set,
        batch_size=512,
        sampler=tordata.SequentialSampler(val_set),
        num_workers=2)

    testloader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=32,
        sampler=tordata.SequentialSampler(test_set),
        num_workers=2)

    # %% train and validate
    data_loaders = {"train": trainloader, "val": valloader}
    data_lengths = {"train": n_train, "val": SampleSize - n_train}
    model = ['8layers']
    #loss_modes = ['mse','mse+snr','mse+std','mse+snr+std']
    loss_modes = ['mse+snr+std']
    #torch.set_num_threads(16)
Example #22
def data_sampler(dataset, shuffle):
    if shuffle:
        return data.RandomSampler(dataset)

    else:
        return data.SequentialSampler(dataset)
Example #23
def get_data(args,
             mode='train',
             return_label=False,
             hierarchical_label=False,
             action_level_gt=False,
             num_workers=0,
             path_dataset='',
             path_data_info=''):
    if hierarchical_label and args.dataset not in ['finegym', 'hollywood2']:
        raise Exception(
            'Hierarchical information is only implemented in finegym and hollywood2 datasets'
        )
    if return_label and not action_level_gt and args.dataset != 'finegym':
        raise Exception(
            'subaction labels are only available in the finegym dataset')

    if mode == 'train':
        transform = transforms.Compose([
            augmentation.RandomSizedCrop(size=args.img_dim,
                                         consistent=True,
                                         p=1.0),
            augmentation.RandomHorizontalFlip(consistent=True),
            augmentation.RandomGray(consistent=False, p=0.5),
            augmentation.ColorJitter(brightness=0.5,
                                     contrast=0.5,
                                     saturation=0.5,
                                     hue=0.25,
                                     p=1.0),
            augmentation.ToTensor(),
            augmentation.Normalize()
        ])
    else:
        transform = transforms.Compose([
            augmentation.CenterCrop(size=args.img_dim, consistent=True),
            augmentation.ToTensor(),
            augmentation.Normalize()
        ])

    if args.dataset == 'kinetics':
        dataset = Kinetics600(mode=mode,
                              transform=transform,
                              seq_len=args.seq_len,
                              num_seq=args.num_seq,
                              downsample=5,
                              return_label=return_label,
                              return_idx=False,
                              path_dataset=path_dataset,
                              path_data_info=path_data_info)
    elif args.dataset == 'hollywood2':
        if return_label:
            assert action_level_gt, 'hollywood2 does not have subaction labels'
        dataset = Hollywood2(mode=mode,
                             transform=transform,
                             seq_len=args.seq_len,
                             num_seq=args.num_seq,
                             downsample=args.ds,
                             return_label=return_label,
                             hierarchical_label=hierarchical_label,
                             path_dataset=path_dataset,
                             path_data_info=path_data_info)
    elif args.dataset == 'finegym':
        if hierarchical_label:
            assert not action_level_gt, 'finegym does not have hierarchical information at the action level'
        dataset = FineGym(
            mode=mode,
            transform=transform,
            seq_len=args.seq_len,
            num_seq=args.num_seq,
            fps=int(25 / args.ds),  # approx
            return_label=return_label,
            hierarchical_label=hierarchical_label,
            action_level_gt=action_level_gt,
            path_dataset=path_dataset,
            return_idx=False,
            path_data_info=path_data_info)
    elif args.dataset == 'movienet':
        assert not return_label, 'Not yet implemented (actions not available online)'
        assert args.seq_len == 3, 'MovieNet provides exactly 3 frames per subclip/scene'
        dataset = MovieNet(mode=mode,
                           transform=transform,
                           num_seq=args.num_seq,
                           path_dataset=path_dataset,
                           path_data_info=path_data_info)
    else:
        raise ValueError('dataset not supported')

    sampler = data.RandomSampler(
        dataset) if mode == 'train' else data.SequentialSampler(dataset)

    data_loader = data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=(mode != 'test'
                   )  # test always same examples independently of batch size
    )
    return data_loader
Example #24
 pytest.param(
     "utils.data.sampler",
     "RandomSampler",
     {},
     [],
     {"data_source": dummy_dataset},
     data.RandomSampler(data_source=dummy_dataset),
     id="RandomSamplerConf",
 ),
 pytest.param(
     "utils.data.sampler",
     "SequentialSampler",
     {},
     [],
     {"data_source": dummy_dataset},
     data.SequentialSampler(data_source=dummy_dataset),
     id="SequentialSamplerConf",
 ),
 pytest.param(
     "utils.data.sampler",
     "SubsetRandomSampler",
     {"indices": [1]},
     [],
     {},
     data.SubsetRandomSampler(indices=[1]),
     id="SubsetRandomSamplerConf",
 ),
 pytest.param(
     "utils.data.sampler",
     "WeightedRandomSampler",
     {
Example #25
    print('Data has been loaded successfully, cost:%.4fs' % (t1 - t0))

    ########################### TRAINING STAGE ##################################
    check_dir('%s/train_log' % conf.out_path)
    log = Logging('%s/train_%s_nrms.log' % (conf.out_path, conf.data_name))
    train_model_path = '%s/train_%s_nrms.mod' % (conf.out_path, conf.data_name)

    # prepare data for the training stage
    train_dataset = data_utils.TrainData(train_data)
    val_dataset = data_utils.TestData(val_data)

    train_batch_sampler = data.BatchSampler(data.RandomSampler(
        range(train_dataset.length)),
                                            batch_size=conf.batch_size,
                                            drop_last=False)
    val_batch_sampler = data.BatchSampler(data.SequentialSampler(
        range(val_dataset.length)),
                                          batch_size=conf.batch_size,
                                          drop_last=True)

    # Start Training !!!
    max_auc = 0
    for epoch in range(1, conf.train_epochs + 1):
        t0 = time()
        model.train()

        train_loss = []
        count = 0
        for batch_idx_list in train_batch_sampler:

            pred_input_news, his_input_news, labels = \
                train_dataset._get_batch(batch_idx_list)
Example #26
    def _setup_dataloader_from_config(self, cfg: DictConfig, predict_last_k=0):

        if cfg.get("use_tarred_dataset", False):
            if cfg.get("metadata_file") is None:
                raise FileNotFoundError(
                    "Trying to use tarred data set but could not find metadata path in config."
                )
            else:
                metadata_file = cfg.get('metadata_file')
                with open(metadata_file) as metadata_reader:
                    metadata = json.load(metadata_reader)
                if cfg.get('tar_files') is None:
                    tar_files = metadata.get('tar_files')
                    if tar_files is not None:
                        logging.info(
                            f'Loading from tarred dataset {tar_files}')
                    else:
                        raise FileNotFoundError(
                            "Could not find tarred dataset in config or metadata."
                        )
                else:
                    tar_files = cfg.get('tar_files')
                    if metadata.get('tar_files') is not None:
                        raise ValueError(
                            'Tar files specified in config and in metadata file. Tar files should only be specified once.'
                        )
            dataset = TarredSentenceDataset(
                text_tar_filepaths=tar_files,
                metadata_path=metadata_file,
                tokenizer=self.tokenizer,
                shuffle_n=cfg.get("tar_shuffle_n", 100),
                shard_strategy=cfg.get("shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
            return torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
        else:
            dataset = SentenceDataset(
                tokenizer=self.tokenizer,
                dataset=cfg.file_name,
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                cache_ids=cfg.get("cache_ids", False),
            )
        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )
Example #27
    train_data, val_data = data_utils.load_all()
    t1 = time()
    print('Data has been loaded successfully, cost:%.4fs' % (t1 - t0))

    ########################### TRAINING STAGE ##################################
    check_dir('%s/train_log' % conf.out_path)
    log = Logging('%s/train_%s_nrms.log' % (conf.out_path, conf.data_name))
    train_model_path = '%s/train_%s_nrms.mod' % (conf.out_path, conf.data_name)

    # prepare data for the training stage
    train_dataset = data_utils.TrainData(train_data)
    val_dataset = data_utils.TestData(val_data)

    train_batch_sampler = data.BatchSampler(data.RandomSampler(
        range(train_dataset.length)), batch_size=conf.batch_size, drop_last=False)
    val_batch_sampler = data.BatchSampler(data.SequentialSampler(
        range(val_dataset.length)), batch_size=conf.batch_size, drop_last=True)

    # Start Training !!!
    max_auc = 0
    for epoch in range(1, conf.train_epochs+1):
        t0 = time()
        model.train()
        
        train_loss = []
        count = 0
        for batch_idx_list in train_batch_sampler:
            
            his_input_title, pred_input_title, labels = \
                train_dataset._get_batch(batch_idx_list)
            obj_loss = model(his_input_title, pred_input_title, labels)
            train_loss.append(obj_loss.item())
Example #28
    def _setup_dataloader_from_config(self, cfg: DictConfig):
        if cfg.get("use_tarred_dataset", False):
            if cfg.get("metadata_file") is None:
                raise FileNotFoundError(
                    "Trying to use tarred data set but could not find metadata path in config."
                )
            metadata_file_list = cfg.get('metadata_file')
            tar_files_list = cfg.get('tar_files', None)
            if isinstance(metadata_file_list, str):
                metadata_file_list = [metadata_file_list]
            if tar_files_list is not None and isinstance(tar_files_list, str):
                tar_files_list = [tar_files_list]
            if tar_files_list is not None and len(tar_files_list) != len(
                    metadata_file_list):
                raise ValueError(
                    'The config must have the same number of tarfile paths and metadata file paths.'
                )

            datasets = []
            for idx, metadata_file in enumerate(metadata_file_list):
                with open(metadata_file) as metadata_reader:
                    metadata = json.load(metadata_reader)
                if tar_files_list is None:
                    tar_files = metadata.get('tar_files')
                    if tar_files is not None:
                        logging.info(
                            f'Loading from tarred dataset {tar_files}')
                else:
                    tar_files = tar_files_list[idx]
                    if metadata.get('tar_files') is not None:
                        logging.info(
                            f'Tar file paths found in both cfg and metadata using one in cfg by default - {tar_files}'
                        )

                dataset = TarredTranslationDataset(
                    text_tar_filepaths=tar_files,
                    metadata_path=metadata_file,
                    encoder_tokenizer=self.encoder_tokenizer,
                    decoder_tokenizer=self.decoder_tokenizer,
                    shuffle_n=cfg.get("tar_shuffle_n", 100),
                    shard_strategy=cfg.get("shard_strategy", "scatter"),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                    reverse_lang_direction=cfg.get("reverse_lang_direction",
                                                   False),
                    prepend_id=self.multilingual_ids[idx]
                    if self.multilingual else None,
                )
                datasets.append(dataset)

            if len(datasets) > 1:
                dataset = ConcatDataset(
                    datasets=datasets,
                    sampling_technique=cfg.get('concat_sampling_technique'),
                    sampling_temperature=cfg.get(
                        'concat_sampling_temperature'),
                    sampling_probabilities=cfg.get(
                        'concat_sampling_probabilities'),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                )
            else:
                dataset = datasets[0]
        else:
            src_file_list = cfg.src_file_name
            tgt_file_list = cfg.tgt_file_name
            if isinstance(src_file_list, str):
                src_file_list = [src_file_list]
            if isinstance(tgt_file_list, str):
                tgt_file_list = [tgt_file_list]

            if len(src_file_list) != len(tgt_file_list):
                raise ValueError(
                    'The same number of filepaths must be passed in for source and target.'
                )

            datasets = []
            for idx, src_file in enumerate(src_file_list):
                dataset = TranslationDataset(
                    dataset_src=str(Path(src_file).expanduser()),
                    dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
                    tokens_in_batch=cfg.tokens_in_batch,
                    clean=cfg.get("clean", False),
                    max_seq_length=cfg.get("max_seq_length", 512),
                    min_seq_length=cfg.get("min_seq_length", 1),
                    max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                    max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                    cache_ids=cfg.get("cache_ids", False),
                    cache_data_per_node=cfg.get("cache_data_per_node", False),
                    use_cache=cfg.get("use_cache", False),
                    reverse_lang_direction=cfg.get("reverse_lang_direction",
                                                   False),
                    prepend_id=self.multilingual_ids[idx]
                    if self.multilingual else None,
                )
                dataset.batchify(self.encoder_tokenizer,
                                 self.decoder_tokenizer)
                datasets.append(dataset)

            if len(datasets) > 1:
                dataset = ConcatDataset(
                    datasets=datasets,
                    shuffle=cfg.get('shuffle'),
                    sampling_technique=cfg.get('concat_sampling_technique'),
                    sampling_temperature=cfg.get(
                        'concat_sampling_temperature'),
                    sampling_probabilities=cfg.get(
                        'concat_sampling_probabilities'),
                    global_rank=self.global_rank,
                    world_size=self.world_size,
                )
            else:
                dataset = datasets[0]

        if cfg.shuffle:
            sampler = pt_data.RandomSampler(dataset)
        else:
            sampler = pt_data.SequentialSampler(dataset)
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=None if cfg.get("use_tarred_dataset", False) else sampler,
            num_workers=cfg.get("num_workers", 2),
            pin_memory=cfg.get("pin_memory", False),
            drop_last=cfg.get("drop_last", False),
        )
Example #29
def main(args):
    utils.init_distributed_mode(args)

    device = torch.device(args.gpus)

    in_chns = 3
    if args.vision_type == 'monochromat':
        in_chns = 1
    elif 'dichromat' in args.vision_type:
        in_chns = 2
    data_reading_kwargs = {
        'target_size': args.target_size,
        'colour_vision': args.vision_type,
        'colour_space': args.colour_space
    }
    dataset, num_classes = utils.get_dataset(args.dataset, args.data_dir,
                                             'train', **data_reading_kwargs)

    json_file_name = os.path.join(args.out_dir, 'args.json')
    with open(json_file_name, 'w') as fp:
        json.dump(dict(args._get_kwargs()), fp, sort_keys=True, indent=4)

    dataset_test, _ = utils.get_dataset(args.dataset, args.data_dir, 'val',
                                        **data_reading_kwargs)

    if args.distributed:
        train_sampler = torch_dist.DistributedSampler(dataset)
        test_sampler = torch_dist.DistributedSampler(dataset_test)
    else:
        train_sampler = torch_data.RandomSampler(dataset)
        test_sampler = torch_data.SequentialSampler(dataset_test)

    data_loader = torch_data.DataLoader(dataset,
                                        batch_size=args.batch_size,
                                        sampler=train_sampler,
                                        num_workers=args.workers,
                                        collate_fn=utils.collate_fn,
                                        drop_last=True)

    data_loader_test = torch_data.DataLoader(dataset_test,
                                             batch_size=1,
                                             sampler=test_sampler,
                                             num_workers=args.workers,
                                             collate_fn=utils.collate_fn)

    if args.network_name == 'unet':
        model = segmentation_models.unet.model.Unet(
            encoder_weights=args.backbone, classes=num_classes)
        if args.pretrained:
            print('Loading %s' % args.pretrained)
            checkpoint = torch.load(args.pretrained, map_location='cpu')
            remove_keys = []
            for key_ind, key in enumerate(checkpoint['state_dict'].keys()):
                if 'segmentation_head' in key:
                    remove_keys.append(key)
            for key in remove_keys:
                del checkpoint['state_dict'][key]
            model.load_state_dict(checkpoint['state_dict'], strict=False)
    elif args.custom_arch:
        print('Custom model!')
        backbone_name, customs = model_utils.create_custom_resnet(
            args.backbone, None)
        if customs is not None:
            args.backbone = {'arch': backbone_name, 'customs': customs}

        model = custom_models.__dict__[args.network_name](
            args.backbone, num_classes=num_classes, aux_loss=args.aux_loss)

        if args.pretrained:
            print('Loading %s' % args.pretrained)
            checkpoint = torch.load(args.pretrained, map_location='cpu')
            num_all_keys = len(checkpoint['state_dict'].keys())
            remove_keys = []
            for key_ind, key in enumerate(checkpoint['state_dict'].keys()):
                if key_ind > (num_all_keys - 3):
                    remove_keys.append(key)
            for key in remove_keys:
                del checkpoint['state_dict'][key]
            pretrained_weights = OrderedDict(
                (k.replace('segmentation_model.', ''), v)
                for k, v in checkpoint['state_dict'].items())
            model.load_state_dict(pretrained_weights, strict=False)
    else:
        model = seg_models.__dict__[args.network_name](
            num_classes=num_classes,
            aux_loss=args.aux_loss,
            pretrained=args.pretrained)
    model.to(device)
    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    best_iou = 0
    model_progress = []
    model_progress_path = os.path.join(args.out_dir, 'model_progress.csv')
    # load the model if training is to be resumed
    if args.resume is not None:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        best_iou = checkpoint['best_iou']
        # if model progress exists, load it
        if os.path.exists(model_progress_path):
            model_progress = np.loadtxt(model_progress_path, delimiter=',')
            model_progress = model_progress.tolist()

    master_model = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpus])
        master_model = model.module

    if args.network_name == 'unet':
        params_to_optimize = model.parameters()
    else:
        params_to_optimize = [
            {
                'params': [
                    p for p in master_model.backbone.parameters()
                    if p.requires_grad
                ]
            },
            {
                'params': [
                    p for p in master_model.classifier.parameters()
                    if p.requires_grad
                ]
            },
        ]
        if args.aux_loss:
            params = [
                p for p in master_model.aux_classifier.parameters()
                if p.requires_grad
            ]
            params_to_optimize.append({'params': params, 'lr': args.lr * 10})
    optimizer = torch.optim.SGD(params_to_optimize,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    lr_lambda = lambda x: (1 - x / (len(data_loader) * args.epochs))**0.9
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    criterion = select_criterion(args.dataset)

    start_time = time.time()
    for epoch in range(args.initial_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_log = train_one_epoch(model, criterion, optimizer, data_loader,
                                    lr_scheduler, device, epoch,
                                    args.print_freq)
        val_confmat = utils.evaluate(model,
                                     data_loader_test,
                                     device=device,
                                     num_classes=num_classes)
        val_log = val_confmat.get_log_dict()
        is_best = val_log['iou'] > best_iou
        best_iou = max(best_iou, val_log['iou'])
        model_data = {
            'epoch': epoch + 1,
            'arch': args.network_name,
            'customs': {
                'aux_loss': args.aux_loss,
                'pooling_type': args.pooling_type,
                'in_chns': in_chns,
                'num_classes': num_classes,
                'backbone': args.backbone
            },
            'state_dict': master_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'target_size': args.target_size,
            'args': args,
            'best_iou': best_iou,
        }
        utils.save_on_master(model_data,
                             os.path.join(args.out_dir, 'checkpoint.pth'))
        if is_best:
            utils.save_on_master(model_data,
                                 os.path.join(args.out_dir, 'model_best.pth'))

        epoch_prog, header = add_to_progress(train_log, [], '')
        epoch_prog, header = add_to_progress(val_log,
                                             epoch_prog,
                                             header,
                                             prefix='v_')
        model_progress.append(epoch_prog)
        np.savetxt(model_progress_path,
                   np.array(model_progress),
                   delimiter=';',
                   header=header,
                   fmt='%s')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
Example #30
    def _setup_eval_dataloader_from_config(self, cfg: DictConfig):
        src_file_name = cfg.get('src_file_name')
        tgt_file_name = cfg.get('tgt_file_name')

        if src_file_name is None or tgt_file_name is None:
            raise ValueError(
                'Validation dataloader needs both cfg.src_file_name and cfg.tgt_file_name to not be None.'
            )
        else:
            # convert src_file_name and tgt_file_name to list of strings
            if isinstance(src_file_name, str):
                src_file_list = [src_file_name]
            elif isinstance(src_file_name, ListConfig):
                src_file_list = src_file_name
            else:
                raise ValueError(
                    "cfg.src_file_name must be string or list of strings")
            if isinstance(tgt_file_name, str):
                tgt_file_list = [tgt_file_name]
            elif isinstance(tgt_file_name, ListConfig):
                tgt_file_list = tgt_file_name
            else:
                raise ValueError(
                    "cfg.tgt_file_name must be string or list of strings")
        if len(src_file_list) != len(tgt_file_list):
            raise ValueError(
                'The same number of filepaths must be passed in for source and target validation.'
            )

        dataloaders = []
        prepend_idx = 0
        for idx, src_file in enumerate(src_file_list):
            if self.multilingual:
                prepend_idx = idx
            dataset = TranslationDataset(
                dataset_src=str(Path(src_file).expanduser()),
                dataset_tgt=str(Path(tgt_file_list[idx]).expanduser()),
                tokens_in_batch=cfg.tokens_in_batch,
                clean=cfg.get("clean", False),
                max_seq_length=cfg.get("max_seq_length", 512),
                min_seq_length=cfg.get("min_seq_length", 1),
                max_seq_length_diff=cfg.get("max_seq_length_diff", 512),
                max_seq_length_ratio=cfg.get("max_seq_length_ratio", 512),
                cache_ids=cfg.get("cache_ids", False),
                cache_data_per_node=cfg.get("cache_data_per_node", False),
                use_cache=cfg.get("use_cache", False),
                reverse_lang_direction=cfg.get("reverse_lang_direction",
                                               False),
                prepend_id=self.multilingual_ids[prepend_idx]
                if self.multilingual else None,
            )
            dataset.batchify(self.encoder_tokenizer, self.decoder_tokenizer)

            if cfg.shuffle:
                sampler = pt_data.RandomSampler(dataset)
            else:
                sampler = pt_data.SequentialSampler(dataset)

            dataloader = torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                sampler=sampler,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
            dataloaders.append(dataloader)

        return dataloaders