예제 #1
0
파일: criteo.py 프로젝트: xz10620/inference
    def __init__(self,
                 data_path,
                 name,
                 pre_process,
                 use_cache,
                 count=None,
                 samples_to_aggregate_fix=None,
                 samples_to_aggregate_min=None,
                 samples_to_aggregate_max=None,
                 samples_to_aggregate_quantile_file=None,
                 samples_to_aggregate_trace_file=None,
                 test_num_workers=0,
                 max_ind_range=-1,
                 sub_sample_rate=0.0,
                 mlperf_bin_loader=False,
                 randomize="total",
                 memory_map=False):
        """Load the Criteo test split and decide how samples are grouped into queries.

        A query ("aggregated sample") is a contiguous run of test samples. Query
        sizes are either fixed (``samples_to_aggregate_fix``, default 1) or
        variable. Variable sizes are drawn uniformly from
        ``[samples_to_aggregate_min, samples_to_aggregate_max]`` or, when
        ``samples_to_aggregate_quantile_file`` is given, by inverse transform
        sampling from the quantile stored on the first line of that file. For
        variable sizes the query boundaries are pre-generated into
        ``self.random_offsets`` (query ``i`` covers
        ``[random_offsets[i], random_offsets[i + 1])``).

        Args:
            data_path: directory containing the raw and processed dataset files.
            name: ``"kaggle"`` or ``"terabyte"``.
            pre_process: not referenced in this method.
            use_cache: not referenced in this method.
            count: optional cap on ``self.num_aggregated_samples``.
            samples_to_aggregate_fix: query size in fixed-size mode.
            samples_to_aggregate_min: lower bound for uniform variable sizes.
            samples_to_aggregate_max: upper bound for uniform variable sizes.
            samples_to_aggregate_quantile_file: text file whose first line holds
                comma-separated query sizes (the quantile function).
            samples_to_aggregate_trace_file: if not None, one
                ``"start, end, length"`` line per query is written to this path.
            test_num_workers: worker count for the non-binary DataLoader.
            max_ind_range, sub_sample_rate, randomize, memory_map: forwarded to
                ``dp.CriteoDataset``.
            mlperf_bin_loader: use the binary terabyte loader; only honoured when
                ``memory_map`` is set and ``name == "terabyte"``.

        Raises:
            ValueError: unknown ``name``; a query size below 1 (fixed size,
                uniform bounds or quantile entries); ``min > max``; an empty
                quantile file; or a failed query-count sanity check.
        """
        super().__init__()

        self.count = count
        self.random_offsets = []
        self.use_fixed_size = ((samples_to_aggregate_quantile_file is None) and
                               (samples_to_aggregate_min is None or samples_to_aggregate_max is None))
        if self.use_fixed_size:
            # fixed size queries
            self.samples_to_aggregate = 1 if samples_to_aggregate_fix is None else samples_to_aggregate_fix
            self.samples_to_aggregate_min = None
            self.samples_to_aggregate_max = None
            # A size of 0 would divide by zero when counting queries below.
            if self.samples_to_aggregate < 1:
                raise ValueError("samples_to_aggregate_fix must be >= 1")
        else:
            # variable size queries
            self.samples_to_aggregate = 1
            self.samples_to_aggregate_min = samples_to_aggregate_min
            self.samples_to_aggregate_max = samples_to_aggregate_max
            # Sizes below 1 make the offset loop spin forever (all zeros) or
            # divide by zero in the sanity check; fail fast before loading data.
            if samples_to_aggregate_quantile_file is None and not (
                    1 <= samples_to_aggregate_min <= samples_to_aggregate_max):
                raise ValueError(
                    "samples_to_aggregate_min/max must satisfy 1 <= min <= max")
        # Always defined (None in fixed-size mode) so later reads cannot hit AttributeError.
        self.samples_to_aggregate_quantile_file = samples_to_aggregate_quantile_file

        if name == "kaggle":
            raw_data_file = data_path + "/train.txt"
            processed_data_file = data_path + "/kaggleAdDisplayChallenge_processed.npz"
        elif name == "terabyte":
            raw_data_file = data_path + "/day"
            processed_data_file = data_path + "/terabyte_processed.npz"
        else:
            raise ValueError("only kaggle|terabyte dataset options are supported")
        self.use_mlperf_bin_loader = mlperf_bin_loader and memory_map and name == "terabyte"

        self.test_data = dp.CriteoDataset(
            dataset=name,
            max_ind_range=max_ind_range,
            sub_sample_rate=sub_sample_rate,
            randomize=randomize,
            split="test",
            raw_path=raw_data_file,
            pro_data=processed_data_file,
            memory_map=memory_map
        )
        # Measured on the per-sample dataset, before it may be replaced below.
        self.num_individual_samples = len(self.test_data)

        if self.use_mlperf_bin_loader:
            test_file = data_path + "/terabyte_processed_test.bin"
            counts_file = raw_data_file + '_fea_count.npz'

            data_loader_terabyte.numpy_to_binary(
                input_files=[raw_data_file + '_23_reordered.npz'],
                output_file_path=test_file,
                split="test")

            self.test_data = data_loader_terabyte.CriteoBinDataset(
                data_file=test_file,
                counts_file=counts_file,
                batch_size=self.samples_to_aggregate,
                max_ind_range=max_ind_range
            )

            # The binary dataset already yields whole batches, hence batch_size=None.
            self.test_loader = torch.utils.data.DataLoader(
                self.test_data,
                batch_size=None,
                batch_sampler=None,
                shuffle=False,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
            )
        else:
            self.test_loader = torch.utils.data.DataLoader(
                self.test_data,
                batch_size=self.samples_to_aggregate,
                shuffle=False,
                num_workers=test_num_workers,
                collate_fn=dp.collate_wrapper_criteo,
                pin_memory=False,
                drop_last=False,
            )

        # WARNING: the original dataset's length is a number of samples, while the
        # binary dataset's length is a number of batches. When a mini-batch of
        # size samples_to_aggregate is one item, the sample count must be
        # converted to a batch count. A data loader always reports batches.
        if self.use_fixed_size:
            # the offsets for fixed query size are generated on-the-fly later on
            print("Using fixed query size: " + str(self.samples_to_aggregate))
            if self.use_mlperf_bin_loader:
                self.num_aggregated_samples = len(self.test_data)
            else:
                # ceil(num_individual_samples / samples_to_aggregate)
                self.num_aggregated_samples = (self.num_individual_samples + self.samples_to_aggregate - 1) // self.samples_to_aggregate
        else:
            # the offsets for variable query sizes are pre-generated here
            num_samples = self.num_individual_samples

            def _draw_offsets(draw_query_size):
                # Cumulative query boundaries [0, ..., num_samples]. The last query
                # is truncated so that the final offset equals num_samples.
                offsets = []
                qo = 0
                while True:
                    offsets.append(int(qo))
                    qo = min(qo + draw_query_size(), num_samples)
                    if qo >= num_samples:
                        break
                offsets.append(int(qo))
                return offsets

            if self.samples_to_aggregate_quantile_file is None:
                # query sizes ~ uniform(min, max), bounds inclusive
                print("Using variable query size: uniform distribution (" + str(self.samples_to_aggregate_min) + "," + str(self.samples_to_aggregate_max) + ")")
                self.random_offsets = _draw_offsets(
                    lambda: random.randint(self.samples_to_aggregate_min,
                                           self.samples_to_aggregate_max))
                smallest_query = self.samples_to_aggregate_min
                largest_query = self.samples_to_aggregate_max
            else:
                # Query sizes follow a custom distribution whose quantile (the
                # inverse of its cdf) is given in the file. Picking a uniformly
                # random quantile entry is inverse transform sampling.
                #
                # For instance, for the query-length distribution
                # length = [100, 200, 300,  400,  500,  600,  700] # x
                # pdf =    [0.1, 0.6, 0.1, 0.05, 0.05, 0.05, 0.05] # p(x)
                # cdf =    [0.1, 0.7, 0.8, 0.85,  0.9, 0.95,  1.0] # f(x) = prefix-sum of p(x)
                # the inverse of its cdf with granularity 0.05 is
                # quantile_p = [.05, .10, .15, .20, .25, .30, .35, .40, .45, .50, .55, .60, .65, .70, .75, .80, .85, .90, .95, 1.0] # p
                # quantile_x = [100, 100, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 300, 300, 400, 500, 600, 700] # q(p) = x, such that f(x) >= p
                print("Using variable query size: custom distribution (file " + str(samples_to_aggregate_quantile_file) + ")")
                with open(self.samples_to_aggregate_quantile_file, 'r') as f:
                    line = f.readline()
                    quantile = np.fromstring(line, dtype=int, sep=", ")
                # An empty quantile cannot be sampled, and sizes below 1 would hang
                # the offset loop or divide by zero in the sanity check.
                if quantile.size == 0 or (quantile < 1).any():
                    raise ValueError(
                        "quantile file must contain at least one query size, all >= 1")

                num_quantiles = len(quantile)
                self.random_offsets = _draw_offsets(
                    lambda: quantile[np.random.randint(low=0, high=num_quantiles)])
                # min()/max() rather than [0]/[-1]: identical for a sorted quantile,
                # and still correct if the file is not sorted.
                smallest_query = quantile.min()
                largest_query = quantile.max()

            # bounds on the number of queries: ceil(N / largest) .. ceil(N / smallest)
            nas_max = (num_samples + smallest_query - 1) // smallest_query
            nas_min = (num_samples + largest_query - 1) // largest_query

            self.num_aggregated_samples = len(self.random_offsets) - 1

            if self.num_aggregated_samples < nas_min or nas_max < self.num_aggregated_samples:
                raise ValueError("Sannity check failed")

        # limit number of items to count if needed
        if self.count is not None:
            self.num_aggregated_samples = min(self.count, self.num_aggregated_samples)

        # dump the trace of aggregated samples: "start, end, length" per query
        if samples_to_aggregate_trace_file is not None:
            with open(samples_to_aggregate_trace_file, 'w') as f:
                for query in range(self.num_aggregated_samples):
                    if self.use_fixed_size:
                        s = query * self.samples_to_aggregate
                        e = min((query + 1) * self.samples_to_aggregate, self.num_individual_samples)
                    else:
                        s = self.random_offsets[query]
                        e = self.random_offsets[query + 1]
                    f.write(str(s) + ", " + str(e) + ", " + str(e - s) + "\n")
