Example #1
import torch
from torch.utils.data import TensorDataset

# NOTE: VQVAE is assumed to be the project's own model class; adjust the
# import path to wherever it lives in this codebase.
from vqvae import VQVAE


def load_dataset(self, hparams):
    """Encode the VQ-VAE training set into discrete codes and wrap them in a
    TensorDataset, prepending a start token to every flattened code sequence."""
    print("Loading the dataset of codes into memory...")
    device = "cuda" if hparams.gpus else "cpu"
    vqvae = VQVAE.load_from_checkpoint(hparams.vqvae_model_path)
    vqvae = vqvae.to(device)
    codes = []
    nb = 0
    for X, _ in vqvae.train_dataloader():
        X = X.to(device)
        zinds = vqvae.encode(X)
        codes.append(zinds.detach().cpu())
        nb += len(zinds)
        # Stop once enough examples have been collected, if a cap is set.
        if hparams.nb_examples and nb >= hparams.nb_examples:
            break
    codes = torch.cat(codes)
    if hparams.nb_examples and len(codes) >= hparams.nb_examples:
        codes = codes[:hparams.nb_examples]
    # Reserve one extra vocabulary entry for the start token.
    vocab_size = vqvae.model.num_embeddings + 1
    start_token = vqvae.model.num_embeddings
    # Flatten each code map and prepend the start token to every example.
    codes_ = codes.view(len(codes), -1)
    codes_ = torch.cat([
        torch.full((len(codes_), 1), start_token, dtype=torch.long),
        codes_,
    ], dim=1)
    length = codes_.shape[1]
    dataset = TensorDataset(codes_)
    dataset.vocab_size = vocab_size
    dataset.shape = codes.shape
    dataset.length = length
    dataset.start_token = start_token
    print("Done loading dataset")
    return dataset
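
A minimal sketch of how the returned dataset might be consumed; the DataLoader wrapping, batch size, and the `module` name below are illustrative assumptions, not code from the original project:

from torch.utils.data import DataLoader

dataset = module.load_dataset(hparams)  # "module" stands for whatever object this method belongs to (assumption)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for (batch,) in loader:  # TensorDataset yields 1-tuples
    # batch[:, 0] is the start token; batch[:, 1:] are the flattened VQ-VAE code indices
    ...
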
def construct_datasets(dataset_constructor, embeddings, args):
    ## Train dataset
    if args.tensor_dataset:
        from serialize_pytorch_dataset import load_serialized_data
        input_shape = None

        Xs = list()
        Ys = list()
        for path in args.tensor_dataset:
            X, Y = load_serialized_data(path)
            Xs.append(X)
            Ys.append(Y)
            if input_shape is None:
                input_shape = X.shape[1:]

        assert len({el.shape[0] for el in Xs + Ys}) == 1, \
            'All datasets must have the same batch dimension'
        # Note: only the first label tensor is used as the target.
        train_dataset = TensorDataset(*(Xs + [Ys[0]]))

    elif args.train_news_dir is not None and args.train_ground_truth_dir is not None:
        from datasets import CachedDataset
        train_dataset = dataset_constructor(args.train_news_dir,
                                            args.train_ground_truth_dir,
                                            embeddings, args)
        input_shape = train_dataset.shape()
        train_dataset = CachedDataset(train_dataset)

    else:
        train_dataset = None
        print('\n=> # NOT USING A TRAIN DATASET; ONLY FOR TESTING #\n')

    ## Val dataset
    if args.val_news_dir is None or args.val_ground_truth_dir is None:
        val_dataset = None
    else:
        from datasets import CachedDataset
        val_dataset = dataset_constructor(args.val_news_dir,
                                          args.val_ground_truth_dir,
                                          embeddings, args)
        val_dataset = CachedDataset(val_dataset)

    ## Test dataset: if no test dataset was provided, split train dataset
    # Parenthesize the `or` so the `and` applies to the whole condition;
    # without it, operator precedence reads this as `A or (B and C)`.
    if (args.test_news_dir is None or args.test_ground_truth_dir is None) \
            and train_dataset is not None:
        print('=> No test dataset provided, splitting train/test {}/{}'.format(
            1. - args.test_percentage, args.test_percentage))

        num_train = int(len(train_dataset) * (1 - args.test_percentage) + 0.5)
        # Derive num_test from num_train so the two sizes always sum to
        # len(train_dataset); rounding both independently can overshoot it,
        # which makes random_split raise a ValueError.
        num_test = len(train_dataset) - num_train
        train_dataset, test_dataset = torch.utils.data.random_split(
            train_dataset, [num_train, num_test])
    else:
        print('Test dataset: "{}"'.format(args.test_news_dir))
        test_dataset = dataset_constructor(args.test_news_dir,
                                           args.test_ground_truth_dir,
                                           embeddings, args)
        input_shape = test_dataset.shape()

    return train_dataset, val_dataset, test_dataset, input_shape
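
A hypothetical end-to-end invocation, for orientation only; the argparse flags simply mirror the attributes construct_datasets reads, and `dataset_constructor` / `embeddings` stand in for the project's own dataset class and embedding object (assumptions):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--tensor-dataset', nargs='+', default=None)
parser.add_argument('--train-news-dir', default=None)
parser.add_argument('--train-ground-truth-dir', default=None)
parser.add_argument('--val-news-dir', default=None)
parser.add_argument('--val-ground-truth-dir', default=None)
parser.add_argument('--test-news-dir', default=None)
parser.add_argument('--test-ground-truth-dir', default=None)
parser.add_argument('--test-percentage', type=float, default=0.2)
args = parser.parse_args()

train_ds, val_ds, test_ds, input_shape = construct_datasets(
    dataset_constructor, embeddings, args)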