Example #1
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import framework
from paddle.fluid.executor import Executor

import reader  # local data helpers from the example project
from args import parse_args  # local modules; exact names assumed
from attention_model import AttentionModel
from base_model import BaseModel


def train():
    args = parse_args()

    num_layers = args.num_layers
    src_vocab_size = args.src_vocab_size
    tar_vocab_size = args.tar_vocab_size
    batch_size = args.batch_size
    dropout = args.dropout
    init_scale = args.init_scale
    max_grad_norm = args.max_grad_norm
    hidden_size = args.hidden_size
    # inference process

    print("src", src_vocab_size)

    # The dropout implementation uses upscale_in_train, so dropout can be
    # removed at inference time by setting the rate to 0.
    if args.attention:
        model = AttentionModel(hidden_size,
                               src_vocab_size,
                               tar_vocab_size,
                               batch_size,
                               num_layers=num_layers,
                               init_scale=init_scale,
                               dropout=0.0)
    else:
        model = BaseModel(hidden_size,
                          src_vocab_size,
                          tar_vocab_size,
                          batch_size,
                          num_layers=num_layers,
                          init_scale=init_scale,
                          dropout=0.0)

    beam_size = args.beam_size
    trans_res = model.build_graph(mode='beam_search', beam_size=beam_size)
    # use the default main program for inference
    main_program = fluid.default_main_program()

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = Executor(place)
    exe.run(framework.default_startup_program())

    source_vocab_file = args.vocab_prefix + "." + args.src_lang
    infer_file = args.infer_file

    infer_data = reader.raw_mono_data(source_vocab_file, infer_file)

    def prepare_input(batch, epoch_id=0):
        src_ids, src_mask, tar_ids, tar_mask = batch
        res = {}
        src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1], 1))
        # Target ids are not consumed during beam search, so feed zero
        # placeholders with the correct shapes.
        in_tar = tar_ids[:, :-1]
        label_tar = tar_ids[:, 1:]

        in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1], 1))
        in_tar = np.zeros_like(in_tar, dtype='int64')
        label_tar = label_tar.reshape(
            (label_tar.shape[0], label_tar.shape[1], 1))
        label_tar = np.zeros_like(label_tar, dtype='int64')

        res['src'] = src_ids
        res['tar'] = in_tar
        res['label'] = label_tar
        res['src_sequence_length'] = src_mask
        res['tar_sequence_length'] = tar_mask

        return res, np.sum(tar_mask)

    dir_name = args.reload_model
    print("dir name", dir_name)
    fluid.io.load_params(exe, dir_name)

    train_data_iter = reader.get_data_iter(infer_data, 1, mode='eval')

    tar_id2vocab = []
    tar_vocab_file = args.vocab_prefix + "." + args.tar_lang
    with open(tar_vocab_file, "r") as f:
        for line in f:
            tar_id2vocab.append(line.strip())

    infer_output_file = args.infer_output_file

    out_file = open(infer_output_file, 'w')

    for batch_id, batch in enumerate(train_data_iter):
        input_data_feed, word_num = prepare_input(batch, epoch_id=0)

        fetch_outs = exe.run(main_program,
                             feed=input_data_feed,
                             fetch_list=[trans_res.name],
                             use_program_cache=False)

        # map token ids back to words; drop the leading start token
        res = [tar_id2vocab[e] for e in fetch_outs[0].reshape(-1)]
        res = res[1:]

        # truncate at the first end-of-sentence token
        new_res = []
        for ele in res:
            if ele == "</s>":
                break
            new_res.append(ele)

        out_file.write(' '.join(new_res))
        out_file.write('\n')

    out_file.close()
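Note that the keys fed in prepare_input ('src', 'tar', 'label', 'src_sequence_length', 'tar_sequence_length') must match the input variable names declared inside model.build_graph. A minimal script entry point for the function above might look like the following (a sketch; the original file may wire this up differently):

if __name__ == "__main__":
    train()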
Example #2
import io
import os

import numpy as np
import paddle.fluid as fluid

import reader  # local data helpers from the example project
from args import parse_args  # local modules; exact names assumed
from attention_model import AttentionModel
from base_model import BaseModel


def infer():
    args = parse_args()

    num_layers = args.num_layers
    src_vocab_size = args.src_vocab_size
    tar_vocab_size = args.tar_vocab_size
    batch_size = args.batch_size
    dropout = args.dropout
    init_scale = args.init_scale
    max_grad_norm = args.max_grad_norm
    hidden_size = args.hidden_size
    # inference process

    print("src", src_vocab_size)
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        # The dropout implementation uses upscale_in_train, so dropout can be
        # removed at inference time by setting the rate to 0.
        if args.attention:
            model = AttentionModel(hidden_size,
                                   src_vocab_size,
                                   tar_vocab_size,
                                   batch_size,
                                   beam_size=args.beam_size,
                                   num_layers=num_layers,
                                   init_scale=init_scale,
                                   dropout=0.0,
                                   mode='beam_search')
        else:
            model = BaseModel(hidden_size,
                              src_vocab_size,
                              tar_vocab_size,
                              batch_size,
                              beam_size=args.beam_size,
                              num_layers=num_layers,
                              init_scale=init_scale,
                              dropout=0.0,
                              mode='beam_search')
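        # Note: unlike the static-graph version in Example #1, the dygraph
        # models take beam_size and mode at construction time rather than
        # through a separate build_graph call.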

        source_vocab_file = args.vocab_prefix + "." + args.src_lang
        infer_file = args.infer_file

        infer_data = reader.raw_mono_data(source_vocab_file, infer_file)

        def prepare_input(batch, epoch_id=0):
            src_ids, src_mask, tar_ids, tar_mask = batch
            res = {}
            src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
            in_tar = tar_ids[:, :-1]
            label_tar = tar_ids[:, 1:]

            in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
            label_tar = label_tar.reshape(
                (label_tar.shape[0], label_tar.shape[1], 1))
            inputs = [src_ids, in_tar, label_tar, src_mask, tar_mask]
            return inputs, np.sum(tar_mask)

        dir_name = args.reload_model
        print("dir name", dir_name)
        state_dict, _ = fluid.dygraph.load_dygraph(dir_name)
        model.set_dict(state_dict)
        # switch the model to eval mode for inference
        model.eval()

        train_data_iter = reader.get_data_iter(infer_data,
                                               batch_size,
                                               mode='infer')

        tar_id2vocab = []
        tar_vocab_file = args.vocab_prefix + "." + args.tar_lang
        with io.open(tar_vocab_file, "r", encoding='utf-8') as f:
            for line in f:
                tar_id2vocab.append(line.strip())

        infer_output_file = args.infer_output_file
        infer_output_dir = os.path.dirname(infer_output_file)
        if infer_output_dir and not os.path.exists(infer_output_dir):
            os.makedirs(infer_output_dir)

        with io.open(infer_output_file, 'w', encoding='utf-8') as out_file:

            for batch_id, batch in enumerate(train_data_iter):
                input_data_feed, word_num = prepare_input(batch, epoch_id=0)
                outputs = model(input_data_feed)
                for i in range(outputs.shape[0]):
                    ins = outputs[i].numpy()
                    # take the top beam (column 0) and map ids back to words
                    res = [tar_id2vocab[int(e)] for e in ins[:, 0].reshape(-1)]
                    # truncate at the first end-of-sentence token
                    new_res = []
                    for ele in res:
                        if ele == "</s>":
                            break
                        new_res.append(ele)

                    out_file.write(u' '.join(new_res))
                    out_file.write(u'\n')
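Both examples end with the same post-processing: map ids back to words and truncate at the first </s>. That logic could be factored into a small standalone helper; the sketch below is illustrative and not part of the original code.

def ids_to_sentence(ids, id2vocab, eos=u"</s>"):
    # Map token ids to words, stopping at the first end-of-sentence token.
    words = []
    for idx in ids:
        word = id2vocab[int(idx)]
        if word == eos:
            break
        words.append(word)
    return u' '.join(words)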