예제 #2
0
def make_criteo_data_and_loaders(args):
    """Build the Criteo train/test datasets and their loaders from ``args``.

    Three layouts are handled:

    * ``mlperf_logging`` + ``memory_map`` + terabyte + ``mlperf_bin_loader``:
      binary ``*_train.bin`` / ``*_test.bin`` files read through
      ``data_loader_terabyte.CriteoBinDataset``. They are preprocessed first
      if any expected file is missing.
    * ``mlperf_logging`` + ``memory_map`` + terabyte without the binary loader:
      ``CriteoDataset`` objects, plus ``data_loader_terabyte.DataLoader`` over
      days 0-22 (train) and day 23 (test).
    * otherwise: ``CriteoDataset`` objects wrapped in
      ``torch.utils.data.DataLoader`` with ``collate_wrapper_criteo``.

    Returns:
        ``(train_data, train_loader, test_data, test_loader)``.
    """
    def _criteo_split(split):
        # Positional order matches the CriteoDataset constructor call sites.
        return CriteoDataset(
            args.data_set,
            args.max_ind_range,
            args.data_sub_sample_rate,
            args.data_randomize,
            split,
            args.raw_data_file,
            args.processed_data_file,
            args.memory_map
        )

    terabyte_fast_path = (args.mlperf_logging and args.memory_map
                          and args.data_set == "terabyte")

    if not terabyte_fast_path:
        train_data = _criteo_split("train")
        test_data = _criteo_split("test")
        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.mini_batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_wrapper_criteo,
            pin_memory=False,
            drop_last=False,
        )
        test_loader = torch.utils.data.DataLoader(
            test_data,
            batch_size=args.test_mini_batch_size,
            shuffle=False,
            num_workers=args.test_num_workers,
            collate_fn=collate_wrapper_criteo,
            pin_memory=False,
            drop_last=False,
        )
        return train_data, train_loader, test_data, test_loader

    # the terabyte-specific loaders are more efficient for larger batches
    data_directory = path.dirname(args.raw_data_file)

    if args.mlperf_bin_loader:
        pieces = args.processed_data_file.split("/")
        d_path = "/".join(pieces[:-1]) + "/" + pieces[-1].split(".")[0]
        train_file = d_path + "_train.bin"
        test_file = d_path + "_test.bin"
        counts_file = args.raw_data_file + '_fea_count.npz'

        if not all(path.exists(p) for p in (train_file, test_file, counts_file)):
            ensure_dataset_preprocessed(args, d_path)

        def _bin_loader(dataset, sampler=None):
            # The binary dataset yields ready-made batches, so batch_size=None.
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=None,
                batch_sampler=None,
                shuffle=False,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
                sampler=sampler,
            )

        train_data = data_loader_terabyte.CriteoBinDataset(
            data_file=train_file,
            counts_file=counts_file,
            batch_size=args.mini_batch_size,
            max_ind_range=args.max_ind_range
        )
        train_loader = _bin_loader(
            train_data,
            RandomSampler(train_data) if args.mlperf_bin_shuffle else None,
        )

        test_data = data_loader_terabyte.CriteoBinDataset(
            data_file=test_file,
            counts_file=counts_file,
            batch_size=args.test_mini_batch_size,
            max_ind_range=args.max_ind_range
        )
        test_loader = _bin_loader(test_data)
        return train_data, train_loader, test_data, test_loader

    data_filename = args.raw_data_file.split("/")[-1]
    train_data = _criteo_split("train")
    test_data = _criteo_split("test")

    def _day_loader(days, batch_size, split):
        return data_loader_terabyte.DataLoader(
            data_directory=data_directory,
            data_filename=data_filename,
            days=days,
            batch_size=batch_size,
            max_ind_range=args.max_ind_range,
            split=split
        )

    train_loader = _day_loader(list(range(23)), args.mini_batch_size, "train")
    test_loader = _day_loader([23], args.test_mini_batch_size, "test")
    return train_data, train_loader, test_data, test_loader
