def get_feature_metadata(
    cls, feature_config: FeatureConfig, feature_meta: Dict[str, FieldMeta]
):
    """Collect export input names, dummy inputs and vocab lookups.

    Only fields of ``feature_config`` whose value is a ``ConfigBase``
    instance are exported. Fields are visited in ``_asdict()`` order.

    Args:
        feature_config: Config object exposing ``_asdict()``; each value is
            a per-feature config with an ``export_input_names`` list.
        feature_meta: Per-feature metadata, keyed by the same field names.
            Each entry provides ``dummy_model_input`` and optionally
            ``vocab``.

    Returns:
        ``(input_names, dummy_model_input, feature_itos_map)``:

        * the exported input names;
        * a tuple with one dummy input per exported feature, followed by a
          ``torch.tensor([1, 1], dtype=torch.long)`` lengths tensor for each
          of ``tokens_vals`` / ``seq_tokens_vals`` present in the names;
        * a dict mapping a feature's first export name to its
          ``vocab.itos``, for features whose ``vocab`` is truthy.
    """
    # The number of names in input_names *must* be equal to the number of
    # tensors passed in dummy_input
    names: List[str] = []
    dummies: List = []
    itos_by_input = {}

    for feat_name, sub_config in feature_config._asdict().items():
        if not isinstance(sub_config, ConfigBase):
            continue
        meta = feature_meta[feat_name]
        export_names = sub_config.export_input_names
        names.extend(export_names)
        vocab = getattr(meta, "vocab", None)
        if vocab:
            itos_by_input[export_names[0]] = vocab.itos
        dummies.append(meta.dummy_model_input)

    # For each values input that was exported, add a matching dummy lengths
    # tensor and its name. Tokens are checked first, then seq tokens.
    for vals_name, lens_name in (
        ("tokens_vals", "tokens_lens"),
        ("seq_tokens_vals", "seq_tokens_lens"),
    ):
        if vals_name in names:
            dummies.append(torch.tensor([1, 1], dtype=torch.long))
            names.append(lens_name)

    return names, tuple(dummies), itos_by_input
def _get_exportable_metadata(
    cls,
    exportable_filter: Callable,
    feature_config: FeatureConfig,
    feature_meta: Dict[str, FieldMeta],
) -> Tuple[List[str], List, Dict]:
    """Collect export input names, dummy inputs and vocab lookups.

    Args:
        exportable_filter: Predicate called with each per-feature config.
            Only configs for which it returns a truthy value are exported.
        feature_config: Config object exposing ``_asdict()``; each exported
            value must have an ``export_input_names`` list.
        feature_meta: Per-feature metadata, keyed by the same field names.
            Each exported entry provides ``dummy_model_input`` and
            optionally ``vocab``.

    Returns:
        ``(input_names, dummy_model_input, feature_itos_map)``, where
        ``dummy_model_input`` is a list with one entry per exported feature.
        ``feature_itos_map`` maps a feature's first export name to its
        ``vocab.itos``, for features whose ``vocab`` is truthy.
    """
    # The number of names in input_names *must* be equal to the number of
    # tensors passed in dummy_input
    names: List[str] = []
    dummies: List = []
    itos_by_input = {}

    # Lazily filtered, so the predicate is still called once per field,
    # in field order, interleaved with the loop body below.
    selected = (
        (feat_name, sub_config)
        for feat_name, sub_config in feature_config._asdict().items()
        if exportable_filter(sub_config)
    )
    for feat_name, sub_config in selected:
        meta = feature_meta[feat_name]
        export_names = sub_config.export_input_names
        names.extend(export_names)
        vocab = getattr(meta, "vocab", None)
        if vocab:
            itos_by_input[export_names[0]] = vocab.itos
        dummies.append(meta.dummy_model_input)

    return names, dummies, itos_by_input
def create_sub_embs(
    cls, emb_config: FeatureConfig, metadata: CommonMetadata
) -> Dict[str, EmbeddingBase]:
    """
    Creates the embedding modules defined in the `emb_config`.

    A field of `emb_config` counts as an embedding config when its
    `__COMPONENT__` attribute (defaulting to `object` when absent) is a
    subclass of `EmbeddingBase`. Such fields are built with `create_module`,
    passing the field's entry in `metadata.features` as `metadata`. Every
    other field is reported with a `print` and skipped.

    Args:
        emb_config (FeatureConfig): Object containing all the sub-embedding
            configurations.
        metadata (CommonMetadata): Object containing features and label metadata.

    Returns:
        Dict[str, EmbeddingBase]: Named dictionary of embedding modules.
    """
    modules: Dict[str, EmbeddingBase] = {}
    for field_name, sub_config in emb_config._asdict().items():
        component_cls = getattr(sub_config, "__COMPONENT__", object)
        if not issubclass(component_cls, EmbeddingBase):
            print(f"{field_name} is not a config of embedding, skipping")
            continue
        modules[field_name] = create_module(
            sub_config, metadata=metadata.features[field_name]
        )
    return modules