示例#1
0
文件: bli.py 项目: 15091444119/MASS
def load_xlm_embeddings(path, model_name="model"):
    """
    Load the word-embedding matrix stored in an XLM-style checkpoint.

    Params:
        path: checkpoint file readable by ``torch.load``; it must contain the
            keys 'params', 'dico_id2word', 'dico_word2id', 'dico_counts' and
            the state dict named by ``model_name``
        model_name: model name in the reloaded path, "model" for pretrained
            xlm encoder; "encoder" for encoder of translation model;
            "decoder" for decoder of translation model
    Returns:
        (embeddings, dico): ``model.embeddings.weight.data`` of the rebuilt
        model and the ``Dictionary`` rebuilt from the checkpoint
    Raises:
        AssertionError: if ``model_name`` is not one of the three names above
    """
    # validate the cheap argument before paying for the checkpoint load
    assert model_name in ["model", "encoder", "decoder"]

    # No device move is performed on the rebuilt model below, so the weights
    # are loaded onto the CPU; this also lets checkpoints that were saved
    # from a GPU be read on CPU-only machines.
    reloaded = torch.load(path, map_location="cpu")
    state_dict = reloaded[model_name]

    # handle models from multi-GPU checkpoints (keys carry a 'module.' prefix)
    prefix = "module."
    state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
                  for k, v in state_dict.items()}

    # reload dictionary and model parameters
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    pretrain_params = AttrDict(reloaded['params'])
    pretrain_params.n_words = len(dico)
    pretrain_params.bos_index = dico.index(BOS_WORD)
    pretrain_params.eos_index = dico.index(EOS_WORD)
    pretrain_params.pad_index = dico.index(PAD_WORD)
    pretrain_params.unk_index = dico.index(UNK_WORD)
    pretrain_params.mask_index = dico.index(MASK_WORD)

    # build model and reload weights; the third positional argument is False
    # only for the decoder (presumably the `is_encoder` flag -- confirm
    # against TransformerModel's signature)
    model = TransformerModel(pretrain_params, dico, model_name != "decoder",
                             True)
    model.load_state_dict(state_dict)

    return model.embeddings.weight.data, dico
示例#2
0
文件: mymodel.py 项目: RunxinXu/XLM
    def reload(path, params):
        """
        Create a sentence embedder from a pretrained model.

        Params:
            path: checkpoint file readable by ``torch.load``; it must hold
                the keys 'model', 'params', 'dico_id2word', 'dico_word2id'
                and 'dico_counts'
            params: caller-side parameters; ``max_batch_size`` is set to 0 on
                this object before it is handed to ``MyModel``
        Returns:
            ``MyModel(model, dico, pretrain_params, params)`` where ``model``
            is the reloaded transformer switched to eval mode
        """
        # reload model; no device move happens in this function, so the
        # weights are loaded on the CPU (also works without a GPU)
        reloaded = torch.load(path, map_location='cpu')
        state_dict = reloaded['model']

        # handle models from multi-GPU checkpoints: strip the 'module.'
        # prefix whatever the file is called (keys without the prefix are
        # left untouched, so this is a no-op for single-GPU weights)
        state_dict = {(k[7:] if k.startswith('module.') else k): v
                      for k, v in state_dict.items()}

        # reload dictionary and model parameters
        dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                          reloaded['dico_counts'])
        pretrain_params = AttrDict(reloaded['params'])
        pretrain_params.n_words = len(dico)
        pretrain_params.bos_index = dico.index(BOS_WORD)
        pretrain_params.eos_index = dico.index(EOS_WORD)
        pretrain_params.pad_index = dico.index(PAD_WORD)
        pretrain_params.unk_index = dico.index(UNK_WORD)
        pretrain_params.mask_index = dico.index(MASK_WORD)

        # build model and reload weights
        model = TransformerModel(pretrain_params, dico, True, True)
        model.load_state_dict(state_dict)
        model.eval()

        # adding missing parameters
        params.max_batch_size = 0

        return MyModel(model, dico, pretrain_params, params)
