# Example 1
# 0
def main(args):
    """Evaluate an ensemble of aspect-sentiment models on an input file.

    One model is restored per entry in ``args.checkpoints``; predictions are
    computed data-parallel over ``params.device_list`` and written to
    ``args.output`` together with per-class TP/FP/FN counts and P/R/F1
    (micro and macro) for the three polarity classes (0=bad, 1=mid, 2=good).
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    model_cls_list = [models.get_model(model) for model in args.models]
    # Parameter resolution order: library defaults -> model defaults ->
    # values stored alongside each checkpoint -> command-line overrides.
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    with tf.Graph().as_default():
        model_var_lists = []

        # Read raw tensors from every checkpoint.  Only variables inside the
        # model's own name scope are kept; "losses_avg" bookkeeping entries
        # are skipped.
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Instantiate one model per checkpoint under a unique "<name>_<i>"
        # scope so ensemble members do not share variables.
        model_list = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_list.append(model)

        params = params_list[0]
        params.initializer_gain = 1.0

        # Inputs come back sorted for efficient batching; sorted_keys maps
        # positions back to the original file order.
        sorted_keys, sorted_inputs = dataset.read_eval_input_file(args.input)

        features = dataset.get_predict_input(sorted_inputs, params)

        # One placeholder dict per device for data parallelism.
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "text":
                tf.placeholder(tf.int32, [None, None], "text_%d" % i),
                "text_length":
                tf.placeholder(tf.int32, [None], "text_length_%d" % i),
                "aspect":
                tf.placeholder(tf.int32, [None, None], "aspect_%d" % i),
                "aspect_length":
                tf.placeholder(tf.int32, [None], "aspect_length_%d" % i),
                "polarity":
                tf.placeholder(tf.int32, [None, None], "polarity_%d" % i)
            })

        predict_fn = inference.create_predict_graph

        predictions = parallel.data_parallelism(
            params.device_list, lambda f: predict_fn(model_list, f, params),
            placeholders)

        # Build assign ops that copy the checkpoint values into the freshly
        # created "<name>_<i>" variables; values are fed via feed_dict.
        assign_ops = []
        feed_dict = {}

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i, feed_dict)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        init_op = tf.tables_initializer()
        results = []

        with tf.Session(config=session_config(params)) as sess:
            # Restore variables, then initialize lookup tables.
            sess.run(assign_op, feed_dict=feed_dict)
            sess.run(init_op)

            # Drain the input pipeline batch by batch until exhausted.
            while True:
                try:
                    feats = sess.run(features)
                    op, feed_dict = shard_features(feats, placeholders,
                                                   predictions)
                    results.append(sess.run(op, feed_dict=feed_dict))
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

        # Each per-batch result is indexed as: [0] input features,
        # [1] predicted class ids, [2] per-class scores, [3] attention
        # weights (alphas) — presumably; verify against create_predict_graph.
        input_features = []
        scores1 = []
        scores2 = []
        output_alphas = []
        for result in results:
            for item in result[0]:
                input_features.append(item.tolist())
            for item in result[1]:
                scores1.append(item.tolist())
            for item in result[2]:
                scores2.append(item.tolist())
            for item in result[3]:
                output_alphas.append(item.tolist())

        scores1 = list(itertools.chain(*scores1))
        scores2 = list(itertools.chain(*scores2))
        output_alphas = list(itertools.chain(*output_alphas))

        # Undo the batching sort so outputs line up with the input file.
        restored_scores1 = []
        restored_scores2 = []
        restored_output_alphas = []
        restored_inputs_text = []
        restored_inputs_aspect = []
        restored_inputs_score = []

        for index in range(len(sorted_inputs[0])):
            restored_scores1.append(scores1[sorted_keys[index]][0])
            restored_scores2.append(scores2[sorted_keys[index]])
            restored_output_alphas.append(output_alphas[sorted_keys[index]])

            restored_inputs_text.append(sorted_inputs[0][sorted_keys[index]])
            restored_inputs_aspect.append(sorted_inputs[1][sorted_keys[index]])
            restored_inputs_score.append(sorted_inputs[2][sorted_keys[index]])

        # Per-class counters for classes 0 ("bad"), 1 ("mid"), 2 ("good").
        class3_bad_TP = 0.0
        class3_bad_FP = 0.0
        class3_bad_FN = 0.0

        class3_mid_TP = 0.0
        class3_mid_FP = 0.0
        class3_mid_FN = 0.0

        class3_good_TP = 0.0
        class3_good_FP = 0.0
        class3_good_FN = 0.0

        with open(args.output, "w") as outfile:

            # score1: predicted class id, score2: per-class scores,
            # score3: gold label read from the input file (string).
            for score1, score2, score3, alphas, text, aspect in zip(
                    restored_scores1, restored_scores2, restored_inputs_score,
                    restored_output_alphas, restored_inputs_text,
                    restored_inputs_aspect):
                score1 = str(score1)
                outfile.write("###########################\n")
                pattern = "%s|||%f,%f,%f|||%s\n"
                values = (score1, score2[0], score2[1], score2[2], score3)
                outfile.write(pattern % values)
                outfile.write(aspect + "\n")
                # One "word alpha" pair per token for attention inspection.
                for (word, alpha) in zip(text.split(), alphas):
                    outfile.write(word + " " + str(alpha) + "\t")
                outfile.write("\n")

                # Accumulate TP/FP/FN per class (string comparison since
                # both sides are label strings '0'/'1'/'2').
                if score1 == '0' and score3 == '0':
                    class3_bad_TP += 1.0
                if score1 == '1' and score3 == '1':
                    class3_mid_TP += 1.0
                if score1 == '2' and score3 == '2':
                    class3_good_TP += 1.0

                if score1 == '0' and score3 != '0':
                    class3_bad_FP += 1.0
                if score1 == '1' and score3 != '1':
                    class3_mid_FP += 1.0
                if score1 == '2' and score3 != '2':
                    class3_good_FP += 1.0

                if score1 != '0' and score3 == '0':
                    class3_bad_FN += 1.0
                if score1 != '1' and score3 == '1':
                    class3_mid_FN += 1.0
                if score1 != '2' and score3 == '2':
                    class3_good_FN += 1.0

            # Summary table.  FIX: header previously said "neural"; the
            # middle polarity class is "neutral".
            # NOTE(review): columns map bad/mid/good (0/1/2) to
            # positive/neutral/negative — confirm that label order matches
            # the class-id convention of the training data.
            outfile.write("\n")
            outfile.write("Class 3:\n")
            outfile.write("Confusion Matrix:\n")
            outfile.write("\t" + "{name: >10s}".format(name="positive") +
                          "\t" + "{name: >10s}".format(name="neutral") + "\t" +
                          "{name: >10s}".format(name="negative") + "\n")
            outfile.write("TP\t" + int2int(class3_bad_TP) + "\t" +
                          int2int(class3_mid_TP) + "\t" +
                          int2int(class3_good_TP) + "\n")
            outfile.write("FP\t" + int2int(class3_bad_FP) + "\t" +
                          int2int(class3_mid_FP) + "\t" +
                          int2int(class3_good_FP) + "\n")
            outfile.write("FN\t" + int2int(class3_bad_FN) + "\t" +
                          int2int(class3_mid_FN) + "\t" +
                          int2int(class3_good_FN) + "\n")
            # 0.000001 guards against division by zero for empty classes.
            outfile.write(
                "P\t" + float2int(class3_bad_TP /
                                  (class3_bad_TP + class3_bad_FP + 0.000001)) +
                "\t" + float2int(class3_mid_TP /
                                 (class3_mid_TP + class3_mid_FP + 0.000001)) +
                "\t" +
                float2int(class3_good_TP /
                          (class3_good_TP + class3_good_FP + 0.000001)) + "\n")
            outfile.write(
                "R\t" + float2int(class3_bad_TP /
                                  (class3_bad_TP + class3_bad_FN + 0.000001)) +
                "\t" + float2int(class3_mid_TP /
                                 (class3_mid_TP + class3_mid_FN + 0.000001)) +
                "\t" +
                float2int(class3_good_TP /
                          (class3_good_TP + class3_good_FN + 0.000001)) + "\n")
            outfile.write("F1\t" +
                          float2int(class3_bad_TP * 2 /
                                    (class3_bad_TP * 2 + class3_bad_FP +
                                     class3_bad_FN + 0.000001)) + "\t" +
                          float2int(class3_mid_TP * 2 /
                                    (class3_mid_TP * 2 + class3_mid_FP +
                                     class3_mid_FN + 0.000001)) + "\t" +
                          float2int(class3_good_TP * 2 /
                                    (class3_good_TP * 2 + class3_good_FP +
                                     class3_good_FN + 0.000001)) + "\n")
            outfile.write("F1-Micro:\t" + float2int(
                (class3_bad_TP + class3_mid_TP + class3_good_TP) * 2 /
                ((class3_bad_TP + class3_mid_TP + class3_good_TP) * 2 +
                 (class3_bad_FP + class3_mid_FP + class3_good_FP) +
                 (class3_bad_FN + class3_mid_FN + class3_good_FN) +
                 0.000001)) + "\n")
            outfile.write("F1-Macro:\t" + float2int(
                (class3_bad_TP * 2 /
                 (class3_bad_TP * 2 + class3_bad_FP + class3_bad_FN +
                  0.000001) + class3_mid_TP * 2 /
                 (class3_mid_TP * 2 + class3_mid_FP + class3_mid_FN +
                  0.000001) + class3_good_TP * 2 /
                 (class3_good_TP * 2 + class3_good_FP + class3_good_FN +
                  0.000001)) / 3.0) + "\n")
def main(args):
    """Translate an input file with an ensemble of Transformer models.

    One set of weights is restored per entry in ``args.models``; beam-search
    decoding runs data-parallel over ``params.device_list`` and translations
    are written to ``args.output`` in the original input order.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs: defaults -> model defaults -> stored params -> CLI args.
    model_cls_list = [transformer.Transformer for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.models[i], model_cls_list[i].get_name(),
                      params_list[i]) for i in range(len(args.models))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints; keep only variables inside each model's scope.
        for i, checkpoint in enumerate(args.models):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(
                        model_cls_list[i].get_name()):  # ignore global_step
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build one model per checkpoint under a unique "<name>_<i>" scope
        # and collect its inference function for the ensemble decoder.
        model_fns = []

        for i in range(len(args.models)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Read input file; sentences come back sorted for efficient
        # batching, with sorted_keys mapping back to the original order.
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        # Create one placeholder dict per device.
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i)
            })

        predictions = parallel.data_parallelism(
            params.device_list,
            lambda f: beamsearch.create_inference_graph(model_fns, f, params),
            placeholders)

        # Create assign ops: each variable gets its own placeholder/op/value
        # triple so weights are fed one tensor at a time at restore time.
        assign_ops_all = []
        assign_placeholders_all = []
        assign_values_all = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.models)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            assign_placeholders, assign_ops, assign_values = set_variables(
                un_init_var_list, model_var_lists[i], name + "_%d" % i)

            assign_placeholders_all.append(assign_placeholders)
            assign_ops_all.append(assign_ops)
            assign_values_all.append(assign_values)

        results = []

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables one at a time (keeps each feed small).
            for i in range(len(args.models)):
                for p, assign_op, v in zip(assign_placeholders_all[i],
                                           assign_ops_all[i],
                                           assign_values_all[i]):
                    sess.run(assign_op, {p: v})
            sess.run(tf.tables_initializer())

            # Drain the input queue batch by batch until exhausted.
            while True:
                try:
                    feats = sess.run(features)
                    ops, feed_dict = shard_features(feats, placeholders,
                                                    predictions)
                    results.append(sess.run(ops, feed_dict=feed_dict))
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                except tf.errors.OutOfRangeError:
                    break

        # Convert to plain text: result[0] holds token-id hypotheses,
        # result[1] their scores.
        vocab = params.vocabulary["target"]
        outputs = []
        scores = []

        for result in results:
            for item in result[0]:
                outputs.append(item.tolist())
            for item in result[1]:
                scores.append(item.tolist())

        outputs = list(itertools.chain(*outputs))
        scores = list(itertools.chain(*scores))

        # Undo the batching sort so outputs line up with the input file.
        restored_inputs = []
        restored_outputs = []
        restored_scores = []

        for index in range(len(sorted_inputs)):
            restored_inputs.append(sorted_inputs[sorted_keys[index]])
            restored_outputs.append(outputs[sorted_keys[index]])
            restored_scores.append(scores[sorted_keys[index]])

        # Write to file.
        # NOTE: the loop variables below shadow the outer `outputs`/`scores`
        # lists; both outer lists are no longer used past this point.
        with open(args.output, "w") as outfile:
            count = 0
            for outputs, scores in zip(restored_outputs, restored_scores):
                for output, score in zip(outputs, scores):
                    # Map token ids back to symbols, stopping at EOS.
                    decoded = []
                    for idx in output:
                        if isinstance(idx, six.integer_types):
                            symbol = vocab[idx]
                        else:
                            symbol = idx

                        if symbol == params.eos:
                            break
                        decoded.append(symbol)

                    decoded = str.join(" ", decoded)

                    # Default: emit only the best hypothesis per sentence;
                    # with --log, emit every hypothesis with its score.
                    if not args.log:
                        outfile.write("%s\n" % decoded)
                        break
                    else:
                        pattern = "src[%d]: %s \n trans[%.4f]: %s \n"
                        source = restored_inputs[count]
                        values = (count, source, score, decoded)
                        outfile.write(pattern % values)

                count += 1
