Beispiel #1
0
 def score_batched_itr(self, data_itr, cuda=False, timer=None):
     """Iterate over a batched dataset and yield scored reference translations.

     Each batch is scored with ``self.score`` and then split into per-sentence
     results in the same format as ``SequenceGenerator`` output.

     Args:
         data_itr: iterable of sample dicts providing ``'id'``,
             ``'net_input'['src_tokens']``, ``'target'`` and ``'ntokens'``.
         cuda: if True, move each sample to the GPU before scoring.
         timer: optional StopwatchMeter; started before scoring each batch and
             stopped once per batch with ``sample['ntokens']``.

     Yields:
         ``(id, src, ref, hypos)`` where ``src``/``ref`` have padding stripped
         and ``hypos`` is a one-element list describing the scored reference.

     Raises:
         ValueError: if a sample has no target (there is nothing to score).
     """
     for sample in data_itr:
         s = utils.move_to_cuda(sample) if cuda else sample
         if s['target'] is None:
             raise ValueError('score_batched_itr requires samples with a target to score')
         if timer is not None:
             timer.start()
         pos_scores, attn = self.score(s)
         # Stop once per batch: stopping inside the per-sentence loop below would
         # report the batch's tokens once per sentence.
         if timer is not None:
             timer.stop(s['ntokens'])
         for i, sample_id in enumerate(s['id'].data):
             # remove padding from source and reference
             src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
             ref = utils.strip_pad(s['target'].data[i, :], self.pad)
             tgt_len = ref.numel()
             pos_scores_i = pos_scores[i][:tgt_len]
             # length-normalised score of the reference
             score_i = pos_scores_i.sum() / tgt_len
             if attn is not None:
                 attn_i = attn[i]
                 # most-attended source position for every target position
                 _, alignment = attn_i.max(dim=0)
             else:
                 attn_i = alignment = None
             hypos = [{
                 'tokens': ref,
                 'score': score_i,
                 'attention': attn_i,
                 'alignment': alignment,
                 'positional_scores': pos_scores_i,
             }]
             # return results in the same format as SequenceGenerator
             yield sample_id, src, ref, hypos
Beispiel #2
0
    def generate_batched_itr(
        self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
        cuda=False, timer=None, prefix_size=0,
    ):
        """Generate translations batch by batch and yield them one sentence at a time.

        Args:
            data_itr: iterable of sample dicts; samples without ``'net_input'``
                are skipped.
            beam_size: beam width forwarded to ``self.generate``.
            maxlen_a/b: cap generated length at ``a * x + b``, where ``x`` is the
                (padded) source length of the batch; ``maxlen_b`` defaults to
                ``self.maxlen``.
            cuda: move each sample to the GPU before generating.
            timer: optional StopwatchMeter; stopped with the number of tokens in
                each sentence's top hypothesis, summed over the batch.
            prefix_size: if > 0, the first ``prefix_size`` target tokens are
                forced as a decoding prefix.

        Yields:
            ``(id, src, ref, hypos)`` with padding stripped from ``src`` and
            ``ref``; ``ref`` is None when the sample has no target.
        """
        length_offset = self.maxlen if maxlen_b is None else maxlen_b

        for raw_sample in data_itr:
            batch = utils.move_to_cuda(raw_sample) if cuda else raw_sample
            if 'net_input' not in batch:
                continue
            net_input = batch['net_input']
            src_tokens = net_input['src_tokens']
            max_len = int(maxlen_a * src_tokens.size(1) + length_offset)

            if timer is not None:
                timer.start()
            with torch.no_grad():
                prefix = batch['target'][:, :prefix_size] if prefix_size > 0 else None
                hypos = self.generate(
                    src_tokens,
                    net_input['src_lengths'],
                    beam_size=beam_size,
                    maxlen=max_len,
                    prefix_tokens=prefix,
                )
            if timer is not None:
                timer.stop(sum(len(beam[0]['tokens']) for beam in hypos))

            target = batch['target']
            for idx, sample_id in enumerate(batch['id'].data):
                src = utils.strip_pad(src_tokens.data[idx, :], self.pad)
                if target is None:
                    ref = None
                else:
                    ref = utils.strip_pad(target.data[idx, :], self.pad)
                yield sample_id, src, ref, hypos[idx]
Beispiel #3
0
def eval_lm(
    models: List[fairseq.models.FairseqModel],
    source_dictionary: fairseq.data.Dictionary,
    batch_iterator: Iterable,
    post_process: Optional[str] = None,
    output_word_probs: bool = False,
    output_word_stats: bool = False,
    target_dictionary: Optional[fairseq.data.Dictionary] = None,
    softmax_batch: int = 0,
    remove_bos_token: bool = False,
    device: Optional[torch.device] = None,
):
    """
    Score every batch of ``batch_iterator`` with ``models`` and report the
    average negative log-likelihood.

    Args:
        models (List[~fairseq.models.FairseqModel]): list of models to
            evaluate. Models are essentially `nn.Module` instances, but
            must be compatible with fairseq's `SequenceScorer`.
        source_dictionary (~fairseq.data.Dictionary): dictionary for
            applying any relevant post processing or outputing word
            probs/stats.
        batch_iterator (Iterable): yield batches of data
        post_process (Optional[str]): post-process text by removing BPE,
            letter segmentation, etc. Only ``"subword_nmt"`` and ``"@@ "``
            (both meaning "@@"-continuation BPE) are implemented here.
        output_word_probs (Optional[bool]): output words and their
            predicted log probabilities
        output_word_stats (Optional[bool]): output word statistics such
            as word count and average probability
        target_dictionary (Optional[~fairseq.data.Dictionary]): output
            dictionary (defaults to *source_dictionary*)
        softmax_batch (int): if BxT is more than this, will
            batch the softmax over vocab to this amount of tokens, in
            order to fit into GPU memory
        remove_bos_token (Optional[bool]): if True, confirm that the
            first token is the beginning-of-sentence symbol (according
            to the relevant dictionary) and remove it from the output
        device (Optional[torch.device]): device to use for evaluation
            (defaults to device of first model parameter)

    Returns:
        dict with ``"loss"``, the average negative log-likelihood per counted
        token in base 2, and ``"perplexity"`` (``2 ** loss``).

    Raises:
        NotImplementedError: for any other ``post_process`` value.
    """
    if target_dictionary is None:
        target_dictionary = source_dictionary
    if device is None:
        device = next(models[0].parameters()).device

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(target_dictionary, softmax_batch)

    score_sum = 0.0
    count = 0

    if post_process is not None:
        if post_process in {"subword_nmt", "@@ "}:
            # Both options mean subword-nmt BPE, where a token ending in "@@"
            # continues into the next token. (Deriving the marker from
            # post_process.rstrip() would give "subword_nmt" for the first
            # option, match no tokens, and silently disable word merging.)
            bpe_cont = "@@"
            bpe_toks = {
                i
                for i in range(len(source_dictionary))
                if source_dictionary[i].endswith(bpe_cont)
            }
        else:
            raise NotImplementedError(
                f"--post-process={post_process} is not implemented")
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    for sample in batch_iterator:
        if "net_input" not in sample:
            continue

        sample = utils.move_to_cuda(sample, device=device)

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample["ntokens"])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample["id"][i]

            tokens = hypo["tokens"]
            pos_scores = hypo["positional_scores"].float()

            if remove_bos_token:
                assert hypo["tokens"][0].item() == target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            # Length measured *after* optional BOS removal so the merge loop
            # below never indexes past the end of tokens/pos_scores.
            tgt_len = tokens.numel()

            skipped_toks = 0
            if bpe_toks is not None:
                # Fold each BPE-continuation piece's score into the next token
                # so that every word is scored exactly once.
                for j in range(tgt_len - 1):
                    if tokens[j].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[j + 1] += pos_scores[j]
                        pos_scores[j] = 0

            # Exclude +/-inf scores from the aggregate loss, but keep
            # pos_scores itself aligned with tokens for the per-word output.
            inf_scores = torch.isinf(pos_scores)
            if inf_scores.any():
                logger.info(
                    "skipping tokens with inf scores: %s",
                    target_dictionary.string(tokens[inf_scores]),
                )
                finite_scores = pos_scores[~inf_scores]
            else:
                finite_scores = pos_scores
            score_sum += finite_scores.sum().cpu()
            count += finite_scores.numel() - skipped_toks

            if output_word_probs or output_word_stats:
                w = ""
                word_prob = []
                is_bpe = False
                for j in range(len(tokens)):
                    w_ind = tokens[j].item()
                    w += source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        # strip the continuation marker and keep accumulating
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[j].item()))

                        # Score of the next token whose score was not folded
                        # away (non-zero); passed to WordStat.add with this
                        # word's own score.
                        next_prob = None
                        ind = j + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                            pos_scores[j].item(), next_prob)
                        is_bpe = False
                        w = ""
                if output_word_probs:
                    # NOTE(review): "{:2f}" is a minimum width of 2 with the
                    # default 6 decimals; "{:.2f}" may have been intended.
                    logger.info(
                        str(int(sample_id)) + " " +
                        ("\t".join("{} [{:2f}]".format(x[0], x[1])
                                   for x in word_prob)))

    # convert to base 2
    avg_nll_loss = -score_sum / count / math.log(2) if count > 0 else 0
    logger.info("Evaluated {:,} tokens in {:.1f}s ({:.2f} tokens/s)".format(
        gen_timer.n, gen_timer.sum,
        1.0 / gen_timer.avg if gen_timer.avg > 0 else 0))

    if output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            logger.info(ws)

    return {
        "loss": avg_nll_loss,
        "perplexity": 2**avg_nll_loss,
    }