示例#3
0
def main():
    """
    Evaluate a saved model on the 'test' split described by a YAML config.

    Command-line flags: ``-config`` (YAML config file), ``-log`` (log file
    name inside the experiment directory) and ``-mode`` (parsed, but not
    read anywhere in this function).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/aishell.yaml')
    parser.add_argument('-log', type=str, default='eval.log')
    parser.add_argument('-mode', type=str, default='retrain')
    opt = parser.parse_args()

    # read the config inside a context manager so the handle is closed even
    # when YAML parsing fails
    with open(opt.config) as configfile:
        config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))

    exp_name = os.path.join('egs', config.data.name, 'exp', config.model.type,
                            config.training.save_model)
    # exist_ok avoids the check-then-create race of isdir() + makedirs()
    os.makedirs(exp_name, exist_ok=True)
    logger = init_logger(os.path.join(exp_name, opt.log))

    os.environ["CUDA_VISIBLE_DEVICES"] = config.evaling.gpus
    config.evaling.num_gpu = num_gpus(config.evaling.gpus)
    logger.info('Number of gpu:' + str(config.evaling.num_gpu))

    # scale workers and batch size with the GPU count; with no GPU the
    # multiplier is 1
    use_gpu = config.evaling.num_gpu > 0
    device_factor = config.evaling.num_gpu if use_gpu else 1
    num_workers = 6 * device_factor
    batch_size = config.data.batch_size * device_factor

    dev_dataset = AudioDataset(config.data, 'test')
    dev_sampler = Batch_RandomSampler(len(dev_dataset),
                                      batch_size=batch_size,
                                      shuffle=config.data.shuffle)
    validate_data = AudioDataLoader(dataset=dev_dataset,
                                    num_workers=num_workers,
                                    batch_sampler=dev_sampler)
    logger.info('Load Test Set!')

    # NOTE(review): on the GPU path only torch.cuda.manual_seed is called
    # (torch.manual_seed is not); kept as in the original.
    if use_gpu:
        torch.cuda.manual_seed(config.evaling.seed)
        torch.backends.cudnn.deterministic = True
    else:
        torch.manual_seed(config.evaling.seed)
    logger.info('Set random seed: %d' % config.evaling.seed)

    # without a GPU the checkpoint tensors must be remapped onto the CPU
    if config.evaling.num_gpu == 0:
        checkpoint = torch.load(config.evaling.load_model, map_location='cpu')
    else:
        checkpoint = torch.load(config.evaling.load_model)
    logger.info(str(checkpoint.keys()))

    with torch.no_grad():
        model = new_model(config, checkpoint).eval()
        beam_rnnt_decoder = build_beam_rnnt_decoder(config, model)
        beamctc_decoder = build_ctc_beam_decoder(config, model)
        if use_gpu:
            model = model.cuda()

        # `eval` here is the name this call resolves to in the file (called
        # with a config, model, data loader and logger); its result is unused
        _ = eval(config,
                 model,
                 validate_data,
                 logger,
                 beamctc_decoder=beamctc_decoder,
                 beam_rnnt_decoder=beam_rnnt_decoder)
示例#4
0
def initialize_model():
    """
    Load the checkpoint ``mlm_tlm_xnli15_1024.pth`` from the current working
    directory and rebuild its transformer model.

    Returns:
        (model, params, dico): the ``TransformerModel`` with reloaded
        weights, the checkpoint parameters (updated with the dictionary size
        and special-token indices) and the rebuilt ``Dictionary``
    """
    print('launching model')

    chemin = getcwd()

    # debug output: plain files found in the working directory
    onlyfiles = [f for f in listdir(chemin) if isfile(join(chemin, f))]
    print(onlyfiles)

    # resolve the checkpoint path once; it is both printed and loaded
    model_path = os.path.normpath(
        os.path.join(chemin, './mlm_tlm_xnli15_1024.pth'))
    print(model_path)

    # the model is never moved to a GPU in this function, so load on the CPU
    # (this also works for checkpoints that were saved from a GPU)
    reloaded = torch.load(model_path, map_location='cpu')

    params = AttrDict(reloaded['params'])
    print("Supported languages: %s" % ", ".join(params.lang2id.keys()))

    # build dictionary / update parameters
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build model / reload weights
    model = TransformerModel(params, dico, True, True)
    model.load_state_dict(reloaded['model'])

    print('fin lecture')

    return model, params, dico
示例#5
0
文件: qg.py 项目: PythaGorilla/QAQG
def load_model(params):
    """
    Rebuild the encoder and decoder stored in ``params.model_path``.

    ``params`` is updated in place with the language tables of the encoder
    parameters, the dictionary size and the special-token indices.

    Returns:
        (encoder, decoder, dico)
    """
    # both locations must exist before anything is loaded
    assert os.path.isdir(params.data_path)
    assert os.path.isfile(params.model_path)
    checkpoint = torch.load(params.model_path)

    enc_cfg = AttrDict(checkpoint['enc_params'])
    dec_cfg = AttrDict(checkpoint['dec_params'])

    dico = Dictionary(checkpoint['dico_id2word'], checkpoint['dico_word2id'],
                      checkpoint['dico_counts'])

    # language tables come from the encoder parameters
    for key in ('n_langs', 'id2lang', 'lang2id'):
        setattr(params, key, enc_cfg[key])

    # vocabulary size and special-token indices come from the dictionary
    params.n_words = len(dico)
    special_tokens = (
        ('bos_index', BOS_WORD),
        ('eos_index', EOS_WORD),
        ('pad_index', PAD_WORD),
        ('unk_index', UNK_WORD),
        ('mask_index', MASK_WORD),
    )
    for attr, token in special_tokens:
        setattr(params, attr, dico.index(token))

    encoder = TransformerModel(enc_cfg, dico, is_encoder=True,
                               with_output=False)
    decoder = TransformerModel(dec_cfg, dico, is_encoder=False,
                               with_output=True)

    # reload the weights, dropping the 'module.' prefix that multi-GPU
    # wrappers add to parameter names
    for module, key in ((encoder, 'encoder'), (decoder, 'decoder')):
        weights = {}
        for name, tensor in checkpoint[key].items():
            if name.startswith('module.'):
                name = name[7:]
            weights[name] = tensor
        module.load_state_dict(weights)

    return encoder, decoder, dico
示例#6
0
def reload_ar_checkpoint(path):
    """
    Reload the parameters, dictionary and autoregressive model saved at
    ``path``.

    Returns:
        (params, dico, model)
    """
    checkpoint = torch.load(path)
    params = AttrDict(checkpoint['params'])

    # rebuild the dictionary from its three stored components
    dico = Dictionary(
        checkpoint['dico_id2word'],
        checkpoint['dico_word2id'],
        checkpoint['dico_counts'],
    )

    # refresh the dictionary-dependent parameters; n_langs is forced to 1
    params.n_words = len(dico)
    params.n_langs = 1
    special_tokens = (
        ('bos_index', BOS_WORD),
        ('eos_index', EOS_WORD),
        ('pad_index', PAD_WORD),
        ('unk_index', UNK_WORD),
        ('mask_index', MASK_WORD),
    )
    for attr, token in special_tokens:
        setattr(params, attr, dico.index(token))

    # build the Transformer (is_encoder=False, with an output layer) and
    # restore its weights
    model = TransformerModel(params, is_encoder=False, with_output=True)
    model.load_state_dict(checkpoint['model'])
    return params, dico, model
示例#7
0
    def __init__(self, model_path, tgt_lang, src_lang,dump_path = "./dumped/", exp_name="translate", exp_id="test", batch_size=32):
        """
        Build a CPU translator from the checkpoint at ``model_path``.

        The constructor arguments only provide the defaults of an argparse
        parser; ``parse_args()`` is still called, so values given on the
        process command line override them.

        Sets ``self.supported_languages``, ``self.dico``, ``self.encoder``,
        ``self.decoder``, ``self.model_params`` and ``self.params``.
        """
        # declare every option as (flag, type, default, help), in the order
        # in which it is registered
        option_table = (
            # main parameters
            ("--dump_path", str, dump_path, "Experiment dump path"),
            ("--exp_name", str, exp_name, "Experiment name"),
            ("--exp_id", str, exp_id, "Experiment ID"),
            ("--batch_size", int, batch_size, "Number of sentences per batch"),
            # model / output paths
            ("--model_path", str, model_path, "Model path"),
            # source language / target language
            ("--src_lang", str, src_lang, "Source language"),
            ("--tgt_lang", str, tgt_lang, "Target language"),
        )
        parser = argparse.ArgumentParser(description="Translate sentences")
        for flag, kind, default, description in option_table:
            parser.add_argument(flag, type=kind, default=default, help=description)

        params = parser.parse_args()
        # both languages must be given and must differ
        assert not (params.src_lang == '' or params.tgt_lang == '' or params.src_lang == params.tgt_lang)

        # initialize the experiment
        logger = initialize_exp(params)

        # no GPU available: map every tensor of the checkpoint onto the CPU
        checkpoint = torch.load(params.model_path, map_location=torch.device('cpu'))
        model_params = AttrDict(checkpoint['params'])
        self.supported_languages = model_params.lang2id.keys()
        logger.info("Supported languages: %s" % ", ".join(self.supported_languages))

        # copy the dictionary-related settings stored in the checkpoint
        dico_fields = ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']
        for field in dico_fields:
            setattr(params, field, getattr(model_params, field))

        # build dictionary, then encoder and decoder (both kept on the CPU
        # and switched to eval mode), then reload their weights
        self.dico = Dictionary(checkpoint['dico_id2word'], checkpoint['dico_word2id'], checkpoint['dico_counts'])
        for attr, as_encoder in (('encoder', True), ('decoder', False)):
            module = TransformerModel(model_params, self.dico, is_encoder=as_encoder, with_output=True)
            setattr(self, attr, module.eval())
        for attr in ('encoder', 'decoder'):
            getattr(self, attr).load_state_dict(checkpoint[attr])

        # numeric ids of the language pair
        params.src_id = model_params.lang2id[params.src_lang]
        params.tgt_id = model_params.lang2id[params.tgt_lang]
        self.model_params = model_params
        self.params = params
def main(params):
    """
    Translate whitespace-tokenized sentences read from stdin and write one
    translation per line to ``params.output_path``.

    Each source/target pair is also echoed to stderr. The encoder and decoder
    are moved to the GPU, so a CUDA device is required.

    Params:
        params: used only to initialize the experiment; it is then replaced
            by the arguments parsed from the command line
    Raises:
        AssertionError: if an input line contains no token, or if a decoded
            sentence does not start with the EOS delimiter
    """
    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    # NOTE(review): this rebinds `params`, discarding the argument that was
    # passed in -- kept as in the original.
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded["params"])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            "n_words",
            "bos_index",
            "eos_index",
            "pad_index",
            "unk_index",
            "mask_index",
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded["dico_id2word"], reloaded["dico_word2id"],
                      reloaded["dico_counts"])
    encoder = (TransformerModel(model_params,
                                dico,
                                is_encoder=True,
                                with_output=True).cuda().eval())
    decoder = (TransformerModel(model_params,
                                dico,
                                is_encoder=False,
                                with_output=True).cuda().eval())
    encoder.load_state_dict(reloaded["encoder"])
    decoder.load_state_dict(reloaded["decoder"])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

    # the context manager closes the output file even if translation fails
    # part-way through
    with io.open(params.output_path, "w", encoding="utf-8") as f:

        for i in range(0, len(src_sent), params.batch_size):

            # prepare batch: each column is EOS + word ids + EOS, padded
            # with pad_index up to the longest sentence of the batch
            word_ids = [
                torch.LongTensor([dico.index(w) for w in s.strip().split()])
                for s in src_sent[i:i + params.batch_size]
            ]
            lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
            batch = torch.LongTensor(lengths.max().item(),
                                     lengths.size(0)).fill_(params.pad_index)
            batch[0] = params.eos_index
            for j, s in enumerate(word_ids):
                if lengths[j] > 2:  # if sentence not empty
                    batch[1:lengths[j] - 1, j].copy_(s)
                batch[lengths[j] - 1, j] = params.eos_index
            langs = batch.clone().fill_(params.src_id)

            # encode source batch and translate it
            encoded = encoder(
                "fwd",
                x=batch.cuda(),
                lengths=lengths.cuda(),
                langs=langs.cuda(),
                causal=False,
            )
            encoded = encoded.transpose(0, 1)
            decoded, dec_lengths = decoder.generate(
                encoded,
                lengths.cuda(),
                params.tgt_id,
                max_len=int(1.5 * lengths.max().item() + 10),
            )

            # convert sentences to words
            for j in range(decoded.size(1)):

                # remove delimiters: keep what lies between the leading EOS
                # and the next EOS (or the end if there is no second one)
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                if len(delimiters) == 1:
                    sent = sent[1:]
                else:
                    sent = sent[1:delimiters[1]]

                # output translation
                source = src_sent[i + j].strip()
                target = " ".join(dico[tok.item()] for tok in sent)
                sys.stderr.write("%i / %i: %s -> %s\n" %
                                 (i + j, len(src_sent), source, target))
                f.write(target + "\n")
示例#9
0
def main(params):
    """
    Generate one output sentence per line of ``params.table_path`` and write
    the results to ``params.output_path``.

    Every input line is a whitespace-separated list of records, each record
    being four '|'-separated fields. The four fields are indexed with the
    source dictionary and fed to the encoder as ``x1`` .. ``x4``. The encoder
    and decoder are moved to the GPU, so a CUDA device is required.

    Params:
        params: ignored; it is replaced by the arguments parsed from the
            command line
    Raises:
        AssertionError: if a record does not have exactly four fields, or if
            a decoded sentence does not start with the EOS delimiter
    """
    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])

    # update dictionary parameters
    for name in ['src_n_words', 'tgt_n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    source_dico = Dictionary(reloaded['source_dico_id2word'], reloaded['source_dico_word2id'])
    target_dico = Dictionary(reloaded['target_dico_id2word'], reloaded['target_dico_word2id'])
    encoder = TransformerEncoder(model_params, source_dico, with_output=False).cuda().eval()
    decoder = TransformerDecoder(model_params, target_dico, with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder.load_state_dict(reloaded['decoder'])

    # read the table lines; the context manager closes the input file
    with open(params.table_path, 'r', encoding='utf-8') as table_inf:
        table_lines = list(table_inf)

    n_fields = 4  # fields per record, one encoder input stream per field

    # the context manager closes the output file even if generation fails
    with io.open(params.output_path, 'w', encoding='utf-8') as outf:

        for i in range(0, len(table_lines), params.batch_size):
            # prepare batch: field_ids[k] holds, for every line of the
            # batch, the ids of the k-th field of its records
            field_ids = [[] for _ in range(n_fields)]
            for table_line in table_lines[i:i + params.batch_size]:
                record_seq = [each.split('|') for each in table_line.split()]
                assert all([len(x) == n_fields for x in record_seq])
                for k in range(n_fields):
                    field_ids[k].append(
                        torch.LongTensor([source_dico.index(x[k]) for x in record_seq]))

            # all four streams share the same lengths (records + 2 EOS)
            enc_xlen = torch.LongTensor([len(x) + 2 for x in field_ids[0]])
            max_xlen = enc_xlen.max().item()

            # pad every stream the same way: EOS + ids + EOS, then pad_index
            enc_inputs = []
            for ids in field_ids:
                enc_x = torch.LongTensor(max_xlen, enc_xlen.size(0)).fill_(params.pad_index)
                enc_x[0] = params.eos_index
                for j, s in enumerate(ids):
                    if enc_xlen[j] > 2:  # if sentence not empty
                        enc_x[1:enc_xlen[j] - 1, j].copy_(s)
                    enc_x[enc_xlen[j] - 1, j] = params.eos_index
                enc_inputs.append(enc_x.cuda())
            enc_x1, enc_x2, enc_x3, enc_x4 = enc_inputs
            enc_xlen = enc_xlen.cuda()

            # encode source batch and translate it
            encoder_output = encoder('fwd', x1=enc_x1, x2=enc_x2, x3=enc_x3, x4=enc_x4, lengths=enc_xlen)
            encoder_output = encoder_output.transpose(0, 1)

            # fixed generation budget (the original keeps a length-dependent
            # alternative, int(1.5 * enc_xlen.max().item() + 10), disabled)
            max_len = 602
            if params.beam_size <= 1:
                decoded, dec_lengths = decoder.generate(encoder_output, enc_xlen, max_len=max_len)
            else:
                decoded, dec_lengths = decoder.generate_beam(encoder_output, enc_xlen, params.beam_size,
                                                             params.length_penalty, params.early_stopping,
                                                             max_len=max_len)

            for j in range(decoded.size(1)):

                # remove delimiters: keep what lies between the leading EOS
                # and the next EOS (or the end if there is no second one)
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

                # output translation
                target = " ".join(target_dico[tok.item()] for tok in sent)
                sys.stderr.write("%i / %i: %s\n" % (i + j, len(table_lines), target))
                outf.write(target + "\n")
示例#10
0
    else:
        proj = None
    return proj


# parse parameters
# (`parser` is defined outside this fragment)
params = parser.parse_args()
# a non-negative tokens_per_batch switches group_by_size on
if params.tokens_per_batch > -1:
    params.group_by_size = True

# check parameters
assert os.path.isdir(params.data_path)
assert os.path.isfile(params.model_path)

# NOTE(review): the pretraining parameters are read from a hard-coded
# checkpoint path, not from params.model_path -- confirm this is intended.
reloaded = torch.load('./mlm_xnli15_1024.pth')
pretrain_params = AttrDict(reloaded['params'])
# reload pretrained model
embedder = SentenceEmbedder.reload(params.model_path, params, pretrain_params)

# projection returned by reloaded_proj for this model path and embedder
proj = reloaded_proj(params.model_path, embedder)

# reload langs from pretrained model
params.n_langs = embedder.pretrain_params['n_langs']
params.id2lang = embedder.pretrain_params['id2lang']
params.lang2id = embedder.pretrain_params['lang2id']

# initialize the experiment / build sentence embedder
logger = initialize_exp(params)
# starts empty; not filled within this fragment
scores = {}

# prepare trainers / evaluators
示例#11
0
def main(params):
    """
    Encode the sentences read from stdin with the checkpoint's encoder and
    save the stacked encodings with ``np.save`` at ``params.output_path``.

    ``params`` is used to initialize the experiment and is then replaced by
    the arguments parsed from the command line.
    """
    # set up the experiment (and its logger) from the incoming parameters
    logger = initialize_exp(params)

    # re-read the parameters from the command line, then load the checkpoint
    params = get_parser().parse_args()
    checkpoint = torch.load(params.model_path)
    model_params = AttrDict(checkpoint['params'])
    supported = ", ".join(model_params.lang2id.keys())
    logger.info("Supported languages: %s" % supported)

    # copy the dictionary-related settings stored in the checkpoint
    dico_fields = ('n_words', 'bos_index', 'eos_index', 'pad_index',
                   'unk_index', 'mask_index')
    for field in dico_fields:
        setattr(params, field, getattr(model_params, field))

    # rebuild the dictionary and the encoder (on GPU, eval mode)
    dico = Dictionary(checkpoint['dico_id2word'],
                      checkpoint['dico_word2id'],
                      checkpoint['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True,
                               with_output=True)
    encoder = encoder.cuda().eval()
    encoder.load_state_dict(checkpoint['encoder'])
    decoder = None  # no decoder is built or loaded in this function
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # collect the input sentences; every line must hold at least one token
    src_sent = []
    for raw_line in sys.stdin.readlines():
        tokens = raw_line.strip().split()
        assert len(tokens) > 0
        src_sent.append(raw_line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

    all_encodings = []
    batch_size = params.batch_size
    for start in range(0, len(src_sent), batch_size):
        # turn each sentence of the chunk into a tensor of word ids
        chunk = src_sent[start:start + batch_size]
        word_ids = []
        for sentence in chunk:
            ids = [dico.index(w) for w in sentence.strip().split()]
            word_ids.append(torch.LongTensor(ids))

        # each column is EOS + ids + EOS, padded with pad_index
        lengths = torch.LongTensor([len(ids) + 2 for ids in word_ids])
        n_rows = lengths.max().item()
        n_cols = lengths.size(0)
        batch = torch.LongTensor(n_rows, n_cols).fill_(params.pad_index)
        batch[0] = params.eos_index
        for col, ids in enumerate(word_ids):
            last = lengths[col] - 1
            if lengths[col] > 2:  # sentence has at least one word
                batch[1:last, col].copy_(ids)
            batch[last, col] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encoderouts yields one encoding per sentence, in input order
        encodings = encoderouts(encoder, batch, lengths, langs)
        all_encodings.extend(enc.cpu().numpy() for enc in encodings)

    # stack everything and write it as a .npy file
    stacked = np.stack(all_encodings)
    np.save(params.output_path, stacked)
示例#12
0
def main(params):
    """
    Translate the sentences of ``params.sentences_path`` with beam search.

    Lines holding between 1 and 99 whitespace-separated tokens are kept.
    The kept source lines are written to ``params.output_path + '.src_sent'``
    and their translations, one per line, to ``params.output_path``. The
    encoder and decoder are moved to the GPU, so a CUDA device is required.

    Params:
        params: used only to initialize the experiment; it is then replaced
            by the arguments parsed from the command line
    Raises:
        AssertionError: if a decoded sentence does not start with the EOS
            delimiter
    """
    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    # NOTE(review): this rebinds `params`, discarding the argument that was
    # passed in -- kept as in the original.
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    encoder = TransformerModel(model_params,
                               dico,
                               is_encoder=True,
                               with_output=True).cuda().eval()
    decoder = TransformerModel(model_params,
                               dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from the sentences file, skipping empty lines and lines
    # of 100 tokens or more
    src_sent = []
    with open(params.sentences_path, 'r') as file1:
        for line in file1:
            if 0 < len(line.strip().split()) < 100:
                src_sent.append(line)
    logger.info(
        "Read %i sentences from sentences file.Writing them to a src file. Translating ..."
        % len(src_sent))

    # keep a copy of the retained source lines next to the translations
    with io.open(params.output_path + '.src_sent', 'w',
                 encoding='utf-8') as src_file:
        src_file.writelines(src_sent)
    logger.info("Wrote them to a src file")

    # the context manager closes the output file even if translation fails
    # part-way through
    with io.open(params.output_path, 'w', encoding='utf-8') as f:

        for i in range(0, len(src_sent), params.batch_size):

            # prepare batch: each column is EOS + word ids + EOS, padded
            # with pad_index up to the longest sentence of the batch
            word_ids = [
                torch.LongTensor([dico.index(w) for w in s.strip().split()])
                for s in src_sent[i:i + params.batch_size]
            ]
            lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
            max_src_len = lengths.max().item()  # computed on CPU, no transfer
            batch = torch.LongTensor(max_src_len,
                                     lengths.size(0)).fill_(params.pad_index)
            batch[0] = params.eos_index
            for j, s in enumerate(word_ids):
                if lengths[j] > 2:  # if sentence not empty
                    batch[1:lengths[j] - 1, j].copy_(s)
                batch[lengths[j] - 1, j] = params.eos_index
            langs = batch.clone().fill_(params.src_id)

            # encode source batch and translate it; this encoder returns a
            # pair, of which only the first element is used
            encoded, _ = encoder('fwd',
                                 x=batch.cuda(),
                                 lengths=lengths.cuda(),
                                 langs=langs.cuda(),
                                 causal=False)
            encoded = encoded.transpose(0, 1)
            decoded, dec_lengths = decoder.generate_beam(
                encoded,
                lengths.cuda(),
                params.tgt_id,
                beam_size=params.beam_size,
                length_penalty=params.length_penalty,
                early_stopping=params.early_stopping,
                max_len=int(1.5 * max_src_len + 10))

            # convert sentences to words
            for j in range(decoded.size(1)):

                # remove delimiters: keep what lies between the leading EOS
                # and the next EOS (or the end if there is no second one)
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                if len(delimiters) == 1:
                    sent = sent[1:]
                else:
                    sent = sent[1:delimiters[1]]

                # output translation; only every 100th pair is logged
                source = src_sent[i + j].strip()
                target = " ".join(dico[tok.item()] for tok in sent)
                if (i + j) % 100 == 0:
                    logger.info(
                        "Translation of %i / %i:\n Source sentence: %s \n Translation: %s\n"
                        % (i + j, len(src_sent), source, target))
                f.write(target + "\n")
示例#13
0
def main():
    """
    Embed every line of ``--input`` with a pretrained model and save the
    resulting tensor at ``--output`` with ``torch.save``.

    Each line is split into SentencePiece pieces (``--spm_model``), truncated
    to ``--max_words - 1`` pieces and prefixed with the EOS index; the vector
    kept for a sentence is the model output at position 0.

    Raises:
        AssertionError: if ``--cuda`` is neither "True" nor "False", or if a
            batch is longer than ``--max_words``
    """
    # `parser` is not created here; it must already exist in the enclosing
    # scope
    parser.add_argument("--input", type=str, default="", help="input file")
    parser.add_argument("--model", type=str, default="", help="model path")
    parser.add_argument("--spm_model",
                        type=str,
                        default="",
                        help="spm model path")
    parser.add_argument("--batch_size",
                        type=int,
                        default=64,
                        help="batch size")
    parser.add_argument("--max_words", type=int, default=100, help="max words")
    parser.add_argument("--cuda", type=str, default="True", help="use cuda")
    parser.add_argument("--output", type=str, default="", help="output file")
    args = parser.parse_args()

    # Reload a pretrained model. The weights are loaded on the CPU and the
    # model is moved to the GPU below when requested, so `--cuda False`
    # also works with checkpoints that were saved from a GPU.
    reloaded = torch.load(args.model, map_location='cpu')
    params = AttrDict(reloaded['params'])

    # Reload the SPM model
    spm_model = spm.SentencePieceProcessor()
    spm_model.Load(args.spm_model)

    # cuda: turn the flag into a bool by comparison rather than eval(), so
    # no command-line text is ever evaluated as code
    assert args.cuda in ["True", "False"]
    args.cuda = args.cuda == "True"

    # build dictionary / update parameters
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build model / reload weights; only a leading 'module.' (added by
    # multi-GPU wrappers) is stripped, never an occurrence inside a name
    model = TransformerModel(params, dico, True, True)
    prefix = 'module.'
    reloaded['model'] = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in reloaded['model'].items())
    model.load_state_dict(reloaded['model'])
    model.eval()

    if args.cuda:
        model.cuda()

    # load sentences; one slot is left for the EOS index added below
    sentences = []
    with open(args.input) as f:
        for line in f:
            line = spm_model.EncodeAsPieces(line.rstrip())
            line = line[:args.max_words - 1]
            sentences.append(line)

    # encode sentences
    embs = []
    for i in range(0, len(sentences), args.batch_size):
        batch = sentences[i:i + args.batch_size]
        lengths = torch.LongTensor([len(s) + 1 for s in batch])
        bs, slen = len(batch), lengths.max().item()
        assert slen <= args.max_words

        # each column is EOS + piece ids, padded with pad_index
        x = torch.LongTensor(slen, bs).fill_(params.pad_index)
        for k, pieces in enumerate(batch):
            sent = torch.LongTensor([params.eos_index] +
                                    [dico.index(w) for w in pieces])
            x[:len(sent), k] = sent

        if args.cuda:
            x = x.cuda()
            lengths = lengths.cuda()

        with torch.no_grad():
            # [0] keeps the output at the first position of every sentence
            embedding = model('fwd',
                              x=x,
                              lengths=lengths,
                              langs=None,
                              causal=False).contiguous()[0].cpu()

        embs.append(embedding)

    # save embeddings
    torch.save(torch.cat(embs, dim=0).squeeze(0), args.output)
示例#14
0
def main(params):
    """Translate sentences read from stdin with an ensemble of checkpoints.

    Every comma-separated path in ``params.model_path`` is reloaded with
    ``torch.load``. One encoder/decoder pair is built per checkpoint, and all
    pairs are passed together to ``generate_beam``. One translation per input
    line is written to ``params.output_path``; progress is echoed on stderr.

    Params:
        params: experiment parameters handed to ``initialize_exp``. They are
            then replaced by the arguments parsed from the command line.
    """

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    # NOTE: from here on ``params`` refers to the freshly parsed CLI arguments.
    parser = get_parser()
    params = parser.parse_args()
    models_path = params.model_path.split(',')

    # reload every checkpoint of the ensemble
    models_reloaded = [torch.load(model_path) for model_path in models_path]

    # model hyper-parameters are taken from the first checkpoint only
    model_params = AttrDict(models_reloaded[0]['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary (also taken from the first checkpoint)
    dico = Dictionary(models_reloaded[0]['dico_id2word'],
                      models_reloaded[0]['dico_word2id'],
                      models_reloaded[0]['dico_counts'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    encoders = []
    decoders = []

    def package_module(modules):
        """Return a copy of a state dict without leading 'module.' prefixes."""
        state_dict = OrderedDict()
        for k, v in modules.items():
            if k.startswith('module.'):
                state_dict[k[7:]] = v
            else:
                state_dict[k] = v
        return state_dict

    # build encoder / build decoder / reload weights, once per checkpoint
    for reloaded in models_reloaded:
        encoder = TransformerModel(model_params,
                                   dico,
                                   is_encoder=True,
                                   with_output=True).to(params.device).eval()
        decoder = TransformerModel(model_params,
                                   dico,
                                   is_encoder=False,
                                   with_output=True).to(params.device).eval()
        encoder.load_state_dict(package_module(reloaded['encoder']))
        decoder.load_state_dict(package_module(reloaded['decoder']))

        # float16
        if params.fp16:
            assert torch.backends.cudnn.enabled
            encoder = network_to_half(encoder)
            decoder = network_to_half(decoder)

        encoders.append(encoder)
        decoders.append(decoder)

    # read sentences from stdin; blank lines are rejected
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)

    # the context manager closes the output file even if translation fails
    with io.open(params.output_path, 'w', encoding='utf-8') as f:

        for i in range(0, len(src_sent), params.batch_size):

            # prepare batch: each column is <eos> w_1 ... w_n <eos> <pad>*
            word_ids = [
                torch.LongTensor([dico.index(w) for w in s.strip().split()])
                for s in src_sent[i:i + params.batch_size]
            ]
            lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
            batch = torch.LongTensor(lengths.max().item(),
                                     lengths.size(0)).fill_(params.pad_index)
            batch[0] = params.eos_index
            for j, s in enumerate(word_ids):
                if lengths[j] > 2:  # if sentence not empty
                    batch[1:lengths[j] - 1, j].copy_(s)
                batch[lengths[j] - 1, j] = params.eos_index
            langs = batch.clone().fill_(params.src_id)

            # encode source batch with every encoder of the ensemble
            encodeds = []
            for encoder in encoders:
                encoded = encoder('fwd',
                                  x=batch.to(params.device),
                                  lengths=lengths.to(params.device),
                                  langs=langs.to(params.device),
                                  causal=False)
                encoded = encoded.transpose(0, 1)
                encodeds.append(encoded)

                assert encoded.size(0) == lengths.size(0)

            # translate with the whole ensemble
            decoded, dec_lengths = generate_beam(
                decoders,
                encodeds,
                lengths.to(params.device),
                params.tgt_id,
                beam_size=params.beam,
                length_penalty=params.length_penalty,
                early_stopping=False,
                max_len=int(1.5 * lengths.max().item() + 10),
                params=params)

            # convert sentences to words
            for j in range(decoded.size(1)):

                # remove delimiters: keep tokens between the first two <eos>
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                sent = sent[1:] if len(
                    delimiters) == 1 else sent[1:delimiters[1]]

                # output translation
                source = src_sent[i + j].strip()
                target = " ".join(
                    [dico[sent[k].item()] for k in range(len(sent))])
                sys.stderr.write("%i / %i: %s -> %s\n" %
                                 (i + j, len(src_sent), source, target))
                f.write(target + "\n")
示例#15
0
def main(params):
    """Translate sentences read from stdin with one encoder/decoder checkpoint.

    The checkpoint at ``params.model_path`` holds either a single shared
    dictionary or one dictionary per language, selected by the model's
    ``share_word_embeddings`` / ``share_all_embeddings`` flags. Translations
    are written one per line to ``params.output_path``; progress is echoed on
    stderr.

    Params:
        params: experiment parameters handed to ``initialize_exp``. They are
            then replaced by the arguments parsed from the command line.
    """

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    model_params['mnmt'] = params.mnmt
    logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    if model_params.share_word_embeddings or model_params.share_all_embeddings:
        # A shared vocabulary serves both sides. Binding src_dico / tgt_dico
        # here keeps the batching and detokenisation code below from hitting
        # a NameError in the shared-embedding case.
        dico = Dictionary(reloaded['dico_id2word'],
                          reloaded['dico_word2id'],
                          reloaded['dico_counts'])
        src_dico = tgt_dico = dico
    else:
        dico = {}
        for lang in [params.src_lang, params.tgt_lang]:
            dico[lang] = Dictionary(reloaded[lang]['dico_id2word'],
                                    reloaded[lang]['dico_word2id'],
                                    reloaded[lang]['dico_counts'])
        src_dico = dico[params.src_lang]
        tgt_dico = dico[params.tgt_lang]

    encoder = TransformerModel(model_params, src_dico, is_encoder=True, with_output=False).cuda().eval()
    decoder = TransformerModel(model_params, tgt_dico, is_encoder=False, with_output=True).cuda().eval()

    try:
        encoder.load_state_dict(reloaded['encoder'])
        decoder.load_state_dict(reloaded['decoder'])
    except RuntimeError:
        # Retry after stripping the 'module.' prefix, but only when every key
        # of the state dict carries it.
        prefix = 'module.'

        enc_reload = reloaded['encoder']
        if all(k.startswith(prefix) for k in enc_reload.keys()):
            enc_reload = {k[len(prefix):]: v for k, v in enc_reload.items()}

        dec_reload = reloaded['decoder']
        if all(k.startswith(prefix) for k in dec_reload.keys()):
            dec_reload = {k[len(prefix):]: v for k, v in dec_reload.items()}

        encoder.load_state_dict(enc_reload)
        decoder.load_state_dict(dec_reload)

    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin; blank lines are rejected
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." % len(src_sent))

    # the context manager closes the output file even if translation fails
    with io.open(params.output_path, 'w', encoding='utf-8') as f:

        for i in range(0, len(src_sent), params.batch_size):

            # prepare batch: each column is <eos> w_1 ... w_n <eos> <pad>*
            word_ids = [torch.LongTensor([src_dico.index(w) for w in s.strip().split()])
                        for s in src_sent[i:i + params.batch_size]]

            lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
            batch = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(params.pad_index)
            batch[0] = params.eos_index
            for j, s in enumerate(word_ids):
                if lengths[j] > 2:  # if sentence not empty
                    batch[1:lengths[j] - 1, j].copy_(s)
                batch[lengths[j] - 1, j] = params.eos_index
            langs = batch.clone().fill_(params.src_id)

            # encode source batch and translate it
            encoded = encoder('fwd', x=batch.cuda(), lengths=lengths.cuda(), langs=langs.cuda(), causal=False)
            encoded = encoded.transpose(0, 1)
            max_len = int(1.5 * lengths.max().item() + 10)
            if params.beam_size > 1:
                decoded, dec_lengths = decoder.generate_beam(encoded, lengths.cuda(), params.tgt_id,
                                                             beam_size=params.beam_size,
                                                             length_penalty=params.lenpen,
                                                             early_stopping=params.early_stopping,
                                                             max_len=max_len)
            else:
                decoded, dec_lengths = decoder.generate(encoded, lengths.cuda(), params.tgt_id,
                                                        max_len=max_len)

            # convert sentences to words
            for j in range(decoded.size(1)):
                # remove delimiters: keep tokens between the first two <eos>
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

                # output translation
                source = src_sent[i + j].strip()
                target = " ".join([tgt_dico[sent[k].item()] for k in range(len(sent))])
                sys.stderr.write("%i / %i: %s -> %s\n" % (i + j, len(src_sent), source, target))
                f.write(target + "\n")
import os
import torch
from logging import getLogger
from src.utils import AttrDict
from src.data.dictionary import Dictionary, BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD, MASK_WORD
from src.model.transformer import TransformerModel

# Script-level setup: reload a checkpoint and rebuild its dictionary and model
# at import time.
logger = getLogger()


# NOTE: remember to replace the model path here
model_path = './dumped/XLM_bora_es/abcedf/checkpoint.pth'
# The checkpoint is indexed below by 'params', 'dico_id2word', 'dico_word2id',
# 'dico_counts' and 'model'.
reloaded = torch.load(model_path)
params = AttrDict(reloaded['params'])
print("Supported languages: %s" % ", ".join(params.lang2id.keys()))

# Rebuild the dictionary, then overwrite the vocabulary size and the special
# token indices in params with values looked up in that dictionary.
dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
params.n_words = len(dico)
params.bos_index = dico.index(BOS_WORD)
params.eos_index = dico.index(EOS_WORD)
params.pad_index = dico.index(PAD_WORD)
params.unk_index = dico.index(UNK_WORD)
params.mask_index = dico.index(MASK_WORD)

# build model / reload weights
# NOTE(review): the two positional True flags presumably map to is_encoder and
# with_output, as in keyword-style calls elsewhere in this file; confirm
# against TransformerModel's signature.
model = TransformerModel(params, dico, True, True)
# eval() is called before the weights are loaded.
model.eval()
model.load_state_dict(reloaded['model'])

codes = "./data/processed/XLM_bora_es/60k/codes"  # path to the codes of the model
# Path of the fastBPE binary, resolved against the current working directory.
fastbpe = os.path.join(os.getcwd(), 'tools/fastBPE/fast')
示例#17
0
def main():
    """Train a transducer, CTC or LM model as described by a YAML config.

    Command-line options:
        -config: path of the YAML configuration file.
        -log: name of the log file created inside the experiment directory.
        -mode: 'continue' also restores the optimizer state and start epoch
            from ``config.training.load_model``; any other value starts at
            epoch 0.

    Raises:
        NotImplementedError: if ``config.model.type`` is not one of
            'transducer', 'ctc' or 'lm'.
        Exception: if mode is 'continue' but no ``load_model`` is configured.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/aishell.yaml')
    parser.add_argument('-log', type=str, default='train.log')
    parser.add_argument('-mode', type=str, default='retrain')
    opt = parser.parse_args()

    # close the config file as soon as it has been parsed
    with open(opt.config) as configfile:
        config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))

    exp_name = os.path.join('egs', config.data.name, 'exp', config.model.type,
                            config.training.save_model)
    os.makedirs(exp_name, exist_ok=True)
    logger = init_logger(os.path.join(exp_name, opt.log))

    shutil.copyfile(opt.config, os.path.join(exp_name, 'config.yaml'))
    logger.info('Save config info.')
    os.environ["CUDA_VISIBLE_DEVICES"] = config.training.gpus

    config.training.num_gpu = num_gpus(config.training.gpus)
    # scale workers and batch size by the GPU count (treated as 1 on CPU)
    gpu_factor = config.training.num_gpu if config.training.num_gpu > 0 else 1
    num_workers = 6 * gpu_factor
    batch_size = config.data.batch_size * gpu_factor

    train_dataset = LmDataset(config.data, 'train')
    train_sampler = Batch_RandomSampler(len(train_dataset),
                                        batch_size=batch_size,
                                        shuffle=config.data.shuffle)
    training_data = AudioDataLoader(dataset=train_dataset,
                                    num_workers=num_workers,
                                    batch_sampler=train_sampler)
    logger.info('Load Train Set!')

    dev_dataset = LmDataset(config.data, 'dev')
    dev_sampler = Batch_RandomSampler(len(dev_dataset),
                                      batch_size=batch_size,
                                      shuffle=config.data.shuffle)
    validate_data = AudioDataLoader(dataset=dev_dataset,
                                    num_workers=num_workers,
                                    batch_sampler=dev_sampler)
    logger.info('Load Dev Set!')

    if config.training.num_gpu > 0:
        torch.cuda.manual_seed(config.training.seed)
        torch.backends.cudnn.deterministic = True
    else:
        torch.manual_seed(config.training.seed)
    logger.info('Set random seed: %d' % config.training.seed)

    if config.model.type == "transducer":
        model = Transducer(config.model)
    elif config.model.type == "ctc":
        model = CTC(config.model)
    elif config.model.type == "lm":
        model = LM(config.model)
    else:
        raise NotImplementedError

    # On a CPU-only run, remap every reloaded tensor to the CPU; None keeps
    # torch.load's default behaviour.
    map_location = 'cpu' if config.training.num_gpu == 0 else None

    if config.training.load_model:
        checkpoint = torch.load(config.training.load_model,
                                map_location=map_location)
        logger.info(str(checkpoint.keys()))
        load_model(model, checkpoint)
        # report the path that was actually loaded
        logger.info('Loaded model from %s' % config.training.load_model)

    # Partial reloads use their own variables so that ``checkpoint`` still
    # refers to the full-model checkpoint when 'continue' mode reads the
    # optimizer state and epoch from it below.
    if config.training.load_encoder:
        encoder_checkpoint = torch.load(config.training.load_encoder,
                                        map_location=map_location)
        model.encoder.load_state_dict(encoder_checkpoint['encoder'])
        logger.info('Loaded encoder from %s' %
                    config.training.load_encoder)
    if config.training.load_decoder:
        decoder_checkpoint = torch.load(config.training.load_decoder,
                                        map_location=map_location)
        model.decoder.load_state_dict(decoder_checkpoint['decoder'])
        logger.info('Loaded decoder from %s' %
                    config.training.load_decoder)

    if config.training.num_gpu > 0:
        model = model.cuda()
        if config.training.num_gpu > 1:
            device_ids = list(range(config.training.num_gpu))
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        logger.info('Loaded the model to %d GPUs' % config.training.num_gpu)

    # n_params, enc, dec = count_parameters(model)
    # logger.info('# the number of parameters in the whole model: %d' % n_params)
    # logger.info('# the number of parameters in the Encoder: %d' % enc)
    # logger.info('# the number of parameters in the Decoder: %d' % dec)
    # logger.info('# the number of parameters in the JointNet: %d' %
    #             (n_params - dec - enc))

    optimizer = Optimizer(model.parameters(), config.optim)
    logger.info('Created a %s optimizer.' % config.optim.type)

    if opt.mode == 'continue':
        if not config.training.load_model:
            raise Exception(
                "if mode is 'continue', need 'config.training.load_model'")
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        logger.info('Load Optimizer State!')
    else:
        start_epoch = 0

    # create a visualizer
    if config.training.visualization:
        visualizer = SummaryWriter(os.path.join(exp_name, 'log'))
        logger.info('Created a visualizer.')
    else:
        visualizer = None

    logger.info(model)
    for epoch in range(start_epoch, config.training.epochs):

        train(epoch, config, model, training_data, optimizer, logger,
              visualizer)

        # a checkpoint is written after every epoch
        save_name = os.path.join(
            exp_name, '%s.epoch%d.chkpt' % (config.training.save_model, epoch))
        save_model(model, optimizer, config, save_name)
        logger.info('Epoch %d model has been saved.' % epoch)

        if config.training.eval_or_not:
            # ``eval`` is called with the dev loader, so it cannot be the
            # builtin; it must be defined or imported elsewhere in the file.
            _ = eval(epoch, config, model, validate_data, logger, visualizer)

        if epoch >= config.optim.begin_to_adjust_lr:
            optimizer.decay_lr()
            # early stop
            if optimizer.lr < 1e-6:
                logger.info('The learning rate is too low to train.')
                break
            logger.info('Epoch %d update learning rate: %.6f' %
                        (epoch, optimizer.lr))

    logger.info('The training process is OVER!')