# Example 3
# 0
def main(args):
    """Run PixelLink text-detection inference and evaluation.

    Restores one PixelLinkNetwork per checkpoint, feeds images from the
    inference input pipeline, reconstructs text contours from the network
    output, accumulates recall/precision against ground-truth contours,
    and writes side-by-side visualisation images to ``params.output``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load configs: defaults -> per-checkpoint stored params -> CLI args.
    params_list = [default_parameters() for _ in range(len(args.checkpoints))]

    params_list = [
        import_params(args.checkpoints[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]

    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(args.checkpoints))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints, skipping "losses_avg" bookkeeping variables.
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build one model per checkpoint under a unique scope.
        model_fns = []

        for i in range(len(args.checkpoints)):
            model = pixellink.PixelLinkNetwork(params_list[i],
                                               'PixelLinkNetwork' + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Build input queue
        features = dataset.get_inference_input(params)
        # Create one placeholder dict per device.
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "input_img":
                tf.placeholder(tf.float32, [None, None, None, 3],
                               "input_img_%d" % i),
                'lens':
                tf.placeholder(tf.float32, [
                    None,
                ], 'lens_%d' % i),
                'cnts':
                tf.placeholder(tf.float32, [None, None, None, None],
                               'cnts_%d' % i),
                'care':
                tf.placeholder(tf.float32, [
                    None,
                ], 'care_%d' % i),
            })
            # Expected feature shapes (from the input pipeline):
            #   input_img: (batch, height, width, 3)
            #   lens:      (batch,)                  padded-contour lengths
            #   cnts:      (batch, ?, ?, ?)          padded GT contours
            #   care:      (batch,)                  per-contour care flags
            # TODO(review): confirm against dataset.get_inference_input.

        predictions_dict = parallel.data_parallelism(params.device_list,
                                                     model_fns, placeholders)

        # Create assign ops copying checkpoint values into the scoped vars.
        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []

            for v in all_var_list:
                if v.name.startswith('PixelLinkNetwork' + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                'PixelLinkNetwork' + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            # Restore variables, then initialize lookup tables.
            sess.run(assign_op)
            sess.run(tf.tables_initializer())

            # `time` numbers the visualisation files; the *_sum accumulators
            # hold contour-count-weighted recall/precision totals.
            time = 0
            recall_sum, precise_sum, gt_n_sum, pred_n_sum = 0, 0, 0, 0
            while True:
                try:
                    feats = sess.run(features)
                    # NOTE(review): the returned `op` is unused here —
                    # predictions_dict is run directly below; only the
                    # feed_dict from shard_features is consumed. Confirm.
                    op, feed_dict = shard_features(feats, placeholders,
                                                   predictions_dict)
                    results = []
                    temp = sess.run(predictions_dict, feed_dict=feed_dict)
                    results.append(temp)
                    message = "Finished batch %d" % len(results)
                    tf.logging.log(tf.logging.INFO, message)
                    for res in results:
                        outputs = res[0]
                        img = outputs['input_img']
                        prediction = outputs['prediction']
                        lens = outputs['lens']
                        cnts = outputs['cnts']
                        # Halve contour coordinates — presumably the ground
                        # truth is stored at 2x the prediction resolution;
                        # TODO(review): confirm the scale factor.
                        cnts = [(x / 2).astype(np.int32) for x in cnts]
                        # Strip padding using the per-contour lengths.
                        cnts = _depad(cnts, lens)
                        care = outputs['care']
                        for i in range(img.shape[0]):
                            # Rebuild detected contours from the raw map.
                            re_cnts = reconstruct(img[i], prediction[i])
                            TR, TP, T_gt_n, T_pred_n, PR, PP, P_gt_n, P_pred_n = \
                                evaluate(img[i], cnts, re_cnts, care)
                            tf.logging.info(' recall: ' + str(TR) +
                                            '; precise: ' + str(TP))
                            # Weight by contour counts so the final average
                            # is over contours, not images.
                            recall_sum += TR * T_gt_n
                            precise_sum += TP * T_pred_n
                            gt_n_sum += T_gt_n
                            pred_n_sum += T_pred_n

                            # Compose a 2x2 canvas: input image top-right,
                            # link/pixel probability map bottom-right, with
                            # GT contours (red) and detections (green).
                            height, width = prediction.shape[
                                1], prediction.shape[2]
                            imgoutput = np.zeros(shape=(height * 2, width * 2,
                                                        3),
                                                 dtype=np.uint8)
                            # NOTE(review): draws img[0] regardless of the
                            # loop index i — correct only when the batch
                            # size is 1; likely should be img[i]. Confirm.
                            imgoutput[0:height,
                                      width:width * 2, :] = cv2.resize(
                                          img[0], (width, height))
                            imgoutput[height:height * 2,
                                      width:width * 2, :] = (
                                          _softmax(prediction[i, :, :, 0:2]) *
                                          255).astype(np.uint8)
                            cv2.drawContours(imgoutput, cnts, -1, (0, 0, 255))
                            cv2.drawContours(imgoutput, re_cnts, -1,
                                             (0, 255, 0))
                            cv2.imwrite(
                                os.path.join(
                                    params.output,
                                    'output_{:03d}_r{}_p{}.png'.format(
                                        time, TR, TP)), imgoutput)
                            time += 1

                except tf.errors.OutOfRangeError:
                    # Input exhausted: report contour-weighted averages,
                    # guarding each division against empty denominators.
                    if int(gt_n_sum) != 0:
                        ave_r = recall_sum / gt_n_sum
                    else:
                        ave_r = 0.0
                    if int(pred_n_sum) != 0:
                        ave_p = precise_sum / pred_n_sum
                    else:
                        ave_p = 0.0
                    if ave_r != 0.0 and ave_p != 0.0:
                        # Harmonic mean of recall and precision (F1).
                        ave_f = 2 / (1 / ave_r + 1 / ave_p)
                    else:
                        ave_f = 0.0
                    tf.logging.info('ave recall:{}, precise:{}, f:{}'.format(
                        ave_r, ave_p, ave_f))
                    tf.logging.info('end evaluation')
                    time += 1
                    break