Example #1
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_list = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_list.append(model)

        params = params_list[0]
        placeholder = {}
        placeholder["source"] = tf.placeholder(tf.int32, [None, None],
                                               "source")
        placeholder["source_length"] = tf.placeholder(tf.int32, [None],
                                                      "source_length")
        enc_fn, dec_fn = model_list[0].get_inference_func()
        enc = enc_fn(placeholder, params)
        state = {}
        state["encoder"] = tf.placeholder(tf.float32,
                                          [None, None, params.hidden_size],
                                          "encoder")
        dec = dec_fn(placeholder, state, params)
        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        # Create placeholders
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i)
            })

        # A list of outputs
        if params.generate_samples:
            inference_fn = sampling.create_sampling_graph
        else:
            inference_fn = inference.create_inference_graph

        # Create assign ops
        assign_ops = []
        feed_dict = {}

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i, feed_dict)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        init_op = tf.tables_initializer()
        results = []

        tf.get_default_graph().finalize()

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
            sess.run(assign_op, feed_dict=feed_dict)
            sess.run(init_op)

            total_start = time.time()
            while True:
                start = time.time()
                try:
                    feats = sess.run(features)
                    feed_dict = {
                        placeholder["source"]: feats["source"],
                        placeholder["source_length"]: feats["source_length"]
                    }
                    encoder_output = sess.run(enc, feed_dict=feed_dict)
                    encoder_output = encoder_output['encoder']
                    feed_dict_dec = {
                        placeholder["source"]: feats["source"],
                        placeholder["source_length"]: feats["source_length"],
                        state["encoder"]: encoder_output
                    }
                    result = sess.run(dec, feed_dict=feed_dict_dec)
                    #print(result)
                    results.append(result)
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                    end = time.time()
                    print('time:', end - start, 's')
                except tf.errors.OutOfRangeError:
                    break
            total_end = time.time()
            print('total time:', total_end - total_start, 's')

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            print('result', result)
            for item in result:
                outputs.append(item.tolist())
            #for item in result[1]:
            #    scores.append(item.tolist())

        #outputs = list(itertools.chain(*outputs))

        restored_inputs = []
        restored_outputs = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            for outputs in restored_outputs:
                print('oup', outputs)
                for output in outputs:
                    outfile.write(str(round(output, 2)) + ' ')
                outfile.write('\n')
                # NOTE: `scores` is never populated in this example, so this
                # loop body never runs and nothing extra is written.
                for output, score in zip(outputs, scores):
                    decoded = " ".join([])

                    outfile.write("%s\n" % decoded)

                count += 1
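
Every example in this file calls a set_variables helper without showing it. A minimal sketch of what the four-argument variant might do, assuming checkpoint variable names differ from graph names only by the per-model "_%d" suffix (the real helper in the source project may differ):

import tensorflow as tf

def set_variables(var_list, value_dict, prefix, feed_dict):
    # prefix looks like "transformer_0"; checkpoint names use "transformer"
    base = prefix.rsplit("_", 1)[0]
    ops = []
    for var in var_list:
        # "transformer_0/encoder/w:0" -> "transformer/encoder/w"
        name = var.name.split(":")[0].replace(prefix, base, 1)
        if name not in value_dict:
            continue
        # Feed the values through placeholders so the weights are not
        # serialized into the graph definition itself.
        placeholder = tf.placeholder(var.dtype.base_dtype, var.shape)
        feed_dict[placeholder] = value_dict[name]
        ops.append(tf.assign(var, placeholder))
    return ops

Examples #2, #5, #6, #7 and #8 call a three-argument variant without feed_dict; that version would presumably build the tf.assign ops from the values directly.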
Example #2
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]

        #features = dataset.get_inference_input_with_bert(args.input, params)
        if params.use_bert and params.bert_emb_path:
            features = ds.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = ds.get_inference_input(params.input, params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.append(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)
                if len(results) > 2:
                    break
        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []

        for result in results:
            outputs.append(result.tolist())

        outputs = list(itertools.chain(*outputs))

        #restored_outputs = []

        # Write to file
        with open(args.output, "w") as outfile:
            for output in outputs:
                decoded = []
                for idx in output:
                    #if idx == params.mapping["target"][params.eos]:
                    #if idx != output[-1]:
                    #print("Warning: incomplete predictions as {}".format(" ".join(output)))
                    #break
                    decoded.append(vocab[idx])

                decoded = " ".join(decoded)
                outfile.write("%s\n" % decoded)
Example #3
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_list = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_list.append(model)

        params = params_list[0]
        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        # Create placeholders
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i)
            })

        # A list of outputs
        if params.generate_samples:
            inference_fn = sampling.create_sampling_graph
        else:
            inference_fn = inference.create_inference_graph

        predictions = parallel.data_parallelism(
            params.device_list, lambda f: inference_fn(model_list, f, params),
            placeholders)

        # Create assign ops
        assign_ops = []
        feed_dict = {}

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i, feed_dict)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        init_op = tf.tables_initializer()
        results = []

        tf.get_default_graph().finalize()

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
            sess.run(assign_op, feed_dict=feed_dict)
            sess.run(init_op)

            while True:
                try:
                    feats = sess.run(features)
                    op, feed_dict = shard_features(feats, placeholders,
                                                   predictions)
                    results.append(sess.run(op, feed_dict=feed_dict))
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            for item in result[0]:
                outputs.append(item.tolist())
            for item in result[1]:
                scores.append(item.tolist())

        outputs = list(itertools.chain(*outputs))
        scores = list(itertools.chain(*scores))

        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            for outputs, scores in zip(restored_outputs, restored_scores):
                for output, score in zip(outputs, scores):
                    decoded = []
                    for idx in output:
                        if idx == params.mapping["target"][params.eos]:
                            break
                        decoded.append(vocab[idx])

                    decoded = " ".join(decoded)

                    if not args.verbose:
                        outfile.write("%s\n" % decoded)
                        #break
                    else:
                        pattern = "%d ||| %s ||| %s ||| %f\n"
                        source = restored_inputs[count]
                        values = (count, source, decoded, score)
                        outfile.write(pattern % values)

                count += 1
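
Examples #1, #3, #4 and #7 shard each input batch across params.device_list via shard_features. A minimal sketch of such a helper, assuming every batch is at least as large as the device count (the real implementation likely also handles smaller batches):

import numpy as np

def shard_features(features, placeholders, predictions):
    # Split each feature along the batch axis, one shard per device.
    num_shards = len(placeholders)
    feed_dict = {}
    for name, value in features.items():
        for i, shard in enumerate(np.array_split(value, num_shards)):
            feed_dict[placeholders[i][name]] = shard
    return predictions, feed_dict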
Example #4
def build_graph(params, args, model_list, model_cls_list, model_var_lists, problem=None):
    if problem == "parsing":
        fo = args.parsing_output
        fi = args.parsing_input
    elif problem == "amr":
        fo = args.amr_output
        fi = args.amr_input
    else:
        raise ValueError("problem must be either 'parsing' or 'amr'")

    # Read input file
    sorted_keys, sorted_inputs = dataset.sort_input_file(fi)
    # Build input queue
    features = dataset.get_inference_input(sorted_inputs, params)  # only source data
    # Create placeholders
    placeholders = []

    for i in range(len(params.device_list)):
        placeholders.append({
            "source": tf.placeholder(tf.int32, [None, None],
                                     "source_%d" % i),
            "source_length": tf.placeholder(tf.int32, [None],
                                            "source_length_%d" % i)
        })

    # A list of outputs
    if params.generate_samples:
        inference_fn = sampling.create_sampling_graph
    else:
        inference_fn = inference.create_inference_graph

    predictions = parallel.data_parallelism(
        params.device_list, lambda f: inference_fn(model_list, f, params, problem=problem),
        placeholders)

    # Create assign ops
    assign_ops = []
    feed_dict = {}

    all_var_list = tf.trainable_variables()

    for i in range(len(args.checkpoints)):
        un_init_var_list = []
        name = model_cls_list[i].get_name()

        for v in all_var_list:
            if v.name.startswith(name + "_%d" % i):
                un_init_var_list.append(v)

        ops = set_variables(un_init_var_list, model_var_lists[i],
                            name + "_%d" % i, feed_dict)
        assign_ops.extend(ops)

    assign_op = tf.group(*assign_ops)
    init_op = tf.tables_initializer()

    results = []

    tf.get_default_graph().finalize()

    # Create session
    with tf.Session(config=session_config(params)) as sess:
        # Restore variables
        sess.run(assign_op, feed_dict=feed_dict)
        sess.run(init_op)

        while True:
            try:
                feats = sess.run(features)
                op, feed_dict = shard_features(feats, placeholders,
                                               predictions)
                results.append(sess.run(op, feed_dict=feed_dict))
                message = "Finished %s batch %d" % (len(results), problem)
                tf.logging.log(tf.logging.INFO, message)
            except tf.errors.OutOfRangeError:
                break

    # Convert to plain text
    vocab = params.vocabulary[problem+"_target"]
    outputs = []
    scores = []

    for result in results:
        for item in result[0]:
            outputs.append(item.tolist())
        for item in result[1]:
            scores.append(item.tolist())

    outputs = list(itertools.chain(*outputs))
    scores = list(itertools.chain(*scores))

    restored_inputs = []
    restored_outputs = []
    restored_scores = []

    for index in range(len(sorted_inputs)):
        restored_inputs.append(sorted_inputs[sorted_keys[index]])
        restored_outputs.append(outputs[sorted_keys[index]])
        restored_scores.append(scores[sorted_keys[index]])

    # Write to file

    with open(fo, "w") as outfile:
        count = 0
        for outputs, scores in zip(restored_outputs, restored_scores):
            for output, score in zip(outputs, scores):
                decoded = []
                for idx in output:
                    if idx == params.mapping["target"][params.eos]:
                        break
                    decoded.append(vocab[idx])

                decoded = " ".join(decoded)

                if not args.verbose:
                    outfile.write("%s\n" % decoded)
                    break
                else:
                    pattern = "%d ||| %s ||| %s ||| %f\n"
                    source = restored_inputs[count]
                    values = (count, source, decoded, score)
                    outfile.write(pattern % values)
            count += 1
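
The caller of build_graph is not shown in this example. A hypothetical driver, noting that build_graph finalizes the default graph, so each task needs a fresh graph with the models rebuilt inside it (all names here come from the surrounding examples):

for problem in ("parsing", "amr"):
    with tf.Graph().as_default():
        model_list = [
            model_cls_list[i](params_list[i],
                              model_cls_list[i].get_name() + "_%d" % i)
            for i in range(len(args.checkpoints))
        ]
        build_graph(params_list[0], args, model_list, model_cls_list,
                    model_var_lists, problem=problem)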
Example #5
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(args.models)]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints, args.models,
                      params_list[0])  # import the configuration saved during training
        #for i in range(len([args.checkpoints]))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate([args.checkpoints]):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)  # list all checkpoint variables
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # keep only rnnsearch variables that do not contain "losses_avg"
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)  # read the tensor's value
                values[name] = tensor

            model_var_lists.append(values)  # values of all rnnsearch variables without "losses_avg"

        # Build models
        model_fns = []

        for i in range(len([args.checkpoints])):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()  # obtain the model's inference function
            model_fns.append(model_fn)

        params = params_list[0]

        #features = dataset.get_inference_input_with_bert(args.input, params)
        if params.use_bert and params.bert_emb_path:
            features = dataset.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = dataset.get_inference_input([params.input], params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len([args.checkpoints])):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        result_for_score = []
        result_for_write = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)
            # Record each sentence's length, used later to strip padded tokens.
            length = []
            with open(args.input, "r", encoding="utf8") as f:
                for line in f:
                    # strip() already removes the newline, so compare without it
                    if line.strip() == "-DOCSTART-":
                        continue
                    lines = line.strip().split(" ")
                    length.append(len(lines))
            current_num = 0
            batch = 0
            while not sess.should_stop():
                current_res_arr = sess.run(predictions)
                result_for_write.append(current_res_arr)
                for arr in current_res_arr:
                    result_for_score.extend(list(arr)[:length[current_num]])
                    current_num += 1
                batch += 1
                message = "Finished batch %d" % batch
                tf.logging.log(tf.logging.INFO, message)
        if params.is_validation:
            from sklearn.metrics import precision_score, recall_score, f1_score
            import numpy as np
            # map labels to indices
            voc_lis = params.vocabulary["target"]
            index = list(np.arange(len(voc_lis)))
            dic = dict(zip(voc_lis, index))

            def map_res(x):
                return dic[x]

            targets_list = []
            with open(args.eval_file, "r") as f:  # read the gold label file
                for line in f:
                    if line.strip() == "O":
                        continue
                    lines = line.strip().split(" ")
                    targets_list.extend(list(map(map_res, lines)))  # convert labels to indices

            result_arr = np.array(result_for_score)
            targets_arr = np.array(targets_list)
            precision_ = precision_score(targets_arr,
                                         result_arr,
                                         average="micro",
                                         labels=[0, 2, 3, 4])
            recall_ = recall_score(targets_arr,
                                   result_arr,
                                   average="micro",
                                   labels=[0, 2, 3, 4])
            print("precision_score:{}".format(precision_))
            print("recall_score:{}".format(recall_))
            print("F1_score:{}".format(2 * precision_ * recall_ /
                                       (recall_ + precision_)))
        else:
            # Convert to plain text
            vocab = params.vocabulary["target"]
            outputs = []

            for result in result_for_write:
                outputs.append(result.tolist())

            outputs = list(itertools.chain(*outputs))

            #restored_outputs = []

            # Write to file
            num = 0
            with open(args.output, "w") as outfile:
                for output in outputs:
                    decoded = []
                    for idx in output[:length[num] + 1]:
                        if idx == params.mapping["target"][params.eos]:
                            if idx != output[length[num]]:
                                print(
                                    "Warning: incomplete prediction at line "
                                    "{} of the source file".format(num + 1))
                        decoded.append(vocab[idx])
                    decoded = " ".join(decoded[:-1])
                    outfile.write("%s\n" % decoded)
                    num += 1
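
The validation branch above combines micro-averaged precision and recall by hand via 2*P*R/(P+R). As a sanity check, sklearn can compute micro-F1 directly on the same arrays and label subset (with the (y_true, y_pred) order used above), and the two numbers should agree:

from sklearn.metrics import f1_score

f1_ = f1_score(targets_arr, result_arr, average="micro", labels=[0, 2, 3, 4])
print("F1_score (direct):", f1_)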
Example #6
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(args.models)]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints, args.models,
                      params_list[0])  # import the configuration saved during training
        #for i in range(len([args.checkpoints]))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate([args.checkpoints]):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)  # list all checkpoint variables
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # keep only rnnsearch variables that do not contain "losses_avg"
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)  # read the tensor's value
                values[name] = tensor

            model_var_lists.append(values)  # values of all rnnsearch variables without "losses_avg"

        # Build models
        model_fns = []

        for i in range(len([args.checkpoints])):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()  # obtain the model's inference function
            model_fns.append(model_fn)

        params = params_list[0]

        #features = dataset.get_inference_input_with_bert(args.input, params)
        if params.use_bert and params.bert_emb_path:
            features = dataset.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = dataset.get_inference_input([params.input], params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len([args.checkpoints])):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.extend(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)
            tar = []
            with open(params.input, "r") as inputs_f:
                for line in inputs_f:
                    if line.strip() == "O":
                        continue
                    else:
                        tar.extend(line.split(" ")[:-1])
Example #7
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_list = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_list.append(model)

        params = params_list[0]
        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        # Create placeholders
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i)
            })

        # A list of outputs
        if params.generate_samples:
            inference_fn = sampling.create_sampling_graph
        else:
            inference_fn = inference.create_inference_graph

        predictions = parallel.data_parallelism(
            params.device_list, lambda f: inference_fn(model_list, f, params),
            placeholders)

        # Create assign ops
        assign_ops = []
        feed_dict = {}

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i, feed_dict)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        init_op = tf.tables_initializer()
        results = []

        tf.get_default_graph().finalize()

        tf.logging.info(args.models[0])
        if args.models[0] == 'transformer_raw_t5':
            t5_list = []
            for var in tf.trainable_variables():
                if 'en_t5_bias_mat' in var.name or 'de_self_relative_attention_bias' in var.name:
                    t5_list.append(var)
                    tf.logging.info(var)

            for op in tf.get_default_graph().get_operations():
                if 'encoder_t5_bias' in op.name or 'decoder_t5_bias' in op.name:
                    if 'random' in op.name or 'read' in op.name or 'Assign' in op.name or 'placeholder' in op.name:
                        continue
                    t5_list.append(op.values()[0])
                    tf.logging.info(op.values()[0].name)
        elif args.models[0] == 'transformer_raw_soft_t5':
            soft_t5_bias_list = []
            for op in tf.get_default_graph().get_operations():
                if 'soft_t5_bias' in op.name or 'soft_t5_encoder' in op.name or 'soft_t5_decoder' in op.name:
                    if 'random' in op.name or 'read' in op.name or 'Assign' in op.name or 'placeholder' in op.name or 'decoder' in op.name:
                        continue
                    soft_t5_bias_list.append(op.values()[0])
                    tf.logging.info(op.values()[0].name)

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables
            sess.run(assign_op, feed_dict=feed_dict)
            sess.run(init_op)

            while True:
                try:
                    feats = sess.run(features)
                    op, feed_dict = shard_features(feats, placeholders,
                                                   predictions)
                    results.append(sess.run(op, feed_dict=feed_dict))
                    '''
                    if args.models[0] == 'transformer_raw_t5':
                        var_en_bucket=tf.get_default_graph().get_tensor_by_name(t5_list[0].name)
                        var_de_bucket=tf.get_default_graph().get_tensor_by_name(t5_list[1].name)
                        
                        var_en_bias=tf.get_default_graph().get_tensor_by_name(t5_list[2].name)
                        
                        en_bucket,de_bucket,en_t5_bias = sess.run([var_en_bucket,
                                                                   var_de_bucket,
                                                                   var_en_bias],
                                              feed_dict=feed_dict)
                        
                        ret_param = {'en_bucket':en_bucket,'de_bucket':en_bucket,
                                     'en_t5_bias':en_t5_bias}
                        pickle.dump(ret_param,open(args.checkpoints[0]+'/'+'t5_bias.pkl','wb'))
                        tf.logging.info('store the t5 bias')
                    elif args.models[0] == 'transformer_raw_soft_t5':
                        var_en_alpha=tf.get_default_graph().get_tensor_by_name(soft_t5_bias_list[0].name)
                        var_en_beta=tf.get_default_graph().get_tensor_by_name(soft_t5_bias_list[1].name)
                        var_en_t5_bias=tf.get_default_graph().get_tensor_by_name(soft_t5_bias_list[2].name)
                        en_alpha,en_beta,en_t5_bias = sess.run([var_en_alpha,var_en_beta,var_en_t5_bias], feed_dict=feed_dict)
                    
                        ret_param = {'en_t5_bias':en_t5_bias,'en_alpha':en_alpha,
                              'en_beta':en_beta}
                        pickle.dump(ret_param,open(args.checkpoints[0]+'/'+'soft_t5_bias.pkl','wb'))
                        tf.logging.info('store the soft-t5 bias')
                        '''
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            for shard in result:
                for item in shard[0]:
                    outputs.append(item.tolist())
                for item in shard[1]:
                    scores.append(item.tolist())

        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        if sys.version_info.major == 2:
            outfile = open(args.output, "w")
        elif sys.version_info.major == 3:
            outfile = open(args.output, "w", encoding="utf-8")
        else:
            raise ValueError("Unkown python running environment!")

        count = 0
        for outputs, scores in zip(restored_outputs, restored_scores):
            for output, score in zip(outputs, scores):
                decoded = []
                for idx in output:
                    if idx == params.mapping["target"][params.eos]:
                        break
                    decoded.append(vocab[idx])

                decoded = " ".join(decoded)

                if not args.verbose:
                    outfile.write("%s\n" % decoded)
                else:
                    pattern = "%d ||| %s ||| %s ||| %f\n"
                    source = restored_inputs[count]
                    values = (count, source, decoded, score)
                    outfile.write(pattern % values)

            count += 1
        outfile.close()
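
If the commented-out dump inside the loop above were enabled, the stored biases could be inspected offline; a small sketch using the same path and keys as that block:

import pickle

with open(args.checkpoints[0] + "/t5_bias.pkl", "rb") as f:
    ret_param = pickle.load(f)
print(ret_param["en_t5_bias"].shape)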
Example #8
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params)
        )

        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.append(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            outputs.append(result[0].tolist())
            scores.append(result[1].tolist())

        outputs = list(itertools.chain(*outputs))
        scores = list(itertools.chain(*scores))

        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            for outputs, scores in zip(restored_outputs, restored_scores):
                for output, score in zip(outputs, scores):
                    decoded = []
                    for idx in output:
                        if idx == params.mapping["target"][params.eos]:
                            break
                        decoded.append(vocab[idx])

                    decoded = " ".join(decoded)

                    if not args.verbose:
                        outfile.write("%s\n" % decoded)
                        break
                    else:
                        pattern = "%d ||| %s ||| %s ||| %f\n"
                        source = restored_inputs[count]
                        values = (count, source, decoded, score)
                        outfile.write(pattern % values)

                count += 1
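
All of these examples assume an argparse-style args namespace. A minimal sketch of the flags they read, with names inferred from usage rather than taken from the original CLI:

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="Translate with trained models")
    parser.add_argument("--input", type=str, help="Path of the input file")
    parser.add_argument("--output", type=str, help="Path of the output file")
    parser.add_argument("--checkpoints", type=str, nargs="+",
                        help="Paths of trained model checkpoints")
    parser.add_argument("--models", type=str, nargs="+",
                        help="Names of the model architectures")
    parser.add_argument("--verbose", action="store_true",
                        help="Also write source, rank and score per line")
    return parser.parse_args()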