示例#18
0
def main(params):
    """Run training, or evaluation only, for a language or translation model.

    When ``params.model_path`` is non-empty, an encoder/decoder pair is rebuilt
    from that checkpoint; otherwise the model(s) come from ``build_model``.
    With ``params.eval_only`` set, the scores are logged and the process exits.
    Otherwise training runs for ``params.max_epoch`` epochs, evaluating and
    saving after each one.

    Params:
        params: parsed experiment parameters.
    """

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment
    logger = initialize_exp(params)

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()

    # load data
    data = load_data(params)

    # load checkpoint
    if params.model_path != "":
        # NOTE(review): this branch never binds ``model``, so combining
        # model_path with encoder_only fails below. Confirm whether that
        # combination is meant to be supported.
        reloaded = torch.load(params.model_path)
        model_params = AttrDict(reloaded['params'])
        dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                          reloaded['dico_counts'])
        # Build each network once. The random initialisation is overwritten
        # by load_state_dict right after.
        encoder = TransformerModel(model_params,
                                   dico,
                                   is_encoder=True,
                                   with_output=True).cuda().eval()
        decoder = TransformerModel(model_params,
                                   dico,
                                   is_encoder=False,
                                   with_output=True).cuda().eval()
        encoder.load_state_dict(reloaded['encoder'])
        decoder.load_state_dict(reloaded['decoder'])
        logger.info("Supported languages: %s" %
                    ", ".join(model_params.lang2id.keys()))
    else:
        # build model
        if params.encoder_only:
            model = build_model(params, data['dico'])
        else:
            encoder, decoder = build_model(params, data['dico'])

    # build trainer, reload potential checkpoints / build evaluator
    if params.encoder_only:
        trainer = SingleTrainer(model, data, params)
        evaluator = SingleEvaluator(trainer, data, params)
    else:
        trainer = EncDecTrainer(encoder, decoder, data, params)
        evaluator = EncDecEvaluator(trainer, data, params)

    # evaluation
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer)
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # set sampling probabilities for training
    set_sampling_probs(data, params)

    # language model training
    for _ in range(params.max_epoch):

        logger.info("============ Starting epoch %i ... ============" %
                    trainer.epoch)

        trainer.n_sentences = 0

        while trainer.n_sentences < trainer.epoch_size:

            # CLM steps
            for lang1, lang2 in shuf_order(params.clm_steps, params):
                trainer.clm_step(lang1, lang2, params.lambda_clm)

            # MLM steps (also includes TLM if lang2 is not None)
            for lang1, lang2 in shuf_order(params.mlm_steps, params):
                trainer.mlm_step(lang1, lang2, params.lambda_mlm)

            # parallel classification steps
            for lang1, lang2 in shuf_order(params.pc_steps, params):
                trainer.pc_step(lang1, lang2, params.lambda_pc)

            # denoising auto-encoder steps
            for lang in shuf_order(params.ae_steps):
                trainer.mt_step(lang, lang, params.lambda_ae)

            # machine translation steps
            for lang1, lang2 in shuf_order(params.mt_steps, params):
                trainer.mt_step(lang1, lang2, params.lambda_mt)

            # back-translation steps
            for lang1, lang2, lang3 in shuf_order(params.bt_steps):
                trainer.bt_step(lang1, lang2, lang3, params.lambda_bt)

            trainer.iter()

        logger.info("============ End of epoch %i ============" %
                    trainer.epoch)

        # evaluate perplexity
        scores = evaluator.run_all_evals(trainer)

        # print / JSON log
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
示例#19
0
def main(params):
    """Generate one text line per table with a reloaded encoder/decoder.

    Each line of ``params.table_path`` is a whitespace-separated sequence of
    records, and every record is four ``|``-separated fields. The four fields
    are indexed with the source dictionary and fed to the encoder as the
    parallel inputs ``x1`` .. ``x4``. Decoded sentences are written one per
    line to ``params.output_path`` and echoed on stderr.

    Params:
        params: parsed parameters. This function reads ``model_path``,
            ``table_path``, ``title_path``, ``output_path``, ``batch_size``,
            ``beam_size``, ``length_penalty`` and ``early_stopping``.
    """
    reloaded = torch.load(params.model_path)

    model_params = AttrDict(reloaded['params'])

    # update dictionary parameters
    for name in [
            'src_n_words', 'tgt_n_words', 'bos_index', 'eos_index',
            'pad_index', 'unk_index', 'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    source_dico = Dictionary(reloaded['source_dico_id2word'],
                             reloaded['source_dico_word2id'])
    target_dico = Dictionary(reloaded['target_dico_id2word'],
                             reloaded['target_dico_word2id'])
    encoder = TransformerEncoder(model_params, source_dico,
                                 with_output=False).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder = TransformerDecoder(model_params, target_dico,
                                 with_output=True).cuda().eval()
    decoder.load_state_dict(reloaded['decoder'])

    # Read tables and titles; both files are closed as soon as they are read.
    with open(params.table_path, 'r', encoding='utf-8') as table_inf:
        table_lines = list(table_inf)
    with open(params.title_path, 'r', encoding='utf-8') as title_inf:
        title_lines = list(title_inf)

    # titles are only used for this consistency check
    assert len(title_lines) == len(table_lines)

    n_fields = 4  # fields per record, one encoder input stream each
    max_len = 602  # decoding length limit

    # the context manager closes the output file even if generation fails
    with io.open(params.output_path, 'w', encoding='utf-8') as outf:

        for i in range(0, len(table_lines), params.batch_size):

            # prepare batch: index every field of every record
            field_ids = [[] for _ in range(n_fields)]
            for table_line in table_lines[i:i + params.batch_size]:
                record_seq = [each.split('|') for each in table_line.split()]
                assert all(len(x) == n_fields for x in record_seq)
                for c in range(n_fields):
                    field_ids[c].append(
                        torch.LongTensor(
                            [source_dico.index(x[c]) for x in record_seq]))

            # every stream shares the lengths: <eos> r_1 ... r_n <eos> <pad>*
            enc_xlen = torch.LongTensor([len(x) + 2 for x in field_ids[0]])

            enc_inputs = []
            for ids in field_ids:
                enc_x = torch.LongTensor(
                    enc_xlen.max().item(),
                    enc_xlen.size(0)).fill_(params.pad_index)
                enc_x[0] = params.eos_index
                for j, s in enumerate(ids):
                    if enc_xlen[j] > 2:  # if sentence not empty
                        enc_x[1:enc_xlen[j] - 1, j].copy_(s)
                    enc_x[enc_xlen[j] - 1, j] = params.eos_index
                enc_inputs.append(enc_x.cuda())

            enc_x1, enc_x2, enc_x3, enc_x4 = enc_inputs
            enc_xlen = enc_xlen.cuda()

            # encode source batch and translate it
            encoder_output = encoder('fwd',
                                     x1=enc_x1,
                                     x2=enc_x2,
                                     x3=enc_x3,
                                     x4=enc_x4,
                                     lengths=enc_xlen)
            encoder_output = encoder_output.transpose(0, 1)

            if params.beam_size <= 1:
                decoded, dec_lengths = decoder.generate(encoder_output,
                                                        enc_xlen,
                                                        max_len=max_len)
            else:
                decoded, dec_lengths = decoder.generate_beam(
                    encoder_output,
                    enc_xlen,
                    params.beam_size,
                    params.length_penalty,
                    params.early_stopping,
                    max_len=max_len)

            for j in range(decoded.size(1)):
                # remove delimiters: keep tokens between the first two <eos>
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                sent = sent[1:] if len(
                    delimiters) == 1 else sent[1:delimiters[1]]

                # output translation
                tokens = [target_dico[sent[k].item()] for k in range(len(sent))]
                target = " ".join(tokens)
                sys.stderr.write("%i / %i: %s\n" %
                                 (i + j, len(table_lines), target))
                outf.write(target + "\n")