예제 #3
0
    def __init__(self, data_path, name, pre_process, use_cache, count=None, test_num_workers=0, max_ind_range=-1, sub_sample_rate=0.0, mlperf_bin_loader=False, randomize="total", memory_map=False):
        """Open the Criteo test split and build a DataLoader over it.

        ``name`` selects the file names under ``data_path`` ("kaggle" or
        "terabyte"); any other value raises ``ValueError``. When
        ``mlperf_bin_loader`` and ``memory_map`` are both set for "terabyte",
        the day-23 data is converted with ``data_loader_terabyte.numpy_to_binary``
        and read back through ``data_loader_terabyte.CriteoBinDataset``.
        ``pre_process`` and ``use_cache`` are not referenced here.
        """
        super().__init__()
        self.count = count

        if name == "kaggle":
            raw_data_file, processed_data_file = (
                data_path + "/train.txt",
                data_path + "/kaggleAdDisplayChallenge_processed.npz",
            )
        elif name == "terabyte":
            raw_data_file, processed_data_file = (
                data_path + "/day",
                data_path + "/terabyte_processed.npz",
            )
        else:
            raise ValueError("only kaggle|terabyte dataset options are supported")

        self.use_mlperf_bin_loader = mlperf_bin_loader and memory_map and name == "terabyte"

        self.test_data = dp.CriteoDataset(
            dataset=name,
            max_ind_range=max_ind_range,
            sub_sample_rate=sub_sample_rate,
            randomize=randomize,
            split="test",
            raw_path=raw_data_file,
            pro_data=processed_data_file,
            memory_map=memory_map
        )

        if not self.use_mlperf_bin_loader:
            self.test_loader = torch.utils.data.DataLoader(
                self.test_data,
                batch_size=1,  # TODO: placeholder batch size (original note: "FIGURE this out")
                shuffle=False,
                num_workers=test_num_workers,
                collate_fn=dp.collate_wrapper_criteo,
                pin_memory=False,
                drop_last=False,
            )
            return

        # Binary terabyte path: convert day 23 to a .bin file, then read it back.
        bin_test_file = data_path + "/terabyte_processed_test.bin"
        counts_file = raw_data_file + '_fea_count.npz'

        data_loader_terabyte.numpy_to_binary(
            input_files=[raw_data_file + '_23_reordered.npz'],
            output_file_path=bin_test_file,
            split="test")

        self.test_data = data_loader_terabyte.CriteoBinDataset(
            data_file=bin_test_file,
            counts_file=counts_file,
            batch_size=1,  # TODO: placeholder batch size (original note: "FIGURE this out")
            max_ind_range=max_ind_range
        )

        # The binary dataset does its own batching, so the loader uses batch_size=None.
        self.test_loader = torch.utils.data.DataLoader(
            self.test_data,
            batch_size=None,
            batch_sampler=None,
            shuffle=False,
            num_workers=0,
            collate_fn=None,
            pin_memory=False,
            drop_last=False,
        )