Beispiel #4
0
def main(args, task=None, model_state=None):
    """Decode ``args.gen_subset`` with a wav2letter decoder, or dump emissions/features.

    Args:
        args: parsed command-line namespace (validated by ``check_args``).
        task: optional pre-built task; when None, one is set up from ``args``
            and ``args.gen_subset`` is loaded.
        model_state: optional model state forwarded to
            ``load_models_and_criterions``.

    Returns:
        ``(task, wer)`` where ``wer`` is the word error rate in percent, or
        None when dumping emissions/features or when no reference tokens
        were scored.
    """
    check_args(args)

    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    if task is None:
        # Load dataset splits
        task = tasks.setup_task(args)
        task.load_dataset(args.gen_subset)

        logger.info(
            "| {} {} {} examples".format(
                args.data, args.gen_subset, len(task.dataset(args.gen_subset))
            )
        )

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble (not needed when decoding precomputed emissions)
    if args.load_emissions:
        models, criterions = [], []
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, criterions, _ = load_models_and_criterions(
            args.path,
            data_path=args.data,
            # NOTE(security): eval() executes arbitrary code taken from the
            # command line; only run with trusted arguments.
            arg_overrides=eval(args.model_overrides),  # noqa
            task=task,
            model_state=model_state,
        )
        optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        """Return the wav2letter decoder selected by ``args.w2l_decoder``, or None."""
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            # Not fatal here: the dump-emissions/dump-features paths below
            # never call the generator.
            print('only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment')
            return None

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True)
        )
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, sample_id in enumerate(sample["id"]):
                        emissions[sample_id.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    padding_mask = encoder_out["encoder_padding_mask"]
                    for i, sample_id in enumerate(sample["id"]):
                        padding = padding_mask[i].cpu().numpy() if padding_mask is not None else None
                        features[sample_id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                toks = sample["target"][i, :] if 'target_label' not in sample else sample["target_label"][i, :]
                target_tokens = (
                    utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                )
                # Process top predictions
                errs, length = process_predictions(
                    args, hypos[i], None, tgt_dict, target_tokens, res_files, speaker, sample_id
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()

    wer = None
    if args.dump_emissions:
        # ids are assumed to be 0..N-1 so the saved array is ordered by id
        emm_arr = [emissions[i] for i in range(len(emissions))]
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = [features[i] for i in range(len(features))]
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} features to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        # Guard the rates: both are zero when no batch was decoded.
        logger.info(
            "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
            "sentences/s, {:.2f} tokens/s)".format(
                num_sentences,
                gen_timer.n,
                gen_timer.sum,
                num_sentences / gen_timer.sum if gen_timer.sum > 0 else 0,
                1.0 / gen_timer.avg if gen_timer.avg > 0 else 0,
            )
        )
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer
Beispiel #5
0
    def forward(self, model, sample, reduce=True):
        """Compute a mixed MLE + sentence-BLEU-reward (RL) loss for ``sample``.

        One hypothesis per source sentence is produced without gradients
        (beam size 1, with sampling controlled by
        ``args.multinomial_sample_train``). Its sentence BLEU against the
        reference is then used as a reward that scales the hypothesis'
        negative log-likelihood under the model. If ``args.mle_weight`` is
        set, the per-token RL loss is interpolated with the per-token MLE loss
        on the reference.

        Per-sentence rewards and per-batch losses are appended to CSV files
        under ``./results/reward/`` and ``./results/loss/``.

        Returns:
            ``(total_loss, total_tokens, logging_output)``.
        """
        import os

        # --- sample hypotheses with the current model (no gradients) ---
        was_training = model.training
        model.eval()
        tgt_dict = self.task.target_dictionary
        src_pad = self.task.source_dictionary.pad()
        eos_idx = tgt_dict.eos()
        sample_beam = 1
        translator = SequenceGenerator([model], tgt_dict=tgt_dict, sampling=self.args.multinomial_sample_train,
                                       beam_size=sample_beam, minlen=1)
        translator.cuda()
        translations = []

        s = utils.move_to_cuda(sample)
        net_input = s['net_input']
        max_len = 200
        with torch.no_grad():
            hypos = translator.generate(
                net_input['src_tokens'],
                net_input['src_lengths'],
                beam_size=sample_beam,
                maxlen=max_len,
            )
        for i, sample_id in enumerate(s['id'].data):
            # Strip padding so each source can be re-fed as an unpadded batch of one.
            src = utils.strip_pad(net_input['src_tokens'].data[i, :], src_pad)
            ref = utils.strip_pad(s['target'].data[i, :], tgt_dict.pad()) if s['target'] is not None else None
            translations.append((sample_id, src, ref, hypos[i]))

        # Restore the caller's mode instead of forcing train mode, which would
        # re-enable dropout when this criterion is used for validation.
        model.train(was_training)

        # --- MLE loss on the reference ---
        mle_net_output = model(**sample['net_input'])
        mle_lprobs = model.get_normalized_probs(mle_net_output, log_probs=True)
        mle_lprobs = mle_lprobs.view(-1, mle_lprobs.size(-1))
        mle_target = model.get_targets(sample, mle_net_output).view(-1)
        # Equivalent of the deprecated size_average=False/reduce=reduce pair.
        mle_loss = F.nll_loss(mle_lprobs, mle_target, ignore_index=self.padding_idx,
                              reduction='sum' if reduce else 'none')
        mle_tokens = sample['ntokens']
        avg_mle_loss = mle_loss / mle_tokens
        print('avg_mle_loss:', avg_mle_loss)

        # --- RL loss: reward-weighted NLL of each sampled hypothesis ---
        batch_rl_loss = 0
        batch_tokens = 0
        result = []
        for row_no, (sample_id, src_tokens, tgt_tokens, sent_hypos) in enumerate(translations, start=1):
            hypo = sent_hypos[0]  # only extract the first hypo (beam1 or sample1)
            trans_tokens = hypo['tokens']
            if self.args.delta_reward:
                reward = self.compute_sentence_bleu(tgt_tokens.cpu(), trans_tokens.cpu()).cuda()
            else:
                reward = self.compute_sentence_total_bleu(tgt_tokens.cpu(), trans_tokens.cpu()).cuda()

            result.append((row_no, reward.item(), tgt_tokens.size(0), trans_tokens.size(0)))
            # Teacher-force the hypothesis: the decoder input is the hypothesis
            # shifted right by one, starting with EOS.
            assert trans_tokens[-1] == eos_idx
            tgt_input_tokens = trans_tokens.new_zeros(trans_tokens.shape)
            tgt_input_tokens[0] = eos_idx
            tgt_input_tokens[1:] = trans_tokens[:-1]
            train_sample = {
                'net_input': {
                    'src_tokens': src_tokens.view(1, -1),
                    # A one-element tensor holding the length;
                    # torch.LongTensor(n) would allocate n uninitialised values.
                    'src_lengths': torch.LongTensor([src_tokens.numel()]),
                    'prev_output_tokens': tgt_input_tokens.view(1, -1),
                },
                'target': trans_tokens.view(1, -1)
            }
            train_sample = utils.move_to_cuda(train_sample)
            net_output = model(**train_sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            lprobs = lprobs.view(-1, lprobs.size(-1))
            target = model.get_targets(train_sample, net_output).view(-1, 1)
            non_pad_mask = target.ne(tgt_dict.pad())
            lprob = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
            rl_loss = torch.sum(lprob * reward)  # one sample loss
            # Count real target tokens (len() of the (1, T) target is always 1).
            ntokens = int(non_pad_mask.sum())

            batch_tokens += ntokens
            batch_rl_loss += rl_loss
        avg_rl_loss = batch_rl_loss / batch_tokens

        run_tag = 'v0_m' + str(self.args.mle_weight) + 'r' + str(self.args.rl_weight) + '_lr' + str(self.args.lr)
        reward_path = './results/reward/' + run_tag + '_r.csv'
        os.makedirs(os.path.dirname(reward_path), exist_ok=True)
        with open(reward_path, 'a', newline='') as csv_file:
            csv.writer(csv_file).writerows(result)
        print('avg_rl_loss:', avg_rl_loss)
        if self.args.mle_weight:
            assert self.args.rl_weight
            total_loss = self.args.mle_weight * avg_mle_loss + self.args.rl_weight * avg_rl_loss
            total_tokens = batch_tokens + mle_tokens
        else:
            total_loss = avg_rl_loss
            total_tokens = batch_tokens
        logging_output = {
            'loss': utils.item(total_loss.data),
            'ntokens': total_tokens,
            'sample_size': total_tokens,
        }
        print('total: ',total_loss)
        loss_path = './results/loss/' + run_tag + '_l.csv'
        os.makedirs(os.path.dirname(loss_path), exist_ok=True)
        with open(loss_path, 'a', newline='') as csv_file:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow((avg_mle_loss.item(), avg_rl_loss.item(), total_loss.item(), total_tokens))
        return total_loss, total_tokens, logging_output
Beispiel #6
0
def _main(args, output_file):
    """Decode ``args.gen_subset`` with a speech-recognition ensemble and score it.

    Hypotheses (and references, when present) are printed to ``output_file``
    unless ``args.quiet``. Decoded results, plus WER/CER and alignments when
    targets exist, are written under ``args.results_path``.

    Args:
        args: parsed command-line namespace.
        output_file: stream receiving log records and T-/H- lines; if it is not
            ``sys.stdout``, log records are also echoed to stdout.

    Returns:
        The ``wer.Scorer`` holding every prediction (and evaluation, when the
        data has targets).
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('espresso.speech_recognize')
    if output_file is not sys.stdout:  # also print to stdout
        logger.addHandler(logging.StreamHandler(sys.stdout))

    print_options_meaning_changes(args, logger)

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    # Fix seed for stochastic decoding
    if args.seed is not None and not args.no_seed_provided:
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset split
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionary
    dictionary = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        # NOTE(security): eval() executes arbitrary code taken from the
        # command line; only run with trusted arguments.
        arg_overrides=eval(args.model_overrides),
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Wrap word-level LMs for fusion. Build a new list rather than deleting
    # from ``models`` while enumerating it (which skips the following element),
    # and never look "before" the first model (models[-1] would wrap around).
    num_models = len(models)
    fused_models = []
    for i, m in enumerate(models):
        if hasattr(m, 'is_wordlm') and m.is_wordlm:
            # assume subword LM comes before word LM
            if fused_models and isinstance(fused_models[-1], FairseqLanguageModel):
                fused_models[-1] = MultiLevelLanguageModel(
                    m,
                    fused_models[-1],
                    subwordlm_weight=args.subwordlm_weight,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                logger.info('LM fusion with Multi-level LM')
            else:
                fused_models.append(TensorizedLookaheadLanguageModel(
                    m,
                    dictionary,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                ))
                logger.info('LM fusion with Look-ahead Word LM')
        else:
            # assume subword LM comes after E2E models
            if i == num_models - 1 and isinstance(m, FairseqLanguageModel):
                logger.info('LM fusion with Subword LM')
            fused_models.append(m)
    models = fused_models
    if args.lm_weight != 0.0:
        logger.info('using LM fusion with lm-weight={:.2f}'.format(
            args.lm_weight))

    # Optimize ensemble for generation
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[
                model.max_positions() if hasattr(model, 'encoder') else
                (None, model.max_positions()) for model in models
            ]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    if args.match_source_len:
        logger.warning(
            'The option match_source_len is not applicable to speech recognition. Ignoring it.'
        )
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        """Undo BPE and then tokenization, whichever are configured."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Defined up front so the final log line cannot hit a NameError when no
    # attention plot was ever produced; the directory is still created lazily.
    if args.print_alignment:
        save_dir = os.path.join(args.results_path, 'attn_plots')

    # Generate and compute WER
    scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter)
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        # obtain nonpad mask of encoder output to plot attentions
        if args.print_alignment:
            net_input = sample['net_input']
            src_tokens = net_input['src_tokens']
            output_lengths = models[0].encoder.output_lengths(
                net_input['src_lengths'])
            nonpad_idxs = sequence_mask(
                output_lengths,
                models[0].encoder.output_lengths(src_tokens.size(1)))

        for i in range(len(sample['id'])):
            has_target = sample['target'] is not None
            utt_id = sample['utt_id'][i]

            # Retrieve the original sentences
            if has_target:
                target_str = sample['target_raw_text'][i]
                if not args.quiet:
                    detok_target_str = decode_fn(target_str)
                    print('T-{}\t{}'.format(utt_id, detok_target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_str = dictionary.string(
                    hypo['tokens'].int().cpu(),
                    bpe_symbol=None,
                    extra_symbols_to_ignore={dictionary.pad()},
                )  # not removing bpe at this point
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    print('H-{}\t{}\t{}'.format(utt_id, detok_hypo_str, score),
                          file=output_file)

                # Score and obtain attention only the top hypothesis
                if j == 0:
                    # src_len x tgt_len
                    attention = hypo['attention'][nonpad_idxs[i]].float().cpu() \
                        if args.print_alignment and hypo['attention'] is not None else None
                    if args.print_alignment and attention is not None:
                        os.makedirs(save_dir, exist_ok=True)
                        plot_attention(attention, detok_hypo_str, utt_id,
                                       save_dir)
                    scorer.add_prediction(utt_id, hypo_str)
                    if has_target:
                        scorer.add_evaluation(utt_id, target_str, hypo_str)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    # Guard the rates: both are zero when nothing was decoded.
    logger.info(
        'Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum if gen_timer.sum > 0 else 0,
                1. / gen_timer.avg if gen_timer.avg > 0 else 0))
    if args.print_alignment:
        logger.info('Saved attention plots in ' + save_dir)

    if has_target:
        scorer.add_ordered_utt_list(task.datasets[args.gen_subset].tgt.utt_ids)

    fn = 'decoded_char_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_char_results())
        logger.info('Decoded char results saved as ' + f.name)

    fn = 'decoded_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_results())
        logger.info('Decoded results saved as ' + f.name)

    if has_target:
        header = 'Recognize {} with beam={}: '.format(args.gen_subset,
                                                      args.beam)
        fn = 'wer'
        with open(os.path.join(args.results_path, fn), 'w',
                  encoding='utf-8') as f:
            res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.wer()))
            logger.info(header + res)
            f.write(res + '\n')
            logger.info('WER saved in ' + f.name)

        fn = 'cer'
        with open(os.path.join(args.results_path, fn), 'w',
                  encoding='utf-8') as f:
            res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.cer()))
            logger.info(' ' * len(header) + res)
            f.write(res + '\n')
            logger.info('CER saved in ' + f.name)

        fn = 'aligned_results.txt'
        with open(os.path.join(args.results_path, fn), 'w',
                  encoding='utf-8') as f:
            f.write(scorer.print_aligned_results())
            logger.info('Aligned results saved as ' + f.name)
    return scorer
Beispiel #7
0
def main(args, override_args=None):
    """Evaluate a checkpoint on every subset listed in ``args.valid_subset``.

    Args:
        args: parsed command-line namespace (``path``, ``valid_subset``,
            batching options, ``fp16``/``cpu`` ...).
        override_args: optional namespace whose attributes, plus the dict
            literal in its ``model_overrides``, override the checkpoint's
            stored arguments.

    Raises:
        Exception: if a requested validation subset cannot be found.
    """
    utils.import_user_module(args)

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if override_args is not None:
        # Copy: vars() returns the namespace's own __dict__, so updating it in
        # place would silently mutate the caller's override_args.
        overrides = dict(vars(override_args))
        # NOTE(security): eval() executes arbitrary code taken from the
        # command line; only run with trusted arguments.
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
    )
    model = models[0]

    # Move models to GPU. A separate loop variable keeps ``model`` bound to
    # models[0], the model validated below.
    for m in models:
        if use_fp16:
            m.half()
        if use_cuda:
            m.cuda()

    # Print args
    print(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=0)
            dataset = task.dataset(subset)
        except KeyError as err:
            raise Exception('Cannot find dataset: ' + subset) from err

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        log_outputs = []
        # Reset per subset so an empty iterator neither raises NameError nor
        # reports the previous subset's last step.
        step = None
        for step, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(
                sample, model, criterion)

            progress.log(log_output, step=step)
            log_outputs.append(log_output)

        log_output = task.aggregate_logging_outputs(log_outputs, criterion)

        progress.print(log_output, tag=subset, step=step)
Beispiel #8
0
def main(args):
    """Translate ``args.gen_subset`` with a model ensemble and score it with BLEU.

    When ``args.edit_sample_index`` is set, the single loaded model is first
    edited via ``EditableTrainingCriterion.get_edited_transformer`` using the
    criterion sample at that index. The recovered transformer is then used for
    generation, and the edit's success flag and complexity are printed at the
    end.

    Args:
        args: parsed command-line namespace (fairseq generation options plus
            ``moses_detokenizer`` and ``edit_sample_index``).

    Returns:
        The BLEU scorer (``bleu.SacrebleuScorer`` or ``bleu.Scorer``) holding
        the accumulated statistics.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries (a task without a source side may raise here)
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble; several checkpoint paths may be separated by ':'
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        # NOTE(security): eval() executes arbitrary code taken from the
        # command line. ast.literal_eval would be safer if the overrides are
        # always plain dict literals; kept as eval to preserve behavior.
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # The detokenizer is only used under `if args.moses_detokenizer:` in the
    # scoring branch below, so only construct it when it is requested.
    detokenizer = (MosesDetokenizer(args.moses_detokenizer)
                   if args.moses_detokenizer else None)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        if args.edit_sample_index is not None:
            # make EditableTransformer: replace the single model with an
            # edited copy built from the selected criterion sample
            assert len(models) == 1

            criterion = EditableTrainingCriterion(args, task).train(False)
            # The checkpoint may also carry the criterion's state; restore it
            # when present.
            criterion_state_dict = torch.load(args.path)
            if 'criterion' in criterion_state_dict:
                criterion.load_state_dict(criterion_state_dict['criterion'])

            model = models[0]
            edit_sample = criterion.samples[args.edit_sample_index]
            device = 'cuda' if use_cuda else 'cpu'
            edited_model, success, _, complexity = criterion.get_edited_transformer(
                model, edit_sample, device, detach=True)
            edited_model.train(False)

            models[0] = edited_model.recover_transformer()

        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            if args.moses_detokenizer:
                                target_str = detokenizer(target_str.split())
                                hypo_str = detokenizer(hypo_str.split())

                            scorer.add_string(target_str, hypo_str)
                        else:
                            assert not args.moses_detokenizer, "detokenizer has no effect with current bleu scorer"
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            # Some batches may not carry 'nsentences'; fall back to the
            # number of sample ids in the batch.
            num_sentences += (sample['nsentences']
                              if 'nsentences' in sample else sample['id'].numel())

    # With an empty subset gen_timer.sum (and n) stay 0, so the plain
    # divisions would raise ZeroDivisionError; report 0 rates instead.
    # tokens/s is computed as n / sum, i.e. 1 / avg assuming
    # StopwatchMeter.avg == sum / n.
    elapsed = gen_timer.sum
    sentences_per_sec = num_sentences / elapsed if elapsed > 0 else 0.0
    tokens_per_sec = gen_timer.n / elapsed if elapsed > 0 else 0.0
    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, elapsed,
                sentences_per_sec, tokens_per_sec))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))
        if args.edit_sample_index is not None:
            print('EditResult(success={}, complexity={})'.format(
                success, complexity))
    return scorer
Beispiel #9
0
 def _prepare_sample(self, sample):
     if sample is None or len(sample) == 0:
         return None
     return utils.move_to_cuda(sample)
Beispiel #10
0
def _main(cfg: DictConfig, output_file):
    """Generate hypotheses for ``cfg.dataset.gen_subset`` and score them.

    The function:

    * loads the model ensemble, plus an optional language model from
      ``cfg.generation.lm_path``;
    * runs ``task.inference_step`` over every batch of the generation subset;
    * unless ``cfg.common_eval.quiet`` is set, writes S-/T-/H-/D-/P-/A-/I-/E-
      lines to ``output_file``;
    * feeds the top hypothesis of each sentence to the configured scorer.

    Args:
        cfg: full fairseq configuration (common, task, dataset, generation,
            checkpoint, distributed_training, tokenizer, bpe, scoring, ...).
        output_file: text stream receiving both the log records and the
            generation output.

    Returns:
        The scorer built by ``scoring.build_scorer``.

    Raises:
        Re-raises any exception from loading the language model, after
        logging a hint about the dictionary location.
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Load dataset splits
    task = tasks.setup_task(cfg.task)

    # Set dictionaries (a task without a source side may raise here)
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None)
        except Exception:
            # Only real loading errors get the hint; KeyboardInterrupt /
            # SystemExit propagate without a misleading warning.
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1
    else:
        lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm"
                            if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {
        "lm_model": lms[0],
        "lm_weight": cfg.generation.lm_weight
    }
    generator = task.build_generator(models,
                                     cfg.generation,
                                     extra_gen_cls_kwargs=extra_gen_cls_kwargs)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(cfg.tokenizer)
    bpe = encoders.build_bpe(cfg.bpe)

    def decode_fn(x):
        # Undo BPE first, then detokenize (each step only if configured).
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, tgt_dict)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, :cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad())
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (utils.strip_pad(sample["target"][i, :],
                                                 tgt_dict.pad()).int().cpu())

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    cfg.dataset.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    cfg.dataset.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens,
                                              cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        cfg.common_eval.post_process,
                        escape_unk=True,
                        extra_symbols_to_ignore=
                        get_symbols_to_strip_from_output(generator),
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:cfg.generation.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print(
                        "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
                        file=output_file,
                    )
                    # detokenized hypothesis
                    print(
                        "D-{}\t{}\t{}".format(sample_id, score,
                                              detok_hypo_str),
                        file=output_file,
                    )
                    print(
                        "P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    # convert from base e to base 2 without
                                    # mutating the hypothesis tensor in place
                                    hypo["positional_scores"].div(
                                        math.log(2)).tolist(),
                                )),
                        ),
                        file=output_file,
                    )

                    if cfg.generation.print_alignment == "hard":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join([
                                    "{}-{}".format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ]),
                            ),
                            file=output_file,
                        )
                    if cfg.generation.print_alignment == "soft":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join([
                                    ",".join(src_probs)
                                    for src_probs in alignment
                                ]),
                            ),
                            file=output_file,
                        )

                    if cfg.generation.print_step:
                        print(
                            "I-{}\t{}".format(sample_id, hypo["steps"]),
                            file=output_file,
                        )

                    if cfg.generation.retain_iter_history:
                        for step, h in enumerate(hypo["history"]):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h["tokens"].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print(
                                "E-{}_{}\t{}".format(sample_id, step, h_str),
                                file=output_file,
                            )

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or cfg.common_eval.post_process is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True)
                    if hasattr(scorer, "add_string"):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (sample["nsentences"]
                          if "nsentences" in sample else sample["id"].numel())

    # With an empty subset gen_timer.sum (and n) stay 0, so the plain
    # divisions would raise ZeroDivisionError; report 0 rates instead.
    # tokens/s is computed as n / sum, i.e. 1 / avg assuming
    # StopwatchMeter.avg == sum / n.
    elapsed = gen_timer.sum
    sentences_per_sec = num_sentences / elapsed if elapsed > 0 else 0.0
    tokens_per_sec = gen_timer.n / elapsed if elapsed > 0 else 0.0
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info(
        "Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)"
        .format(
            num_sentences,
            gen_timer.n,
            elapsed,
            sentences_per_sec,
            tokens_per_sec,
        ))
    if has_target:
        if cfg.bpe and not cfg.generation.sacrebleu:
            if cfg.common_eval.post_process:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.  Use --sacrebleu for standard 13a BLEU tokenization"
                )
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            "Generate {} with beam={}: {}".format(cfg.dataset.gen_subset,
                                                  cfg.generation.beam,
                                                  scorer.result_string()),
            file=output_file,
        )

    return scorer