示例#20
0
class Environment(CarRacing):
    """CarRacing subclass exposing an MPC-style state/action interface.

    The tracked state is the 7-vector named in ``state_variable_names``; an
    action is the 2-vector named in ``action_variable_names``. The class
    offers helpers to read state out of the simulator, to roll a simple
    dynamics model forward, and to turn MPC actions into simulator controls.
    """
    constants = AttrDict.from_dict(
        dict(
            size=SIZE,
            ell=SIZE * (80 + 82),
            fps=FPS,  # Number of frames per second in the simulation
        ))

    state_variable_names = ("xpos", "ypos", "theta", "velocity", "kappa",
                            "accel", "pinch")
    action_variable_names = ("jerk", "juke")

    def __init__(self):
        # The simulator options are fixed for now; alternatives may be
        # supported later.
        fixed_options = dict(allow_reverse=True,
                             grayscale=1,
                             show_info_panel=1,
                             discretize_actions=None,
                             num_obstacles=100,
                             num_tracks=1,
                             num_lanes=1,
                             num_lanes_changes=4,
                             max_time_out=0,
                             frames_per_state=4)
        super().__init__(**fixed_options)

    @staticmethod
    def add_argparse_args(parser):
        """Return ``parser`` unchanged; no options are registered."""
        return parser

    @staticmethod
    def from_argparse_args(args):
        """Build an Environment; ``args`` is not consulted."""
        return Environment()

    @property
    def time_step_duration(self):
        """Length of one simulation step: the reciprocal of the fps constant."""
        return 1 / self.constants.fps

    @staticmethod
    def num_states():
        """Number of entries in ``state_variable_names``."""
        return len(Environment.state_variable_names)

    @staticmethod
    def num_actions():
        """Number of entries in ``action_variable_names``."""
        return len(Environment.action_variable_names)

    @property
    def current_state(self):
        """Copy of the stored current state."""
        return self._current_state.copy()

    @property
    def goal_state(self):
        """Copy of the stored goal state."""
        return self._goal_state.copy()

    @property
    def obstacle_centers(self):
        """Copy of the stored obstacle centers."""
        return self._obstacle_centers.copy()

    @property
    def obstacle_radii(self):
        """Copy of the stored obstacle radii."""
        return self._obstacle_radii.copy()

    def disable_view_window(self):
        """Patch the gym Viewer constructor so new windows are made invisible."""
        from gym.envs.classic_control import rendering
        wrapped_init = rendering.Viewer.__init__

        def constructor(viewer, *args, **kwargs):
            # Run the previous constructor, then hide the window it created.
            wrapped_init(viewer, *args, **kwargs)
            viewer.window.set_visible(visible=False)

        rendering.Viewer.__init__ = constructor

    def reset(self, disable_view=False):
        """Reset the simulator and re-initialise the tracked state.

        The last two state entries (accel, pinch) start at zero.
        """
        if disable_view:
            self.disable_view_window()
        super().reset()
        env_part = self._get_env_vars()
        self._current_state = np.concatenate([env_part, np.zeros(2)])

    def update_goal(self, relative_goal):
        """Store a goal state from a goal relative to the current pose.

        Input: relative goal in polar coordinates, as (r, phi, delta_theta).
        """
        start = self.current_state
        x0, y0, heading = start[0], start[1], start[2]

        radius, bearing, delta_heading = relative_goal

        # Direction of the offset: bearing shifted by pi/2 plus the heading.
        direction = bearing + np.pi / 2 + heading
        offset = np.array([
            radius * np.cos(direction),
            radius * np.sin(direction),
            delta_heading,
        ])

        # The remaining four goal entries are zero.
        goal_pose = np.array([x0, y0, heading]) + offset
        self._goal_state = np.concatenate([goal_pose, np.array([0, 0, 0, 0])])

    def update_obstacles(self,
                         relative_obstacle_centers=None,
                         obstacle_radii=None):
        """Store obstacles, shifting their centers by the current (x, y)."""
        position = self.current_state[:2]
        self._obstacle_centers = relative_obstacle_centers + position
        self._obstacle_radii = obstacle_radii

    def get_next_state(self, state, action):
        """ Simulate one step of nonlinear dynamics """
        dt = self.time_step_duration
        heading, speed = state[2], state[3]
        successor = np.zeros_like(state)
        successor[0] = state[0] + dt * speed * np.cos(heading)  # xpos
        successor[1] = state[1] + dt * speed * np.sin(heading)  # ypos
        successor[2] = heading + dt * speed * state[4]  # theta
        successor[3] = speed + dt * state[5]  # velocity
        successor[4] = state[4] + dt * state[6]  # kappa
        successor[5] = state[5] + dt * action[0]  # accel
        successor[6] = state[6] + dt * action[1]  # pinch
        return successor

    def rollout_actions(self, state, actions):
        """Apply ``actions`` in order from ``state``.

        Returns an array with one row per visited state; row 0 is ``state``.
        """
        assert len(actions.shape) == 2
        assert actions.shape[1] == self.num_actions()
        assert len(state.shape) == 1
        assert state.shape[0] == self.num_states()
        horizon = actions.shape[0]
        trajectory = np.zeros((horizon + 1, state.shape[0]))
        trajectory[0] = state
        for step, action in enumerate(actions):
            trajectory[step + 1] = self.get_next_state(trajectory[step],
                                                       action)
        return trajectory

    def _get_env_vars(self):
        """ Get a subset of MPC state variables from the environment

        Returns the array (x, y, theta, velocity, kappa).
        """
        car = self.car
        theta_mpc = car.hull.angle + np.pi / 2

        # Project the hull velocity onto the unit vector rotated by theta.
        hull_velocity = np.array(car.hull.linearVelocity)
        heading_vector = rotate_by_angle(np.array([1, 0]), theta_mpc)
        velocity_mpc = np.dot(hull_velocity, heading_vector)

        kappa_mpc = np.tan(car.wheels[0].joint.angle) / self.constants.ell

        # Position is the midpoint of wheels 2 and 3.
        wheel_a = car.wheels[2].position
        wheel_b = car.wheels[3].position
        x_mpc = (1 / 2) * (wheel_a[0] + wheel_b[0])
        y_mpc = (1 / 2) * (wheel_a[1] + wheel_b[1])

        return np.array([x_mpc, y_mpc, theta_mpc, velocity_mpc, kappa_mpc])

    def take_action(self, action):
        """ Receive MPC action and feed it to the underlying environment

        Expects np.ndarray of (jerk, juke). Returns the tuple
        (current_state, reward, done, info).
        """
        predicted = self.get_next_state(self.current_state, action)

        # Derive the env controls from the predicted curvature and
        # acceleration, then step the simulator.
        kappa, accel = predicted[4], predicted[5]
        steering_action = -1 * np.arctan(self.constants.ell * kappa)
        gas_action = (1 / 500) * accel  # Polo's magic constant
        brake_action = 0
        controls = np.array([steering_action, gas_action, brake_action])
        _, reward, done, info = self.step(controls)

        # Measured pose/velocity/curvature come from the simulator; accel and
        # pinch are carried over from the predicted state.
        self._current_state = np.concatenate(
            [self._get_env_vars(), predicted[5:7]])

        return self.current_state, reward, done, info
