def define_dataset(conf, data, display_log=True):
    """Build the train/val/test datasets for `data`.

    Sets `conf.partitioned_by_user` (true only when `conf.data` is
    "femnist"), loads the train and test splits through `get_dataset`, and
    hands both to `define_val_dataset` to obtain the validation split.

    Args:
        conf: configuration object; `data` and `data_dir` are read from it,
            `partitioned_by_user` is written, and `logger` is used when
            `display_log` is true.
        data: dataset identifier forwarded to `get_dataset`.
        display_log: when true, log the number of samples in each split.

    Returns:
        dict with the keys "train", "val" and "test"; the "val" entry may be
        None.
    """
    # flag the datasets that come partitioned by user.
    conf.partitioned_by_user = conf.data == "femnist"

    # load the raw splits (train first, then test).
    raw_splits = {
        name: get_dataset(conf, data, conf.data_dir, split=name)
        for name in ("train", "test")
    }

    # carve the validation split out of the training data.
    train_split, val_split, test_split = define_val_dataset(
        conf, raw_splits["train"], raw_splits["test"])

    if display_log:
        n_train = len(train_split)
        n_val = 0 if val_split is None else len(val_split)
        n_test = len(test_split)
        message = (
            "Data stat for original dataset: we have {} samples for train, {} samples for val, {} samples for test."
        ).format(n_train, n_val, n_test)
        conf.logger.log(message)

    return {"train": train_split, "val": val_split, "test": test_split}
# Example #2
# 0
def define_nlp_dataset(conf, force_shuffle):
    """Create the vocabulary and the BPTT train/validation iterators.

    The dataset is obtained from `get_dataset`, the vocabulary is built on
    the training split (optionally with pretrained GloVe vectors), and two
    `torchtext.data.BPTTIterator`s are created: a repeating training
    iterator whose batch size is `conf.batch_size * conf.graph.n_nodes`,
    and a non-shuffled validation iterator with batch size
    `conf.batch_size`.

    Args:
        conf: configuration object; reads `data`, `data_dir`,
            `rnn_use_pretrained_emb`, `rnn_n_hidden`, `rnn_bptt_len`,
            `batch_size`, `reshuffle_per_epoch` and the `graph` fields
            `rank`, `n_nodes`, `device`, `on_cuda`.
        force_shuffle: shuffle the training iterator even when
            `conf.reshuffle_per_epoch` is false.

    Returns:
        dict with the keys "TEXT", "train_loader" and "val_loader".
    """
    print("create {} dataset for rank {}".format(conf.data, conf.graph.rank))
    # create dataset.
    TEXT, train, valid, _ = get_dataset(conf, conf.data, conf.data_dir)

    # Build vocab.
    # we can use some precomputed word embeddings,
    # e.g., GloVe vectors with 100, 200, and 300.
    vectors, vectors_cache = None, None
    if conf.rnn_use_pretrained_emb:
        try:
            vectors = "glove.6B.{}d".format(conf.rnn_n_hidden)
            vectors_cache = os.path.join(conf.data_dir, ".vector_cache")
        except Exception:
            # best effort: fall back to randomly initialised embeddings, but
            # do not swallow KeyboardInterrupt/SystemExit like a bare except.
            vectors, vectors_cache = None, None
    TEXT.build_vocab(train, vectors=vectors, vectors_cache=vectors_cache)

    # both iterators live on the same device.
    device = ("cuda:{}".format(conf.graph.device[0])
              if conf.graph.on_cuda else None)

    # Partition training data.
    # `splits` always builds one iterator per dataset; only the training one
    # is kept here because the validation iterator needs other settings.
    train_loader, _ = torchtext.data.BPTTIterator.splits(
        (train, valid),
        batch_size=conf.batch_size * conf.graph.n_nodes,
        bptt_len=conf.rnn_bptt_len,
        device=device,
        repeat=True,
        shuffle=force_shuffle or conf.reshuffle_per_epoch,
    )
    _, val_loader = torchtext.data.BPTTIterator.splits(
        (train, valid),
        batch_size=conf.batch_size,
        bptt_len=conf.rnn_bptt_len,
        device=device,
        shuffle=False,
    )

    # get some stat.
    _get_nlp_data_stat(conf, train, valid, train_loader, val_loader)
    return {
        "TEXT": TEXT,
        "train_loader": train_loader,
        "val_loader": val_loader
    }
