Exemplo n.º 1
0
  def run(self,
          inputs: List[JsonDict],
          dataset: lit_dataset.Dataset,
          config: Optional[JsonDict] = None):
    """Generate new examples by substituting candidates into text fields.

    For every TextSegment field in dataset.spec(), that field's values across
    all inputs are handed to self.generate_from_texts() in a single call. Each
    candidate returned for an input produces one new example: a copy of that
    input with only the one field replaced.

    Args:
      inputs: sequence of inputs, following dataset.spec()
      dataset: dataset, used to access dataset.spec()
      config: additional runtime options (not read by this method)

    Returns:
      list of list of new generated inputs, following dataset.spec(); the
      outer list is parallel to `inputs`.
    """
    # TODO(lit-team): configure a subset of fields to operate on
    segment_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)

    # One batched generation call per text field.
    generated = {
        key: self.generate_from_texts([example[key] for example in inputs])
        for key in segment_keys
    }

    # Substitute one field at a time.
    # TODO(lit-team): substitute on a combination of fields?
    results = [[] for _ in inputs]
    for key, per_example in generated.items():
      for idx, example in enumerate(inputs):
        results[idx].extend(
            utils.copy_and_update(example, {key: candidate})
            for candidate in per_example[idx])
    return results
Exemplo n.º 2
0
  def test_copy_and_update(self):
    """Checks utils.copy_and_update against four source/update pairs."""
    # These two dicts are each used both as an argument and as the expected
    # result, so the same object is passed in and compared against.
    unchanged = {"a": True, "b": False}
    additions = {"a": False, "c": True}

    # Each row: (source dict, update dict, expected result).
    cases = [
        # Update overwrites existing keys; the untouched key survives.
        ({"a": True, "b": False, "c": True},
         {"a": False, "b": True},
         {"a": False, "b": True, "c": True}),
        # Update overwrites one key and introduces a new one.
        ({"a": True, "b": False},
         {"a": False, "c": True},
         {"a": False, "b": False, "c": True}),
        # Empty update: result equals the source.
        (unchanged, {}, unchanged),
        # Empty source: result equals the update.
        ({}, additions, additions),
    ]
    for source, patch, want in cases:
      self.assertDictEqual(want, utils.copy_and_update(source, patch))
Exemplo n.º 3
0
    def run(self,
            inputs: List[JsonDict],
            dataset: lit_dataset.Dataset,
            config: Optional[JsonDict] = None):
        """Generate new examples by substituting candidates into text fields.

        The fields to process are read from config under
        FIELDS_TO_BACKTRANSLATE_KEY; when that key is absent (or config is
        empty/None), every TextSegment field in dataset.spec() is used. For
        each selected field, the values across all inputs are passed to
        self.generate_from_texts() in one call, and every returned candidate
        yields a copy of the source example with that single field replaced.

        Args:
          inputs: sequence of inputs, following dataset.spec()
          dataset: dataset, used to access dataset.spec()
          config: additional runtime options

        Returns:
          list of list of new generated inputs, following dataset.spec()
        """
        options = config if config else {}

        # Default to all text fields when the config does not select any.
        text_keys = utils.find_spec_keys(dataset.spec(), types.TextSegment)
        selected = list(options.get(FIELDS_TO_BACKTRANSLATE_KEY, text_keys))

        # One batched generation call per selected field.
        generated = {}
        for name in selected:
            field_values = [example[name] for example in inputs]
            generated[name] = self.generate_from_texts(field_values)

        # Substitute one field at a time.
        # TODO(lit-team): substitute on a combination of fields?
        results = [[] for _ in inputs]
        for name, per_example in generated.items():
            for idx, example in enumerate(inputs):
                for candidate in per_example[idx]:
                    replacement = {name: candidate}
                    results[idx].append(
                        utils.copy_and_update(example, replacement))
        return results