示例#21
0
def main(params):
    """Translate a file of sentences with a reloaded encoder/decoder pair.

    Reads one whitespace-tokenized sentence per line from
    ``params.input_path``, translates the sentences in batches of
    ``params.batch_size`` from ``params.src_lang`` to ``params.tgt_lang``
    (models and batches are moved to CUDA), and writes one translation per
    line to ``params.output_path``.

    Args:
        params: experiment parameters handed to ``initialize_exp``. They are
            then replaced by the arguments parsed from the command line.

    Raises:
        AssertionError: if an input line is empty, or if a decoded sentence
            does not start with the end-of-sentence index.
    """

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    model_params.add_pred = ""
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    src_dico = load_binarized(params.src_data)
    tgt_dico = load_binarized(params.tgt_data)

    encoder = TransformerModel(model_params,
                               src_dico,
                               is_encoder=True,
                               with_output=False).cuda().eval()
    decoder = TransformerModel(model_params,
                               tgt_dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()

    # strip the 'module.' prefix left by multi-GPU checkpoints
    if all([k.startswith('module.') for k in reloaded['encoder'].keys()]):
        reloaded['encoder'] = {
            k[len('module.'):]: v
            for k, v in reloaded['encoder'].items()
        }
        reloaded['decoder'] = {
            k[len('module.'):]: v
            for k, v in reloaded['decoder'].items()
        }

    encoder.load_state_dict(reloaded['encoder'], strict=False)
    decoder.load_state_dict(reloaded['decoder'], strict=False)

    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from the input file; the context manager closes the
    # handle, and the encoding matches the UTF-8 output written below
    src_sent = []
    with open(params.input_path, 'r', encoding='utf-8') as input_f:
        for line in input_f:
            line = line.strip()
            assert len(line.split()) > 0
            src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

    # the output file is closed even if translation fails part-way
    with io.open(params.output_path, 'w', encoding='utf-8') as f:

        for i in range(0, len(src_sent), params.batch_size):

            # prepare batch
            word_ids = [
                torch.LongTensor(
                    [src_dico.index(w) for w in s.strip().split()])
                for s in src_sent[i:i + params.batch_size]
            ]
            # +2 for the eos delimiters placed before and after each sentence
            lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
            batch = torch.LongTensor(lengths.max().item(),
                                     lengths.size(0)).fill_(params.pad_index)
            batch[0] = params.eos_index
            for j, s in enumerate(word_ids):
                if lengths[j] > 2:  # if sentence not empty
                    batch[1:lengths[j] - 1, j].copy_(s)
                batch[lengths[j] - 1, j] = params.eos_index
            langs = batch.clone().fill_(params.src_id)

            # encode source batch and translate it
            encoded = encoder('fwd',
                              x=batch.cuda(),
                              lengths=lengths.cuda(),
                              langs=langs.cuda(),
                              causal=False)
            encoded = [enc.transpose(0, 1) for enc in encoded]
            decoded, dec_lengths = decoder.generate(
                encoded,
                lengths.cuda(),
                params.tgt_id,
                max_len=int(1.5 * lengths.max().item() + 10))

            # convert sentences to words
            for j in range(decoded.size(1)):

                # remove delimiters: drop the leading eos and cut at the
                # second eos when there is one
                sent = decoded[:, j]
                delimiters = (sent == params.eos_index).nonzero().view(-1)
                assert len(delimiters) >= 1 and delimiters[0].item() == 0
                if len(delimiters) == 1:
                    sent = sent[1:]
                else:
                    sent = sent[1:delimiters[1]]

                # output translation
                target = " ".join(tgt_dico[tok.item()] for tok in sent)
                f.write(target + "\n")
示例#22
0
def run_xnlg():
    """Build an encoder/decoder pair from a checkpoint and run transfer tasks.

    Parameters come from ``get_params()``. The checkpoint at
    ``params.model_path`` supplies the model hyper-parameters, the dictionary
    and (unless disabled through ``params.no_init`` or replaced by
    ``params.reload_emb``) the weights. Each task listed in the
    comma-separated ``params.transfer_tasks`` is then run; only "XQG" and
    "XSumm" are dispatched here.

    Raises:
        AssertionError: if the data directory or model file is missing, or a
            requested task is not in ``TASKS``.
    """
    params = get_params()

    # initialize the experiment / build sentence embedder
    logger = initialize_exp(params)

    if params.tokens_per_batch > -1:
        params.group_by_size = True

    # check parameters
    assert os.path.isdir(params.data_path)
    assert os.path.isfile(params.model_path)

    # tasks
    params.transfer_tasks = params.transfer_tasks.split(',')
    assert len(params.transfer_tasks) > 0
    assert all([task in TASKS for task in params.transfer_tasks])

    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))
    params.n_langs = model_params['n_langs']
    params.id2lang = model_params['id2lang']
    params.lang2id = model_params['lang2id']

    # Encoder hyper-parameters: dedicated ones from the checkpoint if present,
    # otherwise the shared ones, with the layer count overridden on a fresh
    # copy when a different depth was requested.
    if "enc_params" in reloaded:
        encoder_model_params = AttrDict(reloaded["enc_params"])
    elif params.n_enc_layers == model_params.n_layers or params.n_enc_layers == 0:
        encoder_model_params = model_params
    else:
        encoder_model_params = AttrDict(reloaded['params'])
        encoder_model_params.n_layers = params.n_enc_layers
        # The override must not have leaked into the shared parameters.
        # Compare by value: identity comparison of ints only works by
        # accident of small-int caching.
        assert model_params.n_layers != encoder_model_params.n_layers

    # Same selection logic for the decoder.
    if "dec_params" in reloaded:
        decoder_model_params = AttrDict(reloaded["dec_params"])
    elif params.n_dec_layers == model_params.n_layers or params.n_dec_layers == 0:
        decoder_model_params = model_params
    else:
        decoder_model_params = AttrDict(reloaded['params'])
        decoder_model_params.n_layers = params.n_dec_layers
        assert model_params.n_layers != decoder_model_params.n_layers

    params.encoder_model_params = encoder_model_params
    params.decoder_model_params = decoder_model_params

    if params.emb_dim != -1:
        encoder_model_params.emb_dim = params.emb_dim
        decoder_model_params.emb_dim = params.emb_dim

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])

    for p in [params, encoder_model_params, decoder_model_params]:
        p.n_words = len(dico)
        p.bos_index = dico.index(BOS_WORD)
        p.eos_index = dico.index(EOS_WORD)
        p.pad_index = dico.index(PAD_WORD)
        p.unk_index = dico.index(UNK_WORD)
        p.mask_index = dico.index(MASK_WORD)

    encoder = TransformerModel(encoder_model_params,
                               dico,
                               is_encoder=True,
                               with_output=False)
    decoder = TransformerModel(decoder_model_params,
                               dico,
                               is_encoder=False,
                               with_output=True)

    def _process_state_dict(state_dict):
        # handle models from multi-GPU checkpoints ('module.' key prefix)
        return {(k[7:] if k.startswith('module.') else k): v
                for k, v in state_dict.items()}

    if params.no_init == "all":
        logger.info("All Models will not load state dict.!!!")
    elif params.reload_emb != "":
        logger.info("Reloading embedding from %s ..." % params.reload_emb)
        word2id, embeddings = read_txt_embeddings(logger, params.reload_emb)
        set_pretrain_emb(logger, encoder, dico, word2id, embeddings)
        set_pretrain_emb(logger, decoder, dico, word2id, embeddings)
    else:
        if "model" in reloaded:
            # single-model checkpoint: both sides start from the same weights
            if params.no_init != "encoder":
                encoder.load_state_dict(_process_state_dict(reloaded['model']),
                                        strict=False)
            if params.no_init != "decoder":
                decoder.load_state_dict(_process_state_dict(reloaded['model']),
                                        strict=False)
        else:
            if params.no_init != "encoder":
                encoder.load_state_dict(_process_state_dict(
                    reloaded['encoder']),
                                        strict=False)
            if params.no_init != "decoder":
                # NOTE(review): this is the only load without strict=False;
                # kept as-is -- confirm whether strict loading is intended.
                decoder.load_state_dict(
                    _process_state_dict(reloaded['decoder']))

    scores = {}

    # run
    for task in params.transfer_tasks:
        if task == "XQG":
            XQG_v3(encoder, decoder, scores, dico, params).run()
        elif task == "XSumm":
            XSumm(encoder, decoder, scores, dico, params).run()