예제 #4
0
파일: criteo.py 프로젝트: wom-ai/inference
    def __init__(self,
                 data_path,
                 name,
                 pre_process,
                 use_cache,
                 count=None,
                 samples_to_aggregate=None,
                 min_samples_to_aggregate=None,
                 max_samples_to_aggregate=None,
                 test_num_workers=0,
                 max_ind_range=-1,
                 sub_sample_rate=0.0,
                 mlperf_bin_loader=False,
                 randomize="total",
                 memory_map=False,
                 samples_to_aggregate_trace_file='dlrm_trace_of_aggregated_samples.txt'):
        """Load the Criteo test split and decide how samples are grouped into queries.

        Queries are fixed-size (``samples_to_aggregate``, default 1) unless both
        ``min_samples_to_aggregate`` and ``max_samples_to_aggregate`` are given.
        In that case, query sizes are drawn uniformly from ``[min, max]`` and the
        boundaries are pre-generated into ``self.random_offsets`` (query ``i``
        covers ``[random_offsets[i], random_offsets[i + 1])``).

        Args:
            data_path: directory containing the raw and processed dataset files.
            name: ``"kaggle"`` or ``"terabyte"``.
            pre_process: not referenced in this method.
            use_cache: not referenced in this method.
            count: optional cap on ``self.num_aggregated_samples``.
            samples_to_aggregate: query size in fixed-size mode.
            min_samples_to_aggregate: lower bound for variable query sizes.
            max_samples_to_aggregate: upper bound for variable query sizes.
            test_num_workers: worker count for the non-binary DataLoader.
            max_ind_range, sub_sample_rate, randomize, memory_map: forwarded to
                ``dp.CriteoDataset``.
            mlperf_bin_loader: use the binary terabyte loader; only honoured when
                ``memory_map`` is set and ``name == "terabyte"``.
            samples_to_aggregate_trace_file: path that receives one
                ``"start, end"`` line per query. Pass None to skip writing it.
                The default keeps the previous always-write behaviour.

        Raises:
            ValueError: unknown ``name``; a query size below 1; ``min > max``;
                or a failed query-count sanity check.
        """
        super().__init__()

        self.count = count
        self.random_offsets = []
        self.use_fixed_size = min_samples_to_aggregate is None or max_samples_to_aggregate is None
        if self.use_fixed_size:
            # fixed size queries
            self.samples_to_aggregate = 1 if samples_to_aggregate is None else samples_to_aggregate
            self.min_samples_to_aggregate = None
            self.max_samples_to_aggregate = None
            # A size of 0 would divide by zero when counting queries below.
            if self.samples_to_aggregate < 1:
                raise ValueError("samples_to_aggregate must be >= 1")
        else:
            # variable size queries
            # min == max == 0 would make the offset loop spin forever, and
            # min == 0 divides by zero in the sanity check; fail fast instead.
            if not 1 <= min_samples_to_aggregate <= max_samples_to_aggregate:
                raise ValueError(
                    "min/max_samples_to_aggregate must satisfy 1 <= min <= max")
            self.samples_to_aggregate = 1
            self.min_samples_to_aggregate = min_samples_to_aggregate
            self.max_samples_to_aggregate = max_samples_to_aggregate

        if name == "kaggle":
            raw_data_file = data_path + "/train.txt"
            processed_data_file = data_path + "/kaggleAdDisplayChallenge_processed.npz"
        elif name == "terabyte":
            raw_data_file = data_path + "/day"
            processed_data_file = data_path + "/terabyte_processed.npz"
        else:
            raise ValueError(
                "only kaggle|terabyte dataset options are supported")
        self.use_mlperf_bin_loader = mlperf_bin_loader and memory_map and name == "terabyte"

        self.test_data = dp.CriteoDataset(dataset=name,
                                          max_ind_range=max_ind_range,
                                          sub_sample_rate=sub_sample_rate,
                                          randomize=randomize,
                                          split="test",
                                          raw_path=raw_data_file,
                                          pro_data=processed_data_file,
                                          memory_map=memory_map)
        # Measured on the per-sample dataset, before it may be replaced below.
        self.num_individual_samples = len(self.test_data)

        if self.use_mlperf_bin_loader:
            test_file = data_path + "/terabyte_processed_test.bin"
            counts_file = raw_data_file + '_fea_count.npz'

            data_loader_terabyte.numpy_to_binary(
                input_files=[raw_data_file + '_23_reordered.npz'],
                output_file_path=test_file,
                split="test")

            self.test_data = data_loader_terabyte.CriteoBinDataset(
                data_file=test_file,
                counts_file=counts_file,
                batch_size=self.samples_to_aggregate,
                max_ind_range=max_ind_range)

            # The binary dataset already yields whole batches, hence batch_size=None.
            self.test_loader = torch.utils.data.DataLoader(
                self.test_data,
                batch_size=None,
                batch_sampler=None,
                shuffle=False,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
            )
        else:
            self.test_loader = torch.utils.data.DataLoader(
                self.test_data,
                batch_size=self.samples_to_aggregate,
                shuffle=False,
                num_workers=test_num_workers,
                collate_fn=dp.collate_wrapper_criteo,
                pin_memory=False,
                drop_last=False,
            )

        # WARNING: the original dataset's length is a number of samples, while the
        # binary dataset's length is a number of batches. When a mini-batch of
        # size samples_to_aggregate is one item, the sample count must be
        # converted to a batch count. A data loader always reports batches.
        if self.use_fixed_size:
            # the offsets for fixed query size are generated on-the-fly later on
            if self.use_mlperf_bin_loader:
                self.num_aggregated_samples = len(self.test_data)
            else:
                # ceil(num_individual_samples / samples_to_aggregate)
                self.num_aggregated_samples = (self.num_individual_samples +
                                               self.samples_to_aggregate -
                                               1) // self.samples_to_aggregate
        else:
            # Pre-generate cumulative query boundaries [0, ..., num_individual_samples].
            # The last query is truncated so that the final offset equals the sample count.
            qo = 0
            while True:
                self.random_offsets.append(int(qo))
                qs = random.randint(self.min_samples_to_aggregate,
                                    self.max_samples_to_aggregate)
                qo = min(qo + qs, self.num_individual_samples)
                if qo >= self.num_individual_samples:
                    break
            self.random_offsets.append(int(qo))

            self.num_aggregated_samples = len(self.random_offsets) - 1

            # bounds on the number of queries: ceil(N / max) .. ceil(N / min)
            nas_max = (self.num_individual_samples +
                       self.min_samples_to_aggregate -
                       1) // self.min_samples_to_aggregate
            nas_min = (self.num_individual_samples +
                       self.max_samples_to_aggregate -
                       1) // self.max_samples_to_aggregate
            if self.num_aggregated_samples < nas_min or nas_max < self.num_aggregated_samples:
                raise ValueError("Sannity check failed")

        # limit number of items to count if needed
        if self.count is not None:
            self.num_aggregated_samples = min(self.count,
                                              self.num_aggregated_samples)

        # dump the trace of aggregated samples: "start, end" per query
        if samples_to_aggregate_trace_file is not None:
            with open(samples_to_aggregate_trace_file, 'w') as f:
                for query in range(self.num_aggregated_samples):
                    if self.use_fixed_size:
                        s = query * self.samples_to_aggregate
                        e = min((query + 1) * self.samples_to_aggregate,
                                self.num_individual_samples)
                    else:
                        s = self.random_offsets[query]
                        e = self.random_offsets[query + 1]
                    f.write(str(s) + ", " + str(e) + "\n")