Beispiel #11
0
def main(args):
    """Translate ``args.gen_subset`` with a model ensemble and report BLEU.

    Hypotheses are printed as S-/T-/H-/P-/A-/I-/E- lines unless ``args.quiet``
    is set. The top hypothesis of each sentence is added to the BLEU scorer
    whenever the dataset has targets.

    Args:
        args: parsed fairseq generation command-line namespace.

    Returns:
        The BLEU scorer (``bleu.SacrebleuScorer`` or ``bleu.Scorer``) holding
        the accumulated statistics.
    """
    assert args.path is not None, "--path required for generation!"
    assert (not args.sampling or args.nbest
            == args.beam), "--sampling requires --nbest to be equal to --beam"
    assert (args.replace_unk is None or args.raw_text
            ), "--replace-unk requires a raw text dataset (--raw-text)"

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries (a task without a source side may raise here)
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble; several checkpoint paths may be separated by ':'
    print("| loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(":"),
        # NOTE(security): eval() executes arbitrary code taken from the
        # command line. ast.literal_eval would be safer if the overrides are
        # always plain dict literals; kept as eval to preserve behavior.
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                has_target = sample["target"] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = (utils.strip_pad(
                        sample["target"][i, :], tgt_dict.pad()).int().cpu())

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print("S-{}\t{}".format(sample_id, src_str))
                    if has_target:
                        print("T-{}\t{}".format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print("H-{}\t{}\t{}".format(sample_id, hypo["score"],
                                                    hypo_str))
                        print("P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    hypo["positional_scores"].tolist(),
                                )),
                        ))

                        if args.print_alignment:
                            print("A-{}\t{}".format(
                                sample_id,
                                " ".join([
                                    "{}-{}".format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ]),
                            ))

                        if args.print_step:
                            print("I-{}\t{}".format(sample_id, hypo["steps"]))

                        if getattr(args, "retain_iter_history", False):
                            # One E- line per intermediate decoding step
                            print("\n".join([
                                "E-{}_{}\t{}".format(
                                    sample_id,
                                    step,
                                    utils.post_process_prediction(
                                        h["tokens"].int().cpu(),
                                        src_str,
                                        None,
                                        None,
                                        tgt_dict,
                                        None,
                                    )[1],
                                ) for step, h in enumerate(hypo["history"])
                            ]))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, "add_string"):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            # Some batches may not carry 'nsentences'; fall back to the
            # number of sample ids in the batch.
            num_sentences += (sample["nsentences"]
                              if "nsentences" in sample else sample["id"].numel())

    # With an empty subset gen_timer.sum (and n) stay 0, so the plain
    # divisions would raise ZeroDivisionError; report 0 rates instead.
    # tokens/s is computed as n / sum, i.e. 1 / avg assuming
    # StopwatchMeter.avg == sum / n.
    elapsed = gen_timer.sum
    sentences_per_sec = num_sentences / elapsed if elapsed > 0 else 0.0
    tokens_per_sec = gen_timer.n / elapsed if elapsed > 0 else 0.0
    print(
        "| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)"
        .format(
            num_sentences,
            gen_timer.n,
            elapsed,
            sentences_per_sec,
            tokens_per_sec,
        ))
    if has_target:
        print("| Generate {} with beam={}: {}".format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))

    return scorer
Beispiel #12
0
def _write_decoded_lines(path, lines_by_id):
    """Write the strings collected per sample to ``path``, one per line.

    Samples are written in ascending sample-id order; within one sample the
    strings keep the order in which they were appended.

    Args:
        path: destination file (overwritten, UTF-8 encoded).
        lines_by_id: dict mapping an integer sample id to a list of strings.
    """
    with open(path, 'w', encoding='utf-8') as f:
        for sample_id in sorted(lines_by_id):
            f.writelines(sent + '\n' for sent in lines_by_id[sample_id])


def main(args):
    """Decode ``args.gen_subset`` with a model ensemble and score it with BLEU.

    Unless ``--quiet`` is given, S-/T-/H-/P- (and optionally A-/I-/E-) lines
    are printed per sample. Only the top hypothesis of each sample is scored
    against its target. When ``args.decoding_path`` is set, the source,
    target and hypothesis strings are also written to ``source.txt``,
    ``target.txt`` and ``decoding.txt`` in that directory, ordered by sample
    id (``target.txt`` only receives lines for samples that have a target).

    Returns:
        The BLEU scorer (``bleu.SacrebleuScorer`` or ``bleu.Scorer``).
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.dataset_impl == 'raw', \
        '--replace-unk requires a raw text dataset (--dataset-impl=raw)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))

    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=args.model_overrides,
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True

    # Decoded strings keyed by sample id; only filled when --decoding-path is
    # set. Dicts grow with the data instead of preallocating millions of
    # empty lists, and cannot overflow for large sample ids.
    src_sents, tgt_sents, hyp_sents = {}, {}, {}

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                # Reset per sample so a target-less sample never reuses (or
                # fails to find) a previous sample's reference string.
                target_str = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join([
                                    '{}-{}'.format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ])))

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']))

                        if getattr(args, 'retain_iter_history', False):
                            print("\n".join([
                                'E-{}_{}\t{}'.format(
                                    sample_id, step,
                                    utils.post_process_prediction(
                                        h['tokens'].int().cpu(), src_str, None,
                                        None, tgt_dict, None)[1])
                                for step, h in enumerate(hypo['history'])
                            ]))

                    if args.decoding_path is not None:
                        key = int(sample_id)
                        src_sents.setdefault(key, []).append(src_str)
                        if target_str is not None:
                            tgt_sents.setdefault(key, []).append(target_str)
                        hyp_sents.setdefault(key, []).append(hypo_str)

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print(
        '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset,
                                                      args.beam,
                                                      scorer.result_string()))

    if args.decoding_path is not None:
        _write_decoded_lines(
            os.path.join(args.decoding_path, 'source.txt'), src_sents)
        _write_decoded_lines(
            os.path.join(args.decoding_path, 'target.txt'), tgt_sents)
        _write_decoded_lines(
            os.path.join(args.decoding_path, 'decoding.txt'), hyp_sents)

    # NOTE: an external multi-bleu.perl evaluation against the references in
    # args.valid_decoding_path used to be sketched here but was disabled. Its
    # leftover (unused) reference-path bookkeeping crashed with
    # os.path.join(None, ...) when --decoding-path was unset, so it was removed.
    return scorer
Beispiel #13
0
def _main(args, output_file):
    """Render ``args.gen_subset`` with a model ensemble and save the images.

    Logging is configured to write to ``output_file``. The checkpoints listed
    in ``args.path`` (separated by ``os.pathsep``) are loaded, optionally cast
    to fp16 and moved to the GPU. ``task.inference_step`` is then called with
    ``[sample, step]``. The output files it returns are accumulated and
    handed to ``generator.save_images``.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairnr_cli.render')

    utils.import_user_module(args)

    no_batch_limit = args.max_tokens is None and args.max_sentences is None
    if no_batch_limit:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Set up the task and the split to render.
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    logger.info('loading model(s) from {}'.format(args.path))
    checkpoint_paths = args.path.split(os.pathsep)
    # NOTE(review): eval() executes the --model-overrides string as Python
    # code; only pass trusted input.
    overrides = eval(args.model_overrides)
    models, _model_args = checkpoint_utils.load_model_ensemble(
        checkpoint_paths, arg_overrides=overrides, task=task)

    for net in models:
        if args.fp16:
            net.half()
        if use_cuda:
            net.cuda()

    # Build the (possibly sharded) batch iterator over the split.
    dataset = task.dataset(args.gen_subset)
    max_positions = utils.resolve_max_positions(
        task.max_positions(), *(net.max_positions() for net in models))
    batches = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    )
    itr = batches.next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Fixed nominal token count reported to the timers for each batch.
    nominal_tokens = 500
    output_files = []
    step = 0
    with progress_bar.build_progress_bar(args, itr) as progress:
        wps_meter = TimeMeter()
        for sample in progress:
            if use_cuda:
                sample = utils.move_to_cuda(sample)
            gen_timer.start()

            step, batch_files = task.inference_step(
                generator, models, [sample, step])
            output_files.extend(batch_files)

            gen_timer.stop(nominal_tokens)
            wps_meter.update(nominal_tokens)
            progress.log({'wps': round(wps_meter.avg)})

            # Stops after the first batch, as in the original code (a bounded
            # `if i > 5: break` variant had been commented out).
            # NOTE(review): confirm that one batch is all that should render.
            break

    generator.save_images(output_files,
                          combine_output=args.render_combine_output)
Beispiel #14
0
def main(args):
    """Run inference on ``args.gen_subset``; report BLEU and exact-match accuracy.

    Only the top hypothesis of each sample is used. With
    ``--criterion ctc_loss``, target and hypothesis tokens are used as
    returned (no padding strip, no int/CPU cast). The running ``wps`` and
    ``acc`` are logged, where ``acc`` is the number of samples whose
    hypothesis string equals their target string divided by the number of
    sentences seen. Samples without a target never count as correct.

    Returns:
        The BLEU scorer (``bleu.SacrebleuScorer`` or ``bleu.Scorer``).
    """
    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    use_ctc_loss = args.criterion == 'ctc_loss'

    # Setup task, e.g., image captioning
    task = tasks.setup_task(args)
    # Load dataset split
    task.load_dataset(args.gen_subset, combine=True, epoch=0)

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    model_paths = args.path.split(':')
    models, _model_args = checkpoint_utils.load_model_ensemble(
        model_paths,
        # NOTE(review): eval() executes the --model-overrides string as Python
        # code; only pass trusted input.
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Set dictionaries
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    stats = collections.OrderedDict()
    num_sentences = 0
    num_correct = 0
    has_target = True

    with progress_bar.build_progress_bar(
        args, itr,
        prefix='inference on \'{}\' subset'.format(args.gen_subset),
        no_progress_bar='simple',
    ) as progress:
        wps_meter = TimeMeter()
        for sample in progress:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            gen_timer.start()
            hypos = task.inference_step(generator, models, sample)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None
                target_tokens = None
                # Reset per sample so a target-less sample never compares
                # against (or fails on) a previous sample's reference.
                target_str = None
                if has_target:
                    if use_ctc_loss:
                        # CTC targets are used as returned by the task.
                        target_tokens = sample['target'][i]
                    else:
                        # Remove padding
                        target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()
                    # Regenerate original sentences from tokens.
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

                if not args.quiet:
                    if has_target:
                        print('\nT-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                hypo = hypos[i][0]
                hypo_tokens = hypo['tokens'] if use_ctc_loss else hypo['tokens'].int().cpu()
                hypo_str = tgt_dict.string(hypo_tokens, args.remove_bpe, escape_unk=True)
                alignment = hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None

                # Exact-match accuracy is only defined when a reference exists.
                if has_target and hypo_str == target_str:
                    num_correct += 1

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo_str, hypo['score']))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        )) if not use_ctc_loss else None
                    ))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                        ))

                # Score only the top hypothesis
                if has_target:
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            num_sentences += sample['nsentences']
            stats['wps'] = round(wps_meter.avg)
            stats['acc'] = num_correct / num_sentences
            progress.log(stats, tag='accuracy')
        progress.print(stats, tag='accuracy')

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
    return scorer
