def make_batches(lines, cfg, task, max_positions, batch_size=None, max_tokens=None):
    """Encode raw input lines and yield them as inference batches.

    When ``cfg.generation.constraints`` is truthy, each line may carry
    tab-separated constraints after the source text
    (``"source\\tconstraint1\\tconstraint2"``). These are split off, encoded
    with ``task.target_dictionary`` and packed with ``pack_constraints``.

    Args:
        lines: input strings. The sequence is copied before constraints are
            stripped, so the caller's object is never modified and tuples
            are accepted as well as lists.
        cfg: config object; reads ``cfg.generation.constraints`` and
            ``cfg.dataset.skip_invalid_size_inputs_valid_test``.
        task: provides ``source_dictionary``, ``target_dictionary``,
            ``build_dataset_for_inference`` and ``get_batch_iterator``.
        max_positions: forwarded to ``task.get_batch_iterator``.
        batch_size: forwarded to ``task.get_batch_iterator`` as
            ``max_sentences``.
        max_tokens: forwarded to ``task.get_batch_iterator`` as ``max_tokens``.

    Yields:
        ``Batch(ids, src_tokens, src_lengths, constraints)`` namedtuples.
        ``constraints`` is ``None`` when the batch dict has no
        ``"constraints"`` entry.
    """
    Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")

    # Work on a private copy: stripping constraints below rewrites entries,
    # and that must not leak back into (or fail on) the caller's sequence.
    lines = list(lines)

    if cfg.generation.constraints:
        # Strip (tab-delimited) constraints, if present, from input lines and
        # keep them per line in batch_constraints.
        batch_constraints = [[] for _ in lines]
        for i, line in enumerate(lines):
            if "\t" in line:
                lines[i], *batch_constraints[i] = line.split("\t")
        # Convert each List[str] to List[Tensor] using the target dictionary.
        batch_constraints = [
            [
                task.target_dictionary.encode_line(
                    constraint,
                    append_eos=False,
                    add_if_not_exist=False,
                )
                for constraint in constraint_list
            ]
            for constraint_list in batch_constraints
        ]

    tokens = [
        task.source_dictionary.encode_line(src_str, add_if_not_exist=False).long()
        for src_str in lines
    ]

    if cfg.generation.constraints:
        constraints_tensor = pack_constraints(batch_constraints)
    else:
        constraints_tensor = None

    lengths = [t.numel() for t in tokens]
    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor
        ),
        max_tokens=max_tokens,
        max_sentences=batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        yield Batch(
            ids=batch["id"],
            src_tokens=batch["net_input"]["src_tokens"],
            src_lengths=batch["net_input"]["src_lengths"],
            constraints=batch.get("constraints", None),
        )
def make_batches(lines, args, task, max_positions, encode_fn, truncate_size):
    """Encode raw input lines with ``encode_fn`` and yield inference batches.

    When ``args.constraints`` is truthy, each line may carry tab-separated
    constraints after the source text. These are split off, passed through
    ``encode_fn``, encoded with ``task.target_dictionary`` and packed with
    ``pack_constraints``.

    Args:
        lines: input strings. The sequence is copied before constraints are
            stripped, so the caller's object is never modified.
        args: namespace; reads ``constraints``, ``max_tokens``,
            ``batch_size`` and ``skip_invalid_size_inputs_valid_test``.
        task: provides the source/target dictionaries,
            ``build_dataset_for_inference`` and ``get_batch_iterator``.
        max_positions: forwarded to ``task.get_batch_iterator``.
        encode_fn: applied to every source line and constraint string before
            dictionary lookup.
        truncate_size: forwarded as ``max_tokens`` to the source
            dictionary's ``encode_line``.

    Yields:
        ``Batch`` records built from each batch dict. ``constraints`` is
        ``None`` when the batch has no ``"constraints"`` entry.
    """
    # Work on a private copy so stripping constraints never mutates the
    # caller's sequence (and tuples are accepted).
    lines = list(lines)

    if args.constraints:
        # Strip (tab-delimited) constraints, if present, from input lines and
        # keep them per line in batch_constraints.
        batch_constraints = [[] for _ in lines]
        for i, line in enumerate(lines):
            if "\t" in line:
                lines[i], *batch_constraints[i] = line.split("\t")
        # Convert each List[str] to List[Tensor] using the target dictionary.
        batch_constraints = [
            [
                task.target_dictionary.encode_line(
                    encode_fn(constraint),
                    append_eos=False,
                    add_if_not_exist=False,
                )
                for constraint in constraint_list
            ]
            for constraint_list in batch_constraints
        ]

    # NOTE(review): stock fairseq Dictionary.encode_line has no `max_tokens`
    # parameter; this assumes a project dictionary that truncates the
    # encoding to `truncate_size` tokens -- confirm.
    tokens = [
        task.source_dictionary.encode_line(
            encode_fn(src_str),
            add_if_not_exist=False,
            max_tokens=truncate_size,
        ).long()
        for src_str in lines
    ]

    if args.constraints:
        constraints_tensor = pack_constraints(batch_constraints)
    else:
        constraints_tensor = None

    lengths = [t.numel() for t in tokens]
    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor
        ),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        yield Batch(
            ids=batch["id"],
            src_tokens=batch["net_input"]["src_tokens"],
            src_lengths=batch["net_input"]["src_lengths"],
            constraints=batch.get("constraints", None),
        )