def make_criteo_data_and_loaders(args):
    """Build the Criteo train/test datasets and their loaders from ``args``.

    Three layouts are handled:

    * ``mlperf_logging`` + ``memory_map`` + terabyte + ``mlperf_bin_loader``:
      binary ``*_train.bin`` / ``*_test.bin`` files read through
      ``data_loader_terabyte.CriteoBinDataset``. They are preprocessed first
      if any expected file is missing.
    * ``mlperf_logging`` + ``memory_map`` + terabyte without the binary loader:
      ``CriteoDataset`` objects, plus ``data_loader_terabyte.DataLoader`` over
      days 0-22 (train) and day 23 (test).
    * otherwise: ``CriteoDataset`` objects wrapped in
      ``torch.utils.data.DataLoader``. When ``deviceIsHabana(args)`` is true,
      batches are collated by ``dlrm_habana_kernels.CustomPreProcessor``;
      otherwise by ``collate_wrapper_criteo``. These loaders use
      ``drop_last=True``, so an incomplete final batch is dropped.

    Returns:
        ``(train_data, train_loader, test_data, test_loader)``.
    """
    use_pin_memory = False
    # FIXME: Check if the pinned memory is still required after the CustomPreprocessor
    # if (hasattr(args, 'no_habana') and (not args.no_habana)) or args.use_gpu:
    #     use_pin_memory = True

    if args.mlperf_logging and args.memory_map and args.data_set == "terabyte":
        # more efficient for larger batches
        data_directory = path.dirname(args.raw_data_file)

        if args.mlperf_bin_loader:
            lstr = args.processed_data_file.split("/")
            d_path = "/".join(lstr[0:-1]) + "/" + lstr[-1].split(".")[0]
            train_file = d_path + "_train.bin"
            test_file = d_path + "_test.bin"
            counts_file = args.raw_data_file + '_fea_count.npz'

            if any(not path.exists(p) for p in [train_file,
                                                test_file,
                                                counts_file]):
                ensure_dataset_preprocessed(args, d_path)

            train_data = data_loader_terabyte.CriteoBinDataset(
                data_file=train_file,
                counts_file=counts_file,
                batch_size=args.mini_batch_size,
                max_ind_range=args.max_ind_range
            )

            # The binary dataset yields whole batches, hence batch_size=None.
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=None,
                batch_sampler=None,
                shuffle=False,
                num_workers=0,
                collate_fn=None,
                pin_memory=use_pin_memory,
                drop_last=False,
                sampler=RandomSampler(train_data) if args.mlperf_bin_shuffle else None
            )

            test_data = data_loader_terabyte.CriteoBinDataset(
                data_file=test_file,
                counts_file=counts_file,
                batch_size=args.test_mini_batch_size,
                max_ind_range=args.max_ind_range
            )

            test_loader = torch.utils.data.DataLoader(
                test_data,
                batch_size=None,
                batch_sampler=None,
                shuffle=False,
                num_workers=0,
                collate_fn=None,
                pin_memory=use_pin_memory,
                drop_last=False,
            )
        else:
            data_filename = args.raw_data_file.split("/")[-1]

            train_data = CriteoDataset(
                args.data_set,
                args.max_ind_range,
                args.data_sub_sample_rate,
                args.data_randomize,
                "train",
                args.raw_data_file,
                args.processed_data_file,
                args.memory_map
            )

            test_data = CriteoDataset(
                args.data_set,
                args.max_ind_range,
                args.data_sub_sample_rate,
                args.data_randomize,
                "test",
                args.raw_data_file,
                args.processed_data_file,
                args.memory_map
            )

            train_loader = data_loader_terabyte.DataLoader(
                data_directory=data_directory,
                data_filename=data_filename,
                days=list(range(23)),
                batch_size=args.mini_batch_size,
                max_ind_range=args.max_ind_range,
                split="train"
            )

            test_loader = data_loader_terabyte.DataLoader(
                data_directory=data_directory,
                data_filename=data_filename,
                days=[23],
                batch_size=args.test_mini_batch_size,
                max_ind_range=args.max_ind_range,
                split="test"
            )
    else:
        train_data = CriteoDataset(
            args.data_set,
            args.max_ind_range,
            args.data_sub_sample_rate,
            args.data_randomize,
            "train",
            args.raw_data_file,
            args.processed_data_file,
            args.memory_map
        )

        test_data = CriteoDataset(
            args.data_set,
            args.max_ind_range,
            args.data_sub_sample_rate,
            args.data_randomize,
            "test",
            args.raw_data_file,
            args.processed_data_file,
            args.memory_map
        )

        if deviceIsHabana(args):
            from dlrm_habana_kernels import CustomPreProcessor
            collate_fn_train = CustomPreProcessor(train_data.counts, train_data.m_den, args, is_train=True).collate_wrapper_criteo
            # NOTE(review): the test pre-processor is also built from train_data's
            # counts/m_den -- presumably intentional (shared table sizes); confirm.
            collate_fn_test = CustomPreProcessor(train_data.counts, train_data.m_den, args, is_train=False).collate_wrapper_criteo
        else:
            # Both loaders below read collate_fn_train / collate_fn_test, so both
            # names must be bound on every device.
            collate_fn_train = collate_wrapper_criteo
            collate_fn_test = collate_wrapper_criteo

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.mini_batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_fn_train,
            pin_memory=use_pin_memory,
            drop_last=True,
        )

        test_loader = torch.utils.data.DataLoader(
            test_data,
            batch_size=args.test_mini_batch_size,
            shuffle=False,
            num_workers=args.test_num_workers,
            collate_fn=collate_fn_test,
            pin_memory=use_pin_memory,
            drop_last=True,
        )

    return train_data, train_loader, test_data, test_loader