Beispiel #15
0
def main(parsed_args):
    """Evaluate a language model on ``parsed_args.gen_subset`` and log perplexity.

    Besides the loss/perplexity summary, this function supports two modes:

    * ``--save-knnlm-dstore``: write decoder keys to
      ``<dstore_mmap>_keys.npy`` and the matching token ids to
      ``<dstore_mmap>_vals.npy``. Both are pre-sized ``np.memmap`` files of
      ``dstore_size`` rows, and only chunks whose key count equals
      ``tokens_per_sample`` are stored.
    * ``--knnlm``: score with a ``KNN_Dstore`` and write the per-token kNN
      distances and indices to ``<output_log_probs_file_prefix>.dists.txt``
      and ``.knn_indices.txt``.

    In either mode the evaluated tokens are also written, one per line, to
    ``args.output_tokens_file``. The two modes are mutually exclusive.
    """
    import contextlib  # stdlib; local so the block does not rely on file imports

    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        # NOTE(review): eval() executes the --model-overrides string as Python
        # code; only pass trusted input.
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    # Copy command-line args over the checkpoint args, except these, which
    # keep the values stored in the checkpoint.
    checkpoint_only_args = {
        'self_target',
        'future_target',
        'past_target',
        'tokens_per_sample',
        'output_size_dictionary',
        'add_bos_token',
    }
    for arg in vars(parsed_args).keys():
        if arg not in checkpoint_only_args:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset,
                                           len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(
        sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary,
                            args.softmax_batch,
                            args=args)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError(
            "Cannot use knnlm while trying to build the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    # The ExitStack closes every output file opened below, in every mode and
    # also when an exception escapes the loop.
    with progress_bar.build_progress_bar(args, itr) as t, \
            contextlib.ExitStack() as open_files:
        wps_meter = TimeMeter()

        if args.save_knnlm_dstore:
            print('keytype being saved:', args.knn_keytype)
            if args.dstore_fp16:
                print('Saving fp16')
                # NOTE(review): int16 values can only hold token ids up to
                # 32767; larger vocabularies overflow. Kept as-is because
                # readers of existing fp16 datastores presumably expect int16.
                key_dtype, val_dtype = np.float16, np.int16
            else:
                print('Saving fp32')
                # `int` is exactly what the removed `np.int` alias pointed to.
                key_dtype, val_dtype = np.float32, int
            dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy',
                                    dtype=key_dtype,
                                    mode='w+',
                                    shape=(args.dstore_size,
                                           args.decoder_embed_dim))
            dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy',
                                    dtype=val_dtype,
                                    mode='w+',
                                    shape=(args.dstore_size, 1))

        dstore_idx = 0
        shape = None  # shape of the last datastore chunk, reported at the end

        #knn_probs_file = open(args.output_log_probs_file_prefix + '.knn.txt', 'w')
        #orig_probs_file = open(args.output_log_probs_file_prefix + '.orig.txt', 'w')
        if args.knnlm:
            dists_file = open_files.enter_context(
                open(args.output_log_probs_file_prefix + '.dists.txt', 'w'))
            knns_file = open_files.enter_context(
                open(args.output_log_probs_file_prefix + '.knn_indices.txt',
                     'w'))
        if args.save_knnlm_dstore or args.knnlm:
            tokens_file = open_files.enter_context(
                open(args.output_tokens_file, 'w'))

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            if args.knnlm:
                hypos = scorer.generate(models, sample, knn_dstore=knn_dstore)
            else:
                hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for i, hypos_i in enumerate(hypos):
                # NOTE(review): the last hypothesis of every batch is skipped
                # (neither scored nor stored); the reason is not visible here.
                if i == len(hypos) - 1:
                    continue
                hypo = hypos_i[0]
                skipped = False
                if args.save_knnlm_dstore:
                    shape = hypo['dstore_keys'].shape
                    if shape[0] == args.tokens_per_sample:
                        dstore_tokens = hypo['tokens']
                        if dstore_idx + shape[0] > args.dstore_size:
                            # Truncate the final chunk to the remaining space;
                            # keys and values must be cut to the same length.
                            shape = [args.dstore_size - dstore_idx]
                            hypo['dstore_keys'] = hypo[
                                'dstore_keys'][:shape[0]]
                            dstore_tokens = dstore_tokens[:shape[0]]
                        end_idx = dstore_idx + shape[0]
                        dstore_keys[dstore_idx:end_idx] = hypo['dstore_keys'].view(
                            -1, args.decoder_embed_dim).cpu().numpy().astype(key_dtype)
                        dstore_vals[dstore_idx:end_idx] = dstore_tokens.view(
                            -1, 1).cpu().numpy().astype(val_dtype)

                        dstore_idx += shape[0]
                    else:
                        skipped = True
                        print('Skipping this one with shape', shape)

                sample_id = sample['id'][i]

                tokens = hypo['tokens']
                pos_scores = hypo['positional_scores'].float()
                orig_scores = hypo['original_scores'].float()
                yhat_scores = hypo['yhat_scores'].float()
                if args.knnlm:
                    assert hypo['dists_full'] is not None
                    dists_full = hypo['dists_full'].float()
                    knns_full = hypo['knns_full']

                    # knn_probs_file.write('\n'.join([str(prob) for prob in yhat_scores.tolist()]) + '\n')
                    # orig_probs_file.write('\n'.join([str(prob) for prob in orig_scores.tolist()]) + '\n')
                    dists_file.write('\n'.join([
                        str(dists_for_token)
                        for dists_for_token in dists_full.tolist()
                    ]) + '\n')
                    knns_file.write('\n'.join([
                        str(knns_for_token)
                        for knns_for_token in knns_full.tolist()
                    ]) + '\n')

                if args.save_knnlm_dstore or args.knnlm:
                    if not skipped:
                        word_tokens = [
                            task.target_dictionary[token]
                            for token in hypo['tokens']
                        ]
                        tokens_file.write('\n'.join(word_tokens) + '\n')
                        assert len(
                            hypo['yhat_scores'].float().tolist()) == len(
                                word_tokens)
                # doc = spacy.tokens.doc.Doc(
                #     nlp.vocab, words=word_tokens, spaces=[True for token in tokens])
                # for name, proc in nlp.pipeline:
                #     doc = proc(doc)

                if args.add_bos_token:
                    assert hypo['tokens'][0].item(
                    ) == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                # Fold the score of each BPE continuation piece into the next
                # piece so every word is scored once. Use the length *after*
                # the optional BOS strip; the pre-strip length overran
                # pos_scores by one with --add-bos-token.
                skipped_toks = 0
                if bpe_toks is not None:
                    for tok_idx in range(len(tokens) - 1):
                        if tokens[tok_idx].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[tok_idx + 1] += pos_scores[tok_idx]
                            pos_scores[tok_idx] = 0

                #inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                #if inf_scores.any():
                #    logger.info(
                #        'skipping tokens with inf scores:',
                #        task.target_dictionary.string(tokens[inf_scores.nonzero()])
                #    )
                #    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for tok_idx in range(len(tokens)):
                        w_ind = tokens[tok_idx].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[tok_idx].item()))

                            next_prob = None
                            ind = tok_idx + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(
                                pos_scores[tok_idx].item(), next_prob)
                            is_bpe = False
                            w = ''

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)
        # Persist the datastore explicitly rather than relying on the memmaps
        # being garbage collected.
        dstore_keys.flush()
        dstore_vals.flush()

    # Entities
    # mask = torch.tensor([1 if token.ent_type_ else 0 for token in doc], dtype=float)
    # count_entities = mask.sum()
    # if torch.cuda.is_available() and not parsed_args.cpu:
    #     mask = mask.cuda()
    # avg_nll_loss_entities = - (pos_scores * mask).sum() / count_entities.cpu() / math.log(2)

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(),
                         key=lambda x: x.count,
                         reverse=True):
            logger.info(ws)
Beispiel #16
0
 def _prepare_sample(self, sample):
     if sample is None or len(sample) == 0:
         return None
     return utils.move_to_cuda(sample)
Beispiel #17
0
def main(args):
    """Extract alignments with a trained model and optionally write them out.

    Loads the ensemble from ``args.path`` (colon-separated) and iterates over
    ``args.gen_subset``. When ``args.print_vanilla_alignment`` is set, it asks
    the first model for three alignment variants per batch:

    - merge-then-align;
    - forward align-then-merge;
    - backward (``reverse=True``) align-then-merge.

    If ``args.decoding_path`` is also set, the results are written into that
    directory as
    ``{gen_subset}.{source_lang}2{target_lang}.{bidual,dualf,dualb}.align``.
    Each file holds one ``str(alignment)`` per line, ordered by sample id.
    """
    import string
    import time
    from collections import defaultdict

    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        # NOTE: eval() of a command-line string; only use with trusted input.
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
        model.eval()

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    if args.print_vanilla_alignment:
        # Ids of source-dictionary entries contained in string.punctuation;
        # handed to the align-then-merge extractors.
        src_punc_tokens = [
            w for w in range(len(src_dict))
            if src_dict[w] in string.punctuation
        ]
    else:
        src_punc_tokens = None

    save_alignments = (args.print_vanilla_alignment
                       and args.decoding_path is not None)
    # Alignments grouped by sample id. A dict keeps memory proportional to the
    # data (instead of preallocating millions of empty lists) and places no
    # upper bound on sample ids.
    align_sents = defaultdict(list)
    f_align_sents = defaultdict(list)
    b_align_sents = defaultdict(list)

    print('start time is :', time.strftime("%Y-%m-%d %X"))
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue
            if not args.print_vanilla_alignment:
                continue

            net_output = models[0](sample)
            alignments = models[0].extract_merge_then_align_alignment(
                sample, net_output)
            f_alignments = models[0].extract_align_then_merge_alignment(
                sample, net_output, src_punc_tokens)
            b_alignments = models[0].extract_align_then_merge_alignment(
                sample, net_output, src_punc_tokens, reverse=True)

            if not save_alignments:
                continue
            for sample_id in sample['id'].tolist():
                # NOTE(review): the merge-then-align result is looked up by
                # str(id) while the align-then-merge results use the raw id;
                # presumably this mirrors the extractors' return types.
                # Confirm before unifying.
                key = int(sample_id)
                align_sents[key].append(alignments[str(sample_id)])
                f_align_sents[key].append(f_alignments[sample_id])
                b_align_sents[key].append(b_alignments[sample_id])

    print('end time is :', time.strftime("%Y-%m-%d %X"))
    if save_alignments:
        prefix = f'{args.gen_subset}.{args.source_lang}2{args.target_lang}'
        for suffix, sents_by_id in (
            ('bidual', align_sents),
            ('dualf', f_align_sents),
            ('dualb', b_align_sents),
        ):
            path = os.path.join(args.decoding_path, f'{prefix}.{suffix}.align')
            with open(path, 'w') as f:
                for sample_id in sorted(sents_by_id):
                    for sent in sents_by_id[sample_id]:
                        f.write(str(sent) + '\n')
        print("finished ...")
Beispiel #18
0
def main(rank, world_size, args):
    """Sample continuations of speech-unit prompts and write them as JSON lines.

    Each worker (``rank`` of ``world_size``) loads the model ensemble and task
    from ``args.path`` and builds the prompt dataset for ``args.subset``. It
    then produces ``args.batch_explosion_rate`` continuations per prompt, via
    ``do_sampling`` or, with ``short_curcuit``, by copying the targets.

    One JSON object per continuation is written to ``args.output``; the file
    name gets a ``_{rank}`` suffix when ``world_size > 1``. In that case rank 0
    concatenates the per-rank files into ``args.output`` afterwards and
    deletes them.
    """
    if world_size > 1:
        torch.distributed.init_process_group(backend="gloo",
                                             init_method="env://",
                                             world_size=world_size,
                                             rank=rank)
        torch.cuda.set_device(rank)

    raw_args = args
    args = convert_namespace_to_omegaconf(args)
    if args.common.seed is not None:
        random.seed(args.common.seed)
        np.random.seed(args.common.seed)
        utils.set_torch_seed(args.common.seed)

    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [raw_args.path], arg_overrides={"data": args.task.data})
    tgt_dict = task.target_dictionary

    for model in models:
        model.prepare_for_inference_(args)
        model.cuda().eval()
        if raw_args.fp16:
            # nn.Module.half() converts in place, so no rebinding is needed.
            model.half()

    config = ExpressiveCodeDataConfig(args.task.data)

    dataset = CodeDataset(
        manifest=config.manifests[raw_args.subset],
        dictionary=task.source_dictionary,
        dur_dictionary=task.source_duration_dictionary,
        f0_dictionary=task.source_f0_dictionary,
        config=config,
        discrete_dur=task.cfg.discrete_duration,
        discrete_f0=task.cfg.discrete_f0,
        log_f0=task.cfg.log_f0,
        normalize_f0_mean=task.cfg.normalize_f0_mean,
        normalize_f0_std=task.cfg.normalize_f0_std,
        interpolate_f0=task.cfg.interpolate_f0,
        shifts=task.cfg.stream_shifts,
        return_filename=True,
        strip_filename=False,
    )
    shifts = dataset.shifts.dur, dataset.shifts.f0
    max_shift = max(shifts)

    output_fname = raw_args.output
    if world_size > 1:
        output_fname += f"_{rank}"

    if raw_args.filter_names:
        dataset = FilterNamesDataset(dataset, raw_args.filter_names)

    dataset = InferenceDataset(dataset,
                               raw_args.prefix_length,
                               filter_short=True)
    print(f"Dataset size {len(dataset)}")
    sampler = (None if world_size == 1 else DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False))
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset.collater,
        sampler=sampler,
    )

    Ts = raw_args.T_token, raw_args.T_duration, raw_args.T_f0
    decoder = TemperatureDecoder(Ts,
                                 discrete_dur=task.cfg.discrete_duration,
                                 discrete_f0=task.cfg.discrete_f0)

    dataset_size = len(dataset)

    f0_decoder = None
    if raw_args.f0_discretization_bounds:
        assert task.cfg.discrete_f0
        f0_decoder = Naive_F0_Decoder(raw_args.f0_discretization_bounds).cuda()

    pbar = (tqdm.tqdm(
        total=dataset_size if raw_args.max_samples is None else min(
            raw_args.max_samples, dataset_size)) if world_size == 1 else None)

    samples_produced = 0

    # The context manager closes (and flushes) the per-rank file even if
    # sampling raises.
    with open(output_fname, "w") as output_file:
        for batch in dataloader:
            if (raw_args.max_samples is not None
                    and samples_produced >= raw_args.max_samples):
                break

            prefix = batch["prefix"][0]

            batch = explode_batch(batch, raw_args.batch_explosion_rate)
            batch = move_to_cuda(batch)

            if not raw_args.short_curcuit:
                produced_tokens, produced_durations, produced_f0, _ = do_sampling(
                    models[0],
                    batch,
                    tgt_dict.eos(),
                    decoder,
                    autoregressive_steps=raw_args.max_length - prefix + max_shift,
                    teacher_force_tokens=raw_args.teacher_force_tokens,
                    match_duration=raw_args.match_duration,
                    teacher_force_duration=raw_args.teacher_force_duration,
                    teacher_force_f0=raw_args.teacher_force_f0,
                )

                # strip entries corresponding to <s>
                produced_tokens = produced_tokens[:, 1:]
                produced_durations = produced_durations[:, 1:]
                produced_f0 = produced_f0[:, 1:]

            else:
                max_length = raw_args.max_length + max_shift
                produced_tokens, produced_durations, produced_f0 = (
                    batch["target"][:, :max_length],
                    batch["dur_target"][:, :max_length],
                    batch["f0_target"][:, :max_length],
                )

            if f0_decoder is not None:
                produced_f0 = f0_decoder(produced_f0)

            produced_tokens, produced_durations, produced_f0 = (
                produced_tokens.cpu().tolist(),
                produced_durations.cpu().tolist(),
                produced_f0.cpu().tolist(),
            )

            bsz = batch["target"].size(0)
            assert bsz == raw_args.batch_explosion_rate

            for i in range(bsz):
                if (raw_args.max_samples is not None
                        and samples_produced >= raw_args.max_samples):
                    break

                produced_tokens_i = produced_tokens[i]
                produced_durations_i = produced_durations[i]
                produced_f0_i = produced_f0[i]

                (
                    produced_tokens_i,
                    produced_durations_i,
                    produced_f0_i,
                ) = realign_shifted_streams(produced_tokens_i,
                                            produced_durations_i, produced_f0_i,
                                            shifts)

                produced_tokens_i, produced_durations_i, produced_f0_i = maybe_cut_eos(
                    produced_tokens_i, produced_durations_i, produced_f0_i,
                    tgt_dict.eos())

                produced_tokens_i, produced_durations_i, produced_f0_i = maybe_filter_pad(
                    produced_tokens_i, produced_durations_i, produced_f0_i,
                    tgt_dict.pad())

                if raw_args.match_duration:
                    # NB: here we cheat a bit and use that padding has duration 0
                    # so no need to re-align and remove padding
                    dur_target_i = batch["dur_target"][i, :].sum().item()
                    produced_tokens_i, produced_durations_i, produced_f0_i = match_duration(
                        produced_tokens_i, produced_durations_i, produced_f0_i,
                        dur_target_i)

                if raw_args.cut_prompt:
                    produced_tokens_i, produced_durations_i, produced_f0_i = (
                        produced_tokens_i[prefix:],
                        produced_durations_i[prefix:],
                        produced_f0_i[prefix:],
                    )

                prompt_fname = batch["filename"][0]
                # Distinct name so the output file name above is not shadowed.
                audio_fname = str(
                    pathlib.Path(prompt_fname).with_suffix("")) + f"__{i}.wav"

                token_stream = unroll_duration(produced_tokens_i,
                                               produced_durations_i)
                f0_stream = unroll_duration(produced_f0_i, produced_durations_i)
                output_line = json.dumps({
                    "audio":
                    audio_fname,
                    "prompt":
                    prompt_fname,
                    raw_args.code_type:
                    " ".join(map(str, token_stream)),
                    "duration":
                    round(
                        sum(produced_durations_i) *
                        CODETYPE_TO_FRAMETIME[raw_args.code_type],
                        3,
                    ),
                    "raw_duration":
                    produced_durations_i,
                    "raw_f0":
                    produced_f0_i,
                    "f0": [round(f0, 3) for f0 in f0_stream],
                })
                print(output_line, file=output_file)

                if pbar is not None:
                    pbar.update(1)
                samples_produced += 1

            if raw_args.debug:
                break

    if pbar is not None:
        pbar.close()

    if world_size > 1:
        # important that everything is flushed before aggregating
        torch.distributed.barrier()

    if world_size > 1 and rank == 0:
        with open(raw_args.output, "w") as fout:
            for part_rank in range(world_size):
                part_fname = raw_args.output + f"_{part_rank}"
                with open(part_fname, "r") as fin:
                    fout.write(fin.read())
                os.remove(part_fname)
