Example 1
def get_bound_length(batch_size, lens, min_len, max_len, min_len_factor,
                     max_len_factor):
    min_lens = None
    max_lens = None
    # 1. coarse-grained: the same absolute bounds for every sequence
    if min_len:
        min_lens = torch.as_tensor([min_len] * batch_size).long()
    if max_len:
        max_lens = torch.as_tensor([max_len] * batch_size).long()
    # 2. fine-grained: scale each source length by the factors and
    #    intersect with the coarse-grained bounds
    if lens is not None:  # guard first: `lens` may legitimately be None
        lens = lens.float()
        if min_len_factor:
            # cast back to long so torch.max/min compare matching dtypes
            f_min_lens = (lens * min_len_factor).long()
            min_lens = f_min_lens if min_lens is None else torch.max(
                f_min_lens, min_lens)
        if max_len_factor:
            f_max_lens = (lens * max_len_factor).long()
            max_lens = f_max_lens if max_lens is None else torch.min(
                f_max_lens, max_lens)

    # plus 1, for bos token
    if min_lens is not None:
        min_lens += 1
    # plus 1, for eos token
    if max_lens is not None:
        max_lens += 1

    if min_lens is not None and max_lens is not None:
        assert (min_lens < max_lens).all()

    return cuda(min_lens), cuda(max_lens)
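
To make the interval arithmetic concrete, here is a small CPU-only sketch that calls the function above. The library's `cuda` helper is stubbed with a passthrough so the example runs without a GPU; the numbers are illustrative only.

import torch

cuda = lambda t: t  # passthrough stub for the library's `cuda` helper

lens = torch.tensor([10, 20])
min_lens, max_lens = get_bound_length(
    batch_size=2, lens=lens, min_len=5, max_len=40,
    min_len_factor=0.5, max_len_factor=3.0)
print(min_lens)  # tensor([ 6, 11]): max(5, 0.5 * len) + 1 for the bos token
print(max_lens)  # tensor([31, 41]): min(40, 3.0 * len) + 1 for the eos token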
Example 2
def main(args):
    states, models, paths = loads(args.models, args.select, args.n)
    vocabularies = states[0].get('vocabularies')
    source_vocab, target_vocab = vocabularies

    if args.verbose:
        sys.stderr.write(f'Decoding using checkpoints: {paths}\n')

    if len(models) == 1:
        model = models[0]
    else:
        model = ensemble.AvgPrediction(models)

    cuda(model)

    meter = TimeMeter()
    vocabs = [source_vocab]
    if len(args.input) == 2:
        vocabs.append(target_vocab)

    it = get_eval_iterator(args, vocabs)

    n_tok = 0
    n_snt = 0
    for batch in it.iter_epoch():
        translate_batch(model, batch, True, getattr(states[0]['args'], 'r2l', None))
        meter.update(batch.data['n_tok'])
        n_snt += batch.data['n_snt']
        n_tok += batch.data['n_tok']

    sys.stderr.write(f'Snt={n_snt}, Tok={n_tok}, '
                     f'Time={timedelta(seconds=meter.elapsed_time)}, '
                     f'Avg={meter.avg:.2f}tok/s\n')
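
`TimeMeter` comes from the library; judging by how it is used above (`update(n_tok)`, `elapsed_time`, `avg`), a minimal stand-in could look like the sketch below. This is an assumption drawn from the call sites, not the library's actual implementation.

import time

class TimeMeter:
    """Hypothetical stand-in consistent with Example 2's usage."""

    def __init__(self):
        self._start = time.perf_counter()
        self._n = 0

    def update(self, n):
        self._n += n  # accumulate processed tokens

    @property
    def elapsed_time(self):
        return time.perf_counter() - self._start

    @property
    def avg(self):
        # average throughput in tokens per second
        return self._n / max(self.elapsed_time, 1e-9)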
Example 3
def collate_fn(xs):
    return {
        'src': cuda(pack_tensors(aggregate_value_by_key(xs, 'src'),
                                 source_vocab.pad_id)),
        'trg': cuda(pack_tensors(aggregate_value_by_key(xs, 'trg'),
                                 target_vocab.pad_id)),
        'n_src_tok': aggregate_value_by_key(xs, 'n_src_tok', sum),
        'n_trg_tok': aggregate_value_by_key(xs, 'n_trg_tok', sum),
    }
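
The collate functions in these examples lean on `aggregate_value_by_key`. Judging from the call sites, it gathers one key across the sample dicts and optionally folds the values with a reducer; the stand-in below is an assumption consistent with that usage, not the library's code.

def aggregate_value_by_key(xs, key, reduce=None):
    # Pull `key` out of every sample dict; optionally fold with `reduce`.
    values = [x[key] for x in xs]
    return reduce(values) if reduce else values

xs = [{'n_tok': 3}, {'n_tok': 5}]
print(aggregate_value_by_key(xs, 'n_tok'))       # [3, 5]
print(aggregate_value_by_key(xs, 'n_tok', sum))  # 8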
Example 4
def collate_fn(xs):
    inputs = aggregate_value_by_key(xs, 'src')
    inputs = list(zip(*inputs))  # regroup per-sample fields by vocabulary
    inputs = [cuda(pack_tensors(input, voc.pad_id))
              for input, voc in zip(inputs, vocabs)]
    return {
        'src': inputs[0] if len(inputs) == 1 else inputs,
        'n_tok': aggregate_value_by_key(xs, 'n_tok', sum),
        'n_snt': inputs[0].size(0)
    }
Example 5
def convert_data(data, voc, tokenizer=True, try_cuda=True):
    # Map raw sequences to vocabulary indices, tokenizing first if requested.
    if tokenizer:
        data = [voc.to_indices(tokenize(seq)) for seq in data]
    else:
        data = [voc.to_indices(seq) for seq in data]
    seq = convert_to_array(data, voc.pad_id)
    if try_cuda:
        seq = cuda(seq)
    return seq
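
`convert_to_array` (like `pack_tensors` in the other examples) is presumably responsible for right-padding the index lists into a dense batch filled with `pad_id`. The effect can be illustrated with plain PyTorch; this is a sketch of the assumed behavior, not the library's implementation.

import torch

def pad_to_array(seqs, pad_id):
    # Right-pad variable-length index lists into a (batch, max_len) LongTensor.
    out = torch.full((len(seqs), max(map(len, seqs))), pad_id, dtype=torch.long)
    for i, s in enumerate(seqs):
        out[i, :len(s)] = torch.as_tensor(s)
    return out

print(pad_to_array([[4, 5, 6], [7, 8]], pad_id=0))
# tensor([[4, 5, 6],
#         [7, 8, 0]])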
Example 6
def collate_fn(xs):
    return {
        'src': cuda(pack_tensors(aggregate_value_by_key(xs, 'src'),
                                 source_vocab.pad_id)),
        'n_tok': aggregate_value_by_key(xs, 'n_tok', sum),
        'refs': aggregate_value_by_key(xs, 'refs')
    }
Example 7
def collate_fn(xs):
    return {
        'src': cuda(pack_tensors(aggregate_value_by_key(xs, 'src'),
                                 source_vocab.pad_id)),
        'r2l': cuda(pack_tensors(aggregate_value_by_key(xs, 'r2l'),
                                 target_vocab.pad_id)),
        'l2r': cuda(pack_tensors(aggregate_value_by_key(xs, 'l2r'),
                                 target_vocab.pad_id)),
        'ntok_src': aggregate_value_by_key(xs, 'n_src_tok', sum),
        'ntok_r2l': aggregate_value_by_key(xs, 'ntok_r2l', sum),
        'ntok_l2r': aggregate_value_by_key(xs, 'ntok_l2r', sum),
    }
Example 8
def main(args):
    set_seed(args.seed)

    # load vocabularies
    vocabularies = state_dict.get('vocabularies')

    if not vocabularies:
        if not args.vocab_size:
            args.vocab_size = [None]
        if len(args.vocab_size) == 1:
            args.vocab_size *= len(args.vocab)
        assert len(args.vocab_size) == len(args.vocab)
        vocabularies = [
            Vocabulary(filename, size)
            for filename, size in zip(args.vocab, args.vocab_size)
        ]

    source_vocab: Vocabulary = vocabularies[0]
    target_vocab: Vocabulary = vocabularies[1]

    # build model and criterion
    stop_watcher = StopwatchMeter(state_less=True)

    # 1. Build model

    model = models.build_model(args, vocabularies)
    if args.pretrain:
        logger.info('Loading pretraining parameters ...')
        pretrain = Loader.load_state(args.pretrain, 'cpu')
        pretrain = pretrain['model']
        from thseq.utils.misc import load_pretrain
        loaded, not_loaded = load_pretrain(model.encoder, pretrain, 'encoder')
        logger.info(f'Encoder loaded: {" ".join(loaded)}')
        logger.info(f'Encoder not loaded: {" ".join(not_loaded)}')
        loaded, not_loaded = load_pretrain(model.decoder.r2l, pretrain,
                                           'decoder')
        logger.info(f'Decoder loaded: {" ".join(loaded)}')
        logger.info(f'Decoder not loaded: {" ".join(not_loaded)}')


    # Initialize parameters
    if not resume:
        logger.info(f'Model: \n{model}')
        model.apply(init_parameters)

        stat_parameters(model)
        logger.info(
            f'Batch size = {args.batch_size[0] * torch.cuda.device_count()} '
            f'({args.batch_size[0]} x {torch.cuda.device_count()})')

    model = cuda(model)

    optimizer = optim.build_optimizer(args, model.parameters())
    lr_scheduler = thseq.optim.lr_scheduler.build_lr_scheduler(args, optimizer)

    # build trainer
    trainer = Trainer(args, model, optimizer, None, lr_scheduler)

    # build data iterator
    iterator = get_train_iterator(args, source_vocab, target_vocab)

    # Group stateful instances as a checkpoint
    state = State(args.save_checkpoint_secs,
                  args.save_checkpoint_steps,
                  args.keep_checkpoint_max,
                  args.keep_best_checkpoint_max,
                  args=args,
                  trainer=trainer,
                  model=model,
                  criterion=None,
                  optimizer=optimizer,
                  lr_scheduler=lr_scheduler,
                  iterator=iterator,
                  vocabularies=vocabularies)

    # Restore state
    state.load_state_dict(state_dict)

    # Train until the learning rate gets too small
    import math
    max_epoch = args.max_epoch or math.inf
    max_step = args.max_step or math.inf

    eval_iter = get_dev_iterator(args, [source_vocab, target_vocab])

    def reseed():
        set_seed(args.seed + state.step)

    kwargs = {}
    if resume:
        kwargs = {'purge_step': state.step}
    reseed()

    def before_epoch_callback():
        # 0-based
        logger.info(f'Start epoch {state.epoch + 1}')

    def after_epoch_callback():
        step0, step1 = state.step_in_epoch, iterator.step_in_epoch
        total0, total1 = state.step, iterator.step
        logger.info(
            f'Finished epoch {state.epoch + 1}. '
            f'Failed steps: {step1 - step0} out of {step1} in last epoch and '
            f'{total1 - total0} out of {total1} in total. ')

        state.increase_epoch()
        if state.eval_scores:
            trainer.lr_step(state.epoch, state.eval_scores[-1])

    trainer.reset_meters()

    with SummaryWriter(log_dir=os.path.join(args.model, 'tensorboard'),
                       **kwargs) as writer:
        batches = []
        for batch in iterator.while_true(predicate=(
                lambda: (args.min_lr is None or trainer.get_lr() > args.min_lr)
                and state.epoch < max_epoch and state.step < max_step),
                                         before_epoch=before_epoch_callback,
                                         after_epoch=after_epoch_callback):
            model.train()
            reseed()

            batches.append(batch)
            if len(batches) % args.accumulate == 0:
                samples = []
                for batch in batches:
                    input = batch.data['src']
                    r2l = batch.data['r2l']
                    l2r = batch.data['l2r']

                    ntok_l2r = batch.data['ntok_l2r']
                    ntok_src = batch.data['ntok_src']

                    sample = {
                        'net_input': (input, r2l, l2r),
                        'target_r2l': r2l,
                        'target_l2r': l2r,
                        'ntok_src': ntok_src,
                        'ntok_l2r': ntok_l2r,
                    }
                    samples.append(sample)
                batches.clear()
                log = trainer.train_step(samples)
                if not log:
                    continue
            else:
                continue

            state.increase_num_steps()
            trainer.lr_step_update(state.step)
            pwc = log["per_word_loss"]  # natural logarithm
            total_steps = state.step

            wps = trainer.meters["wps"].avg
            gnorm = trainer.meters['gnorm'].val
            cur_lr = trainer.get_lr()
            info = f'{total_steps} ' \
                f'|loss={log["loss"]:.4f} ' \
                f'|pwc={pwc:.4f} ' \
                f'|lr={cur_lr:.6e} ' \
                f'|norm={gnorm:.2f} ' \
                f'|wps={wps:.2f} ' \
                f'|input={(log.get("ntok_src", 0), log.get("ntok_l2r", 0))} '
            logger.info(info)

            writer.add_scalar('loss', log['loss_l2r'], total_steps)
            writer.add_scalar('lr', cur_lr, total_steps)

            if total_steps % args.eval_steps == 0:
                stop_watcher.start()
                with torch.no_grad():
                    val_score = trainer.evaluate(eval_iter, r2l=args.r2l)
                stop_watcher.stop()
                state.add_valid_score(val_score)
                writer.add_scalar('dev/bleu', val_score, total_steps)
                logger.info(
                    f'Validation bleu at {total_steps}: {val_score:.2f}, '
                    f'took {timedelta(seconds=stop_watcher.sum // 1)}')

            state.try_save()

        # Evaluate at the end of training.
        stop_watcher.start()
        with torch.no_grad():
            val_score = trainer.evaluate(eval_iter, r2l=args.r2l)
        stop_watcher.stop()
        state.add_valid_score(val_score)
        writer.add_scalar('dev/bleu', val_score, state.step)
        logger.info(f'Validation bleu at {state.step}: {val_score:.2f}, '
                    f'took {timedelta(seconds=stop_watcher.sum // 1)}')
    logger.info(
        f'Training finished at {strftime("%b %d, %Y, %H:%M:%S", localtime())}, '
        f'took {timedelta(seconds=state.elapsed_time // 1)}')
    logger.info(
        f'Best validation bleu: {max(state.eval_scores)}, at {state.get_best_time()}'
    )
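
Example 8 folds `args.accumulate` consecutive mini-batches into a single parameter update. The underlying gradient-accumulation pattern, isolated with a toy model and random data (illustration only):

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
accumulate = 4

for step, x in enumerate(torch.randn(16, 8, 4).unbind(0), start=1):
    loss = model(x).pow(2).mean() / accumulate  # scale so gradients average
    loss.backward()                             # gradients sum across calls
    if step % accumulate == 0:                  # update every `accumulate` steps
        opt.step()
        opt.zero_grad()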
Example 9
def beamsearch_kbest(fn,
                     state,
                     lens,
                     batch_size,
                     beam_width,
                     eos: int,
                     bos: int = None,
                     length_penalty: float = 1.0,
                     min_len_factor: float = 0.5,
                     max_len_factor: float = 3.0,
                     min_len: int = None,
                     max_len: int = None,
                     topk: int = 1,
                     stop_criteria: str = 'find_K_ended',
                     expand_args: bool = False) -> List[List[Path]]:
    """
    Args:
        fn: A callable function that takes `state` as input and outputs a tuple of (log_prob, new_state).
        state: A tuple, list or a dictionary. This is the initial state of the decoding process.
        lens: A list of ints representing source sequence lengths.
            Setting it to `None` will disable fine-grained length constraint.
        batch_size:
        beam_width:
        eos:
        bos: (Optional.) if not provided, reuse eos as bos.
        length_penalty:
        min_len_factor: Fine-grained constraint over each output sequence's length.
        max_len_factor:
        min_len: Coarse-grained constraint over all output sequences' length.
        max_len:
        topk:
        device:
        stop_criteria: Available options: ['find_K_ended', 'top_path_ended'],
            also support rules combination. For example, 'find_K_ended || top_path_ended' for or logic,
             or 'A && B' for and logic.
        expand_args: feed expanded `state` as args to `fn`.
    Returns:

    """
    B = batch_size

    if lens is not None:
        assert len(lens) == B, (len(lens), B)

    bos = eos if bos is None else bos

    # set up constraint as the intersection of the intervals.
    min_lens, max_lens = get_bound_length(batch_size, lens, min_len, max_len,
                                          min_len_factor, max_len_factor)

    # initialize beam
    beam = Beam(beam_width, [[Path([bos])] for _ in range(B)])
    is_path_ended = lambda path: path.nodes[-1] == eos

    # set up stopping criteria
    criterion = StopCriterion(beam_width, min_lens, max_lens, is_path_ended)
    stop_criteria = get_stop_criteria(criterion, stop_criteria)

    length = 1  # including bos token
    while not beam.empty():
        B = beam.effective_batch_size()
        # the effective batch size may shrink once a source meets the
        # stopping criteria
        input = [[path.nodes for path in paths] for paths in beam.alive]
        input = cuda(torch.as_tensor(input).long())  # B x K x T
        input = input.view(-1, input.size(-1))  # BK x T

        # log_prob is of shape BK x 1 x V
        if expand_args:
            if isinstance(state, (tuple, list)):
                log_prob, state = fn(input, *state)
            elif isinstance(state, dict):
                log_prob, state = fn(input, **state)
            else:
                log_prob, state = fn(input, state)
        else:
            log_prob, state = fn(input, state)

        log_prob = log_prob.view(B, -1, log_prob.size(-1))  # B x K x V
        K_ = log_prob.size(1)  # 1 at first step and K at succeeding steps.

        # tweak probs
        idxs = torch.as_tensor(list(beam.source_indices())).long()

        if min_lens is not None:
            mask = length < min_lens[idxs]
            if mask.any():
                mask = cuda(mask)
                log_prob[:, :, eos].masked_fill_(mask.view(-1, 1), -numpy.inf)
        if max_lens is not None:
            mask = length == max_lens[idxs] - 1
            if mask.any():
                mask = cuda(mask)
                log_prob[:, :, :eos].masked_fill_(mask.view(-1, 1, 1),
                                                  -numpy.inf)
                log_prob[:, :, eos + 1:].masked_fill_(mask.view(-1, 1, 1),
                                                      -numpy.inf)

        if K_ == 1:
            # first step: expand the state from B to B*K entries by
            # repeating each source beam_width times
            repeat_idxs = torch.arange(B).view(-1, 1).expand(
                -1, beam_width).flatten()
            state = select(state, cuda(repeat_idxs))

        beam_idxs = beam.forward(log_prob, is_path_ended, stop_criteria)

        if beam_idxs is not None:
            # B' x K
            beam_idxs = cuda(torch.as_tensor(beam_idxs).long())
            beam_idxs = beam_idxs.view(-1)
            state = select(state, beam_idxs)

        length += 1

    def rescore(path: Path):
        """
        Re-score the path.
        """
        score = path.score / (len(path) - 1)  # exclude bos token
        path.penalized_score = score

    def top(paths: List[Path], k=None):
        """
        Select the top-k paths by (penalized) score.
        """
        paths.sort(key=lambda path: path.penalized_score
                   if path.penalized_score is not None else path.score,
                   reverse=True)
        return paths[:(k or 1)]

    ended = beam.ended

    for paths in ended:
        for path in paths:
            rescore(path)

    paths = [top(paths, topk) for paths in ended]

    return paths
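
The bounds computed by `get_bound_length` (Example 1) are enforced inside the decoding loop by masking the log-probabilities: `eos` is forbidden while a source is below its minimum length, and everything except `eos` is forbidden one step before the maximum. A self-contained sketch of the first mask, with toy shapes and made-up values:

import torch

B, K, V, eos = 2, 3, 5, 4
log_prob = torch.zeros(B, K, V)    # stand-in for the model's output
length, min_lens = 2, torch.tensor([4, 2])

mask = length < min_lens           # (B,): which sources are still too short
log_prob[:, :, eos].masked_fill_(mask.view(-1, 1), float('-inf'))
print(log_prob[:, :, eos])         # row 0 all -inf (cannot end yet); row 1 zeros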
Example 10
def main(args):
    set_seed(args.seed)

    # load vocabularies
    vocabularies = state_dict.get('vocabularies')

    if not vocabularies:
        if not args.vocab_size:
            args.vocab_size = [None]
        if len(args.vocab_size) == 1:
            args.vocab_size *= len(args.vocab)
        assert len(args.vocab_size) == len(args.vocab)
        vocabularies = [
            Vocabulary(filename, size)
            for filename, size in zip(args.vocab, args.vocab_size)
        ]

    source_vocab: Vocabulary = vocabularies[0]
    target_vocab: Vocabulary = vocabularies[1]

    # build model and criterion
    stop_watcher = StopwatchMeter(state_less=True)

    # 1. Build model
    model = models.build_model(args, vocabularies)
    # 2. Set up training criterion
    criterion = criterions.build_criterion(args, target_vocab)


    # Initialize parameters
    if not resume:
        logger.info(f'Model: \n{model}')
        model.apply(init_parameters)

        stat_parameters(model)
        logger.info(
            f'Batch size = {args.batch_size[0] * torch.cuda.device_count()} '
            f'({args.batch_size[0]} x {torch.cuda.device_count()})')

    model = cuda(model)
    criterion = cuda(criterion)

    optimizer = optim.build_optimizer(args, model.parameters())
    lr_scheduler = thseq.optim.lr_scheduler.build_lr_scheduler(args, optimizer)

    # build trainer
    trainer = Trainer(args, model, optimizer, criterion, lr_scheduler)

    # build data iterator
    iterator = get_train_iterator(args, source_vocab, target_vocab)

    # Group stateful instances as a checkpoint
    state = State(args.save_checkpoint_secs,
                  args.save_checkpoint_steps,
                  args.keep_checkpoint_max,
                  args.keep_best_checkpoint_max,
                  args=args,
                  trainer=trainer,
                  model=model,
                  criterion=criterion,
                  optimizer=optimizer,
                  lr_scheduler=lr_scheduler,
                  iterator=iterator,
                  vocabularies=vocabularies)

    # Restore state
    state.load_state_dict(state_dict)

    # Train until the learning rate gets too small
    import math
    max_epoch = args.max_epoch or math.inf
    max_step = args.max_step or math.inf

    eval_iter = get_dev_iterator(args, source_vocab)

    def reseed():
        set_seed(args.seed + state.step)

    kwargs = {}
    if resume:
        kwargs = {'purge_step': state.step}
    reseed()

    def before_epoch_callback():
        # 0-based
        logger.info(f'Start epoch {state.epoch + 1}')

    def after_epoch_callback():
        step0, step1 = state.step_in_epoch, iterator.step_in_epoch
        total0, total1 = state.step, iterator.step
        logger.info(
            f'Finished epoch {state.epoch + 1}. '
            f'Failed steps: {step1 - step0} out of {step1} in last epoch and '
            f'{total1 - total0} out of {total1} in total. ')

        state.increase_epoch()
        if state.eval_scores:
            trainer.lr_step(state.epoch, state.eval_scores[-1])

    trainer.reset_meters()

    with SummaryWriter(log_dir=os.path.join(args.model, 'tensorboard'),
                       **kwargs) as writer:
        for batch in iterator.while_true(predicate=(
                lambda: (args.min_lr is None or trainer.get_lr() > args.min_lr)
                and state.epoch < max_epoch and state.step < max_step),
                                         before_epoch=before_epoch_callback,
                                         after_epoch=after_epoch_callback):

            model.train()
            reseed()

            input = batch.data['src']
            output = batch.data['trg']

            n_src_tok = batch.data['n_src_tok']
            n_trg_tok = batch.data['n_trg_tok']

            sample = {
                'net_input': (input, output),
                'target': output,
                'ntokens': n_trg_tok
            }
            if (state.step + 1) % args.accumulate > 0:
                # accumulate updates according to --update-freq
                trainer.train_step(sample, update_params=False)
                continue
            else:
                log_output = trainer.train_step(sample, update_params=True)
                if not log_output:  # failed
                    continue
            state.increase_num_steps()
            trainer.lr_step_update(state.step)
            pwc = log_output["per_word_loss"]  # natural logarithm
            total_steps = state.step

            wps = trainer.meters["wps"].avg
            gnorm = trainer.meters['gnorm'].val
            cur_lr = trainer.get_lr()

            batch_size = (output.size(0)
                          if args.batch_by_sentence else n_trg_tok)

            info = f'{total_steps} ' \
                f'|loss={pwc:.4f} ' \
                f'|lr={cur_lr:.6e} ' \
                f'|norm={gnorm:.2f} ' \
                f'|batch={batch_size}/{wps:.2f} ' \
                f'|input={list(input.shape)}/{n_src_tok}, {list(output.shape)}/{n_trg_tok} '
            logger.info(info)

            writer.add_scalar('loss', log_output['loss'], total_steps)
            writer.add_scalar('lr', cur_lr, total_steps)

            if total_steps % args.eval_steps == 0:
                stop_watcher.start()
                with torch.no_grad():
                    val_score = trainer.evaluate(eval_iter, r2l=args.r2l)
                stop_watcher.stop()
                state.add_valid_score(val_score)
                writer.add_scalar('dev/bleu', val_score, total_steps)
                logger.info(
                    f'Validation bleu at {total_steps}: {val_score:.2f}, '
                    f'took {timedelta(seconds=stop_watcher.sum // 1)}')

            state.try_save()

        # Evaluate at the end of training.
        stop_watcher.start()
        with torch.no_grad():
            val_score = trainer.evaluate(eval_iter, r2l=args.r2l)
        stop_watcher.stop()
        state.add_valid_score(val_score)
        writer.add_scalar('dev/bleu', val_score, state.step)
        logger.info(f'Validation bleu at {state.step}: {val_score:.2f}, '
                    f'took {timedelta(seconds=stop_watcher.sum // 1)}')
    logger.info(
        f'Training finished at {strftime("%b %d, %Y, %H:%M:%S", localtime())}, '
        f'took {timedelta(seconds=state.elapsed_time // 1)}')
    logger.info(
        f'Best validation bleu: {max(state.eval_scores)}, at {state.get_best_time()}'
    )