예제 #6
0
def make_criteo_data_and_loaders(args, offset_to_length_converter=False):
    """Build the Criteo train/test datasets and their loaders from ``args``.

    Three layouts are handled:

    * ``mlperf_logging`` + ``memory_map`` + terabyte + ``mlperf_bin_loader``:
      binary ``*_train.bin`` / ``*_test.bin`` files read through
      ``data_loader_terabyte.CriteoBinDataset``. The train/eval sample counts
      are logged via ``mlperf_logger``.
    * ``mlperf_logging`` + ``memory_map`` + terabyte without the binary loader:
      ``CriteoDataset`` objects, plus ``data_loader_terabyte.DataLoader`` over
      days 0-22 (train) and day 23 (test).
    * otherwise: ``CriteoDataset`` objects wrapped in
      ``torch.utils.data.DataLoader``. When ``torch.distributed`` is
      initialized, each split gets its own ``DistributedSampler``.

    Args:
        args: parsed command-line options.
        offset_to_length_converter: collate with ``collate_wrapper_criteo_length``
            instead of ``collate_wrapper_criteo_offset`` (plain layout only).

    Returns:
        ``(train_data, train_loader, test_data, test_loader)``.
    """
    if args.mlperf_logging and args.memory_map and args.data_set == "terabyte":
        # more efficient for larger batches
        data_directory = path.dirname(args.raw_data_file)

        if args.mlperf_bin_loader:
            lstr = args.processed_data_file.split("/")
            d_path = "/".join(lstr[0:-1]) + "/" + lstr[-1].split(".")[0]
            train_file = d_path + "_train.bin"
            test_file = d_path + "_test.bin"
            counts_file = args.raw_data_file + '_fea_count.npz'

            if any(not path.exists(p)
                   for p in [train_file, test_file, counts_file]):
                ensure_dataset_preprocessed(args, d_path)

            train_data = data_loader_terabyte.CriteoBinDataset(
                data_file=train_file,
                counts_file=counts_file,
                batch_size=args.mini_batch_size,
                max_ind_range=args.max_ind_range)

            mlperf_logger.log_event(key=mlperf_logger.constants.TRAIN_SAMPLES,
                                    value=train_data.num_samples)

            # The binary dataset yields whole batches, hence batch_size=None.
            train_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=None,
                batch_sampler=None,
                shuffle=False,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
                sampler=RandomSampler(train_data)
                if args.mlperf_bin_shuffle else None)

            test_data = data_loader_terabyte.CriteoBinDataset(
                data_file=test_file,
                counts_file=counts_file,
                batch_size=args.test_mini_batch_size,
                max_ind_range=args.max_ind_range)

            mlperf_logger.log_event(key=mlperf_logger.constants.EVAL_SAMPLES,
                                    value=test_data.num_samples)

            test_loader = torch.utils.data.DataLoader(
                test_data,
                batch_size=None,
                batch_sampler=None,
                shuffle=False,
                num_workers=0,
                collate_fn=None,
                pin_memory=False,
                drop_last=False,
            )
        else:
            data_filename = args.raw_data_file.split("/")[-1]

            train_data = CriteoDataset(
                args.data_set, args.max_ind_range, args.data_sub_sample_rate,
                args.data_randomize, "train", args.raw_data_file,
                args.processed_data_file, args.memory_map,
                args.dataset_multiprocessing)

            test_data = CriteoDataset(
                args.data_set, args.max_ind_range, args.data_sub_sample_rate,
                args.data_randomize, "test", args.raw_data_file,
                args.processed_data_file, args.memory_map,
                args.dataset_multiprocessing)

            train_loader = data_loader_terabyte.DataLoader(
                data_directory=data_directory,
                data_filename=data_filename,
                days=list(range(23)),
                batch_size=args.mini_batch_size,
                max_ind_range=args.max_ind_range,
                split="train")

            test_loader = data_loader_terabyte.DataLoader(
                data_directory=data_directory,
                data_filename=data_filename,
                days=[23],
                batch_size=args.test_mini_batch_size,
                max_ind_range=args.max_ind_range,
                split="test")
    else:
        train_data = CriteoDataset(
            args.data_set,
            args.max_ind_range,
            args.data_sub_sample_rate,
            args.data_randomize,
            "train",
            args.raw_data_file,
            args.processed_data_file,
            args.memory_map,
            args.dataset_multiprocessing,
        )

        test_data = CriteoDataset(
            args.data_set,
            args.max_ind_range,
            args.data_sub_sample_rate,
            args.data_randomize,
            "test",
            args.raw_data_file,
            args.processed_data_file,
            args.memory_map,
            args.dataset_multiprocessing,
        )

        # Local name chosen so it does not shadow the module-level collate_wrapper_criteo.
        collate_fn = (collate_wrapper_criteo_length if offset_to_length_converter
                      else collate_wrapper_criteo_offset)

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            # Each split needs its own sampler: a sampler built over train_data
            # yields train-set indices, which are wrong (and possibly out of
            # range) for test_data. The test split is kept in order, matching
            # shuffle=False on its loader.
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_data, shuffle=False)
        else:
            print("distributed training is off. sampler is set to None")
            train_sampler = None
            test_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.mini_batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            pin_memory=False,
            drop_last=False,  # True
            sampler=train_sampler)

        test_loader = torch.utils.data.DataLoader(
            test_data,
            batch_size=args.test_mini_batch_size,
            shuffle=False,
            num_workers=args.test_num_workers,
            collate_fn=collate_fn,
            pin_memory=False,
            drop_last=False,  # True
            sampler=test_sampler)
    return train_data, train_loader, test_data, test_loader