def main(args, task=None, model_state=None):
    """Decode ``args.gen_subset`` with a CIF-BERT decoder and report WER.

    Args:
        args: parsed command-line namespace (validated by ``check_args``).
        task: optional pre-built task; when ``None`` it is set up from
            ``args`` and ``args.gen_subset`` is loaded.
        model_state: optional model state forwarded to
            ``load_models_and_criterions``.

    Returns:
        ``(task, wer)``. ``wer`` is the word error rate in percent, or ``None``
        when no reference words were counted.
    """
    check_args(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    if task is None:
        # Load dataset splits
        task = tasks.setup_task(args)
        task.load_dataset(args.gen_subset)
        logger.info("| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    all_trans = []
    if 'audio' in args.task:
        # Tasks that load tsv data: raw (pre-BPE) transcripts are read from
        # <data>/<gen_subset>.word, one per line, indexed by sample id below.
        trans_path = os.path.join(args.data, "{}.word".format(args.gen_subset))
        with open(trans_path, "r") as f:
            # NOTE(review): lines keep their trailing newline, unlike the
            # .strip()-ed fallback below. Confirm process_predictions copes.
            all_trans.extend(f)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    models, criterions, _ = load_models_and_criterions(
        args.path,
        data_path=args.data,
        # NOTE: eval() of a command-line string; trusted input only.
        arg_overrides=eval(args.model_overrides),  # noqa
        task=task,
        model_state=model_state,
    )
    optimize_models(args, use_cuda, models)

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    generator = CIF_BERT_Decoder(args, task.target_dictionary)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # utt_id = task.dataset(args.gen_subset).ids[int(sample_id)]
                utt_id = sample_id
                # Prefer label targets when the batch provides them.
                toks = (sample["target_label"][i, :]
                        if 'target_label' in sample else sample["target"][i, :])
                target_tokens = (utils.strip_pad(toks,
                                                 tgt_dict.pad()).int().cpu())
                trans = all_trans[utt_id] if all_trans else task.dataset(
                    args.gen_subset).ids[sample_id][1]['output']['text'].strip()
                # Process top predictions
                errs, length = process_predictions(args, hypos[i], None,
                                                   tgt_dict, target_tokens,
                                                   res_files, speaker, utt_id,
                                                   trans)
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (sample["nsentences"] if "nsentences" in sample
                              else sample["id"].numel())

    wer = None

    if lengths_t > 0:
        wer = errs_t * 100.0 / lengths_t
        logger.info(f"WER: {wer}")

    # Guard the throughput figures against an empty subset.
    sents_per_sec = (num_sentences / gen_timer.sum
                     if gen_timer.sum > 0 else 0.0)
    tokens_per_sec = 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0.0
    logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
                "sentences/s, {:.2f} tokens/s)".format(
                    num_sentences,
                    gen_timer.n,
                    gen_timer.sum,
                    sents_per_sec,
                    tokens_per_sec,
                ))
    logger.info("| Generate {} with beam={}".format(args.gen_subset,
                                                    args.beam))

    return task, wer
def _main(cfg, output_file):
    """Run speech-recognition decoding and write hypotheses and WER/CER reports.

    Logging goes to ``output_file``, mirrored to stdout when ``output_file`` is
    not stdout. The function loads the acoustic-model ensemble and optional
    language model(s) for fusion, then decodes ``cfg.dataset.gen_subset``.

    Files written to ``cfg.common_eval.results_path``:

    - always: ``decoded_char_results.txt`` and ``decoded_results.txt``;
    - when targets exist: ``wer``, ``cer`` and ``aligned_results.txt``;
    - when ``print_alignment`` is requested: attention plots under
      ``<results_path>/attn_plots``.

    Returns:
        The ``wer.Scorer`` holding every prediction (and evaluation, if any).
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("espresso.speech_recognize")
    if output_file is not sys.stdout:  # also print to stdout
        logger.addHandler(logging.StreamHandler(sys.stdout))

    print_options_meaning_changes(cfg, logger)

    utils.import_user_module(cfg.common)

    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    task = tasks.setup_task(cfg.task)
    task.build_tokenizer(cfg.tokenizer)
    task.build_bpe(cfg.bpe)

    # Set dictionary
    dictionary = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        overrides["data"] = cfg.task.data

        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                utils.split_paths(cfg.generation.lm_path),
                arg_overrides=overrides,
                task=None,
            )
        except Exception:
            # Add a hint, then propagate the original error unchanged.
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1 or len(lms) == 2  # Multi-level LM expects two LMs
    else:
        lms = [None]

    for i, m in enumerate(lms):
        if m is None:
            continue
        if hasattr(m, "is_wordlm") and m.is_wordlm:
            # assume subword LM comes before word LM
            if i > 0 and isinstance(lms[i - 1], FairseqLanguageModel):
                lms[i - 1] = MultiLevelLanguageModel(
                    m,
                    lms[i - 1],
                    subwordlm_weight=cfg.generation.subwordlm_weight,
                    oov_penalty=cfg.generation.oov_penalty,
                    open_vocab=not cfg.generation.disable_open_vocab,
                )
                # Deleting while enumerating is safe only because lms has at
                # most two entries, so i is the last index here.
                del lms[i]
                logger.info("LM fusion with Multi-level LM")
            else:
                lms[i] = TensorizedLookaheadLanguageModel(
                    m,
                    dictionary,
                    oov_penalty=cfg.generation.oov_penalty,
                    open_vocab=not cfg.generation.disable_open_vocab,
                )
                logger.info("LM fusion with Look-ahead Word LM")
        else:
            assert isinstance(m, FairseqLanguageModel)
            logger.info("LM fusion with Subword LM")
    if cfg.generation.lm_weight != 0.0:
        logger.info("using LM fusion with lm-weight={:.2f}".format(
            cfg.generation.lm_weight))

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm"
                            if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    if cfg.generation.match_source_len:
        logger.warning(
            "The option match_source_len is not applicable to speech recognition. Ignoring it."
        )
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {
        "lm_model": lms[0],
        "lm_weight": cfg.generation.lm_weight,
        "eos_factor": cfg.generation.eos_factor,
    }
    cfg.generation.score_reference = False  # not applicable for ASR
    save_attention_plot = cfg.generation.print_alignment is not None
    cfg.generation.print_alignment = None  # not applicable for ASR
    # Resolved before decoding so the summary log at the end can reference it
    # even when no attention map ends up being plotted.
    save_dir = (os.path.join(cfg.common_eval.results_path, "attn_plots")
                if save_attention_plot else None)
    generator = task.build_generator(models,
                                     cfg.generation,
                                     extra_gen_cls_kwargs=extra_gen_cls_kwargs)

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = wer.Scorer(dictionary,
                        wer_output_filter=cfg.task.wer_output_filter)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, :cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        # obtain nonpad mask of encoder output to plot attentions
        if save_attention_plot:
            net_input = sample["net_input"]
            src_tokens = net_input["src_tokens"]
            output_lengths = models[0].encoder.output_lengths(
                net_input["src_lengths"])
            nonpad_idxs = sequence_mask(
                output_lengths,
                models[0].encoder.output_lengths(src_tokens.size(1)))

        for i in range(len(sample["id"])):
            has_target = sample["target"] is not None
            utt_id = sample["utt_id"][i]

            # Retrieve the original sentences
            if has_target:
                target_str = dictionary.wordpiece_encode(sample["text"][i])
                if not cfg.common_eval.quiet:
                    print("T-{}\t{}".format(utt_id, sample["text"][i]),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:cfg.generation.nbest]):
                hypo_str = dictionary.string(
                    hypo["tokens"].int().cpu(),
                    bpe_symbol=None,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )  # not removing bpe at this point
                detok_hypo_str = decode_fn(hypo_str)
                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    print("H-{}\t{}\t{}".format(utt_id, detok_hypo_str, score),
                          file=output_file)

                # Score and obtain attention only the top hypothesis
                if j == 0:
                    # src_len x tgt_len
                    attention = hypo["attention"][nonpad_idxs[i]].float().cpu() \
                        if save_attention_plot and hypo["attention"] is not None else None
                    if save_attention_plot and attention is not None:
                        os.makedirs(save_dir, exist_ok=True)
                        plot_attention(attention, detok_hypo_str, utt_id,
                                       save_dir)
                    scorer.add_prediction(utt_id, hypo_str)
                    if has_target:
                        scorer.add_evaluation(utt_id, target_str, hypo_str)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += sample[
            "nsentences"] if "nsentences" in sample else sample["id"].numel()

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    # Guard the throughput figures against an empty subset.
    sents_per_sec = (num_sentences / gen_timer.sum
                     if gen_timer.sum > 0 else 0.0)
    tokens_per_sec = 1. / gen_timer.avg if gen_timer.avg > 0 else 0.0
    logger.info(
        "Recognized {:,} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)"
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                sents_per_sec, tokens_per_sec))
    if save_attention_plot:
        logger.info("Saved attention plots in " + save_dir)

    if has_target:
        scorer.add_ordered_utt_list(
            task.datasets[cfg.dataset.gen_subset].tgt.utt_ids)

    fn = "decoded_char_results.txt"
    with open(os.path.join(cfg.common_eval.results_path, fn),
              "w",
              encoding="utf-8") as f:
        f.write(scorer.print_char_results())
        logger.info("Decoded char results saved as " + f.name)

    fn = "decoded_results.txt"
    with open(os.path.join(cfg.common_eval.results_path, fn),
              "w",
              encoding="utf-8") as f:
        f.write(scorer.print_results())
        logger.info("Decoded results saved as " + f.name)

    if has_target:
        header = "Recognize {} with beam={}: ".format(cfg.dataset.gen_subset,
                                                      cfg.generation.beam)
        fn = "wer"
        with open(os.path.join(cfg.common_eval.results_path, fn),
                  "w",
                  encoding="utf-8") as f:
            res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format(
                *(scorer.wer()))
            logger.info(header + res)
            f.write(res + "\n")
            logger.info("WER saved in " + f.name)

        fn = "cer"
        with open(os.path.join(cfg.common_eval.results_path, fn),
                  "w",
                  encoding="utf-8") as f:
            res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format(
                *(scorer.cer()))
            logger.info(" " * len(header) + res)
            f.write(res + "\n")
            logger.info("CER saved in " + f.name)

        fn = "aligned_results.txt"
        with open(os.path.join(cfg.common_eval.results_path, fn),
                  "w",
                  encoding="utf-8") as f:
            f.write(scorer.print_aligned_results())
            logger.info("Aligned results saved as " + f.name)
    return scorer
Beispiel #21
0
def main(args, override_args=None):
    """Compute validation metrics of a checkpoint on each ``--valid-subset``.

    Args:
        args: parsed command-line namespace; either ``max_tokens`` or
            ``max_sentences`` must be set.
        override_args: optional namespace whose attributes (plus the dict
            parsed from its ``model_overrides`` string) override the saved
            checkpoint arguments. It is not modified.

    Raises:
        Exception: if a requested subset cannot be loaded (chained to the
            underlying ``KeyError``).
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if use_cuda:
        torch.cuda.set_device(args.device_id)

    if override_args is not None:
        # vars() returns the namespace's own __dict__; copy it so merging the
        # model overrides does not mutate the caller's override_args.
        overrides = dict(vars(override_args))
        # NOTE: eval() of a command-line string; trusted input only.
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Move models to GPU (separate loop variable so it cannot clobber `model`)
    for m in models:
        if use_fp16:
            m.half()
        if use_cuda:
            m.cuda()
    model = models[0]

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError as err:
            raise Exception('Cannot find dataset: ' + subset) from err

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm'
                                if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        step = None  # remains None if this subset yields no batches
        for step, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(
                sample, model, criterion)
            progress.log(log_output, step=step)
            log_outputs.append(log_output)

        if args.distributed_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=getattr(args, 'all_gather_list_size', 16384),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=step)
