コード例 #1
0
    def get_dataloader(self, context, dataset_name, dataloader):
        """Attach the reader configured for ``dataset_name`` to ``dataloader``.

        Settings are read through ``envs.get_global_env`` under the
        ``dataset.<dataset_name>.`` key prefix. If either ``sparse_slots`` or
        ``dense_slots`` is non-empty after stripping, a slot reader is built
        with ``dataloader_instance.slotdataloader_by_name`` and ``SlotReader``.
        Otherwise the reader named by ``data_converter`` /
        ``reader_class_name`` is built with
        ``dataloader_instance.dataloader_by_name`` and its class is loaded
        with ``envs.lazy_instance_by_fliename``.

        Args:
            context: mapping providing ``"config_yaml"``; it is also handed to
                the reader factories.
            dataset_name: dataset name used to build the configuration prefix.
            dataloader: loader object that receives the generator.

        Returns:
            The ``dataloader`` object that was passed in.
        """
        prefix = "dataset." + dataset_name + "."

        def setting(key, *fallback):
            # Look up one dataset-scoped configuration entry.
            return envs.get_global_env(prefix + key, *fallback)

        sparse = setting("sparse_slots", "").strip()
        dense = setting("dense_slots", "").strip()
        batch_size = setting("batch_size")

        converter = setting("data_converter")
        converter_cls_name = setting("reader_class_name", "Reader")

        if sparse != "" or dense != "":
            # Slot configuration present: use the built-in slot reader.
            reader = dataloader_instance.slotdataloader_by_name(
                "", dataset_name, context["config_yaml"], context)
            probe = SlotReader(context["config_yaml"])
        else:
            # No slots configured: use the user-supplied converter module.
            reader = dataloader_instance.dataloader_by_name(
                converter,
                dataset_name,
                context["config_yaml"],
                context,
                reader_class_name=converter_cls_name)

            converter_cls = envs.lazy_instance_by_fliename(
                converter, converter_cls_name)
            probe = converter_cls(context["config_yaml"])

        # The probe instance is only inspected to decide which setter of the
        # dataloader receives the generator.
        if hasattr(probe, 'generate_batch_from_trainfiles'):
            dataloader.set_sample_list_generator(reader)
            return dataloader
        if hasattr(probe, 'batch_tensor_creator'):
            dataloader.set_batch_generator(reader)
            return dataloader
        dataloader.set_sample_generator(reader, batch_size)
        return dataloader
コード例 #2
0
ファイル: single_infer.py プロジェクト: vslyu/PaddleRec
 def _get_dataloader(self, dataset_name, dataloader):
     """Attach the reader configured for ``dataset_name`` to ``dataloader``.

     Settings are read through ``envs.get_global_env`` under the
     ``dataset.<dataset_name>.`` key prefix. When both ``sparse_slots`` and
     ``dense_slots`` are empty after stripping, the ``TrainReader`` class of
     the module named by ``data_converter`` is used; otherwise the built-in
     ``SlotReader`` is used.

     Args:
         dataset_name: dataset name used to build the configuration prefix.
         dataloader: loader object that receives the generator.

     Returns:
         The ``dataloader`` object that was passed in.
     """
     name = "dataset." + dataset_name + "."
     batch_size = envs.get_global_env(name + "batch_size")
     reader_class = envs.get_global_env(name + "data_converter")
     sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
     dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
     if sparse_slots == "" and dense_slots == "":
         # No slots configured: use the user-supplied converter module.
         reader = dataloader_instance.dataloader_by_name(
             reader_class, dataset_name, self._config_yaml)
         reader_class = envs.lazy_instance_by_fliename(
             reader_class, "TrainReader")
         reader_ins = reader_class(self._config_yaml)
     else:
         # Slot configuration present: use the built-in slot reader.
         reader = dataloader_instance.slotdataloader_by_name(
             "", dataset_name, self._config_yaml)
         reader_ins = SlotReader(self._config_yaml)
     # The reader instance is only inspected to decide which setter of the
     # dataloader receives the generator.
     if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
         dataloader.set_sample_list_generator(reader)
     else:
         dataloader.set_sample_generator(reader, batch_size)
     return dataloader