def make_batches(lines, cfg, task, max_positions, encode_fn, constrainted_decoding=False):
    """Tokenize input lines through the task and yield inference batches.

    When ``constrainted_decoding`` is true, each line may carry tab-separated
    constraints after the source text. These are split off, passed through
    ``encode_fn``, encoded with ``task.target_dictionary`` and packed with
    ``pack_constraints``.

    Args:
        lines: input strings. The sequence is copied before constraints are
            stripped, so the caller's object is never modified.
        cfg: config; reads ``cfg.dataset.max_tokens``,
            ``cfg.dataset.batch_size`` and
            ``cfg.dataset.skip_invalid_size_inputs_valid_test``.
        task: provides ``target_dictionary``,
            ``get_interactive_tokens_and_lengths``,
            ``build_dataset_for_inference`` and ``get_batch_iterator``.
        max_positions: forwarded to ``task.get_batch_iterator``.
        encode_fn: applied to constraint strings here, and passed to
            ``task.get_interactive_tokens_and_lengths`` for source lines.
        constrainted_decoding: enable constraint parsing/packing.

    Yields:
        ``Batch`` records built from each batch dict. ``constraints`` is
        ``None`` when the batch has no ``"constraints"`` entry.
    """
    # Work on a private copy so stripping constraints never mutates the
    # caller's sequence (and tuples are accepted).
    lines = list(lines)

    constraints_tensor = None
    if constrainted_decoding:
        # Strip (tab-delimited) constraints, if present, from input lines and
        # keep them per line in batch_constraints.
        batch_constraints = [[] for _ in lines]
        for i, line in enumerate(lines):
            if "\t" in line:
                lines[i], *batch_constraints[i] = line.split("\t")
        # Convert each List[str] to List[Tensor] using the target dictionary.
        batch_constraints = [
            [
                task.target_dictionary.encode_line(
                    encode_fn(constraint),
                    append_eos=False,
                    add_if_not_exist=False,
                )
                for constraint in constraint_list
            ]
            for constraint_list in batch_constraints
        ]
        constraints_tensor = pack_constraints(batch_constraints)

    # Source lines (already stripped of constraints) are tokenized by the task.
    tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)

    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor
        ),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        yield Batch(
            ids=batch["id"],
            src_tokens=batch["net_input"]["src_tokens"],
            src_lengths=batch["net_input"]["src_lengths"],
            constraints=batch.get("constraints", None),
        )
def test_sequences(self):
    """Drive an OrderedConstraintState through each token sequence.

    Each entry of ``self.sequences`` is ``(constraints, tokens, expected)``.
    The constraints are packed and a state is created from them, then
    advanced token by token. Every attribute named in ``expected`` is read
    from the final state and compared with the expected value.
    """
    for constraints, tokens, expected in self.sequences:
        packed = pack_constraints([constraints])
        state = OrderedConstraintState.create(packed[0])
        for tok in tokens:
            state = state.advance(tok)
        result = {name: getattr(state, name) for name in expected}
        assert result == expected, f"TEST({tokens}) GOT: {result} WANTED: {expected}"