Beispiel #22
0
def main():
    """Profile forward/backward latency of Transformer building blocks.

    Builds a ``TransformerModel``, its criterion and an Adam optimizer from
    the command-line ``parser``, then registers forward hooks that, every
    time the hooked module runs, re-execute and time:

    * each sub-layer (``layer.layer[k]``) of every encoder / decoder layer;
      timings for sub-layer ``k`` are accumulated across all layers into the
      same meter,
    * the encoder token + positional embedding,
    * the decoder output projection against ``module.embed_out``.

    One training step on a dummy batch triggers the hooks. The averaged
    timings (milliseconds, as returned by ``torch.cuda.Event.elapsed_time``)
    are written to ``time.json``. Requires a CUDA device.
    """
    args = parser.parse_args()

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)
    model = TransformerModel.build_model(args, task).cuda()
    criterion = task.build_criterion(args).cuda()
    # NOTE(review): ``eval`` of a CLI string (e.g. "(0.9, 0.98)") executes
    # arbitrary code; only run with trusted arguments.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 betas=eval(args.adam_betas),
                                 eps=args.adam_eps,
                                 weight_decay=args.weight_decay)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    # NOTE(review): the train iterator is sized with max_sentences_valid;
    # only its length matters here since batches are replaced by a dummy.
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=(args.max_source_positions, args.max_target_positions),
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=1,
        num_shards=1,
        shard_id=0,
    )

    losses = AverageMeter()

    num_encoder_sublayers = len(model.encoder.layers[0].layer)
    num_decoder_sublayers = len(model.decoder.layers[0].layer)
    encoder_layer_forward = [AverageMeter() for _ in range(num_encoder_sublayers)]
    decoder_layer_forward = [AverageMeter() for _ in range(num_decoder_sublayers)]
    encoder_layer_backward = [AverageMeter() for _ in range(num_encoder_sublayers)]
    decoder_layer_backward = [AverageMeter() for _ in range(num_decoder_sublayers)]

    def time_forward_backward(run_forward, forward_meter, backward_meter,
                              warmup=5, iters=50):
        """Time ``run_forward()`` and its backward pass with CUDA events.

        ``run_forward`` takes no arguments and returns a tensor; the backward
        pass is seeded with that output itself. ``warmup`` untimed iterations
        run first, then each of ``iters`` iterations adds one forward and one
        backward latency to the given meters.
        """
        # Untimed warm-up so first-call overheads do not skew the averages.
        for _ in range(warmup):
            out = run_forward()
            torch.autograd.backward(out, out)

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        for _ in range(iters):
            starter.record()
            out = run_forward()
            ender.record()
            # elapsed_time is only valid once both events have completed.
            torch.cuda.synchronize()
            forward_meter.update(starter.elapsed_time(ender))

            starter.record()
            torch.autograd.backward(out, out)
            ender.record()
            torch.cuda.synchronize()
            backward_meter.update(starter.elapsed_time(ender))

    def measure_hook(forward, backward):
        """Build a forward hook timing every sub-layer of a Transformer layer.

        Sub-layer ``k`` (``module.layer[k]``) reports into ``forward[k]`` /
        ``backward[k]``. Every sub-layer is driven by the hooked layer's own
        input, not by the previous sub-layer's output.
        """
        def hook(module, input, output):
            for idx, sublayer in enumerate(module.layer):
                # The hooked module's positional input is either (x,) or
                # (x, <extra>); only x is used to drive the sub-layers.
                if len(input) == 2:
                    x, _ = input
                else:
                    x, = input
                # Fresh leaf tensor so the benchmark's backward passes stay
                # out of the real training graph.
                x = x.detach().clone().requires_grad_()

                def run_forward(sublayer=sublayer, x=x):
                    if isinstance(sublayer, nn.MultiheadAttention):
                        out, _ = sublayer(x, x, x)
                    else:
                        out = sublayer(x)
                    return out

                time_forward_backward(run_forward, forward[idx], backward[idx])

        return hook

    for layer in model.encoder.layers:
        layer.register_forward_hook(
            measure_hook(encoder_layer_forward, encoder_layer_backward))

    for layer in model.decoder.layers:
        layer.register_forward_hook(
            measure_hook(decoder_layer_forward, decoder_layer_backward))

    embed_forward = AverageMeter()
    embed_backward = AverageMeter()

    def embed_hook(module, input, output):
        """Forward hook on the encoder timing scaled token + positional embedding."""
        tokens, _ = input

        def run_forward():
            x = module.embed_scale * module.embed_tokens(tokens)
            x += module.embed_positions(tokens)
            return x

        time_forward_backward(run_forward, embed_forward, embed_backward)

    model.encoder.register_forward_hook(embed_hook)

    linear_forward = AverageMeter()
    linear_backward = AverageMeter()

    def linear_hook(module, input, output):
        """Forward hook on the decoder timing the projection onto ``embed_out``.

        NOTE(review): the decoder's second positional input is used as the
        activation fed to ``F.linear``; confirm its feature size matches
        ``module.embed_out``.
        """
        _, encode_out = input
        encode_out = encode_out.detach().clone().requires_grad_()

        def run_forward():
            x = encode_out.transpose(0, 1)
            return F.linear(x, module.embed_out)

        time_forward_backward(run_forward, linear_forward, linear_backward)

    model.decoder.register_forward_hook(linear_hook)

    itr = epoch_itr.next_epoch_itr()
    max_positions = (args.max_source_positions, args.max_target_positions)
    for i, sample in enumerate(itr):
        # Replace the loaded batch with a maximum-size dummy batch so the
        # profiled shapes are fixed by max_tokens / max_positions.
        sample = task.dataset('train').get_dummy_batch(args.max_tokens,
                                                       max_positions)
        sample = utils.move_to_cuda(sample)
        loss, _, logging_output = criterion(model, sample)
        num_tokens = logging_output['ntokens']
        losses.update(loss.item() / num_tokens / math.log(2), num_tokens)
        if i % 100 == 0:
            print('Loss: {loss.val:.4f} ({loss.avg:.4f})'.format(loss=losses))
            # Encoder first sub-layer, then decoder last sub-layer; the
            # trailing space keeps the two pairs of numbers apart.
            print(
                'Time: {forward_time.avg:.3f} ({backward_time.avg:.3f}) '
                '{forward_time_decoder.avg:.3f} ({backward_time_decoder.avg:.3f})'
                .format(forward_time=encoder_layer_forward[0],
                        backward_time=encoder_layer_backward[0],
                        forward_time_decoder=decoder_layer_forward[-1],
                        backward_time_decoder=decoder_layer_backward[-1]))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # A single step suffices: the forward hooks above already recorded
        # their timings while the model ran.
        break

    # Sub-layer index -> timings. Sized to the larger side so neither loop
    # below can hit a missing key when encoder and decoder layers differ.
    stat = {
        idx: {}
        for idx in range(max(num_encoder_sublayers, num_decoder_sublayers))
    }
    for idx, (f, b) in enumerate(zip(encoder_layer_forward,
                                     encoder_layer_backward)):
        stat[idx]['encoder'] = {'forward': f.avg, 'backward': b.avg}

    for idx, (f, b) in enumerate(zip(decoder_layer_forward,
                                     decoder_layer_backward)):
        stat[idx]['decoder'] = {'forward': f.avg, 'backward': b.avg}

    stat['embed'] = {
        'forward': embed_forward.avg,
        'backward': embed_backward.avg,
    }
    stat['linear'] = {
        'forward': linear_forward.avg,
        'backward': linear_backward.avg,
    }

    with open('time.json', 'w') as fh:
        json.dump(stat, fh, indent=4)
Beispiel #23
0
def main(cfg: DictConfig, override_args=None):
    """Compute validation metrics of a checkpoint on each validation subset.

    Args:
        cfg: run configuration; an ``argparse.Namespace`` is converted with
            ``convert_namespace_to_omegaconf`` first.
        override_args: optional namespace whose attributes (plus the dict
            literal in its ``model_overrides`` string, if present) are passed
            as ``arg_overrides`` when loading the checkpoint. It is not
            modified.

    Every loaded model is cast to fp16 / moved to the GPU as configured, but
    only ``models[0]`` is evaluated. Results for each subset in
    ``cfg.dataset.valid_subset`` (comma separated) are printed through the
    progress bar.

    Raises:
        Exception: if a validation subset cannot be found.
    """
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    assert (
        cfg.dataset.max_tokens is not None
        or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if override_args is not None:
        # ``vars`` returns the namespace's live __dict__; copy it so merging
        # model_overrides does not add attributes to the caller's namespace.
        overrides = dict(vars(override_args))
        # NOTE(review): ``eval`` of a CLI string; only safe with trusted input.
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
    )
    model = models[0]

    # Move models to GPU. A distinct loop variable keeps ``model`` bound to
    # models[0] instead of silently rebinding it to the last model.
    for ensemble_member in models:
        if use_fp16:
            ensemble_member.half()
        if use_cuda:
            ensemble_member.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args.criterion)
    criterion.eval()

    for subset in cfg.dataset.valid_subset.split(","):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception("Cannot find dataset: " + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=cfg.dataset.max_tokens,
            max_sentences=cfg.dataset.batch_size,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
            seed=cfg.common.seed,
            num_shards=cfg.distributed_training.distributed_world_size,
            shard_id=cfg.distributed_training.distributed_rank,
            num_workers=cfg.dataset.num_workers,
            data_buffer_size=cfg.dataset.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else
                                "simple"),
        )

        log_outputs = []
        # Defined up front so the final print works even if the subset
        # yields no batches.
        i = 0
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(
                sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if cfg.distributed_training.distributed_world_size > 1:
            # Gather every worker's per-batch outputs and flatten them.
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=cfg.common.all_gather_list_size,
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
def _main(args, output_file):
    """Generate translations for ``args.gen_subset`` and score them with BLEU.

    Loads the ensemble from ``args.path`` and runs ``task.inference_step`` on
    every batch. Unless ``args.quiet`` is set, it writes S-/T-/H-/D-/P- lines
    (and optionally A-/I-/E- lines) to ``output_file``. The top hypothesis of
    every sentence with a reference is added to the BLEU scorer.

    Args:
        args: parsed generation arguments.
        output_file: text stream receiving both log records and generated
            lines.

    Returns:
        The scorer: ``bleu.SacrebleuScorer`` when ``args.sacrebleu`` is set,
        otherwise ``bleu.Scorer``.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    # Fix seed for stochastic decoding
    if args.seed is not None and not args.no_seed_provided:
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): ``eval`` of the model_overrides CLI string; only safe with
    # trusted input.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Optimize ensemble for generation
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)
    # Symbols removed from every target/hypothesis string; depends only on
    # the generator, so compute it once instead of per hypothesis.
    symbols_to_strip = get_symbols_to_strip_from_output(generator)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        """Undo BPE, then tokenization, for whichever of the two is configured."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        args.remove_bpe,
                        escape_unk=True,
                        extra_symbols_to_ignore=symbols_to_strip,
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not args.quiet:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore=symbols_to_strip,
                )
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str),
                          file=output_file)
                    # detokenized hypothesis
                    print('D-{}\t{}\t{}'.format(sample_id, score,
                                                detok_hypo_str),
                          file=output_file)
                    # Convert from base e to base 2 with an out-of-place
                    # division so the hypothesis' stored scores are untouched.
                    positional_scores_base2 = hypo['positional_scores'].div(
                        math.log(2)).tolist()
                    print(
                        'P-{}\t{}'.format(
                            sample_id,
                            ' '.join('{:.4f}'.format(x)
                                     for x in positional_scores_base2)),
                        file=output_file)

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id, ' '.join([
                                '{}-{}'.format(src_idx, tgt_idx)
                                for src_idx, tgt_idx in alignment
                            ])),
                              file=output_file)

                    if args.print_step:
                        print('I-{}\t{}'.format(sample_id, hypo['steps']),
                              file=output_file)

                    if getattr(args, 'retain_iter_history', False):
                        for step, h in enumerate(hypo['history']):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print('E-{}_{}\t{}'.format(sample_id, step, h_str),
                                  file=output_file)

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True)
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    # If nothing was timed (e.g. an empty subset) the timer's sum/avg may be
    # zero; report zero rates instead of raising ZeroDivisionError.
    sentences_per_sec = (num_sentences / gen_timer.sum
                         if gen_timer.sum > 0 else 0.0)
    tokens_per_sec = 1. / gen_timer.avg if gen_timer.avg > 0 else 0.0
    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                sentences_per_sec, tokens_per_sec))
    if has_target:
        if args.bpe and not args.sacrebleu:
            if args.remove_bpe:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.  Use --sacrebleu for standard 13a BLEU tokenization"
                )
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))

    return scorer
Beispiel #25
0
def test(parsed_args):
    """Print the negated average sentence log-probability of an LM on a split.

    Loads the ensemble from ``parsed_args.path`` but scores only the first
    model. For every sentence of ``args.gen_subset`` it sums the model's
    log-probabilities of the non-pad target tokens. The mean of those sums
    over all sentences is printed, negated.

    Raises:
        AssertionError: if ``--path`` is missing or the sample break mode is
            not ``eos``.
        RuntimeError: if no sentence was scored.
    """
    # Make sure we didn't screw up the params
    assert parsed_args.path is not None, '--path required for evaluation!'
    assert parsed_args.sample_break_mode == 'eos', 'Sample break mode must be eos!'

    # Print the args
    import_user_module(parsed_args)
    print(parsed_args)

    # Do we use CUDA
    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    # Get the task (Language Modeling)
    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    # NOTE(review): ``eval`` of a CLI string; only safe with trusted input.
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'),
        task,
        model_arg_overrides=eval(parsed_args.model_overrides),
    )

    # Copy command-line values onto the loaded args, except for the fields
    # listed here, which keep the loaded values.
    for arg in vars(parsed_args).keys():
        if arg not in {
                'self_target', 'future_target', 'past_target',
                'tokens_per_sample', 'output_size_dictionary'
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(dataset)))

    # Optimize model for generation
    assert len(models) > 0
    model = models[0]
    model.make_generation_fast_()
    if args.fp16:
        model.half()
    if use_cuda:
        model.cuda()
    # Nothing below switches back to train mode, so set eval mode once.
    model.eval()

    # Make data iterator
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    pad_idx = task.source_dictionary.pad()

    # Iterate over batches of sentences and collect each sentence's logp sum.
    all_s_logps = []
    all_n_tokens = 0
    all_n_sentences = 0
    for sample in itr:
        if 'net_input' not in sample:
            continue

        # Move sample to GPU if possible
        sample = utils.move_to_cuda(sample) if use_cuda else sample

        # Number of sentences in this batch
        bsz = sample['nsentences']
        all_n_sentences += bsz

        # Log-softmax outputs for the batch: BATCH_SZ x N_TOKENS x VOCAB_SZ
        with torch.no_grad():
            decoder_out = model.forward(**sample['net_input'])
            probs = model.get_normalized_probs(decoder_out,
                                               log_probs=True,
                                               sample=sample).data

        # Make sure we have a softmax-sequence for each sentence in the batch
        assert len(probs) == bsz

        # Check that every position is a normalized distribution. The
        # reference is built on probs' device (the model may be on the GPU)
        # and the check runs in fp32 so fp16 outputs do not trip it.
        prob_sums = torch.exp(probs.float()).sum(dim=2)
        assert torch.allclose(prob_sums, torch.ones_like(prob_sums))

        # Log-probability the model assigns to each target token.
        target = sample['target']
        logps = probs.gather(
            dim=2,
            index=target.unsqueeze(-1),
        ).squeeze(2)

        for i in range(bsz):
            tokens = utils.strip_pad(target[i], pad_idx)
            s_len = len(tokens)
            all_n_tokens += s_len
            # NOTE(review): assumes padding sits at the end of each row, so
            # the first s_len positions hold the real tokens.
            all_s_logps.append(torch.sum(logps[i][:s_len]))

    if all_n_sentences == 0:
        raise RuntimeError(
            'No sentences were scored in subset {}'.format(args.gen_subset))

    # Get the average sentence logp over all sentences in the test set
    avg_s_logp = sum(all_s_logps) / all_n_sentences
    print(-1 * avg_s_logp.item())