示例#23
0
#%%
# Perplexity | layers | dropout |
#    9.8509  |   12   |   0.1   | /checkpoint/guismay/dumped/clm_test1/10347724/train.log
#   10.2989  |   18   |   0.1   | /checkpoint/guismay/dumped/clm_test2/10402246/train.log
#   10.7602  |   12   |   0.1   | /checkpoint/guismay/dumped/clm_test3/10431903/train.log
#   11.0479  |   12   |   0.1   | /checkpoint/guismay/dumped/clm_test1/10347726/train.log
#   11.3784  |   12   |   0.1   | /checkpoint/guismay/dumped/clm_test1/10347725/train.log
#   11.8830  |   18   |   0.1   | /checkpoint/guismay/dumped/clm_test2/10403080/train.log
#   12.0149  |   12   |   0.3   | /checkpoint/guismay/dumped/clm_test3/10431904/train.log
#   12.5228  |   18   |   0.1   | /checkpoint/guismay/dumped/clm_test2/10403079/train.log

#%%
# Checkpoint to inspect: run clm_test3/10431904 from the table above.
model_path = '/checkpoint/guismay/dumped/clm_test3/10431904/periodic-23.pth'
reloaded = torch.load(model_path)
params = AttrDict(reloaded['params'])
print(
    "Supported languages: %s" % ", ".join(params.lang2id.keys())
)

