Example #1
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls = models.get_model(args.model)
    params = default_parameters()

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.output, args.model, params)
    override_parameters(params, args)

    # Export all parameters and model specific parameters
    export_params(params.output, "params.json", params)
    export_params(params.output, "%s.json" % args.model,
                  collect_params(params, model_cls.get_parameters()))

    # Build Graph
    with tf.Graph().as_default():
        if not params.record:
            # Build input queue
            features = dataset.get_training_input(params.input, params)
        else:
            features = record.get_input_features(
                os.path.join(params.record, "*train*"), "train", params)

        # Build model
        initializer = get_initializer(params)
        model = model_cls(params)
        if params.MRT:
            # Minimum risk training (MRT); the candidate sampling assumes
            # a batch size of 1.
            assert params.batch_size == 1
            features = mrt_utils.get_MRT(features, params, model)

        # Multi-GPU setting
        sharded_losses = parallel.parallel_model(
            model.get_training_func(initializer), features, params.device_list)
        loss = tf.add_n(sharded_losses) / len(sharded_losses)

        # Create global step
        global_step = tf.train.get_or_create_global_step()

        # Print parameters
        all_weights = {v.name: v for v in tf.trainable_variables()}
        total_size = 0

        for v_name in sorted(list(all_weights)):
            v = all_weights[v_name]
            tf.logging.info("%s\tshape    %s", v.name[:-2].ljust(80),
                            str(v.shape).ljust(20))
            v_size = np.prod(np.array(v.shape.as_list())).tolist()
            total_size += v_size
        tf.logging.info("Total trainable variables size: %d", total_size)

        learning_rate = get_learning_rate_decay(params.learning_rate,
                                                global_step, params)
        learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
        tf.summary.scalar("learning_rate", learning_rate)

        # Create optimizer
        opt = tf.train.AdamOptimizer(learning_rate,
                                     beta1=params.adam_beta1,
                                     beta2=params.adam_beta2,
                                     epsilon=params.adam_epsilon)

        train_op = tf.contrib.layers.optimize_loss(
            name="training",
            loss=loss,
            global_step=global_step,
            learning_rate=learning_rate,
            clip_gradients=params.clip_grad_norm or None,
            optimizer=opt,
            colocate_gradients_with_ops=True)

        # Validation
        if params.validation and params.references[0]:
            files = [params.validation] + list(params.references)
            eval_inputs = dataset.sort_and_zip_files(files)
            eval_input_fn = dataset.get_evaluation_input
        else:
            eval_input_fn = None

        # Add hooks
        train_hooks = [
            tf.train.StopAtStepHook(last_step=params.train_steps),
            tf.train.NanTensorHook(loss),
            tf.train.LoggingTensorHook(
                {
                    "step": global_step,
                    "loss": loss,
                    "source": tf.shape(features["source"]),
                    "target": tf.shape(features["target"])
                },
                every_n_iter=1),
            tf.train.CheckpointSaverHook(
                checkpoint_dir=params.output,
                save_secs=params.save_checkpoint_secs or None,
                save_steps=params.save_checkpoint_steps or None,
                saver=tf.train.Saver(max_to_keep=params.keep_checkpoint_max,
                                     sharded=False))
        ]

        config = session_config(params)

        if eval_input_fn is not None:
            train_hooks.append(
                hooks.EvaluationHook(
                    lambda f: search.create_inference_graph(
                        model.get_evaluation_func(), f, params),
                    lambda: eval_input_fn(eval_inputs, params),
                    lambda x: decode_target_ids(x, params),
                    params.output,
                    config,
                    params.keep_top_checkpoint_max,
                    eval_secs=params.eval_secs,
                    eval_steps=params.eval_steps))

        # Create session, do not use default CheckpointSaverHook
        with tf.train.MonitoredTrainingSession(checkpoint_dir=params.output,
                                               hooks=train_hooks,
                                               save_checkpoint_secs=None,
                                               config=config) as sess:
            while not sess.should_stop():
                sess.run(train_op)
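The schedule produced by get_learning_rate_decay above is project-specific and
not shown in the snippet. A minimal sketch of one common choice for this kind
of trainer, linear warmup followed by inverse-square-root decay (the
warmup_steps name is illustrative, not taken from the example):

import tensorflow as tf

def noam_style_decay(base_rate, global_step, warmup_steps=4000):
    # Linear warmup for warmup_steps steps, then decay ~ 1/sqrt(step).
    step = tf.to_float(global_step) + 1.0
    return base_rate * tf.minimum(step * warmup_steps ** -1.5, step ** -0.5)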
Example #2
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]

        if params.use_bert and params.bert_emb_path:
            features = ds.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = ds.get_inference_input(params.input, params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.append(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)
                if len(results) > 2:
                    # Only decode the first three batches in this example.
                    break
        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []

        for result in results:
            outputs.append(result.tolist())

        outputs = list(itertools.chain(*outputs))


        # Write to file
        with open(args.output, "w") as outfile:
            for output in outputs:
                decoded = []
                for idx in output:
                    # Note: unlike the translation examples, no EOS-based
                    # truncation is applied here; every predicted id is kept.
                    decoded.append(vocab[idx])

                decoded = " ".join(decoded)
                outfile.write("%s\n" % decoded)
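The set_variables helper used by the ensemble examples is not shown here. A
minimal sketch of what it might do, assuming checkpoint names carry the bare
model name (e.g. "rnnsearch/w") while graph variables live under the suffixed
scope (e.g. "rnnsearch_0/w:0"):

import tensorflow as tf

def set_variables(var_list, value_dict, prefix):
    # Map each graph variable back to its checkpoint name by dropping the
    # "_i" suffix from the scope, then emit an assign op for it.
    ops = []
    base = prefix.rsplit("_", 1)[0]
    for var in var_list:
        name = var.name.split(":")[0].replace(prefix, base, 1)
        if name in value_dict:
            ops.append(tf.assign(var, value_dict[name]))
    return ops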
Example #3
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(args.models)]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        # Import the parameters saved during training.
        import_params(args.checkpoints, args.models, params_list[0])
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate([args.checkpoints]):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)  # List every variable stored in the checkpoint
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # Keep only this model's variables (e.g. "rnnsearch")
                # and skip the "losses_avg" statistics.
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)  # Read the variable's value
                values[name] = tensor

            model_var_lists.append(values)  # Name -> value map for this model's variables

        # Build models
        model_fns = []

        for i in range(len([args.checkpoints])):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()  # The model's inference function
            model_fns.append(model_fn)

        params = params_list[0]

        if params.use_bert and params.bert_emb_path:
            features = dataset.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = dataset.get_inference_input([params.input], params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len([args.checkpoints])):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.extend(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)
            # Collect the reference tokens from the input file
            # (unused in this excerpt).
            tar = []
            with open(params.input, "r") as inputs_f:
                for line in inputs_f:
                    if line.strip() == "O":
                        continue
                    tar.extend(line.split(" ")[:-1])
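Example #3 collects per-sentence reference tokens but stops before scoring.
The step it leads up to, cutting each padded prediction row down to its true
sentence length, can be sketched on its own (names here are illustrative):

def trim_padded_batch(batch_rows, lengths, start):
    # Predictions come back padded to the batch's max length; keep only
    # the first lengths[i] ids of each row, as Example #4 does below.
    trimmed = []
    offset = start
    for row in batch_rows:
        trimmed.extend(list(row)[:lengths[offset]])
        offset += 1
    return trimmed, offset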
Example #4
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(args.models)]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        # Import the parameters saved during training.
        import_params(args.checkpoints, args.models, params_list[0])
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate([args.checkpoints]):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)  # List every variable stored in the checkpoint
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # Keep only this model's variables (e.g. "rnnsearch")
                # and skip the "losses_avg" statistics.
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)  # Read the variable's value
                values[name] = tensor

            model_var_lists.append(values)  # Name -> value map for this model's variables

        # Build models
        model_fns = []

        for i in range(len([args.checkpoints])):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()  # The model's inference function
            model_fns.append(model_fn)

        params = params_list[0]

        if params.use_bert and params.bert_emb_path:
            features = dataset.get_inference_input_with_bert(
                params.input + [params.bert_emb_path], params)
        else:
            features = dataset.get_inference_input([params.input], params)

        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len([args.checkpoints])):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params))

        result_for_score = []
        result_for_write = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)
            # Record each sentence's length; used later to strip padded tokens.
            length = []
            with open(args.input, "r", encoding="utf8") as f:
                for line in f:
                    if line.strip() == "-DOCSTART-":
                        continue
                    lines = line.strip().split(" ")
                    length.append(len(lines))
            current_num = 0
            batch = 0
            while not sess.should_stop():
                current_res_arr = sess.run(predictions)
                result_for_write.append(current_res_arr)
                for arr in current_res_arr:
                    result_for_score.extend(list(arr)[:length[current_num]])
                    current_num += 1
                batch += 1
                message = "Finished batch %d" % batch
                tf.logging.log(tf.logging.INFO, message)
        if params.is_validation:
            from sklearn.metrics import precision_score, recall_score, f1_score
            import numpy as np
            # Map each target label to its index in the vocabulary.
            voc_lis = params.vocabulary["target"]
            index = list(np.arange(len(voc_lis)))
            dic = dict(zip(voc_lis, index))

            def map_res(x):
                return dic[x]

            targets_list = []
            with open(args.eval_file, "r") as f:  # Read the gold label file
                for line in f:
                    if line.strip() == "O":
                        continue
                    lines = line.strip().split(" ")
                    targets_list.extend(list(map(map_res, lines)))  # Gold labels -> indices

            result_arr = np.array(result_for_score)
            targets_arr = np.array(targets_list)
            precision_ = precision_score(targets_arr,
                                         result_arr,
                                         average="micro",
                                         labels=[0, 2, 3, 4])
            # recall_score, like precision_score, takes (y_true, y_pred).
            recall_ = recall_score(targets_arr,
                                   result_arr,
                                   average="micro",
                                   labels=[0, 2, 3, 4])
            print("precision_score:{}".format(precision_))
            print("recall_score:{}".format(recall_))
            print("F1_score:{}".format(2 * precision_ * recall_ /
                                       (recall_ + precision_)))
        else:
            # Convert to plain text
            vocab = params.vocabulary["target"]
            outputs = []

            for result in result_for_write:
                outputs.append(result.tolist())

            outputs = list(itertools.chain(*outputs))


            # Write to file
            num = 0
            with open(args.output, "w") as outfile:
                for output in outputs:
                    decoded = []
                    for idx in output[:length[num] + 1]:
                        if idx == params.mapping["target"][params.eos]:
                            if idx != output[length[num]]:
                                print("Warning: incomplete prediction at "
                                      "line %d of the source file" % (num + 1))
                        decoded.append(vocab[idx])
                    decoded = " ".join(decoded[:-1])
                    outfile.write("%s\n" % decoded)
                    num += 1
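Example #4 combines micro-averaged precision and recall by hand even though
f1_score is already imported; scikit-learn computes the same quantity
directly. An equivalent one-liner over the same arrays and label subset:

from sklearn.metrics import f1_score

f1_ = f1_score(targets_arr, result_arr, average="micro", labels=[0, 2, 3, 4])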
Example #5
def main(args):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params)
        )

        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.append(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            outputs.append(result[0].tolist())
            scores.append(result[1].tolist())

        outputs = list(itertools.chain(*outputs))
        scores = list(itertools.chain(*scores))

        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            count = 0
            for outputs, scores in zip(restored_outputs, restored_scores):
                for output, score in zip(outputs, scores):
                    decoded = []
                    for idx in output:
                        if idx == params.mapping["target"][params.eos]:
                            break
                        decoded.append(vocab[idx])

                    decoded = " ".join(decoded)

                    if not args.verbose:
                        outfile.write("%s\n" % decoded)
                        break
                    else:
                        pattern = "%d ||| %s ||| %s ||| %f\n"
                        source = restored_inputs[count]
                        values = (count, source, decoded, score)
                        outfile.write(pattern % values)

                count += 1
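Example #5 depends on dataset.sort_input_file returning both the sorted
sentences and keys for undoing the sort (restored_inputs indexes
sorted_inputs by sorted_keys[index] above). A self-contained sketch of that
idiom, assuming sentences are sorted by token count to reduce padding:

def sort_input(lines):
    # keys[orig_pos] is the sentence's position in the sorted list, so
    # lines[i] == sorted_lines[keys[i]] for every original index i.
    order = sorted(range(len(lines)), key=lambda i: len(lines[i].split()))
    sorted_lines = [lines[i] for i in order]
    keys = [0] * len(lines)
    for new_pos, orig_pos in enumerate(order):
        keys[orig_pos] = new_pos
    return keys, sorted_lines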