Beispiel #26
0
def main(parsed_args):
    """Evaluate a language model's average per-token loss and perplexity.

    Loads the ensemble from ``parsed_args.path``, scores ``args.gen_subset``
    with ``SequenceScorer`` and prints the average negative log-likelihood
    and ``exp`` of it. When ``--remove-bpe`` is set, the score of each BPE
    continuation piece is folded into the following piece so the loss is
    averaged per word. Tokens scored as +/-inf are excluded from the loss.
    Optionally prints per-word probabilities and aggregated word statistics.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        # str() because model_overrides may arrive as a dict ({}) rather than
        # the string '{}' after training.
        # NOTE(review): ``eval`` of CLI input; only safe with trusted args.
        arg_overrides=eval(str(parsed_args.model_overrides)),
        task=task,
    )

    # Copy command-line values onto the loaded args, except for the fields
    # listed here, which keep the loaded values.
    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    # Indices of dictionary entries ending in the BPE continuation marker.
    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = set(
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            )
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for batch_idx, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                sample_id = sample['id'][batch_idx]

                tokens = hypo['tokens']
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    # Fold each continuation piece's score into the next
                    # piece. Bounded by the current (possibly BOS-stripped)
                    # length so pos_scores[tok_idx + 1] stays in range.
                    for tok_idx in range(len(tokens) - 1):
                        if tokens[tok_idx].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[tok_idx + 1] += pos_scores[tok_idx]
                            pos_scores[tok_idx] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                finite_scores = pos_scores
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(tokens[inf_scores.nonzero()]))
                    # Filter into a separate tensor so pos_scores stays
                    # index-aligned with tokens for the word-level output.
                    finite_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += finite_scores.sum().cpu()
                count += finite_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for tok_idx in range(len(tokens)):
                        w_ind = tokens[tok_idx].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            # Strip the continuation marker and keep building
                            # the word.
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[tok_idx].item()))

                            # Score of the next non-zero (i.e. word-final)
                            # position, if any.
                            next_prob = None
                            ind = tok_idx + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[tok_idx].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        # Output format: "<sample_id>|||<score> <score> ...",
                        # one score per word (replaces the earlier
                        # tab-separated "word [score]" format).
                        print(
                            str(int(sample_id)) + "|||"
                                + (' '.join('{:2f}'.format(x[1]) for x in word_prob))
                        )

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    # Avoid ZeroDivisionError when nothing was timed (e.g. empty subset).
    tokens_per_sec = 1. / gen_timer.avg if gen_timer.avg > 0 else 0.0
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, tokens_per_sec))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)
Beispiel #27
0
def _main(args, output_file):
    """Run generation on ``args.gen_subset``, print the hypotheses and score them.

    With ``--save-knnlm-dstore`` the per-token keys in ``hypo['dstore_keys']``
    and the token ids in ``hypo['tokens']`` are additionally written into two
    ``np.memmap`` files (``<dstore_mmap>_keys.npy`` / ``<dstore_mmap>_vals.npy``)
    holding at most ``args.dstore_size`` rows. With ``--knnlm`` a ``KNN_Dstore``
    is built and passed to ``task.inference_step``.

    Args:
        args: parsed command-line namespace.
        output_file: stream that receives the log output and the
            S-/T-/H-/P-/A-/I-/E- prefixed result lines.

    Returns:
        The BLEU scorer (``bleu.SacrebleuScorer`` or ``bleu.Scorer``) fed with
        the top hypothesis of every sentence that has a target.

    Raises:
        ValueError: if ``--knnlm`` and ``--save-knnlm-dstore`` are both set.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() of a command-line string; acceptable only because
    # the value is supplied by the user running this script.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Take decoder_embed_dim (the datastore key width used below) and
    # vocab_size from the checkpoint's arguments when they are present there.
    args.vocab_size = len(tgt_dict)
    for arg in vars(_model_args).keys():
        if arg in {'decoder_embed_dim', 'vocab_size'}:
            setattr(args, arg, getattr(_model_args, arg))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    if args.knnlm and args.save_knnlm_dstore:
        raise ValueError(
            "Cannot use knnlm while trying to build the datastore!")

    if args.knnlm:
        knn_dstore = KNN_Dstore(args)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    if args.save_knnlm_dstore:
        print('keytype being saved:', args.knn_keytype)
        # Keys: (dstore_size, decoder_embed_dim); values: (dstore_size, 1).
        if args.dstore_fp16:
            print('Saving fp16')
            dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy',
                                    dtype=np.float16,
                                    mode='w+',
                                    shape=(args.dstore_size,
                                           args.decoder_embed_dim))
            # NOTE(review): int16 only represents ids < 32768; larger ids wrap
            # when cast. Kept because datastore readers presumably expect it.
            dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy',
                                    dtype=np.int16,
                                    mode='w+',
                                    shape=(args.dstore_size, 1))
        else:
            print('Saving fp32')
            dstore_keys = np.memmap(args.dstore_mmap + '_keys.npy',
                                    dtype=np.float32,
                                    mode='w+',
                                    shape=(args.dstore_size,
                                           args.decoder_embed_dim))
            # ``np.int`` was merely an alias of the builtin ``int`` and was
            # removed in NumPy 1.24; ``int`` selects exactly the same dtype.
            dstore_vals = np.memmap(args.dstore_mmap + '_vals.npy',
                                    dtype=int,
                                    mode='w+',
                                    shape=(args.dstore_size, 1))
    dstore_idx = 0
    # Shape of the last chunk written to the datastore (reported at the end);
    # stays None if nothing was written.
    shape = None

    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            if args.knnlm:
                hypos = task.inference_step(generator,
                                            models,
                                            sample,
                                            prefix_tokens,
                                            knn_dstore=knn_dstore)
            else:
                hypos = task.inference_step(generator, models, sample,
                                            prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            if args.save_knnlm_dstore:
                if args.dstore_fp16:
                    key_dtype, val_dtype = np.float16, np.int16
                else:
                    key_dtype, val_dtype = np.float32, int
                for hypos_i in hypos:
                    hypo = hypos_i[0]
                    shape = hypo['dstore_keys'].shape
                    # Local alias so hypo['tokens'] stays intact for the
                    # printing and BLEU scoring further below.
                    # NOTE(review): assumes one key row per hypothesis token.
                    vals = hypo['tokens']
                    if dstore_idx + shape[0] > args.dstore_size:
                        # Datastore is (nearly) full: keep only what fits.
                        shape = [args.dstore_size - dstore_idx]
                        hypo['dstore_keys'] = hypo['dstore_keys'][:shape[0]]
                        # The values must be truncated too, otherwise the slice
                        # assignment below fails with a shape mismatch.
                        vals = vals[:shape[0]]
                    n = shape[0]
                    dstore_keys[dstore_idx:dstore_idx + n] = hypo['dstore_keys'].view(
                        -1, args.decoder_embed_dim).cpu().numpy().astype(key_dtype)
                    dstore_vals[dstore_idx:dstore_idx + n] = vals.view(
                        -1, 1).cpu().numpy().astype(val_dtype)
                    dstore_idx += n

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str),
                              file=output_file)
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str),
                              file=output_file)

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        score = hypo['score'] / math.log(
                            2)  # convert to base 2
                        print('H-{}\t{}\t{}'.format(sample_id, score,
                                                    hypo_str),
                              file=output_file)
                        print(
                            'P-{}\t{}'.format(
                                sample_id,
                                ' '.join(
                                    map(
                                        lambda x: '{:.4f}'.format(x),
                                        # convert from base e to base 2
                                        hypo['positional_scores'].div_(
                                            math.log(2)).tolist(),
                                    ))),
                            file=output_file)

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join([
                                    '{}-{}'.format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ])),
                                  file=output_file)

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']),
                                  file=output_file)

                        if getattr(args, 'retain_iter_history', False):
                            for step, h in enumerate(hypo['history']):
                                _, h_str, _ = utils.post_process_prediction(
                                    hypo_tokens=h['tokens'].int().cpu(),
                                    src_str=src_str,
                                    alignment=None,
                                    align_dict=None,
                                    tgt_dict=tgt_dict,
                                    remove_bpe=None,
                                )
                                print('E-{}_{}\t{}'.format(
                                    sample_id, step, h_str),
                                      file=output_file)

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))

    if args.save_knnlm_dstore:
        print("dstore_idx", dstore_idx, "final shape", shape)
        print("Keys", dstore_keys.shape, dstore_keys.dtype)
        print("Vals", dstore_vals.shape, dstore_vals.dtype)

    return scorer
Beispiel #28
0
def backtranslate_samples(samples,
                          collate_fn,
                          generate_fn,
                          cuda=True,
                          noising=None):
    """Backtranslate a list of samples.

    Given an input (*samples*) of the form:

        [{'id': 1, 'source': 'hallo welt'}]

    this returns a pair ``(ret_samples, backward_samples)`` where
    *ret_samples* uses the generated backtranslation as source:

        [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]

    and *backward_samples* keeps the original direction:

        [{'id': 1, 'source': 'hallo welt', 'target': 'hello world'}]

    (The strings above are illustrative; in practice sources and generated
    hypotheses are token tensors.)

    Args:
        samples (List[dict]): samples to backtranslate. Individual samples are
            expected to have an 'id' and a 'source' key; the 'source' becomes
            the 'target' of *ret_samples* after backtranslation.
        collate_fn (callable): function to collate samples into a mini-batch;
            the result must contain an 'id' tensor.
        generate_fn (callable): function to generate backtranslations. It
            receives the collated batch (moved to GPU when *cuda*) and returns
            one list of hypotheses per sentence; the first hypothesis of each
            list must have a 'tokens' tensor.
        cuda (bool): use GPU for generation (default: ``True``)
        noising (optional): object with a ``noising(x, lengths)`` method. When
            given, it is applied to each original source (shaped ``(len, 1)``)
            before that source is used in *backward_samples*.

    Returns:
        Tuple[List[dict], List[dict]]: ``(ret_samples, backward_samples)`` as
        described above, in the order of ``collate_fn(samples)['id']``.
    """
    collated_samples = collate_fn(samples)
    batch = utils.move_to_cuda(collated_samples) if cuda else collated_samples

    generated_sources = generate_fn(batch)

    id_to_src = {sample['id']: sample['source'] for sample in samples}

    # Pair every sentence id with its hypotheses once, so a one-shot iterable
    # returned by generate_fn is not exhausted before the second pass.
    pairs = [(sample_id.item(), hypos)
             for sample_id, hypos in zip(collated_samples['id'], generated_sources)]

    # {id: id, source: generated backtranslation, target: original tgt}
    ret_samples = [{
        'id': sample_id,
        'target': id_to_src[sample_id],
        'source': hypos[0]['tokens'].cpu()
    } for sample_id, hypos in pairs]

    # {id: id, source: original tgt (optionally noised), target: generated}
    backward_samples = []
    for sample_id, hypos in pairs:
        src = id_to_src[sample_id]
        if noising is not None:
            src_len = torch.LongTensor([src.size(0)])
            noised = noising.noising(src.unsqueeze(1), src_len)
            src = torch.t(noised)[0]
        backward_samples.append({
            'id': sample_id,
            'source': src,
            'target': hypos[0]['tokens'].cpu()
        })
    return ret_samples, backward_samples
Beispiel #29
0
def main(args):
    """Decode ``args.gen_subset`` with a speech model and SentencePiece output.

    Loads the dataset split, the model ensemble and its criterions, runs
    ``task.inference_step`` on every batch and hands the hypotheses of each
    utterance to ``process_predictions``, which writes them to the result
    files returned by ``prepare_result_files``. Timing statistics are logged
    at the end.
    """
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    logger.info("| {} {} {} examples".format(
        args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    models, criterions, _model_args = load_models_and_criterions(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),  # noqa
        task=task,
    )
    optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    gen_timer = meters.StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0

    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test.
    os.makedirs(args.results_path, exist_ok=True)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, "spm.model"))

    res_files = prepare_result_files(args)
    wps_meter = meters.TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample["target"][:, :args.prefix_size]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        dataset = task.dataset(args.gen_subset)
        for i, sample_id in enumerate(sample["id"].tolist()):
            speaker = dataset.speakers[int(sample_id)]
            utt_id = dataset.ids[int(sample_id)]
            target_tokens = (utils.strip_pad(sample["target"][i, :],
                                             tgt_dict.pad()).int().cpu())
            # Process top predictions
            process_predictions(args, hypos[i], sp, tgt_dict, target_tokens,
                                res_files, speaker, utt_id)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += sample["nsentences"]

    # The space after "{:.2f}" was missing, producing e.g. "12.34sentences/s".
    logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
                "sentences/s, {:.2f} tokens/s)".format(
                    num_sentences,
                    gen_timer.n,
                    gen_timer.sum,
                    num_sentences / gen_timer.sum,
                    1.0 / gen_timer.avg,
                ))
    logger.info("| Generate {} with beam={}".format(args.gen_subset,
                                                    args.beam))
Beispiel #30
0
def main(args):
    """Run speech recognition on ``args.gen_subset`` and score it with WER/CER.

    Loads the model ensemble (wrapping word-level LMs for LM fusion where
    found), decodes every batch with ``task.inference_step``, prints T-/H-
    lines unless ``--quiet``, optionally plots the attention of the top
    hypothesis, and writes the decoded results (plus WER, CER and aligned
    results when targets exist) into ``args.results_path``.

    Returns:
        The ``wer.Scorer`` that accumulated every top prediction (and its
        evaluation against the target when one is available).
    """
    assert args.path is not None, '--path required for recognition!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset split
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionary
    dictionary = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )
    # NOTE(review): entries are deleted from ``models`` while enumerating it;
    # this relies on the word LM being the last model in the ensemble.
    for i, m in enumerate(models):
        if hasattr(m, 'is_wordlm') and m.is_wordlm:
            # assume subword LM comes before word LM
            if isinstance(models[i - 1], FairseqLanguageModel):
                models[i-1] = MultiLevelLanguageModel(
                    m, models[i-1],
                    subwordlm_weight=args.subwordlm_weight,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                del models[i]
                print('| LM fusion with Multi-level LM')
            else:
                models[i] = TensorizedLookaheadLanguageModel(
                    m, dictionary,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                print('| LM fusion with Look-ahead Word LM')
        # assume subword LM comes after E2E models
        elif i == len(models) - 1 and isinstance(m, FairseqLanguageModel):
            print('| LM fusion with Subword LM')
    if args.lm_weight != 0.0:
        print('| using LM fusion with lm-weight={:.2f}'.format(args.lm_weight))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() if hasattr(model, 'encoder')
              else (None, model.max_positions()) for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    if args.match_source_len:
        print('| The option match_source_len is not applicable to '
              'speech recognition. Ignoring it.')
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Defined up front so the summary print at the end cannot hit an unbound
    # name when --print-alignment is set but no attention was ever plotted.
    save_dir = os.path.join(args.results_path, 'attn_plots') if args.print_alignment else None

    # Generate and compute WER
    scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter)
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(
                generator, models, sample, prefix_tokens, lm_weight=args.lm_weight,
            )
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            # obtain nonpad mask of encoder output to plot attentions
            if args.print_alignment:
                net_input = sample['net_input']
                src_tokens = net_input['src_tokens']
                output_lengths = models[0].encoder.output_lengths(net_input['src_lengths'])
                nonpad_idxs = sequence_mask(output_lengths, models[0].encoder.output_lengths(src_tokens.size(1)))

            for i in range(len(sample['id'])):
                has_target = sample['target'] is not None
                utt_id = sample['utt_id'][i]

                # Retrieve the original sentences
                if has_target:
                    target_str = sample['target_raw_text'][i]
                    if not args.quiet:
                        target_sent = dictionary.tokens_to_sentence(
                            target_str, use_unk_sym=False, bpe_symbol=args.remove_bpe,
                        )
                        print('T-{}\t{}'.format(utt_id, target_sent))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_str = dictionary.string(hypo['tokens'].int().cpu())  # not removing bpe at this point
                    # The sentence is needed for printing and, for the top
                    # hypothesis (j == 0), for the attention plot. The old
                    # ``i == 0`` test left a stale sentence from another
                    # utterance in place under --quiet.
                    if not args.quiet or j == 0:
                        hypo_sent = dictionary.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe)

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, hypo['score']))

                    # Score and obtain attention only the top hypothesis
                    if j == 0:
                        # src_len x tgt_len
                        attention = hypo['attention'][nonpad_idxs[i]].float().cpu() \
                            if args.print_alignment and hypo['attention'] is not None else None
                        if args.print_alignment and attention is not None:
                            os.makedirs(save_dir, exist_ok=True)
                            plot_attention(attention, hypo_sent, utt_id, save_dir)
                        scorer.add_prediction(utt_id, hypo_str, bpe_symbol=args.remove_bpe)
                        if has_target:
                            scorer.add_evaluation(utt_id, target_str, hypo_str, bpe_symbol=args.remove_bpe)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    print('| Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if args.print_alignment:
        print('| Saved attention plots in ' + save_dir)

    if has_target:
        assert args.test_text_files is not None
        scorer.add_ordered_utt_list(*args.test_text_files)

    os.makedirs(args.results_path, exist_ok=True)

    fn = 'decoded_char_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_char_results())
        print('| Decoded char results saved as ' + f.name)

    fn = 'decoded_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_results())
        print('| Decoded results saved as ' + f.name)

    if has_target:
        header = ' Recognize {} with beam={}: '.format(args.gen_subset, args.beam)
        fn = 'wer'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.wer()))
            print('|' + header + res)
            f.write(res + '\n')
            print('| WER saved in ' + f.name)

        fn = 'cer'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.cer()))
            print('|' + ' ' * len(header) + res)
            f.write(res + '\n')
            print('| CER saved in ' + f.name)

        fn = 'aligned_results.txt'
        with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
            f.write(scorer.print_aligned_results())
            print('| Aligned results saved as ' + f.name)
    return scorer
