コード例 #1
0
 def __init__(self, config, metadata,
              text_embedder: EmbedderInterface.Config):
     """Create the reporter with TensorBoard and console channels.

     Args:
         config: reporter config; its ``output_path`` is passed to the
             ``SummaryWriter`` that backs the TensorBoard channel.
         metadata: accepted for interface compatibility; not read here.
         text_embedder: config from which ``self.text_embedder`` is built via
             ``EmbedderInterface.from_config``.
     """
     tensorboard_channel = C.TensorBoardChannel(SummaryWriter(config.output_path))
     console_channel = C.ConsoleChannel()
     super().__init__(channels=[tensorboard_channel, console_channel])
     self.text_embedder = EmbedderInterface.from_config(text_embedder)
     # NOTE(review): ``_reset`` is defined outside this block; presumably it
     # initialises the reporter's internal state — confirm.
     self._reset()
コード例 #2
0
 def __init__(self, embedder_cfg: EmbedderInterface.Config, all_responses: bool = False,
              fixed_n_turns: bool = False):
   """Build the text embedder and cache its special-token indices.

   This setup is shared by the owning process for its whole lifetime.

   Args:
     embedder_cfg: config passed to ``EmbedderInterface.from_config``.
     all_responses: stored unchanged as ``self.all_responses``.
     fixed_n_turns: stored unchanged as ``self.fixed_n_turns``.
   """
   embedder = EmbedderInterface.from_config(embedder_cfg)
   self.text_embedder = embedder
   # Copy the embedder's special-token indices onto the instance.
   (self.pad_token_idx,
    self.unk_token_idx,
    self.bos_token_idx,
    self.eos_token_idx) = (embedder.pad_idx,
                           embedder.unk_idx,
                           embedder.bos_idx,
                           embedder.eos_idx)
   self.fixed_n_turns = fixed_n_turns
   self.all_responses = all_responses
コード例 #3
0
 class Config(ConfigBase):
   """Configuration bundling the retrieval model, trainer, featurizer, I/O
   configs, metric reporter, text embedder and data handler.

   Every field has a default instance, so an empty config is usable as-is.
   """
   model: RetrievalModel.Config = RetrievalModel.Config()
   # Trainer defaults are overridden here: training metrics are reported and
   # module checkpoints are saved under "exp/retrieval".
   trainer: RetrievalTrainer.Config = RetrievalTrainer.Config(
     report_train_metrics=True,
     save_modules_checkpoint=True,
     modules_save_dir="exp/retrieval"
   )
   featurizer: TokenIdFeaturizer.Config = TokenIdFeaturizer.Config()
   features: ModelInputConfig = ModelInputConfig()
   labels: ModelOutputConfig = ModelOutputConfig()   # was: WordLabelConfig
   metric_reporter: MetaLearnMetricReporter.Config = MetaLearnMetricReporter.Config()
   text_embedder: EmbedderInterface.Config = EmbedderInterface.Config()
   # Maybe we could just have a single instance, would need the nested batch iterator
   data_handler: MetaDataHandler.Config = MetaDataHandler.Config()
   # NOTE(review): presumably tells the task to run meta-training for the
   # model — confirm against the code that reads this flag.
   model_needs_meta_training: bool = True
コード例 #4
0
  def from_config(cls, config: Config,
                  feature_config: ModelInputConfig,
                  target_config: ModelOutputConfig,
                  text_embedder_config: EmbedderInterface.Config,
                  **kwargs):
    """Construct an instance of ``cls`` whose fields share one text embedder.

    Args:
      config: handler config. Every entry of ``config.items()`` is merged into
        ``kwargs`` (overriding same-named caller entries) and forwarded to the
        constructor; several values are also copied onto the new instance.
      feature_config: accepted for interface compatibility; not used here.
      target_config: accepted for interface compatibility; not used here.
      text_embedder_config: config used to build the ``EmbedderInterface``
        shared by the feature and target ``BPEField``s; also stored on the
        instance as ``text_embedder_cfg``.
      **kwargs: additional keyword arguments for the constructor.

    Returns:
      The newly constructed instance of ``cls``.
    """
    text_embedder: EmbedderInterface = EmbedderInterface.from_config(text_embedder_config)
    # The input and target fields are built around the same embedder instance.
    features: Dict[str, Field] = {
      ModelInput.SEQ: BPEField(text_embedder),
    }
    targets: Dict[str, Field] = {
      ModelOutputConfig._name: BPEField(text_embedder, is_target=True, all_responses=config.all_responses),
    }
    # Extra per-example fields, each wrapped in a RawField.
    extra_fields = {
      RAW_TEXT: RawField(),
      ModelInput.DLG_LEN: RawField(),
      ModelInput.DLG_ID: RawField(),
      ModelInput.DOMAIN_ID: RawField(),
      ModelInput.TASK_ID: RawField(),
    }

    # Config entries win over caller-supplied kwargs of the same name.
    kwargs.update(config.items())
    # Use a descriptive local name instead of rebinding ``self`` in a
    # constructor-style method.
    handler = cls(
      raw_columns=[],  # ignored in our read function
      features=features,
      labels=targets,
      extra_fields=extra_fields,
      **kwargs,
    )
    handler.max_turns = config.max_turns
    handler.text_embedder_cfg = text_embedder_config
    handler.all_responses = config.all_responses
    handler.preproc_chunksize = config.preproc_chunksize
    handler.train_domains = config.train_domains
    handler.eval_domains = config.eval_domains
    handler.featurized_cache_dir = config.featurized_cache_dir
    handler.test_domains = config.test_domains
    handler.text_embedder = text_embedder
    handler.seed = config.seed
    return handler
コード例 #5
0
 def __init__(self, config: Config, feature_config, text_embedder_config):
     """Initialise the instance; builds ``self.text_embedder`` from ``text_embedder_config``.

     NOTE(review): verify whether ``config`` and ``feature_config`` are used.
     """
     self.text_embedder = EmbedderInterface.from_config(
         text_embedder_config)