def constraint2tensor(self, constraints: "list[list[str]]"):
    """Encode per-sentence constraint strings and pack them into one tensor.

    Each constraint string is passed through ``self.bart.apply_bpe`` and
    then encoded with ``self.bart.src_dict.encode_line`` (no EOS, no
    dictionary growth).

    Args:
        constraints: one list of constraint strings per input sentence. The
            argument is not modified; a new nested list of encodings is
            built instead of overwriting the caller's entries.

    Returns:
        The result of ``pack_constraints`` over the encoded constraints.
    """
    encoded = [
        [
            # encode with src_dict as this becomes tgt
            self.bart.src_dict.encode_line(
                self.bart.apply_bpe(constraint),
                append_eos=False,
                add_if_not_exist=False,
            )
            for constraint in constraint_list
        ]
        for constraint_list in constraints
    ]
    return pack_constraints(encoded)
def test_packing(self):
    """Check pack_constraints against every (input, expected tensor) pair.

    ``self.examples`` yields a batch of per-sentence constraint tensor lists
    together with the packed tensor that ``pack_constraints`` must produce.
    """
    for constraints, expected in self.examples:
        assert torch.equal(pack_constraints(constraints), expected)
def make_batches(lines, cfg, task, max_positions, encode_fn):
    """Encode raw input lines with ``encode_fn`` and yield inference batches.

    When ``cfg.generation.constraints`` is truthy, each line may carry
    tab-separated constraints after the source text. These are split off,
    passed through ``encode_fn``, encoded with ``task.target_dictionary`` and
    packed with ``pack_constraints``.

    When ``cfg.task.truncate_source`` is set (missing means off), each encoded
    source is cut to the source position limit minus 4 tokens.

    Args:
        lines: input strings. The sequence is copied before constraints are
            stripped, so the caller's object is never modified.
        cfg: config; reads ``cfg.generation.constraints``, the optional
            ``cfg.task.truncate_source``, and ``cfg.dataset.max_tokens`` /
            ``batch_size`` / ``skip_invalid_size_inputs_valid_test``.
        task: provides the source/target dictionaries,
            ``build_dataset_for_inference`` and ``get_batch_iterator``.
        max_positions: forwarded to ``task.get_batch_iterator``. For
            truncation, the source limit is ``max_positions[0]`` when it is a
            tuple/list, otherwise ``max_positions`` itself. A ``None`` limit
            disables truncation.
        encode_fn: applied to every source line and constraint string before
            dictionary lookup.

    Yields:
        ``Batch`` records built from each batch dict. ``constraints`` is
        ``None`` when the batch has no ``"constraints"`` entry.
    """
    # Work on a private copy so stripping constraints never mutates the
    # caller's sequence (and tuples are accepted).
    lines = list(lines)

    if cfg.generation.constraints:
        # Strip (tab-delimited) constraints, if present, from input lines and
        # keep them per line in batch_constraints.
        batch_constraints = [[] for _ in lines]
        for i, line in enumerate(lines):
            if "\t" in line:
                lines[i], *batch_constraints[i] = line.split("\t")
        # Convert each List[str] to List[Tensor] using the target dictionary.
        batch_constraints = [
            [
                task.target_dictionary.encode_line(
                    encode_fn(constraint),
                    append_eos=False,
                    add_if_not_exist=False,
                )
                for constraint in constraint_list
            ]
            for constraint_list in batch_constraints
        ]

    # Resolve the optional source-length cap once instead of per line.
    max_src_len = None
    if getattr(cfg.task, "truncate_source", False):
        if isinstance(max_positions, (tuple, list)):
            src_positions = max_positions[0]
        else:
            src_positions = max_positions
        if src_positions is not None:
            # NOTE(review): 4 positions are held back, presumably for special
            # tokens added after encoding -- confirm against the model.
            max_src_len = src_positions - 4

    tokens = []
    for src_str in lines:
        encoded = task.source_dictionary.encode_line(
            encode_fn(src_str), add_if_not_exist=False
        ).long()
        if max_src_len is not None:
            encoded = encoded[:max_src_len]
        tokens.append(encoded)

    if cfg.generation.constraints:
        constraints_tensor = pack_constraints(batch_constraints)
    else:
        constraints_tensor = None

    lengths = [t.numel() for t in tokens]
    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor
        ),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        yield Batch(
            ids=batch["id"],
            src_tokens=batch["net_input"]["src_tokens"],
            src_lengths=batch["net_input"]["src_lengths"],
            constraints=batch.get("constraints", None),
        )