def _write_result_files(file_head, sample_id, target_str, hypo_str):
    """Write one sample's reference/hypothesis pair into the result directories.

    Two copies are written:

    * ``<file_head>/systems/`` and ``<file_head>/models/`` -- hypothesis with every
      ``<unk>`` rewritten as ``UNK``;
    * ``<file_head>-nounk/systems/`` and ``<file_head>-nounk/models/`` -- the same
      hypothesis with every ``UNK`` removed.

    The reference is written unchanged to both ``systems`` directories and is
    skipped when ``target_str`` is None (batch without targets).

    Args:
        file_head: prefix of the result directories; they must already exist.
        sample_id: dataset id, used in ``system.<id>.txt`` / ``model.<id>.txt``.
        target_str: detokenized reference text, or None.
        hypo_str: detokenized hypothesis text.
    """
    # unified identifier for unk
    hypo_str = hypo_str.replace("<unk>", "UNK")
    variants = (
        (file_head, hypo_str),
        (file_head + "-nounk", hypo_str.replace("UNK", "")),
    )
    for head, hypo_text in variants:
        if target_str is not None:
            with open(head + "/systems/system." + str(sample_id) + ".txt",
                      "w", encoding="utf-8") as f:
                f.write(target_str)
        with open(head + "/models/model." + str(sample_id) + ".txt",
                  "w", encoding="utf-8") as f:
            f.write(hypo_text)


def main(args):
    """Decode ``args.gen_subset`` with a model ensemble and dump per-sample results.

    Loads the ensemble from the colon-separated ``args.path``, runs
    ``task.inference_step`` over every batch, optionally prints S-/T-/H-/P-/A-
    lines, and writes the top hypothesis of each sample (plus its reference when
    the batch has targets) under ``<cwd>/result/<args.results_path>`` and its
    ``-nounk`` sibling.

    Args:
        args: parsed command-line namespace; every attribute read below
            (path, beam, nbest, results_path, gen_subset, ...) must be present.

    Returns:
        None. Results go to disk; progress and throughput go to stdout.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset,
                                       len(task.dataset(args.gen_subset))))

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(':'),
        task,
        # NOTE(review): eval() executes arbitrary code taken from the command
        # line; ast.literal_eval is the safe choice if only literal dicts are used.
        model_arg_overrides=eval(args.model_overrides),
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Create result dirs; the "-nounk" copies hold hypotheses with UNK removed.
    file_head = os.getcwd() + "/result/" + args.results_path
    result_output_list = [
        "/systems/", "/models/", "/alignments/", "-nounk/systems/",
        "-nounk/models/", "/attention/"
    ]
    for item in result_output_list:
        os.makedirs(file_head + item, exist_ok=True)

    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from
                # tokens. target_str is reset for every sample so a batch without
                # targets can never reuse a reference from an earlier sample.
                target_str = None
                if align_dict is not None:
                    dataset = task.dataset(args.gen_subset)
                    src_str = dataset.src.get_original_text(sample_id)
                    if has_target:
                        target_str = dataset.tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions; slicing already caps at the number of
                # hypotheses available for this sample.
                for idx, hypo in enumerate(hypos[i][:args.nbest]):
                    _hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'].int().cpu()
                        if hypo['alignment'] is not None else None,
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo['score'],
                                                    hypo_str))
                        print('P-{}\t{}'.format(
                            sample_id, ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join(
                                    map(lambda x: str(utils.item(x)),
                                        alignment))))

                    # Only the top hypothesis goes to disk: writing every n-best
                    # entry would leave the lowest-ranked one in the files.
                    if idx == 0:
                        _write_result_files(file_head, sample_id, target_str,
                                            hypo_str)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    # n / sum == 1 / avg; guarded because sum is 0 when nothing was decoded.
    elapsed = gen_timer.sum
    print(
        '| Summarized {} articles ({} tokens) in {:.1f}s ({:.2f} articles/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, elapsed,
                num_sentences / elapsed if elapsed > 0 else 0.0,
                gen_timer.n / elapsed if elapsed > 0 else 0.0))
    if has_target:
        print('| Generate {} with beam={}'.format(args.gen_subset, args.beam))
    return None
Beispiel #32
0
 def _prepare_sample(self, sample):
     if sample is None or len(sample) == 0:
         return None
     if self.cuda:
         sample = utils.move_to_cuda(sample)
     return sample
Beispiel #33
0
    def transcribe(self, wav_files):
        """Transcribe audio files and return one hypothesis string per file.

        A scratch directory under ``self.temp_path`` is filled with a 'test'
        split describing ``wav_files`` (with placeholder transcripts) and
        decoded with the cached models. Hypotheses are then read back from the
        result file whose name contains ``hypo.word``. The scratch directory is
        always removed afterwards, even when decoding raises.

        Args:
            wav_files: paths of the audio files to transcribe.

        Returns:
            list[str]: hypotheses sorted by the numeric sample id parsed from
            the result file. Presumably this is manifest order, i.e. the order
            of ``wav_files`` -- verify against the dataset loader.
        """
        import shutil  # local import: only needed for scratch-directory cleanup

        process_dir = os.path.join(self.temp_path, uuid.uuid1().hex)
        os.makedirs(process_dir)
        try:
            # Point the (shared) argument namespace at the scratch split.
            self.args.data = process_dir
            self.args.gen_subset = 'test'
            self.args.results_path = process_dir
            copy2(self.args.w2vec_dict, process_dir)

            self._write_transcribe_split(process_dir, wav_files)
            self._decode_transcribe_split()
            return self._read_transcribe_hypotheses(process_dir)
        finally:
            # Replaces `os.system('rm -rf ' + process_dir)`. No shell is involved,
            # it is safe for paths with spaces, and it also runs on failure.
            shutil.rmtree(process_dir, ignore_errors=True)

    @staticmethod
    def _write_transcribe_split(process_dir, wav_files):
        """Write test.wrd / test.ltr / test.tsv for ``wav_files`` into ``process_dir``.

        The transcripts are the placeholder 'THIS IS A SAMPLE'. Each manifest
        line is '<absolute path>\\t<frame count>', preceded by an empty first
        line.
        """
        test_words = os.path.join(process_dir, 'test.wrd')
        test_letters = os.path.join(process_dir, 'test.ltr')
        test_map = os.path.join(process_dir, 'test.tsv')

        abs_paths = [os.path.abspath(d) for d in wav_files]
        manifest = [p + '\t' + str(soundfile.info(p).frames) for p in abs_paths]

        words = ['THIS IS A SAMPLE'] * len(manifest)
        # letter transcript: characters separated by spaces, '|' as word boundary
        letters = [' '.join(list(w.replace(' ', '|'))) + ' |' for w in words]

        with open(test_words, 'w') as f:
            f.write('\n'.join(words))

        with open(test_letters, 'w') as f:
            f.write('\n'.join(letters))

        with open(test_map, 'w') as f:
            f.write('\n')
            f.write('\n'.join(manifest))

    def _decode_transcribe_split(self):
        """Run inference over the split currently described by ``self.args``.

        Checkpoint state and models are cached on ``self`` and reused by later
        calls. With ``args.dump_emissions`` / ``args.dump_features`` the
        collected arrays are saved with ``np.save``. Otherwise hypotheses are
        written through ``prepare_result_files`` / ``process_predictions``.

        Returns:
            The word error rate in percent, or None in dump modes or when no
            reference length was accumulated.
        """
        args = self.args

        if args.max_tokens is None and args.batch_size is None:
            args.max_tokens = 4000000

        use_cuda = torch.cuda.is_available() and not args.cpu
        task = tasks.setup_task(args)

        if self.state is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(args.path, None)
            state['cfg']['model']['w2v_path'] = self.pretrain_model
            state['cfg']['generation']['beam'] = self.beam_size
            self.state = state
        else:
            state = self.state

        if self.models is None:
            models, saved_cfg = checkpoint_utils.load_model_ensemble(
                utils.split_paths(args.path),
                arg_overrides=ast.literal_eval(args.model_overrides),
                task=task,
                suffix=args.checkpoint_suffix,
                strict=(args.checkpoint_shard_count == 1),
                num_shards=args.checkpoint_shard_count,
                state=state,
            )
            self.models, self.saved_cfg = models, saved_cfg
        else:
            models, saved_cfg = self.models, self.saved_cfg

        optimize_models(args, use_cuda, models)
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

        # Set dictionary
        tgt_dict = task.target_dictionary

        # hack to pass transitions to W2lDecoder
        if args.criterion == "asg_loss":
            raise NotImplementedError("asg_loss is currently not supported")

        # Load dataset (possibly sharded)
        itr = get_dataset_itr(args, task, models)

        # Initialize generator
        gen_timer = StopwatchMeter()

        def build_generator(args):
            w2l_decoder = getattr(args, "w2l_decoder", None)
            if w2l_decoder == "viterbi":
                from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

                return W2lViterbiDecoder(args, task.target_dictionary)
            elif w2l_decoder == "kenlm":
                from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

                return W2lKenLMDecoder(args, task.target_dictionary)
            elif w2l_decoder == "fairseqlm":
                from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

                return W2lFairseqLMDecoder(args, task.target_dictionary)
            else:
                print(
                    "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
                )

        # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
        if self.generator is None:
            generator = build_generator(args)
        else:
            generator = self.generator

        if args.load_emissions:
            # NOTE(review): allow_pickle=True unpickles the file; load trusted files only.
            generator = ExistingEmissionsDecoder(
                generator, np.load(args.load_emissions, allow_pickle=True))

        if args.results_path is not None and not os.path.exists(
                args.results_path):
            os.makedirs(args.results_path)

        res_files = None
        if args.dump_emissions:
            emissions = {}
        if args.dump_features:
            features = {}
            models[0].bert.proj = None
        else:
            res_files = prepare_result_files(args)

        errs_t = 0
        lengths_t = 0
        with progress_bar.build_progress_bar(args, itr) as t:
            wps_meter = TimeMeter()
            for sample in t:
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                if "net_input" not in sample:
                    continue

                prefix_tokens = None
                if args.prefix_size > 0:
                    prefix_tokens = sample["target"][:, :args.prefix_size]

                gen_timer.start()
                if args.dump_emissions:
                    with torch.no_grad():
                        encoder_out = models[0](**sample["net_input"])
                        emm = models[0].get_normalized_probs(encoder_out,
                                                             log_probs=True)
                        emm = emm.transpose(0, 1).cpu().numpy()
                        for i, sid in enumerate(sample["id"]):
                            emissions[sid.item()] = emm[i]
                        continue
                elif args.dump_features:
                    with torch.no_grad():
                        encoder_out = models[0](**sample["net_input"])
                        feat = encoder_out["encoder_out"].transpose(
                            0, 1).cpu().numpy()
                        for i, sid in enumerate(sample["id"]):
                            padding = (encoder_out["encoder_padding_mask"]
                                       [i].cpu().numpy()
                                       if encoder_out["encoder_padding_mask"]
                                       is not None else None)
                            features[sid.item()] = (feat[i], padding)
                        continue
                hypos = task.inference_step(generator, models, sample,
                                            prefix_tokens)
                num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
                gen_timer.stop(num_generated_tokens)

                for i, sample_id in enumerate(sample["id"].tolist()):
                    toks = (sample["target"][i, :] if "target_label"
                            not in sample else sample["target_label"][i, :])
                    target_tokens = utils.strip_pad(
                        toks, tgt_dict.pad()).int().cpu()
                    # Process top predictions (no speaker information here)
                    errs, length = process_predictions(
                        args,
                        hypos[i],
                        None,
                        tgt_dict,
                        target_tokens,
                        res_files,
                        None,
                        sample_id,
                    )
                    errs_t += errs
                    lengths_t += length

                wps_meter.update(num_generated_tokens)
                t.log({"wps": round(wps_meter.avg)})

        wer = None
        if args.dump_emissions:
            np.save(args.dump_emissions,
                    [emissions[i] for i in range(len(emissions))])
        elif args.dump_features:
            np.save(args.dump_features,
                    [features[i] for i in range(len(features))])
        elif lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
        return wer

    @staticmethod
    def _read_transcribe_hypotheses(process_dir):
        """Parse the ``hypo.word`` result file in ``process_dir``.

        Each line is taken as '<words...> <tag>', where the text between the
        first '-' of the tag and its last character is an integer sample id.
        Presumably the tag looks like '(<speaker>-<id>)'; confirm against
        process_predictions. Returns the word strings sorted by that id.
        """
        hypo_name = [
            name for name in os.listdir(process_dir) if 'hypo.word' in name
        ][0]
        with open(os.path.join(process_dir, hypo_name)) as f:
            lines = f.read().splitlines()

        indexed = []
        for line in lines:
            tokens = line.split()
            words = ' '.join(tokens[:-1])
            sample_idx = int(tokens[-1].split('-')[1][:-1])
            indexed.append((words, sample_idx))

        indexed.sort(key=lambda pair: pair[1])
        return [words for words, _ in indexed]