Example #1
0
def run_imt_server(models, processors=None, port=5007):
    """Attach models, processors and constrained decoders to the Flask app,
    then start serving.

    Note: servers use a special .yaml config format that maps language pairs
    to NMT configuration files; the server instantiates a predictor for each
    config and hashes them by language-pair tuples, e.g. (en, fr).

    Args:
        models: dict of predictors keyed by (source_lang, target_lang) tuples.
        processors: optional dict keyed like ``models``; when omitted, every
            language pair gets ``None``.
        port: TCP port to bind on localhost.
    """
    # Default to a no-op processor for every language pair.
    if processors is None:
        processors = dict.fromkeys(models)
    app.processors = processors

    app.models = models
    # Build one constrained decoder per language pair.
    app.decoders = {pair: create_constrained_decoder(model)
                    for pair, model in models.items()}

    logger.info('Server starting on port: {}'.format(port))
    # Single-threaded, non-debug server bound to localhost only.
    app.run(host='127.0.0.1', port=port, debug=False, threaded=False)
Example #2
0
def main(args):
    """Grid-beam-search decoding driver for pre-encoded inputs.

    Rebuilds one Transformer inference graph per checkpoint, restores the
    checkpoint weights, reads pre-computed decoder inputs and encoder outputs
    from a pickle stream (``args.input``), decodes every sentence with a
    constrained (grid beam search) decoder, restores the original sentence
    order, and writes translations, alignments and per-sentence timing
    statistics to ``args.output`` plus suffixed side files.

    Args:
        args: parsed command-line namespace; reads at least ``models``,
            ``checkpoints``, ``input``, ``output`` and ``verbose``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    # NOTE(review): every model class is transformer.Transformer regardless
    # of args.models -- presumably only Transformer is supported here.
    model_cls_list = [transformer.Transformer for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints: for each model, collect the tensors whose names
        # match that model's scope, skipping "losses_avg" bookkeeping vars.
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor
            model_var_lists.append(values)

        # Build models: one inference-function pair per checkpoint, each in
        # its own uniquely suffixed variable scope ("<name>_<i>").
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_rerank_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Read input file: three pickled objects produced by the encoding
        # pass.  NOTE(review): text-mode "r" with cPickle implies Python 2
        # and protocol-0 pickles -- confirm against the writer.
        with open(args.input, "r") as encoded_file:
            sorted_keys = cPickle.load(encoded_file)
            decoder_input_list = cPickle.load(encoded_file)
            encoder_output_list = cPickle.load(encoded_file)

        # One flat placeholder dict per device; decoding_fn below re-nests
        # it into the structure the transformer model expects.
        state_placeholders = []
        for i in range(len(params.device_list)):
            decode_state = {
                "encoder":
                tf.placeholder(tf.float32, [None, None, params.hidden_size],
                               "encoder_%d" % i),
                # "encoder_weight": the encoder weight is not needed here
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i),
                # [bos_id, ...] => [..., 0]
                "target":
                tf.placeholder(tf.int32, [None, None], "target_%d" % i),
                #"target_length": tf.placeholder(tf.int32, [None, ], "target_length_%d" % i)
            }
            # Cached per-layer key/value tensors required for incremental
            # (step-by-step) decoding.
            for j in range(params.num_decoder_layers):
                decode_state["decoder_layer_%d_key" % j] = tf.placeholder(
                    tf.float32, [None, None, params.hidden_size],
                    "decoder_layer_%d_key_%d" % (j, i))
                decode_state["decoder_layer_%d_value" % j] = tf.placeholder(
                    tf.float32, [None, None, params.hidden_size],
                    "decoder_layer_%d_value_%d" % (j, i))  # layer and GPU
                # we only need the return value of this
                # decode_state["decoder_layer_%d_att_weight" % j] = tf.placeholder(tf.float32, [None, None, None, None],
                #                              # N Head T S; at inference time T is always 1 (a single step)
                #                              "decoder_layer_%d_att_weight" % j),
            state_placeholders.append(decode_state)

        def decoding_fn(s):
            """Run one incremental decoding step of the first model, after
            converting the flat placeholder dict ``s`` into nested
            state/feature structures."""
            _decoding_fn = model_fns[0][1]
            # split s into state and feature, converting to the nested
            # structure required by the transformer model
            state = {
                "encoder": s["encoder"],
                "decoder": {
                    "layer_%d" % j: {
                        "key": s["decoder_layer_%d_key" % j],
                        "value": s["decoder_layer_%d_value" % j],
                    }
                    for j in range(params.num_decoder_layers)
                }
            }
            inputs = s["target"]
            #inputs = tf.Print(inputs, [inputs], "before target", 100, 10000)
            feature = {
                "source":
                s["source"],
                "source_length":
                s["source_length"],
                # [bos_id, ...] => [..., 0]
                # "target": tf.pad(inputs[:,1:], [[0, 0], [0, 1]])
                # "target": tf.pad(inputs, [[0, 0], [0, 1]]),  # no leading
                # bos_id, so a trailing 0 would be padded here to match the
                # bos padding done in decode_graph
                "target":
                inputs,
                "target_length":
                tf.fill([tf.shape(inputs)[0]],
                        tf.shape(inputs)[1])
            }
            #feature["target"] = tf.Print(feature["target"], [feature["target"]], "target", 100,10000)
            ret = _decoding_fn(feature, state, params)
            return ret

        decoder_op = parallel.data_parallelism(params.device_list,
                                               lambda s: decoding_fn(s),
                                               state_placeholders)

        #batch = tf.shape(encoder_output)[0]

        # Create assign ops that copy the checkpoint values into the
        # variables of each per-model scope ("<name>_<i>").
        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        results = []
        sen_decode_time = []
        grid_hyps = []  # per-sentence hypothesis counts for every grid cell, kept for later analysis and statistics
        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # from tensorflow.python import debug as tf_debug
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess,ui_type='curses')#readline

            # Restore variables
            sess.run(assign_op)
            #startpoint=320
            for i, (decode_input, encoder_output) in enumerate(
                    zip(decoder_input_list, encoder_output_list)):
                # if i < startpoint:
                #     continue

                # if i == startpoint:
                #     break
                # print(input["source"])
                # print(input["constraints"])
                #################
                # create constraint translation related model
                # build ensembled TM
                thumt_tm = ThumtTranslationModel(sess, decoder_op,
                                                 encoder_output,
                                                 state_placeholders,
                                                 decode_input, params)

                # Build GBS search
                cons_decoder = create_constrained_decoder(thumt_tm)
                ##################
                # Hypothesis length budget: source length plus a fixed slack.
                max_length = decode_input["source_length"][
                    0] + params.decode_length
                beam_size = params.beam_size
                # top_beams = params.top_beams
                top_beams = 1
                start_time = time.time()
                best_output, search_grid = decode(encoder_output,
                                                  sess,
                                                  decoder_op,
                                                  state_placeholders,
                                                  params,
                                                  cons_decoder,
                                                  thumt_tm,
                                                  decode_input,
                                                  top_beams,
                                                  max_hyp_len=max_length,
                                                  beam_size=beam_size,
                                                  return_alignments=True,
                                                  length_norm=False)
                sen_decode_time.append(time.time() - start_time)
                hyps_num = {k: len(search_grid[k]) for k in search_grid.keys()}
                grid_hyps.append(hyps_num)

                # output_beams = [search_grid[k] for k in search_grid.keys() if k[1] == top_row]
                # output_hyps = [h for beam in output_beams for h in beam]

                # constraints=input_constraints,
                # return_alignments=return_alignments,
                # length_norm=length_norm)
                results.append(best_output)

                message = "Finished decoding sentences index: %d" % (i)
                tf.logging.log(tf.logging.INFO, message)

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []
        mask_ratio = []
        best_alignment = []

        for result in results:
            # NOTE(review): indexing the result of zip() relies on Python 2,
            # where zip returns a list; Python 3 would need list(zip(...)).
            sub_result = zip(*result[0])
            outputs.extend(sub_result[0])
            scores.extend(sub_result[1])
            best_alignment.extend(result[1])

            # for sub_result in result:  # each decode may yield several best-scoring hypotheses
            #     outputs.append(sub_result[0][0][1:])  # seqs
            #     scores.append(sub_result[0][1])  # score
            #     mask_ratio.append(0)
            #     best_alignment.extend(sub_result[1])
        # Drop the first token (presumably bos) of every output sequence.
        new_outputs = []
        for s in outputs:
            new_outputs.append(s[1:])
        outputs = new_outputs

        # Echo each hypothesis (truncated at eos) with its score to stdout.
        for s, score in zip(outputs, scores):
            s1 = []
            for idx in s:
                if idx == params.mapping["target"][params.eos]:
                    break
                s1.append(vocab[idx])
            s1 = " ".join(s1)
            #print("%s" % s1)
            print("%f   %s" % (score, s1))

        # Undo the sorting done by the encoding pass: sorted_keys maps the
        # original sentence order back onto the decoded results.
        restored_inputs = []
        restored_outputs = []
        restored_scores = []
        restored_constraints = []
        restored_alignment = []
        restored_sen_decode_time = []
        restored_grid_hyps = []
        for index in range(len(sorted_keys)):
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])
            #restored_constraints.append(sorted_constraints[sorted_keys[index]])
            restored_alignment.append(best_alignment[sorted_keys[index]])
            restored_sen_decode_time.append(
                sen_decode_time[sorted_keys[index]])
            restored_grid_hyps.append(grid_hyps[sorted_keys[index]])

        # restored_outputs = outputs
        # restored_scores = scores
        # restored_alignment = best_alignment
        # restored_sen_decode_time = sen_decode_time
        # restored_grid_hyps = grid_hyps

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            for output, score, de_time in zip(restored_outputs,
                                              restored_scores,
                                              restored_sen_decode_time):
                decoded = []
                for idx in output:
                    if idx == params.mapping["target"][params.eos]:
                        break
                    decoded.append(vocab[idx])
                decoded = " ".join(decoded)

                if not args.verbose:
                    outfile.write("%s\n" % decoded)
                else:
                    # Verbose line layout: index | text | score | decode time.
                    pattern = "%d |%s |%f |%f \n"
                    # cons = restored_constraints[count]
                    # cons_token_num = 0
                    # for cons_item in cons:
                    #     cons_token_num += cons_item["tgt_len"]
                    values = (count, decoded, score, de_time)
                    outfile.write(pattern % values)
                count += 1

        # Alignments: an index line followed by a pickled alignment each.
        with open(args.output + ".alignment", "w") as outfile:
            count = 0
            for alignment in restored_alignment:
                outfile.write("%d\n" % count)
                cPickle.dump(alignment, outfile)
                count += 1
        # Save decoding times and per-grid hypothesis counts for analysis.
        with open(args.output + ".time_hyps", "w") as outfile:
            cPickle.dump(restored_sen_decode_time, outfile)
            cPickle.dump(restored_grid_hyps, outfile)
        with open(args.output + ".time", "w") as outfile:
            time_sen = np.asarray(restored_sen_decode_time)
            ave = np.average(time_sen)
            outfile.write("average time:%f\n" % ave)
            cPickle.dump(restored_sen_decode_time, outfile)
def run(input_files,
        constraints_file,
        output,
        models,
        configs,
        weights,
        n_best=1,
        length_factor=1.3,
        beam_size=5,
        mert_nbest=False,
        write_alignments=None,
        length_norm=True):
    """Translate input files with a (possibly weighted) Nematus ensemble
    using constrained grid beam search.

    Args:
        input_files: list of parallel input file paths (one per model input).
        constraints_file: optional JSON file of per-sentence constraints.
        output: open file object; rewrapped as utf8 unless it is <stdout>.
        models: list of Nematus model paths.
        configs: optional list of config file paths, one per model.
        weights: optional per-model interpolation weights.
        n_best: number of hypotheses to output per sentence.
        length_factor: multiple of the source length used as the maximum
            hypothesis length.
        beam_size: beam width for the search.
        mert_nbest: if True, format n-best output in MERT format.
        write_alignments: optional path to append JSON alignments to
            (truncated first); also enables returning alignments.
        length_norm: whether to length-normalize hypothesis scores.
    """

    if configs is not None:
        assert len(models) == len(
            configs), 'Number of models differs from numer of config files'

    if weights is not None:
        assert len(models) == len(
            weights
        ), 'If you specify weights, there must be one for each model'

    # Requesting alignment output implies requesting alignments from decode();
    # any stale alignment file is removed so we can append per sentence.
    return_alignments = False
    if write_alignments is not None:
        return_alignments = True
        try:
            os.remove(write_alignments)
        except OSError:
            pass

    # remember Nematus needs _encoded_ utf8
    if configs is not None:
        configs = [load_config(f) for f in configs]

    # build ensembled TM
    nematus_tm = NematusTranslationModel(models,
                                         configs,
                                         model_weights=weights)

    # Build GBS search
    decoder = create_constrained_decoder(nematus_tm)

    constraints = None
    if constraints_file is not None:
        constraints = json.loads(
            codecs.open(constraints_file, encoding='utf8').read())

    if output.name != '<stdout>':
        output = codecs.open(output.name, 'w', encoding='utf8')

    input_iters = []
    for input_file in input_files:
        input_iters.append(codecs.open(input_file, encoding='utf8'))

    # Iterate all model inputs in lockstep, one sentence tuple at a time.
    # NOTE(review): itertools.izip is Python 2 only.
    for idx, inputs in enumerate(itertools.izip(*input_iters)):
        input_constraints = []
        if constraints is not None:
            input_constraints = constraints[idx]

        # Note: the length_factor is used with the length of the first model input of the ensemble
        # in case the users constraints will go beyond the max length according to length_factor
        max_hyp_len = int(round(len(inputs[0].split()) * length_factor))
        if len(input_constraints) > 0:
            num_constraint_tokens = sum(1 for c in input_constraints
                                        for _ in c)
            if num_constraint_tokens >= max_hyp_len:
                logger.warn('The number of tokens in the constraints are greater than max_len*length_factor, ' + \
                            'autoscaling the maximum hypothesis length...')
                max_hyp_len = num_constraint_tokens + int(
                    round(max_hyp_len / 2))

        best_output = decode(decoder,
                             nematus_tm,
                             inputs,
                             n_best,
                             max_hyp_len=max_hyp_len,
                             beam_size=beam_size,
                             constraints=input_constraints,
                             return_alignments=return_alignments,
                             length_norm=length_norm)

        if return_alignments:
            # decoding returned a tuple with 2 items
            best_output, best_alignments = best_output

        if n_best > 1:
            if mert_nbest:
                # format each n-best entry in the mert format
                translations, scores, model_scores = zip(*best_output)
                # start from idx 1 to cut off `None` at the beginning of the sequence
                translations = [u' '.join(s[1:]) for s in translations]
                # create dummy feature names
                model_names = [
                    u'M{}'.format(m_i) for m_i in range(len(model_scores[0]))
                ]
                #Note: we make model scores and logprob negative for MERT optimization to work
                model_score_strings = [
                    u' '.join([
                        u'{}= {}'.format(model_name, -s_i)
                        for model_name, s_i in zip(model_names, m_scores)
                    ]) for m_scores in model_scores
                ]
                nbest_output_strings = [
                    u'{} ||| {} ||| {} ||| {}'.format(idx, translation,
                                                      feature_scores, -logprob)
                    for translation, feature_scores, logprob in zip(
                        translations, model_score_strings, scores)
                ]
                decoder_output = u'\n'.join(nbest_output_strings) + u'\n'
            else:
                # start from idx 1 to cut off `None` at the beginning of the sequence
                # separate each n-best list with newline
                decoder_output = u'\n'.join(
                    [u' '.join(s[0][1:]) for s in best_output]) + u'\n\n'

            # stdout wants encoded bytes; a codecs-wrapped file wants unicode
            if output.name == '<stdout>':
                output.write(decoder_output.encode('utf8'))
            else:
                output.write(decoder_output)
        else:
            # start from idx 1 to cut off `None` at the beginning of the sequence
            decoder_output = u' '.join(best_output[0][1:])
            if output.name == '<stdout>':
                output.write((decoder_output + u'\n').encode('utf8'))
            else:
                output.write(decoder_output + u'\n')

        # Note alignments are always an n-best list (may be n=1)
        # best_alignments is only bound when return_alignments is True, which
        # write_alignments is not None guarantees above.
        if write_alignments is not None:
            with codecs.open(write_alignments, 'a+',
                             encoding='utf8') as align_out:
                align_out.write(
                    json.dumps([a.tolist() for a in best_alignments]) + u'\n')

        if (idx + 1) % 10 == 0:
            logger.info('Wrote {} translations to {}'.format(
                idx + 1, output.name))
Example #4
0
def run(input_files,
        constraints_file,
        output,
        models,
        configs,
        weights,
        n_best=1,
        length_factor=1.3,
        beam_size=5,
        mert_nbest=False,
        write_alignments=None,
        length_norm=True):
    """Translate input files with a Nematus ensemble by driving the
    constrained decoder's search/best_n API directly.

    Unlike the variant above that calls a decode() helper, this version
    requires one config per model and builds the start hypothesis and
    search grid itself.

    Args:
        input_files: list of parallel input file paths (one per model input).
        constraints_file: optional JSON file of per-sentence constraints.
        output: open file object; rewrapped as utf8 unless it is <stdout>.
        models: list of Nematus model paths.
        configs: list of config file paths, one per model (required).
        weights: optional per-model interpolation weights.
        n_best: number of hypotheses to output per sentence.
        length_factor: multiple of the (mapped) source length used as the
            maximum hypothesis length.
        beam_size: beam width for the search.
        mert_nbest: if True, format n-best output in MERT format.
        write_alignments: optional path to append JSON alignments to
            (truncated first).
        length_norm: whether to length-normalize hypothesis scores.
    """

    assert len(models) == len(
        configs), 'We need one config file for every model'
    if weights is not None:
        assert len(models) == len(
            weights
        ), 'If you specify weights, there must be one for each model'

    # Remove any stale alignment file so per-sentence appends start clean.
    if write_alignments is not None:
        try:
            os.remove(write_alignments)
        except OSError:
            pass

    # remember Nematus needs _encoded_ utf8
    configs = [load_config(f) for f in configs]

    # build ensembled TM
    nematus_tm = NematusTranslationModel(models,
                                         configs,
                                         model_weights=weights)

    # Build GBS search
    decoder = create_constrained_decoder(nematus_tm)

    constraints = None
    if constraints_file is not None:
        constraints = json.loads(
            codecs.open(constraints_file, encoding='utf8').read())

    if output.name != '<stdout>':
        output = codecs.open(output.name, 'w', encoding='utf8')

    input_iters = []
    for input_file in input_files:
        input_iters.append(codecs.open(input_file, encoding='utf8'))

    # Iterate all model inputs in lockstep, one sentence tuple at a time.
    # NOTE(review): itertools.izip is Python 2 only.
    for idx, inputs in enumerate(itertools.izip(*input_iters)):
        mapped_inputs = nematus_tm.map_inputs(inputs)

        input_constraints = []
        if constraints is not None:
            input_constraints = nematus_tm.map_constraints(constraints[idx])

        start_hyp = nematus_tm.start_hypothesis(mapped_inputs,
                                                input_constraints)

        # Note: the length_factor is used with the length of the first model input of the ensemble
        search_grid = decoder.search(
            start_hyp=start_hyp,
            constraints=input_constraints,
            max_hyp_len=int(round(len(mapped_inputs[0][0]) * length_factor)),
            beam_size=beam_size)

        # Alignments are always requested here, even when not written out.
        best_output, best_alignments = decoder.best_n(
            search_grid,
            nematus_tm.eos_token,
            n_best=n_best,
            return_model_scores=mert_nbest,
            return_alignments=True,
            length_normalization=length_norm)

        if n_best > 1:
            if mert_nbest:
                # format each n-best entry in the mert format
                translations, scores, model_scores = zip(*best_output)
                # start from idx 1 to cut off `None` at the beginning of the sequence
                translations = [u' '.join(s[1:]) for s in translations]
                # create dummy feature names
                model_names = [
                    u'M{}'.format(m_i) for m_i in range(len(model_scores[0]))
                ]
                #Note: we make model scores and logprob negative for MERT optimization to work
                model_score_strings = [
                    u' '.join([
                        u'{}= {}'.format(model_name, -s_i)
                        for model_name, s_i in zip(model_names, m_scores)
                    ]) for m_scores in model_scores
                ]
                nbest_output_strings = [
                    u'{} ||| {} ||| {} ||| {}'.format(idx, translation,
                                                      feature_scores, -logprob)
                    for translation, feature_scores, logprob in zip(
                        translations, model_score_strings, scores)
                ]
                decoder_output = u'\n'.join(nbest_output_strings) + u'\n'
            else:
                # start from idx 1 to cut off `None` at the beginning of the sequence
                # separate each n-best list with newline
                decoder_output = u'\n'.join(
                    [u' '.join(s[0][1:]) for s in best_output]) + u'\n\n'

            # stdout wants encoded bytes; a codecs-wrapped file wants unicode
            if output.name == '<stdout>':
                output.write(decoder_output.encode('utf8'))
            else:
                output.write(decoder_output)
        else:
            # start from idx 1 to cut off `None` at the beginning of the sequence
            decoder_output = u' '.join(best_output[0][1:])
            if output.name == '<stdout>':
                output.write((decoder_output + u'\n').encode('utf8'))
            else:
                output.write(decoder_output + u'\n')

        # Note alignments are always an n-best list (may be n=1)
        if write_alignments is not None:
            with codecs.open(write_alignments, 'a+',
                             encoding='utf8') as align_out:
                align_out.write(
                    json.dumps([a.tolist() for a in best_alignments]) + u'\n')

        if (idx + 1) % 10 == 0:
            logger.info('Wrote {} translations to {}'.format(
                idx + 1, output.name))
Example #5
0
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [transformer.Transformer for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor
            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_rerank_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Read input file
        sorted_keys, sorted_inputs, sorted_constraints = \
            src_cons_dataset.sort_input_src_cons(args.input, args.constraints)

        # Build input queue
        features = src_cons_dataset.get_input_with_src_constraints(
            sorted_inputs, sorted_constraints, params)

        print(sorted_keys)

        #Create placeholder
        placeholders = []
        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i),
                "constraints_src_pos":
                tf.placeholder(tf.int32, [None, None, None],
                               "constraints_src_pos_%d" % i),
                "constraints":
                tf.placeholder(tf.int32, [None, None, None],
                               "constraints_%d" % i),
                "constraints_len":
                tf.placeholder(tf.int32, [None, None],
                               "constraints_len_%d" % i)
            })
        encoding_fn = model_fns[0][0]

        encoder_op = parallel.data_parallelism(
            params.device_list, lambda f: encoding_fn(f, params), placeholders)

        state_placeholders = []
        for i in range(len(params.device_list)):
            decode_state = {
                "encoder":
                tf.placeholder(tf.float32, [None, None, params.hidden_size],
                               "encoder_%d" % i),
                #"encoder_weight": we doesn't need encoder weight
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i),
                # [bos_id, ...] => [..., 0]
                "target":
                tf.placeholder(tf.int32, [None, None], "target_%d" % i),
                #"target_length": tf.placeholder(tf.int32, [None, ], "target_length_%d" % i)
            }
            #需要这些值,以进行增量式解码
            for j in range(params.num_decoder_layers):
                decode_state["decoder_layer_%d_key" % j] = tf.placeholder(
                    tf.float32, [None, None, params.hidden_size],
                    "decoder_layer_%d_key_%d" % (j, i))
                decode_state["decoder_layer_%d_value" % j] = tf.placeholder(
                    tf.float32, [None, None, params.hidden_size],
                    "decoder_layer_%d_value_%d" % (j, i))  # layer and GPU
                # we only need the return value of this
                # decode_state["decoder_layer_%d_att_weight" % j] = tf.placeholder(tf.float32, [None, None, None, None],
                #                              # N Head T S  inference的时候,T总是为1,表示1步
                #                              "decoder_layer_%d_att_weight" % j),
            state_placeholders.append(decode_state)

        def decoding_fn(s):
            _decoding_fn = model_fns[0][1]
            #split s to state and feature, and 转换为嵌套的结构,以满足transformer模型
            state = {
                "encoder": s["encoder"],
                "decoder": {
                    "layer_%d" % j: {
                        "key": s["decoder_layer_%d_key" % j],
                        "value": s["decoder_layer_%d_value" % j],
                    }
                    for j in range(params.num_decoder_layers)
                }
            }
            inputs = s["target"]
            #inputs = tf.Print(inputs, [inputs], "before target", 100, 10000)
            feature = {
                "source":
                s["source"],
                "source_length":
                s["source_length"],
                # [bos_id, ...] => [..., 0]
                # "target": tf.pad(inputs[:,1:], [[0, 0], [0, 1]])
                #"target": tf.pad(inputs, [[0, 0], [0, 1]]),  # 前面没有bos_id,因此直接补上0,这是为了和decode_graph中的补bos相配合
                "target":
                inputs,
                "target_length":
                tf.fill([tf.shape(inputs)[0]],
                        tf.shape(inputs)[1])
            }
            #feature["target"] = tf.Print(feature["target"], [feature["target"]], "target", 100,10000)
            ret = _decoding_fn(feature, state, params)
            return ret

        decoder_op = parallel.data_parallelism(params.device_list,
                                               lambda s: decoding_fn(s),
                                               state_placeholders)

        #batch = tf.shape(encoder_output)[0]

        # Create assign ops
        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        results = []

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # from tensorflow.python import debug as tf_debug
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess,ui_type='curses')#readline

            # Restore variables
            sess.run(assign_op)
            sess.run(tf.tables_initializer())
            # pad_id = params.mapping["target"][params.pad]
            # bos_id = params.mapping["target"][params.bos]
            # eos_id = params.mapping["target"][params.eos]
            # Consume the input pipeline batch by batch until exhausted.
            while True:
                try:
                    feats = sess.run(features)
                    # Shard the batch across devices, then run the encoder
                    # once for the whole batch.
                    encoder_op, feed_dict = shard_features(
                        feats, placeholders, encoder_op)
                    #print("encoding %d" % i)
                    encoder_state = sess.run(encoder_op, feed_dict=feed_dict)
                    decoder_input_list = []
                    encoder_output_list = []
                    # Re-pack per-sentence decoder inputs (source ids plus the
                    # lexical-constraint features) out of the batched feats.
                    for j in range(len(feats["source"])):
                        decoder_input_item = {
                            "source": [feats["source"][j]],
                            "source_length": [feats["source_length"][j]],
                            "constraints_src_pos":
                            feats["constraints_src_pos"][j],
                            "constraints": feats["constraints"][j],
                            "constraints_len": feats["constraints_len"][j],
                        }
                        decoder_input_list.append(decoder_input_item)
                    # NOTE: iterate over the actual encoder outputs rather
                    # than the GPU count — a GPU can end up empty (e.g. the
                    # last sentence or few sentences cannot fill a full
                    # per-GPU shard).
                    for i in range(len(encoder_state[0])):  # per device
                        state_len = len(encoder_state[0][i])  # sentences on this device
                        for j in range(state_len):
                            # Slice one sentence's encoder output/weights,
                            # keeping a leading batch dimension of size 1.
                            encoder_output_item = {
                                "encoder": encoder_state[0][i][j:j + 1],
                                "encoder_weight": encoder_state[1][i][j:j + 1]
                            }
                            encoder_output_list.append(encoder_output_item)

                    # Decode every sentence independently.
                    for input, encoder_output in zip(decoder_input_list,
                                                     encoder_output_list):
                        # print(input["source"])
                        # print(input["constraints"])
                        #################
                        # create constraint translation related model
                        # build ensembled TM
                        thumt_tm = ThumtTranslationModel(
                            sess, decoder_op, encoder_output,
                            state_placeholders, input, params)

                        # Build GBS search
                        cons_decoder = create_constrained_decoder(thumt_tm)
                        ##################
                        # Hypotheses may extend decode_length tokens beyond
                        # the source length.
                        max_length = input["source_length"][
                            0] + params.decode_length
                        beam_size = params.beam_size
                        # top_beams = params.top_beams
                        top_beams = 1
                        best_output = decode(encoder_output,
                                             sess,
                                             decoder_op,
                                             state_placeholders,
                                             params,
                                             cons_decoder,
                                             thumt_tm,
                                             input,
                                             top_beams,
                                             max_hyp_len=max_length,
                                             beam_size=beam_size,
                                             return_alignments=True,
                                             length_norm=False)
                        # constraints=input_constraints,
                        # return_alignments=return_alignments,
                        # length_norm=length_norm)
                        results.append(best_output)
                    message = "Finished sentences: %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                # The dataset iterator signals end-of-data with this error.
                except tf.errors.OutOfRangeError:
                    break
        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []
        mask_ratio = []
        best_alignment = []
        for result in results:
            # result[0] holds (sequence, score) pairs; transpose so column 0
            # is all sequences and column 1 all scores.  Materializing with
            # list() fixes indexing under Python 3, where zip() returns an
            # iterator — the original indexed the zip object directly, which
            # only works on Python 2.  Behavior on Python 2 is unchanged.
            seqs_and_scores = list(zip(*result[0]))
            outputs.extend(seqs_and_scores[0])
            scores.extend(seqs_and_scores[1])
            # No real mask ratio is computed here; pad with placeholder zeros.
            mask_ratio.extend([0] * len(seqs_and_scores[1]))
            best_alignment.extend(result[1])
        # Strip the first token of every sequence (presumably BOS — confirm
        # against the decoder's output convention).
        outputs = [s[1:] for s in outputs]

        # Echo every decoded sentence to stdout as "<score>   <tokens>",
        # truncating each id sequence at the first EOS token.
        eos_id = params.mapping["target"][params.eos]
        for seq, sent_score in zip(outputs, scores):
            tokens = []
            for token_id in seq:
                if token_id == eos_id:
                    break
                tokens.append(vocab[token_id])
            print("%f   %s" % (sent_score, " ".join(tokens)))

        # Undo the (out-of-view) length-based sorting of the inputs so results
        # line up with the original file order.  NOTE(review): presumably
        # sorted_keys[index] maps original position -> sorted position —
        # confirm against the code that built sorted_inputs/sorted_keys.
        restored_inputs = []
        restored_outputs = []
        restored_scores = []
        restored_ratio = []
        restored_constraints = []
        restored_alignment = []
        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])
            restored_ratio.append(mask_ratio[sorted_keys[index]])
            restored_constraints.append(sorted_constraints[sorted_keys[index]])
            restored_alignment.append(best_alignment[sorted_keys[index]])

        # Write to file: one detokenized sentence per line, or a verbose
        # "|||"-separated record when --verbose is set.
        with open(args.output, "w") as outfile:
            count = 0
            for output, score, ratio in zip(restored_outputs, restored_scores,
                                            restored_ratio):
                # Map ids back to tokens, truncating at the first EOS.
                decoded = []
                for idx in output:
                    if idx == params.mapping["target"][params.eos]:
                        break
                    decoded.append(vocab[idx])
                decoded = " ".join(decoded)

                if not args.verbose:
                    outfile.write("%s\n" % decoded)
                else:
                    # id ||| source ||| output ||| score ||| ratio ||| #constraint tokens
                    pattern = "%d ||| %s ||| %s ||| %f ||| %f ||| %d\n"
                    source = restored_inputs[count]
                    cons = restored_constraints[count]
                    cons_token_num = 0
                    for cons_item in cons:
                        cons_token_num += cons_item["tgt_len"]
                    # Bug fix: the original referenced the undefined name
                    # `ratios[0]` (NameError in verbose mode); the per-sentence
                    # value unpacked from the zip above is `ratio`.
                    values = (count, source, decoded, score, ratio,
                              cons_token_num)
                    outfile.write(pattern % values)
                count += 1

        # Dump each sentence's alignment, preceded by its index on its own
        # line, into a sidecar ".alignment" file.
        with open(args.output + ".alignment", "w") as outfile:
            for count, alignment in enumerate(restored_alignment):
                outfile.write("%d\n" % count)
                cPickle.dump(alignment, outfile)