示例#1
0
def main(params):
    """Translate stdin sentences and write one hypothesis file per beam rank.

    Sentences are read from stdin (one per line, split on whitespace), encoded
    and decoded batch by batch with the checkpoint at ``params.model_path``.
    Hypothesis files are written into the directory part of
    ``params.output_path``; its last path component is embedded in each file
    name. ``restore_segmentation`` is then applied to every file and, when
    ``params.ref_path`` is set, the file is scored with ``eval_moses_bleu``
    and the score is logged.

    Args:
        params: Namespace used to initialize the experiment. It is replaced
            right afterwards by the freshly parsed command-line arguments.

    Raises:
        AssertionError: If sampling is combined with beam search, if
            ``params.amp`` is greater than 1, if an input line is empty, or
            if a decoded sentence does not start with the EOS index.
    """

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    # NOTE(review): the `params` argument is overwritten here by the parsed
    # command line -- confirm this is intended.
    parser = get_parser()
    params = parser.parse_args()
    torch.manual_seed(
        params.seed
    )  # Set random seed. NB: Multi-GPU also needs torch.cuda.manual_seed_all(params.seed)
    assert (params.sample_temperature
            == 0) or (params.beam_size == 1), 'Cannot sample with beam search.'
    assert params.amp <= 1, f'params.amp == {params.amp} not yet supported.'
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ('n_words', 'bos_index', 'eos_index', 'pad_index',
                 'unk_index', 'mask_index'):
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    encoder = TransformerModel(model_params,
                               dico,
                               is_encoder=True,
                               with_output=False).cuda().eval()
    decoder = TransformerModel(model_params,
                               dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()
    for part, model in (('encoder', encoder), ('decoder', decoder)):
        # Drop the 'module.' key prefix when every key of this state dict
        # carries it, so the keys match the bare model built above.
        if all(k.startswith('module.') for k in reloaded[part]):
            reloaded[part] = {
                k[len('module.'):]: v
                for k, v in reloaded[part].items()
            }
        model.load_state_dict(reloaded[part])

    if params.amp != 0:
        encoder, decoder = apex.amp.initialize([encoder, decoder],
                                               opt_level=('O%i' % params.amp))

    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin:
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

    # hypothesis[r] collects, for every source sentence, the hypothesis of rank r
    hypothesis = [[] for _ in range(params.beam_size)]

    # Inference only: no backward pass is run here, so skip autograd bookkeeping.
    with torch.no_grad():
        for i in range(0, len(src_sent), params.batch_size):

            # prepare batch
            word_ids = [
                torch.LongTensor([dico.index(w) for w in s.strip().split()])
                for s in src_sent[i:i + params.batch_size]
            ]
            # +2 for the EOS index placed before and after every sentence
            lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
            batch = torch.LongTensor(lengths.max().item(),
                                     lengths.size(0)).fill_(params.pad_index)
            batch[0] = params.eos_index
            for j, s in enumerate(word_ids):
                if lengths[j] > 2:  # if sentence not empty
                    batch[1:lengths[j] - 1, j].copy_(s)
                batch[lengths[j] - 1, j] = params.eos_index
            langs = batch.clone().fill_(params.src_id)

            # encode source batch and translate it
            encoded = encoder('fwd',
                              x=batch.cuda(),
                              lengths=lengths.cuda(),
                              langs=langs.cuda(),
                              causal=False)
            encoded = encoded.transpose(0, 1)
            max_len = int(1.5 * lengths.max().item() + 10)
            if params.beam_size == 1:
                sample_temperature = (None if params.sample_temperature == 0
                                      else params.sample_temperature)
                decoded, dec_lengths = decoder.generate(
                    encoded,
                    lengths.cuda(),
                    params.tgt_id,
                    max_len=max_len,
                    sample_temperature=sample_temperature)
            else:
                decoded, dec_lengths, all_hyp_strs = decoder.generate_beam(
                    encoded,
                    lengths.cuda(),
                    params.tgt_id,
                    beam_size=params.beam_size,
                    length_penalty=params.length_penalty,
                    early_stopping=params.early_stopping,
                    max_len=max_len,
                    output_all_hyps=True)

            # convert sentences to words
            for j in range(decoded.size(1)):

                # remove delimiters: keep what lies between the leading EOS
                # and the next EOS (or the end of the column if there is none)
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                sent = sent[1:] if len(
                    delimiters) == 1 else sent[1:delimiters[1]]

                # output translation
                source = src_sent[i + j].strip().replace('<unk>', '<<unk>>')
                target = " ".join(dico[idx] for idx in sent.tolist()).replace(
                    '<unk>', '<<unk>>')
                if params.beam_size == 1:
                    hypothesis[0].append(target)
                else:
                    for hyp_rank in range(params.beam_size):
                        hyps = all_hyp_strs[j]
                        # reuse the last available entry when fewer than
                        # beam_size hypotheses were returned
                        hyp = hyps[hyp_rank if hyp_rank < len(hyps) else -1]
                        print(hyp)
                        hypothesis[hyp_rank].append(hyp)

                sys.stderr.write("%i / %i: %s -> %s\n" %
                                 (i + j, len(src_sent), source.replace(
                                     '@@ ', ''), target.replace('@@ ', '')))

    # export sentences to reference and hypothesis files / restore BPE segmentation
    # os.path.split also copes with an output path that has no directory part,
    # where rsplit('/', 1) would fail to unpack into two values.
    save_dir, split = os.path.split(params.output_path)
    for hyp_rank, hyps in enumerate(hypothesis):
        # a single hypothesis list is tagged with the seed, several with their rank
        seed_tag = params.seed if len(hypothesis) == 1 else str(hyp_rank)
        hyp_name = (
            f'hyp.st={params.sample_temperature}.bs={params.beam_size}'
            f'.lp={params.length_penalty}.es={params.early_stopping}'
            f'.seed={seed_tag}.{params.src_lang}-{params.tgt_lang}.{split}.txt')
        hyp_path = os.path.join(save_dir, hyp_name)
        with open(hyp_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(hyps) + '\n')
        restore_segmentation(hyp_path)

        # evaluate BLEU score
        if params.ref_path:
            bleu = eval_moses_bleu(params.ref_path, hyp_path)
            logger.info("BLEU %s %s : %f" % (hyp_path, params.ref_path, bleu))
示例#2
0
def main(params):
    """Translate the sentences of ``params.input_path`` into ``params.output_path``.

    The checkpoint at ``params.model_path`` provides the model parameters and
    weights; the source and target dictionaries are loaded with
    ``load_binarized`` from ``params.src_data`` and ``params.tgt_data``. One
    translation is written per input line, in input order.

    Args:
        params: Namespace used to initialize the experiment. It is replaced
            right afterwards by the freshly parsed command-line arguments.

    Raises:
        AssertionError: If an input line is empty or if a decoded sentence
            does not start with the EOS index.
    """

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    # NOTE(review): the `params` argument is overwritten here by the parsed
    # command line -- confirm this is intended.
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    model_params.add_pred = ""
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ('n_words', 'bos_index', 'eos_index', 'pad_index',
                 'unk_index', 'mask_index'):
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    src_dico = load_binarized(params.src_data)
    tgt_dico = load_binarized(params.tgt_data)

    encoder = TransformerModel(model_params,
                               src_dico,
                               is_encoder=True,
                               with_output=False).cuda().eval()
    decoder = TransformerModel(model_params,
                               tgt_dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()

    for part, model in (('encoder', encoder), ('decoder', decoder)):
        # Test each state dict on its own. Because loading is non-strict,
        # keys that were stripped (or left prefixed) by mistake would be
        # skipped silently instead of raising.
        if all(k.startswith('module.') for k in reloaded[part]):
            reloaded[part] = {
                k[len('module.'):]: v
                for k, v in reloaded[part].items()
            }
        model.load_state_dict(reloaded[part], strict=False)

    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from the input file
    # NOTE(review): the input is decoded with the platform default encoding
    # while the output is written as UTF-8 -- confirm this is intended.
    src_sent = []
    with open(params.input_path, 'r') as input_f:
        for line in input_f:
            line = line.strip()
            assert len(line.split()) > 0
            src_sent.append(line)
    logger.info("Read %i sentences from %s. Translating ..." %
                (len(src_sent), params.input_path))

    # Inference only: no backward pass is run here, so skip autograd bookkeeping.
    with io.open(params.output_path, 'w',
                 encoding='utf-8') as f, torch.no_grad():
        for i in range(0, len(src_sent), params.batch_size):

            # prepare batch
            word_ids = [
                torch.LongTensor(
                    [src_dico.index(w) for w in s.strip().split()])
                for s in src_sent[i:i + params.batch_size]
            ]
            # +2 for the EOS index placed before and after every sentence
            lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
            batch = torch.LongTensor(lengths.max().item(),
                                     lengths.size(0)).fill_(params.pad_index)
            batch[0] = params.eos_index
            for j, s in enumerate(word_ids):
                if lengths[j] > 2:  # if sentence not empty
                    batch[1:lengths[j] - 1, j].copy_(s)
                batch[lengths[j] - 1, j] = params.eos_index
            langs = batch.clone().fill_(params.src_id)

            # encode source batch and translate it
            encoded = encoder('fwd',
                              x=batch.cuda(),
                              lengths=lengths.cuda(),
                              langs=langs.cuda(),
                              causal=False)
            # the encoder output is iterated here, i.e. it is a sequence of tensors
            encoded = [enc.transpose(0, 1) for enc in encoded]
            decoded, dec_lengths = decoder.generate(
                encoded,
                lengths.cuda(),
                params.tgt_id,
                max_len=int(1.5 * lengths.max().item() + 10))

            # convert sentences to words
            for j in range(decoded.size(1)):

                # remove delimiters: keep what lies between the leading EOS
                # and the next EOS (or the end of the column if there is none)
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                sent = sent[1:] if len(
                    delimiters) == 1 else sent[1:delimiters[1]]

                # output translation
                target = " ".join(tgt_dico[idx] for idx in sent.tolist())
                f.write(target + "\n")
示例#3
0
class Translate():
    """CPU-only translator built from a saved encoder/decoder checkpoint.

    The checkpoint is loaded with ``map_location`` set to the CPU and the
    models are never moved to a GPU.
    """

    def __init__(self,
                 model_path,
                 tgt_lang,
                 src_lang,
                 dump_path="./dumped/",
                 exp_name="translate",
                 exp_id="test",
                 batch_size=32):
        """Load the checkpoint and build the encoder and decoder.

        The constructor arguments are only used as defaults of a command-line
        parser: any matching option present on ``sys.argv`` overrides them.
        NOTE(review): ``parse_args()`` also rejects options it does not know,
        which matters when this class is used inside another program.

        Args:
            model_path: Path of the checkpoint to load.
            tgt_lang: Target language; must be a key of the model's ``lang2id``.
            src_lang: Source language; must differ from ``tgt_lang``.
            dump_path: Experiment dump path.
            exp_name: Experiment name.
            exp_id: Experiment ID.
            batch_size: Number of sentences per batch.

        Raises:
            AssertionError: If a language is empty or both are identical.
        """

        # parse parameters
        parser = argparse.ArgumentParser(description="Translate sentences")

        # main parameters
        parser.add_argument("--dump_path",
                            type=str,
                            default=dump_path,
                            help="Experiment dump path")
        parser.add_argument("--exp_name",
                            type=str,
                            default=exp_name,
                            help="Experiment name")
        parser.add_argument("--exp_id",
                            type=str,
                            default=exp_id,
                            help="Experiment ID")
        parser.add_argument("--batch_size",
                            type=int,
                            default=batch_size,
                            help="Number of sentences per batch")
        # model / output paths
        parser.add_argument("--model_path",
                            type=str,
                            default=model_path,
                            help="Model path")
        # source language / target language
        parser.add_argument("--src_lang",
                            type=str,
                            default=src_lang,
                            help="Source language")
        parser.add_argument("--tgt_lang",
                            type=str,
                            default=tgt_lang,
                            help="Target language")
        parser.add_argument('-d',
                            "--text",
                            type=str,
                            default="",
                            nargs='+',
                            help="Text to be translated")

        params = parser.parse_args()
        assert params.src_lang != '' and params.tgt_lang != '' and params.src_lang != params.tgt_lang

        # initialize the experiment
        logger = initialize_exp(params)

        # No GPU available: map every checkpoint tensor to the CPU.
        reloaded = torch.load(params.model_path,
                              map_location=torch.device('cpu'))
        model_params = AttrDict(reloaded['params'])
        self.supported_languages = model_params.lang2id.keys()
        logger.info("Supported languages: %s" %
                    ", ".join(self.supported_languages))

        # update dictionary parameters
        for name in ('n_words', 'bos_index', 'eos_index', 'pad_index',
                     'unk_index', 'mask_index'):
            try:
                setattr(params, name, getattr(model_params, name))
            except AttributeError:
                # Attribute missing at the top level: take it from the first
                # entry of meta_params and mirror it onto model_params.
                key = list(model_params.meta_params.keys())[0]
                attr = getattr(model_params.meta_params[key], name)
                setattr(params, name, attr)
                setattr(model_params, name, attr)

        # build dictionary / build encoder / build decoder / reload weights
        self.dico = Dictionary(reloaded['dico_id2word'],
                               reloaded['dico_word2id'],
                               reloaded['dico_counts'])
        self.encoder = TransformerModel(model_params,
                                        self.dico,
                                        is_encoder=True,
                                        with_output=True).eval()
        self.decoder = TransformerModel(model_params,
                                        self.dico,
                                        is_encoder=False,
                                        with_output=True).eval()
        self.encoder.load_state_dict(reloaded['encoder'])
        self.decoder.load_state_dict(reloaded['decoder'])
        params.src_id = model_params.lang2id[params.src_lang]
        params.tgt_id = model_params.lang2id[params.tgt_lang]
        self.model_params = model_params
        self.params = params

    def translate(self, src_sent=None):
        """Translate one sentence or a list of sentences.

        Each ``source -> target`` pair is also written to stderr.

        Args:
            src_sent: A single sentence (``str``) or a list of sentences, each
                split on whitespace. ``None`` is treated as an empty list.

        Returns:
            The translated string when a single string was given, otherwise
            the list of translated strings in input order.

        Raises:
            AssertionError: If a decoded sentence does not start with the
                EOS index.
        """
        if src_sent is None:
            src_sent = []
        single = isinstance(src_sent, str)
        if single:
            src_sent = [src_sent]

        batch_size = self.params.batch_size
        pad_index = self.params.pad_index
        eos_index = self.params.eos_index

        tgt_sent = []
        # Inference only: no backward pass is run here, so skip autograd bookkeeping.
        with torch.no_grad():
            for i in range(0, len(src_sent), batch_size):
                # prepare batch
                word_ids = [
                    torch.LongTensor(
                        [self.dico.index(w) for w in s.strip().split()])
                    for s in src_sent[i:i + batch_size]
                ]
                # +2 for the EOS index placed before and after every sentence
                lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
                batch = torch.LongTensor(lengths.max().item(),
                                         lengths.size(0)).fill_(pad_index)
                batch[0] = eos_index
                for j, s in enumerate(word_ids):
                    if lengths[j] > 2:  # if sentence not empty
                        batch[1:lengths[j] - 1, j].copy_(s)
                    batch[lengths[j] - 1, j] = eos_index
                langs = batch.clone().fill_(self.params.src_id)

                # encode source batch and translate it
                encoded = self.encoder('fwd',
                                       x=batch,
                                       lengths=lengths,
                                       langs=langs,
                                       causal=False)
                encoded = encoded.transpose(0, 1)
                decoded, dec_lengths = self.decoder.generate(
                    encoded,
                    lengths,
                    self.params.tgt_id,
                    max_len=int(1.5 * lengths.max().item() + 10))

                # convert sentences to words
                for j in range(decoded.size(1)):

                    # remove delimiters: keep what lies between the leading
                    # EOS and the next EOS (or the end of the column)
                    sent = decoded[:, j]
                    delimiters = (sent == eos_index).nonzero().view(-1)
                    assert len(delimiters) >= 1 and delimiters[0].item() == 0
                    sent = sent[1:] if len(
                        delimiters) == 1 else sent[1:delimiters[1]]

                    # output translation
                    source = src_sent[i + j].strip()
                    target = " ".join(self.dico[idx] for idx in sent.tolist())
                    sys.stderr.write("%i / %i: %s -> %s\n" %
                                     (i + j, len(src_sent), source, target))
                    tgt_sent.append(target)

        if single:
            return tgt_sent[0]
        return tgt_sent
示例#4
0
def main(params):
    """Translate sentences read from stdin into ``params.output_path``.

    The checkpoint at ``params.model_path`` provides the dictionary, the model
    parameters and the weights. Decoding uses ``decoder.generate`` when
    ``params.beam`` is 1 and ``decoder.generate_beam`` otherwise. One
    translation is written per input line, and each ``source -> target`` pair
    is also written to stderr.

    Args:
        params: Namespace used to initialize the experiment. It is replaced
            right afterwards by the freshly parsed command-line arguments.

    Raises:
        AssertionError: If ``params.fp16`` is set while cuDNN is disabled, if
            an input line is empty, or if a decoded sentence does not start
            with the EOS index.
    """

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    # NOTE(review): the `params` argument is overwritten here by the parsed
    # command line -- confirm this is intended.
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ('n_words', 'bos_index', 'eos_index', 'pad_index',
                 'unk_index', 'mask_index'):
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    encoder = TransformerModel(model_params,
                               dico,
                               is_encoder=True,
                               with_output=True).cuda().eval()
    decoder = TransformerModel(model_params,
                               dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()
    encoder.load_state_dict(package_module(reloaded['encoder']))
    decoder.load_state_dict(package_module(reloaded['decoder']))
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # float16
    if params.fp16:
        assert torch.backends.cudnn.enabled
        encoder = network_to_half(encoder)
        decoder = network_to_half(decoder)

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin:
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

    # The `with` block closes the output file even if translation fails.
    # Inference only: no backward pass is run here, so skip autograd bookkeeping.
    with io.open(params.output_path, 'w',
                 encoding='utf-8') as f, torch.no_grad():
        for i in range(0, len(src_sent), params.batch_size):

            # prepare batch
            word_ids = [
                torch.LongTensor([dico.index(w) for w in s.strip().split()])
                for s in src_sent[i:i + params.batch_size]
            ]
            # +2 for the EOS index placed before and after every sentence
            lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
            batch = torch.LongTensor(lengths.max().item(),
                                     lengths.size(0)).fill_(params.pad_index)
            batch[0] = params.eos_index
            for j, s in enumerate(word_ids):
                if lengths[j] > 2:  # if sentence not empty
                    batch[1:lengths[j] - 1, j].copy_(s)
                batch[lengths[j] - 1, j] = params.eos_index
            langs = batch.clone().fill_(params.src_id)

            # encode source batch and translate it
            encoded = encoder('fwd',
                              x=batch.cuda(),
                              lengths=lengths.cuda(),
                              langs=langs.cuda(),
                              causal=False)
            encoded = encoded.transpose(0, 1)
            max_len = int(1.5 * lengths.max().item() + 10)
            if params.beam == 1:
                decoded, dec_lengths = decoder.generate(encoded,
                                                        lengths.cuda(),
                                                        params.tgt_id,
                                                        max_len=max_len)
            else:
                decoded, dec_lengths = decoder.generate_beam(
                    encoded,
                    lengths.cuda(),
                    params.tgt_id,
                    beam_size=params.beam,
                    length_penalty=params.length_penalty,
                    early_stopping=False,
                    max_len=max_len)

            # convert sentences to words
            for j in range(decoded.size(1)):

                # remove delimiters: keep what lies between the leading EOS
                # and the next EOS (or the end of the column if there is none)
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                sent = sent[1:] if len(
                    delimiters) == 1 else sent[1:delimiters[1]]

                # output translation
                source = src_sent[i + j].strip()
                target = " ".join(dico[idx] for idx in sent.tolist())
                sys.stderr.write("%i / %i: %s -> %s\n" %
                                 (i + j, len(src_sent), source, target))
                f.write(target + "\n")
示例#5
0
def main(params):
    """Translate a pre-binarized dataset and write the hypotheses as text.

    ``params.input`` is loaded with ``torch.load`` and is read through its
    ``"sentences"``, ``"positions"`` and ``"dico"`` entries. Every batch is
    written back as text to ``<params.dump_path>/input.txt`` and its
    translation to ``params.output_path``.

    Args:
        params: Namespace used to initialize the experiment. It is replaced
            right afterwards by the freshly parsed command-line arguments.

    Raises:
        AssertionError: If ``params.fp16`` is set while cuDNN is disabled, or
            if ``params.subset_start`` is given without a truthy
            ``params.subset_end``.
    """

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    # NOTE(review): the `params` argument is overwritten here by the parsed
    # command line -- confirm this is intended.
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ('n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index'):
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
    decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # float16
    if params.fp16:
        assert torch.backends.cudnn.enabled
        encoder = network_to_half(encoder)
        decoder = network_to_half(decoder)

    input_data = torch.load(params.input)
    eval_dataset = Dataset(input_data["sentences"], input_data["positions"], params)

    if params.subset_start is not None:
        assert params.subset_end
        eval_dataset.select_data(params.subset_start, params.subset_end)

    eval_dataset.remove_empty_sentences()
    eval_dataset.remove_long_sentences(params.max_len)

    inp_dump_path = os.path.join(params.dump_path, "input.txt")

    # Open each file exactly once. The `with` block closes both files even if
    # translation fails. Inference only: no backward pass is run here, so
    # autograd bookkeeping is skipped as well.
    with open(params.output_path, "w", encoding="utf-8") as out, \
            open(inp_dump_path, "w", encoding="utf-8") as inp_dump, \
            torch.no_grad():
        logger.info("logging to {}".format(inp_dump_path))

        for n_batch, (x1, len1) in enumerate(eval_dataset.get_iterator(shuffle=False), 1):
            # NOTE(review): text conversion uses input_data["dico"], not the
            # checkpoint dictionary built above -- presumably both are the
            # same vocabulary; confirm.
            input_text = convert_to_text(x1, len1, input_data["dico"], params)
            inp_dump.write("\n".join(input_text))
            inp_dump.write("\n")

            langs1 = x1.clone().fill_(params.src_id)

            # cuda
            x1, len1, langs1 = to_cuda(x1, len1, langs1)

            # encode source sentence
            enc1 = encoder("fwd", x=x1, lengths=len1, langs=langs1, causal=False)
            enc1 = enc1.transpose(0, 1)

            # generate translation - translate / convert to text
            max_len = int(1.5 * len1.max().item() + 10)
            if params.beam_size == 1:
                generated, lengths = decoder.generate(enc1, len1, params.tgt_id, max_len=max_len)
            else:
                generated, lengths = decoder.generate_beam(
                    enc1, len1, params.tgt_id, beam_size=params.beam_size,
                    length_penalty=params.length_penalty,
                    early_stopping=params.early_stopping,
                    max_len=max_len)

            hypotheses_batch = convert_to_text(generated, lengths, input_data["dico"], params)

            out.write("\n".join(hypotheses_batch))
            out.write("\n")

            if n_batch % 100 == 0:
                logger.info("{} batches processed".format(n_batch))