# Example #3
# 0
def _define_cv_dataset(conf,
                       partition_type,
                       dataset_type,
                       force_shuffle=False):
    """Load one split of `conf.data` and wrap it in a DataLoader.

    When `partition_type` is given and `conf.distributed` is truthy, the
    split is divided into `conf.graph.n_nodes` equally sized shares by
    `DataPartitioner` and only the share of `conf.graph.rank` is loaded;
    otherwise the whole split is loaded.

    Args:
        conf: configuration object; reads `data`, `data_dir`, `batch_size`,
            `distributed`, `num_workers`, `pin_memory`, `graph.n_nodes` and
            `graph.rank`.
        partition_type: partition scheme forwarded to `DataPartitioner`, or
            None to skip partitioning.
        dataset_type: name of the split to load; the loader shuffles when
            this equals "train".
        force_shuffle: shuffle regardless of `dataset_type`.

    Returns:
        The `torch.utils.data.DataLoader` over the selected data.
    """
    full_data = get_dataset(conf, conf.data, conf.data_dir, split=dataset_type)
    per_batch = conf.batch_size
    n_nodes = conf.graph.n_nodes

    # pick what this process loads: the whole split or its own share.
    use_partition = partition_type is not None and conf.distributed
    if not use_partition:
        subset = full_data
        print("Data partition: used whole data.")
    else:
        equal_shares = [1.0 / n_nodes for _ in range(n_nodes)]
        partitioner = DataPartitioner(conf,
                                      full_data,
                                      equal_shares,
                                      partition_type=partition_type)
        subset = partitioner.use(conf.graph.rank)
        print("Data partition: partitioned data and use subdata.")

    # wrap the selection in a DataLoader.
    loader_options = dict(
        batch_size=per_batch,
        shuffle=force_shuffle or dataset_type == "train",
        num_workers=conf.num_workers,
        pin_memory=conf.pin_memory,
        drop_last=False,
    )
    loader = torch.utils.data.DataLoader(subset, **loader_options)

    stat_template = ("Data stat: we have {} samples for {}, "
                     "load {} data for process (rank {}). "
                     "The batch size is {}, number of batches is {}.")
    print(stat_template.format(
        len(full_data),
        dataset_type,
        len(subset),
        conf.graph.rank,
        per_batch,
        len(loader),
    ))
    return loader
    def _define_aggregation_data(self, return_loader=True):
        """Prepare the data objects used for aggregation.

        Two independent pieces are assembled:

        * when `self.dataset["val"]` is not None, a non-shuffled DataLoader
          over it, stored under "self_val_data_loader";
        * when `self.conf.fl_aggregate["data_source"]` contains "other", the
          dataset named by `fl_aggregate["data_name"]` is created, a
          `partition_data.DataSampler` selects indices from it according to
          `fl_aggregate["data_scheme"]`, and the sampler is stored under
          "sampler". If `return_loader` is true, a DataLoader over the
          sampled indices is additionally stored under "data_loader".

        Args:
            return_loader: whether to also build the DataLoader for the
                sampled data. When false, no "data_loader" entry is
                returned.

        Returns:
            dict containing a subset of the keys "self_val_data_loader",
            "sampler" and "data_loader", as described above.

        Raises:
            AssertionError: if a key required by the configured scheme is
                missing from `fl_aggregate`, or the "class_selection"
                preconditions (cifar100 data, imagenet source) do not hold.
            NotImplementedError: if `fl_aggregate["data_scheme"]` is neither
                "random_sampling" nor "class_selection".
        """
        # init.
        fl_aggregate = self.conf.fl_aggregate

        # prepare the data.
        if self.dataset["val"] is not None:
            # prepare the dataloader.
            val_loader = torch.utils.data.DataLoader(
                self.dataset["val"],
                batch_size=self.conf.batch_size,
                shuffle=False,
                num_workers=self.conf.num_workers,
                pin_memory=self.conf.pin_memory,
                drop_last=False,
            )
            # define things to return.
            things_to_return = {"self_val_data_loader": val_loader}
        else:
            things_to_return = {}

        if "data_source" in fl_aggregate and "other" in fl_aggregate["data_source"]:
            assert (
                "data_name" in fl_aggregate
                and "data_scheme" in fl_aggregate
                and "data_type" in fl_aggregate
            )

            # create dataset.
            self.logger.log(f'create data={fl_aggregate["data_name"]} for aggregation.')
            dataset = prepare_data.get_dataset(
                self.conf,
                fl_aggregate["data_name"],
                datasets_path=self.conf.data_dir
                if "data_dir" not in fl_aggregate
                else fl_aggregate["data_dir"],
                split="train",
            )
            self.logger.log(
                f'created data={fl_aggregate["data_name"]} for aggregation with size {len(dataset)}.'
            )

            # sample the indices from the dataset.
            if fl_aggregate["data_scheme"] == "random_sampling":
                assert "data_percentage" in fl_aggregate
                sampler = partition_data.DataSampler(
                    self.conf,
                    data=dataset,
                    data_scheme=fl_aggregate["data_scheme"],
                    data_percentage=fl_aggregate["data_percentage"],
                )
            elif fl_aggregate["data_scheme"] == "class_selection":
                assert "num_overlap_class" in fl_aggregate
                assert "num_total_class" in fl_aggregate
                assert self.conf.data == "cifar100"
                assert "imagenet" in self.conf.fl_aggregate["data_name"]

                # choose the imagenet classes the sampler is restricted to.
                selected_imagenet_classes = partition_data.get_imagenet1k_classes(
                    num_overlap_classes=int(fl_aggregate["num_overlap_class"]),
                    random_state=self.conf.random_state,
                    num_total_classes=int(
                        fl_aggregate["num_total_class"]
                    ),  # for cifar-100
                )
                sampler = partition_data.DataSampler(
                    self.conf,
                    data=dataset,
                    data_scheme=fl_aggregate["data_scheme"],
                    data_percentage=fl_aggregate["data_percentage"]
                    if "data_percentage" in fl_aggregate
                    else None,
                    selected_classes=selected_imagenet_classes,
                )
            else:
                raise NotImplementedError("invalid data_scheme")

            sampler.sample_indices()

            # define things to return.
            things_to_return.update({"sampler": sampler})

            if return_loader:
                aggregation_loader = torch.utils.data.DataLoader(
                    sampler.use_indices(),
                    batch_size=self.conf.batch_size,
                    shuffle=fl_aggregate["randomness"]
                    if "randomness" in fl_aggregate
                    else True,
                    num_workers=self.conf.num_workers,
                    pin_memory=self.conf.pin_memory,
                    drop_last=False,
                )
                # only publish the loader when it was actually built; doing
                # this unconditionally would raise UnboundLocalError or
                # expose the validation loader under "data_loader".
                things_to_return.update({"data_loader": aggregation_loader})
        return things_to_return