Code example #1 (score: 0)
File: Fairseq.py — Project: yyht/EasyNMT
    def make_batches(lines,
                     cfg,
                     task,
                     max_positions,
                     batch_size=None,
                     max_tokens=None):
        """Build and yield inference batches from raw source strings.

        When ``cfg.generation.constraints`` is enabled, tab-delimited
        decoding constraints are split off each input line (mutating
        ``lines`` in place), encoded with the target dictionary and
        packed into one tensor. Yields ``Batch`` namedtuples produced
        via the task's batch iterator.
        """
        Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")

        use_constraints = cfg.generation.constraints
        if use_constraints:
            # Anything after a tab on an input line is a decoding
            # constraint; peel those off and keep them per line.
            batch_constraints = [[] for _ in lines]
            for idx, raw in enumerate(lines):
                if "\t" in raw:
                    lines[idx], *batch_constraints[idx] = raw.split("\t")

            # Encode every constraint string with the target dictionary.
            batch_constraints = [[
                task.target_dictionary.encode_line(
                    text,
                    append_eos=False,
                    add_if_not_exist=False,
                ) for text in per_line
            ] for per_line in batch_constraints]

        tokens = [
            task.source_dictionary.encode_line(
                src, add_if_not_exist=False).long()
            for src in lines
        ]

        constraints_tensor = (pack_constraints(batch_constraints)
                              if use_constraints else None)

        lengths = [t.numel() for t in tokens]
        iterator = task.get_batch_iterator(
            dataset=task.build_dataset_for_inference(
                tokens, lengths, constraints=constraints_tensor),
            max_tokens=max_tokens,
            max_sentences=batch_size,
            max_positions=max_positions,
            ignore_invalid_inputs=(
                cfg.dataset.skip_invalid_size_inputs_valid_test),
        ).next_epoch_itr(shuffle=False)

        for batch in iterator:
            yield Batch(
                ids=batch["id"],
                src_tokens=batch["net_input"]["src_tokens"],
                src_lengths=batch["net_input"]["src_lengths"],
                constraints=batch.get("constraints", None),
            )
Code example #2 (score: 0)
File: translate.py — Project: bcmi220/d2gpo
def make_batches(lines, args, task, max_positions, encode_fn, truncate_size):
    """Yield inference batches for ``lines``.

    Applies ``encode_fn`` to each source string (and to any tab-delimited
    constraints stripped off the lines when ``args.constraints`` is set),
    truncates encoded sources to ``truncate_size`` tokens, and yields
    ``Batch`` objects from the task's batch iterator.
    """
    if args.constraints:
        # Tab-delimited constraints follow the source text on each line;
        # strip them off in place and keep them per input line.
        batch_constraints = [[] for _ in lines]
        for idx, raw in enumerate(lines):
            if "\t" in raw:
                lines[idx], *batch_constraints[idx] = raw.split("\t")

        # Encode every constraint string with the target dictionary;
        # constraints go through the same ``encode_fn`` as source text.
        batch_constraints = [[
            task.target_dictionary.encode_line(
                encode_fn(text),
                append_eos=False,
                add_if_not_exist=False,
            ) for text in per_line
        ] for per_line in batch_constraints]

    # NOTE(review): this fork's ``encode_line`` accepts a ``max_tokens``
    # truncation argument — not present in stock fairseq; confirm.
    tokens = [
        task.source_dictionary.encode_line(
            encode_fn(src),
            add_if_not_exist=False,
            max_tokens=truncate_size,
        ).long()
        for src in lines
    ]

    constraints_tensor = (pack_constraints(batch_constraints)
                          if args.constraints else None)

    lengths = [t.numel() for t in tokens]
    iterator = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test
    ).next_epoch_itr(shuffle=False)

    for batch in iterator:
        yield Batch(
            ids=batch["id"],
            src_tokens=batch["net_input"]["src_tokens"],
            src_lengths=batch["net_input"]["src_lengths"],
            constraints=batch.get("constraints", None),
        )
Code example #3 (score: 0)
def make_batches(lines,
                 cfg,
                 task,
                 max_positions,
                 encode_fn,
                 constrainted_decoding=False):
    """Yield inference batches, optionally with lexical constraints.

    With ``constrainted_decoding`` (name kept as-is for caller
    compatibility), tab-delimited constraints are stripped off each
    input line in place, encoded with the target dictionary (through
    the same ``encode_fn`` as source text) and packed into one tensor.
    """
    if constrainted_decoding:
        # Tab-delimited constraints follow the source text; split them
        # off in place and collect them per input line.
        batch_constraints = [[] for _ in lines]
        for idx, raw in enumerate(lines):
            if "\t" in raw:
                lines[idx], *batch_constraints[idx] = raw.split("\t")

        # Encode each constraint string with the target dictionary.
        batch_constraints = [[
            task.target_dictionary.encode_line(
                encode_fn(text),
                append_eos=False,
                add_if_not_exist=False,
            ) for text in per_line
        ] for per_line in batch_constraints]

        constraints_tensor = pack_constraints(batch_constraints)
    else:
        constraints_tensor = None

    tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)

    iterator = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)

    for batch in iterator:
        yield Batch(
            ids=batch["id"],
            src_tokens=batch["net_input"]["src_tokens"],
            src_lengths=batch["net_input"]["src_lengths"],
            constraints=batch.get("constraints", None),
        )
