예제 #1
0
        def raw_to_tensor(inputs):
            is_one = True  # batch_size 1 flag
            feature, _helper = data_reader.read_one_example(inputs)

            nonlocal helper
            helper.update(_helper)

            if type(feature) == list:
                is_one = False
                features = feature
            else:
                features = [feature]

            self._index_features(features,
                                 data_reader.text_columns,
                                 suppress_tqdm=True)

            if is_one:
                indexed_features = features[0]
            else:  # when features > 1, need to transpose (dict_of_list -> list_of_dict)
                indexed_features = {}
                for key in features[0]:
                    feature_with_key = [feature[key] for feature in features]
                    indexed_features[key] = transpose(feature_with_key,
                                                      skip_keys=["text"])

            for key in indexed_features:
                for token_name in self.token_makers:
                    if token_name not in indexed_features[key]:
                        continue

                    indexed_values = indexed_features[key][token_name]
                    if is_one:
                        indexed_values = [indexed_values]

                    tensor = padding_tokens(indexed_values,
                                            token_name=token_name)
                    if cuda_device is not None and type(tensor) != list:
                        tensor = tensor.cuda(cuda_device)
                    indexed_features[key][token_name] = tensor

            for key in indexed_features:
                if "text" in indexed_features[key]:
                    del indexed_features[key]["text"]

            return indexed_features, helper
예제 #2
0
파일: collate.py 프로젝트: seongl/claf
 def _apply_pad(self, value, token_name=None, pad_value=0):
     """Pad ``value`` by delegating to ``utils.padding_tokens``.

     Args:
         value: Data to pad; forwarded as the first positional argument.
         token_name: Forwarded as the ``token_name`` keyword (default None).
         pad_value: Forwarded as the ``pad_value`` keyword (default 0).

     Returns:
         Whatever ``utils.padding_tokens`` returns for these arguments.
     """
     padded = utils.padding_tokens(
         value, token_name=token_name, pad_value=pad_value
     )
     return padded