#%% [markdown]
# ## Build dictionary / update parameters / build model

#%%
# Rebuild the dictionary from the checkpoint and verify that the stored
# parameters agree with it (vocabulary size and special-token indices).
dico = Dictionary(
    reloaded['dico_id2word'],
    reloaded['dico_word2id'],
    reloaded['dico_counts'],
)
assert len(dico) == params.n_words
assert all(
    getattr(params, attr) == dico.index(word)
    for attr, word in (
        ('bos_index', BOS_WORD),
        ('eos_index', EOS_WORD),
        ('pad_index', PAD_WORD),
        ('unk_index', UNK_WORD),
        ('mask_index', MASK_WORD),
    )
)
示例#24
0
def main():
    """Embed a slice of sentences with a pretrained XLM model and save them.

    Reads lines ``args.sn[0]:args.sn[1]`` of ``args.filename``, tags every
    sentence as English, runs the model once over the whole slice and saves
    ``tensor[0]`` (the first position of every sentence) to ``args.o``.

    ``args`` is not defined in this function; it is expected to exist at
    module level.
    """

    # Load pre-trained model
    model_path = './models/mlm_tlm_xnli15_1024.pth'
    reloaded = torch.load(model_path)
    params = AttrDict(reloaded['params'])

    # build dictionary / update parameters
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    params.n_words = len(dico)
    params.bos_index = dico.index(BOS_WORD)
    params.eos_index = dico.index(EOS_WORD)
    params.pad_index = dico.index(PAD_WORD)
    params.unk_index = dico.index(UNK_WORD)
    params.mask_index = dico.index(MASK_WORD)

    # build model / reload weights
    model = TransformerModel(params, dico, True, True)
    #model.cuda() #if using GPU
    model.load_state_dict(reloaded['model'])
    # Inference only: evaluation mode disables any dropout so the saved
    # embeddings are deterministic.
    model.eval()

    with open(args.filename, "r") as f:
        sentence_list = f.readlines()[args.sn[0]:args.sn[1]]

    # remove new line symbols
    sentence_list = [line.replace("\n", "") for line in sentence_list]

    # save as dataframe and add language tokens
    sentence_df = pd.DataFrame(sentence_list)
    sentence_df.columns = ['sentence']
    sentence_df['language'] = 'en'

    # match xlm format: a list of (sentence, language) pairs
    sentences = list(zip(sentence_df.sentence, sentence_df.language))

    # --- from XLM repo ---
    # add </s> sentence delimiters
    sentences = [(('</s> %s </s>' % sent.strip()).split(), lang)
                 for sent, lang in sentences]

    # Create batch
    bs = len(sentences)
    slen = max([len(sent) for sent, _ in sentences])

    # one column per sentence, padded to the longest sentence
    word_ids = torch.LongTensor(slen, bs).fill_(params.pad_index)
    for i in range(len(sentences)):
        sent = torch.LongTensor([dico.index(w) for w in sentences[i][0]])
        word_ids[:len(sent), i] = sent

    lengths = torch.LongTensor([len(sent) for sent, _ in sentences])
    langs = torch.LongTensor([params.lang2id[lang] for _, lang in sentences
                              ]).unsqueeze(0).expand(slen, bs)

    #if using GPU:
    #word_ids=word_ids.cuda()
    #lengths=lengths.cuda()
    #langs=langs.cuda()

    tensor = model('fwd',
                   x=word_ids,
                   lengths=lengths,
                   langs=langs,
                   causal=False).contiguous()
    print(tensor.size())

    # The variable tensor is of shape (sequence_length, batch_size, model_dimension).
    # tensor[0] is a tensor of shape (batch_size, model_dimension) that corresponds to the first hidden state of the last layer of each sentence.
    # This is this vector that we use to finetune on the GLUE and XNLI tasks.

    torch.save(tensor[0], args.o)