Code example #4 (score: 0)
File: test_constraints.py — Project: kahne/fairseq
 def test_sequences(self):
     """Advance each packed constraint state through its token stream
     and check that the resulting state attributes match expectations."""
     for constraints, tokens, expected in self.sequences:
         state = OrderedConstraintState.create(
             pack_constraints([constraints])[0])
         for token in tokens:
             state = state.advance(token)
         observed = {attr: getattr(state, attr) for attr in expected.keys()}
         assert observed == expected, (
             f"TEST({tokens}) GOT: {observed} WANTED: {expected}")
Code example #5 (score: 0)
 def constraint2tensor(self, constraints: "list[list[str]]"):
     """Encode nested constraint strings and pack them into one tensor.

     Args:
         constraints: one list of constraint strings per input sentence.
             Mutated in place: each inner list of strings is replaced by
             the corresponding list of encoded tensors.

     Returns:
         The packed constraints tensor from ``pack_constraints``.
     """
     for i, constraint_list in enumerate(constraints):
         constraints[i] = [
             # encode with src_dict as this becomes tgt
             self.bart.src_dict.encode_line(
                 self.bart.apply_bpe(constraint),
                 append_eos=False,
                 add_if_not_exist=False,
             ) for constraint in constraint_list
         ]
     return pack_constraints(constraints)
Code example #6 (score: 0)
File: test_constraints.py — Project: kahne/fairseq
 def test_packing(self):
     """Ensures the list of lists of tensors gets packed correctly."""
     for batch_constraints, expected_tensor in self.examples:
         assert torch.equal(pack_constraints(batch_constraints),
                            expected_tensor)
Code example #7 (score: 0)
def make_batches(lines, cfg, task, max_positions, encode_fn):
    """Yield ``Batch`` tuples for inference over raw input ``lines``.

    When ``cfg.generation.constraints`` is set, tab-delimited decoding
    constraints are stripped from each input line (mutating ``lines``
    in place), encoded with the target dictionary and packed into a
    single tensor. Source sentences are encoded with ``encode_fn`` plus
    the source dictionary and, when ``cfg.task.truncate_source`` is
    set, truncated to fit the model's maximum positions.

    Args:
        lines: raw input sentences; may be mutated in place.
        cfg: fairseq config (``generation``, ``task`` and ``dataset``
            sections are read).
        task: fairseq task providing dictionaries and batch iteration.
        max_positions: maximum source positions, indexable at ``[0]``.
        encode_fn: preprocessing (e.g. BPE) applied before dictionary
            encoding.

    Yields:
        ``Batch`` namedtuples with ids, src_tokens, src_lengths and
        optional constraints.
    """

    def encode_fn_target(x):
        # Constraints are target-side text but share the same
        # preprocessing pipeline as the source.
        return encode_fn(x)

    if cfg.generation.constraints:
        # Strip (tab-delimited) constraints, if present, from input
        # lines and store them in batch_constraints.
        batch_constraints = [list() for _ in lines]
        for i, line in enumerate(lines):
            if "\t" in line:
                lines[i], *batch_constraints[i] = line.split("\t")

        # Convert each List[str] to List[Tensor].
        for i, constraint_list in enumerate(batch_constraints):
            batch_constraints[i] = [
                task.target_dictionary.encode_line(
                    encode_fn_target(constraint),
                    append_eos=False,
                    add_if_not_exist=False,
                ) for constraint in constraint_list
            ]

    if getattr(cfg.task, "truncate_source", False):
        # Reserve 4 positions, presumably for special tokens appended
        # later (e.g. EOS / language tags) — TODO confirm the offset.
        truncate_to = max_positions[0] - 4
        tokens = [
            task.source_dictionary.encode_line(
                encode_fn(src_str),
                add_if_not_exist=False).long()[:truncate_to]
            for src_str in lines
        ]
    else:
        tokens = [
            task.source_dictionary.encode_line(
                encode_fn(src_str), add_if_not_exist=False).long()
            for src_str in lines
        ]

    if cfg.generation.constraints:
        constraints_tensor = pack_constraints(batch_constraints)
    else:
        constraints_tensor = None

    lengths = [t.numel() for t in tokens]
    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            ids=batch["id"],
            src_tokens=batch["net_input"]["src_tokens"],
            src_lengths=batch["net_input"]["src_lengths"],
            constraints=batch.get("constraints", None),
        )