示例#25
0
def main(params):
    """Translate sentences from stdin and write ranked hypothesis files.

    Sentences are read from standard input, translated in batches of
    ``params.batch_size`` with either greedy/sampled decoding
    (``params.beam_size == 1``) or beam search, and echoed to stderr. One
    hypothesis file per beam rank is then written next to
    ``params.output_path``, its BPE segmentation is restored, and a BLEU
    score is logged when ``params.ref_path`` is set.

    Args:
        params: experiment parameters handed to ``initialize_exp``. They are
            then replaced by the arguments parsed from the command line.

    Raises:
        AssertionError: if sampling is combined with beam search, if
            ``params.amp`` is greater than 1, if an input line is empty, or
            if a decoded sentence does not start with the end-of-sentence
            index.
    """

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    torch.manual_seed(
        params.seed
    )  # Set random seed. NB: Multi-GPU also needs torch.cuda.manual_seed_all(params.seed)
    assert (params.sample_temperature
            == 0) or (params.beam_size == 1), 'Cannot sample with beam search.'
    assert params.amp <= 1, f'params.amp == {params.amp} not yet supported.'
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" %
                ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in [
            'n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index',
            'mask_index'
    ]:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                      reloaded['dico_counts'])
    encoder = TransformerModel(model_params,
                               dico,
                               is_encoder=True,
                               with_output=False).cuda().eval()
    decoder = TransformerModel(model_params,
                               dico,
                               is_encoder=False,
                               with_output=True).cuda().eval()
    # strip the 'module.' prefix left by multi-GPU checkpoints
    if all([k.startswith('module.') for k in reloaded['encoder'].keys()]):
        reloaded['encoder'] = {
            k[len('module.'):]: v
            for k, v in reloaded['encoder'].items()
        }
    encoder.load_state_dict(reloaded['encoder'])
    if all([k.startswith('module.') for k in reloaded['decoder'].keys()]):
        reloaded['decoder'] = {
            k[len('module.'):]: v
            for k, v in reloaded['decoder'].items()
        }
    decoder.load_state_dict(reloaded['decoder'])

    if params.amp != 0:
        models = apex.amp.initialize([encoder, decoder],
                                     opt_level=('O%i' % params.amp))
        encoder, decoder = models

    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # read sentences from stdin
    src_sent = []
    for line in sys.stdin.readlines():
        assert len(line.strip().split()) > 0
        src_sent.append(line)
    logger.info("Read %i sentences from stdin. Translating ..." %
                len(src_sent))

    # hypothesis[r] collects the rank-r translation of every source sentence
    hypothesis = [[] for _ in range(params.beam_size)]
    for i in range(0, len(src_sent), params.batch_size):

        # prepare batch
        word_ids = [
            torch.LongTensor([dico.index(w) for w in s.strip().split()])
            for s in src_sent[i:i + params.batch_size]
        ]
        # +2 for the eos delimiters placed before and after each sentence
        lengths = torch.LongTensor([len(s) + 2 for s in word_ids])
        batch = torch.LongTensor(lengths.max().item(),
                                 lengths.size(0)).fill_(params.pad_index)
        batch[0] = params.eos_index
        for j, s in enumerate(word_ids):
            if lengths[j] > 2:  # if sentence not empty
                batch[1:lengths[j] - 1, j].copy_(s)
            batch[lengths[j] - 1, j] = params.eos_index
        langs = batch.clone().fill_(params.src_id)

        # encode source batch and translate it
        encoded = encoder('fwd',
                          x=batch.cuda(),
                          lengths=lengths.cuda(),
                          langs=langs.cuda(),
                          causal=False)
        encoded = encoded.transpose(0, 1)
        max_len = int(1.5 * lengths.max().item() + 10)
        if params.beam_size == 1:
            decoded, dec_lengths = decoder.generate(
                encoded,
                lengths.cuda(),
                params.tgt_id,
                max_len=max_len,
                sample_temperature=(None if params.sample_temperature == 0 else
                                    params.sample_temperature))
        else:
            decoded, dec_lengths, all_hyp_strs = decoder.generate_beam(
                encoded,
                lengths.cuda(),
                params.tgt_id,
                beam_size=params.beam_size,
                length_penalty=params.length_penalty,
                early_stopping=params.early_stopping,
                max_len=max_len,
                output_all_hyps=True)

        # convert sentences to words
        for j in range(decoded.size(1)):

            # remove delimiters: drop the leading eos and cut at the second
            # eos when there is one
            sent = decoded[:, j]
            delimiters = (sent == params.eos_index).nonzero().view(-1)
            assert len(delimiters) >= 1 and delimiters[0].item() == 0
            sent = sent[1:] if len(delimiters) == 1 else sent[1:delimiters[1]]

            # output translation
            source = src_sent[i + j].strip().replace('<unk>', '<<unk>>')
            target = " ".join([dico[sent[k].item()] for k in range(len(sent))
                               ]).replace('<unk>', '<<unk>>')
            if params.beam_size == 1:
                hypothesis[0].append(target)
            else:
                sent_hyps = all_hyp_strs[j]
                for hyp_rank in range(params.beam_size):
                    # when fewer than beam_size hypotheses came back, the
                    # last one fills the remaining ranks
                    hyp_str = sent_hyps[hyp_rank
                                        if hyp_rank < len(sent_hyps) else -1]
                    print(hyp_str)
                    hypothesis[hyp_rank].append(hyp_str)

            sys.stderr.write("%i / %i: %s -> %s\n" %
                             (i + j, len(src_sent), source.replace(
                                 '@@ ', ''), target.replace('@@ ', '')))

    # export sentences to reference and hypothesis files / restore BPE segmentation
    # os.path.split also copes with an output path that has no directory
    # part, where rsplit('/', 1) would fail to unpack.
    save_dir, split = os.path.split(params.output_path)
    for hyp_rank in range(len(hypothesis)):
        hyp_name = f'hyp.st={params.sample_temperature}.bs={params.beam_size}.lp={params.length_penalty}.es={params.early_stopping}.seed={params.seed if (len(hypothesis) == 1) else str(hyp_rank)}.{params.src_lang}-{params.tgt_lang}.{split}.txt'
        hyp_path = os.path.join(save_dir, hyp_name)
        with open(hyp_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(hypothesis[hyp_rank]) + '\n')
        restore_segmentation(hyp_path)

        # evaluate BLEU score
        if params.ref_path:
            bleu = eval_moses_bleu(params.ref_path, hyp_path)
            logger.info("BLEU %s %s : %f" % (hyp_path, params.ref_path, bleu))
示例#26
0
def main(params):
    """Translate a binarized dataset and dump both inputs and hypotheses.

    Loads the encoder/decoder from ``params.model_path`` and the dataset from
    ``params.input``, optionally restricts it to
    ``[params.subset_start, params.subset_end)``, then writes the decoded
    hypotheses to ``params.output_path`` and the text of the source batches
    to ``<params.dump_path>/input.txt``.

    Args:
        params: experiment parameters handed to ``initialize_exp``. They are
            then replaced by the arguments parsed from the command line.

    Raises:
        AssertionError: if fp16 is requested without cuDNN, or if
            ``subset_start`` is given without a truthy ``subset_end``.
    """

    # initialize the experiment
    logger = initialize_exp(params)

    # generate parser / parse parameters
    parser = get_parser()
    params = parser.parse_args()
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=True).cuda().eval()
    decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    encoder.load_state_dict(reloaded['encoder'])
    decoder.load_state_dict(reloaded['decoder'])
    params.src_id = model_params.lang2id[params.src_lang]
    params.tgt_id = model_params.lang2id[params.tgt_lang]

    # float16
    if params.fp16:
        assert torch.backends.cudnn.enabled
        encoder = network_to_half(encoder)
        decoder = network_to_half(decoder)

    input_data = torch.load(params.input)
    eval_dataset = Dataset(input_data["sentences"], input_data["positions"], params)

    if params.subset_start is not None:
        assert params.subset_end
        eval_dataset.select_data(params.subset_start, params.subset_end)

    eval_dataset.remove_empty_sentences()
    eval_dataset.remove_long_sentences(params.max_len)

    n_batch = 0
    inp_dump_path = os.path.join(params.dump_path, "input.txt")

    # Each file is opened exactly once, and both are closed even when
    # translation fails part-way.
    with io.open(params.output_path, "w", encoding="utf-8") as out, \
            io.open(inp_dump_path, "w", encoding="utf-8") as inp_dump:
        logger.info("logging to {}".format(inp_dump_path))

        for batch in eval_dataset.get_iterator(shuffle=False):
            n_batch += 1

            # dump the source side of the batch as text
            (x1, len1) = batch
            input_text = convert_to_text(x1, len1, input_data["dico"], params)
            inp_dump.write("\n".join(input_text))
            inp_dump.write("\n")

            langs1 = x1.clone().fill_(params.src_id)

            # cuda
            x1, len1, langs1 = to_cuda(x1, len1, langs1)

            # encode source sentence
            enc1 = encoder("fwd", x=x1, lengths=len1, langs=langs1, causal=False)
            enc1 = enc1.transpose(0, 1)

            # generate translation - translate / convert to text
            max_len = int(1.5 * len1.max().item() + 10)
            if params.beam_size == 1:
                generated, lengths = decoder.generate(enc1, len1, params.tgt_id, max_len=max_len)
            else:
                generated, lengths = decoder.generate_beam(
                    enc1, len1, params.tgt_id, beam_size=params.beam_size,
                    length_penalty=params.length_penalty,
                    early_stopping=params.early_stopping,
                    max_len=max_len)

            hypotheses_batch = convert_to_text(generated, lengths, input_data["dico"], params)

            out.write("\n".join(hypotheses_batch))
            out.write("\n")

            if n_batch % 100 == 0:
                logger.info("{} batches processed".format(n_batch))