Example #1
def evaluate(checkpoint):
    with tf.Graph().as_default() as g, tf.device('/cpu:0'):
        # Get images and labels
        data = tf.placeholder(tf.float32, [
            graphcnn_input.EVAL_BATCH_SIZE, graphcnn_input.HEIGHT,
            graphcnn_input.WIDTH, graphcnn_input.NUM_CHANNELS
        ])
        labels = tf.placeholder(
            tf.int32,
            [graphcnn_input.EVAL_BATCH_SIZE, graphcnn_input.NUM_CLASSES])

        # inference
        # logits = graphcnn_model.inference_GPU(data,eval_data=True,dependencies_loss=False)
        # logits = graphcnn_model.inference_CPU(data, eval_data=True, dependencies_loss=False)
        logits = graphcnn_model.inference(data, eval_data=True)

        logits = tf.sigmoid(logits)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            graphcnn_option.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Build the summary operation based on the TF collection of Summaries.
        # summary_op = tf.merge_all_summaries()
        # summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, g)

        best_eval_value = 0
        best_eval_ckpt = 0

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=FLAGS.log_device_placement)) as sess:
            if checkpoint == '0':
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    # Restores from checkpoint
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # extract global_step
                    global_step_for_restore = int(
                        ckpt.model_checkpoint_path.split('/')[-1].split('-')
                        [-1])
                else:
                    print('No checkpoint file found')
                    return
            else:
                if os.path.exists(
                        os.path.join(FLAGS.checkpoint_dir,
                                     'model.ckpt-' + checkpoint)):
                    saver.restore(
                        sess,
                        os.path.join(FLAGS.checkpoint_dir,
                                     'model.ckpt-' + checkpoint))
                    global_step_for_restore = int(checkpoint)
                else:
                    print('No checkpoint file found')
                    return

            num_iter = int(
                math.floor(graphcnn_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL /
                           graphcnn_input.EVAL_BATCH_SIZE))
            total_sample_count = num_iter * graphcnn_input.EVAL_BATCH_SIZE
            step = 0
            total_predicted_value = np.zeros([1, graphcnn_input.NUM_CLASSES],
                                             dtype=np.float32)  # placeholder row, dropped after the loop
            total_true_value = np.zeros([1, graphcnn_input.NUM_CLASSES],
                                        dtype=np.int32)
            while step < num_iter:
                test_data, test_label = evalDataSet.next_batch(
                    graphcnn_input.EVAL_BATCH_SIZE)
                start_time = time.time()
                predicted_value, true_value = sess.run([logits, labels],
                                                       feed_dict={
                                                           data: test_data,
                                                           labels: test_label
                                                       })
                duration = time.time() - start_time
                sec_per_batch = float(duration)
                print('sec_per_batch:%.3f/%d' %
                      (sec_per_batch, graphcnn_input.EVAL_BATCH_SIZE))
                total_predicted_value = np.concatenate(
                    (total_predicted_value, predicted_value), axis=0)
                total_true_value = np.concatenate(
                    (total_true_value, true_value), axis=0)
                step += 1

            total_predicted_value = total_predicted_value[1:]
            total_true_value = total_true_value[1:]

            detail_filename = os.path.join(
                FLAGS.eval_dir, 'log_eval_for_predicted_value_dictribution')
            if os.path.exists(detail_filename):
                os.remove(detail_filename)
            np.savetxt(detail_filename, total_predicted_value, fmt='%.4f')
            total_predicted_value = (
                (total_predicted_value) >=
                graphcnn_option.EVALUTION_THRESHOLD_FOR_MULTI_LABEL
            ).astype(int)
            assert total_sample_count == total_predicted_value.shape[
                0], 'sample_count error!'
            detail_filename = os.path.join(FLAGS.eval_dir,
                                           'log_eval_for_predicted_value')
            if os.path.exists(detail_filename):
                os.remove(detail_filename)
            np.savetxt(detail_filename, total_predicted_value, fmt='%d')
            detail_filename = os.path.join(FLAGS.eval_dir,
                                           'log_eval_for_true_value')
            if os.path.exists(detail_filename):
                os.remove(detail_filename)
            np.savetxt(detail_filename, total_true_value, fmt='%d')

            filename_eval_log = os.path.join(FLAGS.eval_dir, 'log_eval')
            file_eval_log = open(filename_eval_log, 'w')
            np.set_printoptions(threshold=np.inf)  # print full arrays
            print('\nevaluation:', file=file_eval_log)
            print('\nevaluation:')
            print('  %s, ckpt-%d:' % (datetime.now(), global_step_for_restore),
                  file=file_eval_log)
            print('  %s, ckpt-%d:' % (datetime.now(), global_step_for_restore))

            total_predicted_value = total_predicted_value.astype(bool)
            total_true_value = total_true_value.astype(bool)

            print('  example based evaluations:', file=file_eval_log)
            print('  example based evaluations:')

            equal = total_true_value == total_predicted_value
            match = np.sum(equal, axis=1) == np.size(equal, axis=1)
            exact_match_ratio = np.sum(match) / np.size(match)
            print('      exact_match_ratio = %.4f' % exact_match_ratio,
                  file=file_eval_log)
            print('      exact_match_ratio = %.4f' % exact_match_ratio)

            true_and_predict = np.sum(total_true_value & total_predicted_value,
                                      axis=1)
            true_or_predict = np.sum(total_true_value | total_predicted_value,
                                     axis=1)
            accuracy = np.mean(true_and_predict / true_or_predict)
            print('      accuracy = %.4f' % accuracy, file=file_eval_log)
            print('      accuracy = %.4f' % accuracy)

            precision = np.mean(true_and_predict /
                                (np.sum(total_predicted_value, axis=1) + 1e-9))
            print('      precision = %.4f' % precision, file=file_eval_log)
            print('      precision = %.4f' % precision)

            recall = np.mean(true_and_predict /
                             np.sum(total_true_value, axis=1))
            print('      recall = %.4f' % recall, file=file_eval_log)
            print('      recall = %.4f' % recall)

            F1_Measure = np.mean((true_and_predict * 2) /
                                 (np.sum(total_true_value, axis=1) +
                                  np.sum(total_predicted_value, axis=1)))
            print('      F1_Measure = %.4f' % F1_Measure, file=file_eval_log)
            print('      F1_Measure = %.4f' % F1_Measure)

            HammingLoss = np.mean(total_true_value ^ total_predicted_value)
            print('      HammingLoss = %.4f' % HammingLoss, file=file_eval_log)
            print('      HammingLoss = %.4f' % HammingLoss)

            print('  label based evaluations:', file=file_eval_log)
            print('  label based evaluations:')

            TP = np.sum(total_true_value & total_predicted_value,
                        axis=0,
                        dtype=np.int32)
            FP = np.sum((~total_true_value) & total_predicted_value,
                        axis=0,
                        dtype=np.int32)
            FN = np.sum(total_true_value & (~total_predicted_value),
                        axis=0,
                        dtype=np.int32)

            _P = np.sum(TP) / (np.sum(TP) + np.sum(FP) + 1e-9)
            _R = np.sum(TP) / (np.sum(TP) + np.sum(FN) + 1e-9)
            Micro_F1 = (2 * _P * _R) / (_P + _R + 1e-9)
            print('      P = %.4f' % _P, file=file_eval_log)
            print('      P = %.4f' % _P)
            print('      R = %.4f' % _R, file=file_eval_log)
            print('      R = %.4f' % _R)
            print('      Micro-F1 = %.4f' % Micro_F1, file=file_eval_log)
            print('      Micro-F1 = %.4f' % Micro_F1)

            _P_t = TP / (TP + FP + 1e-9)
            _R_t = TP / (TP + FN + 1e-9)
            Macro_F1 = np.mean((2 * _P_t * _R_t) / (_P_t + _R_t + 1e-9))
            # print('    P_t = %.4f' % _P, file=file_eval_log)
            # print('    P_t = %.4f' % _P)
            # print('    R_t = %.4f' % _R, file=file_eval_log)
            # print('    R_t = %.4f' % _R)
            print('      Macro-F1 = %.4f' % Macro_F1, file=file_eval_log)
            print('      Macro-F1 = %.4f' % Macro_F1)

            print(
                'evaluation samples number:%d, evaluation classes number:%d' %
                (total_predicted_value.shape[0],
                 total_predicted_value.shape[1]),
                file=file_eval_log)
            print(
                'evaluation samples number:%d, evaluation classes number:%d' %
                (total_predicted_value.shape[0],
                 total_predicted_value.shape[1]))
            # print('evaluation detail: ' + os.path.join(FLAGS.eval_dir, 'log_eval_for_true_value')
            #       + ', ' + os.path.join(FLAGS.eval_dir, 'log_eval_for_predicted_value')
            #       + ', ' + os.path.join(FLAGS.eval_dir, 'log_eval_for_predicted_value_dictribution'),
            #       file=file_eval_log)
            # print('evaluation detail: ' + os.path.join(FLAGS.eval_dir, 'log_eval')
            #       + ', ' + os.path.join(FLAGS.eval_dir, 'log_eval_for_true_value')
            #       + ', ' + os.path.join(FLAGS.eval_dir, 'log_eval_for_predicted_value')
            #       + ', ' + os.path.join(FLAGS.eval_dir, 'log_eval_for_predicted_value_dictribution'))
            file_eval_log.close()

            best_eval_ckpt = global_step_for_restore
            best_eval_value = Macro_F1
            sourceFile = os.path.join(FLAGS.eval_dir, 'log_eval')
            targetFile = os.path.join(FLAGS.eval_dir, 'best_eval')
            if os.path.exists(targetFile):
                os.remove(targetFile)
            shutil.copy(sourceFile, targetFile)
            sourceFile = os.path.join(FLAGS.eval_dir,
                                      'log_eval_for_true_value')
            targetFile = os.path.join(FLAGS.eval_dir,
                                      'best_eval_for_true_value')
            if os.path.exists(targetFile):
                os.remove(targetFile)
            shutil.copy(sourceFile, targetFile)
            sourceFile = os.path.join(FLAGS.eval_dir,
                                      'log_eval_for_predicted_value')
            targetFile = os.path.join(FLAGS.eval_dir,
                                      'best_eval_for_predicted_value')
            if os.path.exists(targetFile):
                os.remove(targetFile)
            shutil.copy(sourceFile, targetFile)
            sourceFile = os.path.join(
                FLAGS.eval_dir, 'log_eval_for_predicted_value_dictribution')
            targetFile = os.path.join(
                FLAGS.eval_dir, 'best_eval_for_predicted_value_dictribution')
            if os.path.exists(targetFile):
                os.remove(targetFile)
            shutil.copy(sourceFile, targetFile)
            # sourceFile = ckpt.model_checkpoint_path + '.index'
            # targetFile = os.path.join(FLAGS.eval_dir, 'best_eval.ckpt')
            # if os.path.exists(targetFile):
            #     os.remove(targetFile)
            # shutil.copy(sourceFile, targetFile)

        while True:
            if FLAGS.run_once:
                break
            time.sleep(FLAGS.eval_interval_secs)
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    # extract global_step
                    global_step_for_restore = int(
                        ckpt.model_checkpoint_path.split('/')[-1].split('-')
                        [-1])
                    if global_step_for_restore > best_eval_ckpt:
                        # Restores from checkpoint
                        saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    print('No checkpoint file found')
                    return

                if global_step_for_restore > best_eval_ckpt:
                    num_iter = int(
                        math.floor(
                            graphcnn_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL /
                            graphcnn_input.EVAL_BATCH_SIZE))
                    total_sample_count = num_iter * graphcnn_input.EVAL_BATCH_SIZE
                    step = 0
                    total_predicted_value = np.zeros(
                        [1, graphcnn_input.NUM_CLASSES], dtype=np.float32)  # placeholder row, dropped after the loop
                    total_true_value = np.zeros(
                        [1, graphcnn_input.NUM_CLASSES], dtype=np.int32)
                    while step < num_iter:
                        test_data, test_label = evalDataSet.next_batch(
                            graphcnn_input.EVAL_BATCH_SIZE)
                        predicted_value, true_value = sess.run(
                            [logits, labels],
                            feed_dict={
                                data: test_data,
                                labels: test_label
                            })
                        total_predicted_value = np.concatenate(
                            (total_predicted_value, predicted_value), axis=0)
                        total_true_value = np.concatenate(
                            (total_true_value, true_value), axis=0)
                        step += 1

                    total_predicted_value = total_predicted_value[1:]
                    total_true_value = total_true_value[1:]

                    detail_filename = os.path.join(
                        FLAGS.eval_dir,
                        'log_eval_for_predicted_value_dictribution')
                    if os.path.exists(detail_filename):
                        os.remove(detail_filename)
                    np.savetxt(detail_filename,
                               total_predicted_value,
                               fmt='%.4f')
                    total_predicted_value = (
                        (total_predicted_value) >=
                        graphcnn_option.EVALUTION_THRESHOLD_FOR_MULTI_LABEL
                    ).astype(int)
                    assert total_sample_count == total_predicted_value.shape[
                        0], 'sample_count error!'
                    detail_filename = os.path.join(
                        FLAGS.eval_dir, 'log_eval_for_predicted_value')
                    if os.path.exists(detail_filename):
                        os.remove(detail_filename)
                    np.savetxt(detail_filename,
                               total_predicted_value,
                               fmt='%d')
                    detail_filename = os.path.join(FLAGS.eval_dir,
                                                   'log_eval_for_true_value')
                    if os.path.exists(detail_filename):
                        os.remove(detail_filename)
                    np.savetxt(detail_filename, total_true_value, fmt='%d')

                    filename_eval_log = os.path.join(FLAGS.eval_dir,
                                                     'log_eval')
                    file_eval_log = open(filename_eval_log, 'a')
                    np.set_printoptions(threshold=np.inf)  # print full arrays
                    print('\nevaluation:', file=file_eval_log)
                    print('\nevaluation:')
                    print('  %s, ckpt-%d:' %
                          (datetime.now(), global_step_for_restore),
                          file=file_eval_log)
                    print('  %s, ckpt-%d:' %
                          (datetime.now(), global_step_for_restore))

                    total_predicted_value = total_predicted_value.astype(bool)
                    total_true_value = total_true_value.astype(bool)

                    print('  example based evaluations:', file=file_eval_log)
                    print('  example based evaluations:')

                    equal = total_true_value == total_predicted_value
                    match = np.sum(equal, axis=1) == np.size(equal, axis=1)
                    exact_match_ratio = np.sum(match) / np.size(match)
                    print('      exact_match_ratio = %.4f' % exact_match_ratio,
                          file=file_eval_log)
                    print('      exact_match_ratio = %.4f' % exact_match_ratio)

                    true_and_predict = np.sum(total_true_value
                                              & total_predicted_value,
                                              axis=1)
                    true_or_predict = np.sum(total_true_value
                                             | total_predicted_value,
                                             axis=1)
                    accuracy = np.mean(true_and_predict / true_or_predict)
                    print('      accuracy = %.4f' % accuracy,
                          file=file_eval_log)
                    print('      accuracy = %.4f' % accuracy)

                    precision = np.mean(
                        true_and_predict /
                        (np.sum(total_predicted_value, axis=1) + 1e-9))
                    print('      precision = %.4f' % precision,
                          file=file_eval_log)
                    print('      precision = %.4f' % precision)

                    recall = np.mean(true_and_predict /
                                     np.sum(total_true_value, axis=1))
                    print('      recall = %.4f' % recall, file=file_eval_log)
                    print('      recall = %.4f' % recall)

                    F1_Measure = np.mean(
                        (true_and_predict * 2) /
                        (np.sum(total_true_value, axis=1) +
                         np.sum(total_predicted_value, axis=1)))
                    print('      F1_Measure = %.4f' % F1_Measure,
                          file=file_eval_log)
                    print('      F1_Measure = %.4f' % F1_Measure)

                    HammingLoss = np.mean(total_true_value
                                          ^ total_predicted_value)
                    print('      HammingLoss = %.4f' % HammingLoss,
                          file=file_eval_log)
                    print('      HammingLoss = %.4f' % HammingLoss)

                    print('  label based evaluations:', file=file_eval_log)
                    print('  label based evaluations:')

                    TP = np.sum(total_true_value & total_predicted_value,
                                axis=0,
                                dtype=np.int32)
                    FP = np.sum((~total_true_value) & total_predicted_value,
                                axis=0,
                                dtype=np.int32)
                    FN = np.sum(total_true_value & (~total_predicted_value),
                                axis=0,
                                dtype=np.int32)

                    _P = np.sum(TP) / (np.sum(TP) + np.sum(FP) + 1e-9)
                    _R = np.sum(TP) / (np.sum(TP) + np.sum(FN) + 1e-9)
                    Micro_F1 = (2 * _P * _R) / (_P + _R + 1e-9)
                    print('      P = %.4f' % _P, file=file_eval_log)
                    print('      P = %.4f' % _P)
                    print('      R = %.4f' % _R, file=file_eval_log)
                    print('      R = %.4f' % _R)
                    print('      Micro-F1 = %.4f' % Micro_F1,
                          file=file_eval_log)
                    print('      Micro-F1 = %.4f' % Micro_F1)

                    _P_t = TP / (TP + FP + 1e-9)
                    _R_t = TP / (TP + FN + 1e-9)
                    # assert _P_t.shape[0]==graphcnn_input.NUM_CLASSES, '_P_t has a wrong size'
                    Macro_F1 = np.mean(
                        (2 * _P_t * _R_t) / (_P_t + _R_t + 1e-9))
                    # print('    P_t = %.4f' % _P, file=file_eval_log)
                    # print('    P_t = %.4f' % _P)
                    # print('    R_t = %.4f' % _R, file=file_eval_log)
                    # print('    R_t = %.4f' % _R)
                    print('      Macro-F1 = %.4f' % Macro_F1,
                          file=file_eval_log)
                    print('      Macro-F1 = %.4f' % Macro_F1)

                    print(
                        'evaluation samples number:%d, evaluation classes number:%d'
                        % (total_predicted_value.shape[0],
                           total_predicted_value.shape[1]),
                        file=file_eval_log)
                    print(
                        'evaluation samples number:%d, evaluation classes number:%d'
                        % (total_predicted_value.shape[0],
                           total_predicted_value.shape[1]))
                    # print('evaluation detail: ' + os.path.join(FLAGS.eval_dir, 'log_eval_for_true_value')
                    #       + ', '+os.path.join(FLAGS.eval_dir, 'log_eval_for_predicted_value'),
                    #       file=file_eval_log)
                    # print('evaluation detail: ' + os.path.join(FLAGS.eval_dir, 'log_eval')
                    #       + ', '+ os.path.join(FLAGS.eval_dir, 'log_eval_for_true_value')
                    #       + ', '+ os.path.join(FLAGS.eval_dir, 'log_eval_for_predicted_value')
                    #       + ', '+ os.path.join(FLAGS.eval_dir, 'log_eval_for_predicted_value_dictribution'))

                    file_eval_log.close()

                    if Macro_F1 > best_eval_value:  # best checkpoint is tracked by Macro-F1, as stored below
                        best_eval_value = Macro_F1

                        filename_eval_best = os.path.join(
                            FLAGS.eval_dir, 'best_eval')
                        file_eval_best = open(filename_eval_best, 'w')
                        print('evaluation:', file=file_eval_best)
                        print('  %s, ckpt-%d:' %
                              (datetime.now(), global_step_for_restore),
                              file=file_eval_best)
                        print('  example based evaluations:',
                              file=file_eval_best)
                        print('      exact_match_ratio = %.4f' %
                              exact_match_ratio,
                              file=file_eval_best)
                        print('      accuracy = %.4f' % accuracy,
                              file=file_eval_best)
                        print('      precision = %.4f' % precision,
                              file=file_eval_best)
                        print('      recall = %.4f' % recall,
                              file=file_eval_best)
                        print('      F1_Measure = %.4f' % F1_Measure,
                              file=file_eval_best)
                        print('      HammingLoss = %.4f' % HammingLoss,
                              file=file_eval_best)
                        print('  label based evaluations:',
                              file=file_eval_best)
                        print('      P = %.4f' % _P, file=file_eval_best)
                        print('      R = %.4f' % _R, file=file_eval_best)
                        print('      Micro-F1 = %.4f' % Micro_F1,
                              file=file_eval_best)
                        print('      Macro-F1 = %.4f' % Macro_F1,
                              file=file_eval_best)
                        print(
                            'evaluation samples number:%d, evaluation classes number:%d'
                            % (total_predicted_value.shape[0],
                               total_predicted_value.shape[1]),
                            file=file_eval_best)
                        print(
                            'evaluation detail: ' + os.path.join(
                                FLAGS.eval_dir, 'best_eval_for_true_value') +
                            ', ' +
                            os.path.join(FLAGS.eval_dir,
                                         'best_eval_for_predicted_value') +
                            ', ' + os.path.join(
                                FLAGS.eval_dir,
                                'best_eval_for_predicted_value_dictribution'),
                            file=file_eval_best)
                        file_eval_best.close()

                        sourceFile = os.path.join(FLAGS.eval_dir,
                                                  'log_eval_for_true_value')
                        targetFile = os.path.join(FLAGS.eval_dir,
                                                  'best_eval_for_true_value')
                        if os.path.exists(targetFile):
                            os.remove(targetFile)
                        shutil.copy(sourceFile, targetFile)
                        sourceFile = os.path.join(
                            FLAGS.eval_dir, 'log_eval_for_predicted_value')
                        targetFile = os.path.join(
                            FLAGS.eval_dir, 'best_eval_for_predicted_value')
                        if os.path.exists(targetFile):
                            os.remove(targetFile)
                        shutil.copy(sourceFile, targetFile)
                        sourceFile = os.path.join(
                            FLAGS.eval_dir,
                            'log_eval_for_predicted_value_dictribution')
                        targetFile = os.path.join(
                            FLAGS.eval_dir,
                            'best_eval_for_predicted_value_dictribution')
                        if os.path.exists(targetFile):
                            os.remove(targetFile)
                        shutil.copy(sourceFile, targetFile)
                        # sourceFile = ckpt.model_checkpoint_path
                        # targetFile = os.path.join(FLAGS.eval_dir, 'best_eval.ckpt')
                        # if os.path.exists(targetFile):
                        #     os.remove(targetFile)
                        # shutil.copy(sourceFile, targetFile)
                    best_eval_ckpt = global_step_for_restore
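
For reference, the example-based and label-based metrics printed above reduce to a few NumPy expressions. A minimal, self-contained sketch on toy boolean arrays (the 1e-9 smoothing follows the pattern above; the array values and names are illustrative):

import numpy as np

# Toy multi-label ground truth and predictions: 3 samples, 4 classes.
y_true = np.array([[1, 0, 1, 0],
                   [0, 1, 0, 0],
                   [1, 1, 0, 1]], dtype=bool)
y_pred = np.array([[1, 0, 0, 0],
                   [0, 1, 0, 0],
                   [1, 0, 0, 1]], dtype=bool)

# Example-based metrics (averaged over samples).
exact_match_ratio = np.mean(np.all(y_true == y_pred, axis=1))
intersect = np.sum(y_true & y_pred, axis=1)
union = np.sum(y_true | y_pred, axis=1)
accuracy = np.mean(intersect / union)
precision = np.mean(intersect / (np.sum(y_pred, axis=1) + 1e-9))
recall = np.mean(intersect / np.sum(y_true, axis=1))
f1_measure = np.mean(2 * intersect / (np.sum(y_true, axis=1) + np.sum(y_pred, axis=1)))
hamming_loss = np.mean(y_true ^ y_pred)

# Label-based metrics (micro- and macro-averaged over classes).
TP = np.sum(y_true & y_pred, axis=0)
FP = np.sum(~y_true & y_pred, axis=0)
FN = np.sum(y_true & ~y_pred, axis=0)
micro_p = TP.sum() / (TP.sum() + FP.sum() + 1e-9)
micro_r = TP.sum() / (TP.sum() + FN.sum() + 1e-9)
micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r + 1e-9)
per_class_p = TP / (TP + FP + 1e-9)
per_class_r = TP / (TP + FN + 1e-9)
macro_f1 = np.mean(2 * per_class_p * per_class_r / (per_class_p + per_class_r + 1e-9))
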
Example #2
def evaluate(checkpoint, test_index_array):
    with tf.Graph().as_default() as g, tf.device('/cpu:0'):
        # Get images and labels
        data = tf.placeholder(tf.float32, [
            graphcnn_input.EVAL_BATCH_SIZE, graphcnn_input.HEIGHT,
            graphcnn_input.WIDTH, graphcnn_input.NUM_CHANNELS
        ])
        # labels = tf.placeholder(tf.int32, [graphcnn_input.EVAL_BATCH_SIZE,graphcnn_input.NUM_CLASSES])

        # inference
        logits = graphcnn_model.inference(data, eval_data=True)
        # logits = graphcnn_model.inference_CPU(data, eval_data=True, dependencies_loss=False)

        # multi-label sigmoid
        logits = tf.sigmoid(logits)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            graphcnn_option.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Build the summary operation based on the TF collection of Summaries.
        # summary_op = tf.merge_all_summaries()
        # summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, g)

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=FLAGS.log_device_placement)) as sess:
            if checkpoint == '0':
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    # Restores from checkpoint
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # extract global_step
                    global_step_for_restore = int(
                        ckpt.model_checkpoint_path.split('/')[-1].split('-')
                        [-1])
                else:
                    print('No checkpoint file found')
                    return
            else:
                if os.path.exists(
                        os.path.join(FLAGS.checkpoint_dir,
                                     'model.ckpt-' + checkpoint)):
                    saver.restore(
                        sess,
                        os.path.join(FLAGS.checkpoint_dir,
                                     'model.ckpt-' + checkpoint))
                    global_step_for_restore = int(checkpoint)
                else:
                    print('No checkpoint file found')
                    return

            num_iter = int(
                math.floor(graphcnn_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL /
                           graphcnn_input.EVAL_BATCH_SIZE))
            total_sample_count = num_iter * graphcnn_input.EVAL_BATCH_SIZE
            step = 0
            total_predicted_value = np.zeros([1, graphcnn_input.NUM_CLASSES],
                                             dtype=np.float32)  # placeholder row, dropped after the loop
            while step < num_iter:
                test_data = evalDataSet.next_batch(
                    graphcnn_input.EVAL_BATCH_SIZE)
                predicted_value = sess.run(logits, feed_dict={data: test_data})
                total_predicted_value = np.concatenate(
                    (total_predicted_value, predicted_value), axis=0)
                step += 1

            total_predicted_value = total_predicted_value[1:]

            detail_filename = os.path.join(
                FLAGS.eval_dir, 'log_eval_for_predicted_value_dictribution')
            if os.path.exists(detail_filename):
                os.remove(detail_filename)
            np.savetxt(detail_filename, total_predicted_value, fmt='%.4f')
            total_predicted_value_argmax = np.argmax(total_predicted_value,
                                                     axis=1)
            total_predicted_value = (
                (total_predicted_value) >=
                graphcnn_option.EVALUTION_THRESHOLD_FOR_MULTI_LABEL).astype(int)
            assert total_sample_count == total_predicted_value.shape[
                0], 'sample_count error!'
            detail_filename = os.path.join(FLAGS.eval_dir,
                                           'log_eval_for_predicted_value')
            if os.path.exists(detail_filename):
                os.remove(detail_filename)
            np.savetxt(detail_filename, total_predicted_value, fmt='%d')

            filename = os.path.join(graphcnn_option.EVAL_DATA_DIR,
                                    graphcnn_option.DATA_LABELS_REMAP_NAME)
            total_remap = np.loadtxt(filename, dtype=int)

            detail_filename = os.path.join(
                graphcnn_option.EVAL_DATA_DIR, graphcnn_option.HIER_DIR_NAME,
                graphcnn_option.HIER_labels_remap_file)
            remap = np.loadtxt(detail_filename, dtype=int)

            filename = os.path.join('../hier_result_leaf',
                                    graphcnn_option.HIER_eval_result_leaf_file)
            fr_leaf = open(filename, 'a')
            filename = os.path.join('../hier_result_root',
                                    graphcnn_option.HIER_eval_result_root_file)
            fr_root = open(filename, 'w')

            # filename = os.path.join(graphcnn_option.EVAL_DATA_DIR, 'hier_rootstr')
            # fr = open(filename, 'r')
            # rootstr = fr.readlines()
            # fr.close()
            # filename = os.path.join(graphcnn_option.EVAL_DATA_DIR, 'hier_rootlist')
            # fr = open(filename, 'r')
            # rootlines = fr.readlines()
            # fr.close()
            # rootlist = []
            # for line in rootlines:
            #     line = line.strip()
            #     linelist = line.split(' ')
            #     linelist = [int(k) for k in linelist]
            #     rootlist.append(linelist)

            # rootstr_tmp = []
            detail_filename = os.path.join(
                FLAGS.eval_dir, 'log_eval_for_predicted_value_list')
            fr = open(detail_filename, 'w')
            for i in range(0, np.size(total_predicted_value, axis=0)):
                labels = np.where(total_predicted_value[i] == 1)[0]
                if len(labels) > 0:
                    labels_remap = remap[labels, 0]
                    for elem in labels_remap:
                        print(elem, end=' ', file=fr)
                        if elem in total_remap[:, 0]:  # leaf
                            print('%d %d' % (test_index_array[i], elem),
                                  file=fr_leaf)
                        else:
                            print('%d %d' % (test_index_array[i], elem),
                                  file=fr_root)
                            # for j in range(0,len(rootlist)):
                            #     if elem in rootlist[j]:
                            #         if rootstr[j] not in rootstr_tmp:
                            #             rootstr_tmp.append(rootstr[j])
                    print('', file=fr)
                else:
                    # labels_remap = remap[:, 0]
                    labels = np.array([total_predicted_value_argmax[i]])  # wrap in an array so labels_remap stays iterable
                    labels_remap = remap[labels, 0]
                    for elem in labels_remap:
                        print(elem, end=' ', file=fr)
                        if elem in total_remap[:, 0]:  # leaf
                            print('%d %d' % (test_index_array[i], elem),
                                  file=fr_leaf)
                        else:
                            print('%d %d' % (test_index_array[i], elem),
                                  file=fr_root)
                            # for j in range(0,len(rootlist)):
                            #     if elem in rootlist[j]:
                            #         if rootstr[j] not in rootstr_tmp:
                            #             rootstr_tmp.append(rootstr[j])
                    print('', file=fr)
            fr.close()
            fr_leaf.close()
            fr_root.close()

            # filename = os.path.join(FLAGS.eval_dir, 'hier_next_root')
            # fr = open(filename, 'w')
            # for one in rootstr_tmp:
            #     print(one)
            #     print(one,file=fr)
            # fr.close()

            filename_eval_log = os.path.join(FLAGS.eval_dir, 'log_eval')
            file_eval_log = open(filename_eval_log, 'w')
            np.set_printoptions(threshold=np.inf)  # print full arrays
            print('\nevaluation:', file=file_eval_log)
            print('\nevaluation:')
            print('  %s, ckpt-%d' % (datetime.now(), global_step_for_restore),
                  file=file_eval_log)
            print('  %s, ckpt-%d' % (datetime.now(), global_step_for_restore))
            print('evaluation is end...')
            print('evaluation is end...', file=file_eval_log)

            print(
                'evaluation samples number:%d, evaluation classes number:%d' %
                (total_predicted_value.shape[0],
                 total_predicted_value.shape[1]),
                file=file_eval_log)
            print(
                'evaluation samples number:%d, evaluation classes number:%d' %
                (total_predicted_value.shape[0],
                 total_predicted_value.shape[1]))
            print(
                'evaluation detail: ' +
                os.path.join(FLAGS.eval_dir, 'log_eval_for_predicted_value') +
                ', ' +
                os.path.join(FLAGS.eval_dir,
                             'log_eval_for_predicted_value_dictribution'),
                file=file_eval_log)
            print(
                'evaluation detail: ' +
                os.path.join(FLAGS.eval_dir, 'log_eval') + ', ' +
                os.path.join(FLAGS.eval_dir, 'log_eval_for_predicted_value') +
                ', ' +
                os.path.join(FLAGS.eval_dir,
                             'log_eval_for_predicted_value_dictribution'))
            file_eval_log.close()
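
The post-processing above thresholds the sigmoid scores, falls back to the single top-scoring class when nothing clears the threshold, and then splits the remapped labels into leaf and root result files. A minimal sketch of just the threshold-plus-fallback rule (the 0.5 threshold and the names are stand-ins, not the repository's values):

import numpy as np

scores = np.array([[0.9, 0.2, 0.6],
                   [0.1, 0.3, 0.2]])      # sigmoid outputs, one row per sample
THRESHOLD = 0.5                           # stand-in for EVALUTION_THRESHOLD_FOR_MULTI_LABEL

pred = (scores >= THRESHOLD).astype(int)  # multi-label decision per class
fallback = np.argmax(scores, axis=1)      # best single class per sample

for i in range(pred.shape[0]):
    labels = np.where(pred[i] == 1)[0]
    if len(labels) == 0:                  # no score passed the threshold
        labels = np.array([fallback[i]])  # keep at least one predicted label
    print(i, labels)
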
def evaluate(checkpoint):
    with tf.Graph().as_default() as g, tf.device('/cpu:0'):
        # Get images and labels
        data = tf.placeholder(tf.float32, [
            graphcnn_input.EVAL_BATCH_SIZE, graphcnn_input.HEIGHT,
            graphcnn_input.WIDTH, graphcnn_input.NUM_CHANNELS
        ])
        labels = tf.placeholder(
            tf.int32,
            [graphcnn_input.EVAL_BATCH_SIZE, graphcnn_input.NUM_CLASSES])

        # inference
        logits, conv_vectors = graphcnn_model.inference(data,
                                                        eval_data=True,
                                                        eigenvectors=True)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            graphcnn_option.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Build the summary operation based on the TF collection of Summaries.
        # summary_op = tf.merge_all_summaries()
        # summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, g)

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=FLAGS.log_device_placement)) as sess:
            if checkpoint == '0':
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    # Restores from checkpoint
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # extract global_step
                    global_step_for_restore = int(
                        ckpt.model_checkpoint_path.split('/')[-1].split('-')
                        [-1])
                else:
                    print('No checkpoint file found')
                    return
            else:
                if os.path.exists(
                        os.path.join(FLAGS.checkpoint_dir,
                                     'model.ckpt-' + checkpoint)):
                    saver.restore(
                        sess,
                        os.path.join(FLAGS.checkpoint_dir,
                                     'model.ckpt-' + checkpoint))
                    global_step_for_restore = int(checkpoint)
                else:
                    print('No checkpoint file found')
                    return

            num_iter = int(
                math.floor(graphcnn_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL /
                           graphcnn_input.EVAL_BATCH_SIZE))
            total_sample_count = num_iter * graphcnn_input.EVAL_BATCH_SIZE
            step = 0

            total_samples_X_eigenvector_list = []
            total_true_value_list = []
            while step < num_iter:
                test_data, test_label = evalDataSet.next_batch(
                    graphcnn_input.EVAL_BATCH_SIZE)
                samples_X_eigenvector, samples_Y = sess.run(
                    [conv_vectors, labels],
                    feed_dict={
                        data: test_data,
                        labels: test_label
                    })
                total_samples_X_eigenvector_list.append(samples_X_eigenvector)
                total_true_value_list.append(samples_Y)
                step += 1

            total_samples_X_eigenvector = np.concatenate(
                total_samples_X_eigenvector_list, axis=0)
            total_true_value = np.concatenate(total_true_value_list, axis=0)

            assert total_sample_count == total_samples_X_eigenvector.shape[
                0], 'sample_count error! %d != %d' % (
                    total_sample_count, total_samples_X_eigenvector.shape[0])
            total_predicted_value_list = []
            for bin_class in range(0, graphcnn_input.NUM_CLASSES):
                # total_samples_Y = np.zeros([total_sample_count,1],dtype=int)
                total_samples_Y = total_true_value[:, bin_class]
                clf = svm.SVC()  # class
                clf.fit(total_samples_X_eigenvector,
                        total_samples_Y)  # training the svc model
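
The snippet above is cut off after fitting the SVC for the first class. A minimal sketch of how a one-vs-rest loop over all classes could continue with scikit-learn (the helper names and the score stacking are assumptions, not the repository's code):

import numpy as np
from sklearn import svm

def fit_per_class_svms(features, targets):
    # features: [num_samples, feature_dim]; targets: [num_samples, num_classes] in {0, 1}.
    classifiers = []
    for class_idx in range(targets.shape[1]):
        clf = svm.SVC()
        clf.fit(features, targets[:, class_idx])  # one binary SVM per label
        classifiers.append(clf)
    return classifiers

def predict_per_class_svms(classifiers, features):
    # Stack the per-class binary decisions into a [num_samples, num_classes] matrix.
    columns = [clf.predict(features).reshape(-1, 1) for clf in classifiers]
    return np.concatenate(columns, axis=1)

Note that SVC.fit raises an error when a label column contains only one class value, so very rare labels may need special handling.
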
def train(newTrain, checkpoint):
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        data = tf.placeholder(tf.float32, [
            graphcnn_input.TRAIN_BATCH_SIZE, graphcnn_input.HEIGHT,
            graphcnn_input.WIDTH, graphcnn_input.NUM_CHANNELS
        ])
        labels = tf.placeholder(
            tf.int32,
            [graphcnn_input.TRAIN_BATCH_SIZE, graphcnn_input.NUM_CLASSES])

        # inference model.
        # logits = graphcnn_model.inference_GPU(data)
        logits = graphcnn_model.inference(data)
        # logits = graphcnn_model.inference_CPU(data,dependencies_loss=False)

        # Calculate loss.
        loss = graphcnn_model.loss(logits, labels)

        # updates the model parameters.
        train_op = graphcnn_model.train(loss, global_step)

        # Create a saver.
        saver = tf.train.Saver(var_list=tf.global_variables(),
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=10)

        if graphcnn_option.SUMMARYWRITER:
            # Build the summary operation based on the TF collection of Summaries.
            summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))

        first_step = 0
        if not newTrain:
            if checkpoint == '0':  # choose the latest one
                ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    new_saver = tf.train.import_meta_graph(
                        ckpt.model_checkpoint_path + '.meta')
                    # Restores from checkpoint
                    new_saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step_for_restore = ckpt.model_checkpoint_path.split(
                        '/')[-1].split('-')[-1]
                    first_step = int(global_step_for_restore) + 1
                else:
                    print('No checkpoint file found')
                    return
            else:  #
                if os.path.exists(
                        os.path.join(FLAGS.train_dir,
                                     'model.ckpt-' + checkpoint)):
                    new_saver = tf.train.import_meta_graph(
                        os.path.join(FLAGS.train_dir,
                                     'model.ckpt-' + checkpoint + '.meta'))
                    new_saver.restore(
                        sess,
                        os.path.join(FLAGS.train_dir,
                                     'model.ckpt-' + checkpoint))
                    first_step = int(checkpoint) + 1
                else:
                    print('No checkpoint file found')
                    return
        else:
            sess.run(init)

        if graphcnn_option.SUMMARYWRITER:
            summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        filename_train_log = os.path.join(FLAGS.train_dir, 'log_train')
        if os.path.exists(filename_train_log):
            file_train_log = open(filename_train_log, 'a')
        else:
            file_train_log = open(filename_train_log, 'w')

        # learning_rate = graphcnn_option.lr_decay_value[0]  # 0.1(5), 0.01(100), 0.001(500), 0.0001(300), 0.00001(100)
        # learning_rate_index = 0
        for step in range(first_step, MAX_STEPS):
            # if learning_rate_index < len(graphcnn_option.lr_decay_value) - 1:
            #     if step > STEPS_PER_ECOPH * graphcnn_option.lr_decay_ecophs[learning_rate_index]:
            #         learning_rate_index = learning_rate_index + 1
            #         learning_rate = graphcnn_option.lr_decay_value[learning_rate_index]

            train_data, train_label = trainDataSet.next_batch(
                graphcnn_input.TRAIN_BATCH_SIZE)
            start_time = time.time()
            _, loss_value = sess.run([train_op, loss],
                                     feed_dict={
                                         data: train_data,
                                         labels: train_label
                                     })
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 10 == 0:
                sec_per_batch = float(duration)
                format_str = ('%s: step=%d, loss=%.4f; %.3f sec/batch')
                print(format_str %
                      (datetime.now(), step, loss_value, sec_per_batch),
                      file=file_train_log)
                print(format_str %
                      (datetime.now(), step, loss_value, sec_per_batch))

            if graphcnn_option.SUMMARYWRITER:
                if step % 100 == 0:
                    summary_str = sess.run(summary_op,
                                           feed_dict={
                                               data: train_data,
                                               labels: train_label
                                           })
                    summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically. (named 'model.ckpt-global_step.meta')
            if step % CKPT_PERIOD == 0 or (step + 1) == MAX_STEPS:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
        file_train_log.close()
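
The evaluate() functions in this listing restore exponential-moving-average shadow variables rather than the raw weights, which assumes the training op built by graphcnn_model.train() applies the same decay. A minimal sketch of that pairing in TF 1.x style (the decay value and names here are illustrative):

import tensorflow as tf

MOVING_AVERAGE_DECAY = 0.9999  # stand-in for graphcnn_option.MOVING_AVERAGE_DECAY

# Training side: maintain shadow copies of the trainable variables.
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
maintain_averages_op = ema.apply(tf.trainable_variables())
# with tf.control_dependencies([optimizer_step]):
#     train_op = tf.group(maintain_averages_op)

# Evaluation side: map each variable name to its shadow value before restoring,
# which is what variables_to_restore() / tf.train.Saver do in the snippets above.
variables_to_restore = ema.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
# saver.restore(sess, checkpoint_path)
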
def evaluate(checkpoint, test_index_array):
    with tf.Graph().as_default() as g, tf.device('/cpu:0'):
        # Get images and labels
        data = tf.placeholder(tf.float32, [
            graphcnn_input.EVAL_BATCH_SIZE, graphcnn_input.HEIGHT,
            graphcnn_input.WIDTH, graphcnn_input.NUM_CHANNELS
        ])
        # labels = tf.placeholder(tf.int32, [graphcnn_input.EVAL_BATCH_SIZE,graphcnn_input.NUM_CLASSES])

        # inference
        logits, conv_vectors = graphcnn_model.inference(data,
                                                        eval_data=True,
                                                        eigenvectors=True)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            graphcnn_option.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Build the summary operation based on the TF collection of Summaries.
        # summary_op = tf.merge_all_summaries()
        # summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir, g)

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=FLAGS.log_device_placement)) as sess:
            if checkpoint == '0':
                ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    # Restores from checkpoint
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # extract global_step
                    global_step_for_restore = int(
                        ckpt.model_checkpoint_path.split('/')[-1].split('-')
                        [-1])
                else:
                    print('No checkpoint file found')
                    return
            else:
                if os.path.exists(
                        os.path.join(FLAGS.checkpoint_dir,
                                     'model.ckpt-' + checkpoint)):
                    saver.restore(
                        sess,
                        os.path.join(FLAGS.checkpoint_dir,
                                     'model.ckpt-' + checkpoint))
                    global_step_for_restore = int(checkpoint)
                else:
                    print('No checkpoint file found')
                    return

            num_iter = int(
                math.floor(graphcnn_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL /
                           graphcnn_input.EVAL_BATCH_SIZE))
            total_sample_count = num_iter * graphcnn_input.EVAL_BATCH_SIZE
            step = 0
            total_eval_X_eigenvector_list = []
            while step < num_iter:
                test_data = evalDataSet.next_batch(
                    graphcnn_input.EVAL_BATCH_SIZE)
                samples_X_eigenvector = sess.run(conv_vectors,
                                                 feed_dict={data: test_data})
                total_eval_X_eigenvector_list.append(samples_X_eigenvector)
                step += 1
            total_eval_X_eigenvector = np.concatenate(
                total_eval_X_eigenvector_list, axis=0)

            assert total_sample_count == total_eval_X_eigenvector.shape[
                0], 'sample_count error! %d != %d' % (
                    total_sample_count, total_eval_X_eigenvector.shape[0])

            total_predicted_value = evaluate_SVM(checkpoint,
                                                 total_eval_X_eigenvector)

            detail_filename = os.path.join(
                FLAGS.eval_dir,
                'log_eval_for_predicted_value_dictribution_all')
            if os.path.exists(detail_filename):
                os.remove(detail_filename)
            np.savetxt(detail_filename, total_predicted_value, fmt='%.4f')

            filename_eval_log = os.path.join(FLAGS.eval_dir, 'log_eval')
            file_eval_log = open(filename_eval_log, 'w')
            np.set_printoptions(threshold=np.inf)  # print full arrays
            print('\nevaluation:', file=file_eval_log)
            print('\nevaluation:')
            print('  %s, ckpt-%d' % (datetime.now(), global_step_for_restore),
                  file=file_eval_log)
            print('  %s, ckpt-%d' % (datetime.now(), global_step_for_restore))
            print('evaluation is end...')
            print('evaluation is end...', file=file_eval_log)

            print(
                'evaluation samples number:%d, evaluation classes number:%d' %
                (total_predicted_value.shape[0],
                 total_predicted_value.shape[1]),
                file=file_eval_log)
            print(
                'evaluation samples number:%d, evaluation classes number:%d' %
                (total_predicted_value.shape[0],
                 total_predicted_value.shape[1]))
            print(
                'evaluation detail: ' +
                os.path.join(FLAGS.eval_dir, 'log_eval_for_predicted_value') +
                ', ' +
                os.path.join(FLAGS.eval_dir,
                             'log_eval_for_predicted_value_dictribution'),
                file=file_eval_log)
            print(
                'evaluation detail: ' +
                os.path.join(FLAGS.eval_dir, 'log_eval') + ', ' +
                os.path.join(FLAGS.eval_dir, 'log_eval_for_predicted_value') +
                ', ' +
                os.path.join(FLAGS.eval_dir,
                             'log_eval_for_predicted_value_dictribution'))
            file_eval_log.close()
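
evaluate_SVM() is called above but not shown in this listing. A hypothetical sketch of what it might do, loading one pickled per-class SVC and stacking their predictions; the directory layout, file naming, and pickling scheme are assumptions, not the repository's actual code:

import os
import pickle

import numpy as np

def evaluate_SVM(checkpoint, features):
    # Hypothetical layout: one binary SVC per class, pickled at training time.
    # graphcnn_input is the repository's input module, assumed importable here.
    columns = []
    for class_idx in range(graphcnn_input.NUM_CLASSES):
        path = os.path.join('svm_models', 'svm-%s-class%d.pkl' % (checkpoint, class_idx))
        with open(path, 'rb') as f:
            clf = pickle.load(f)
        columns.append(clf.predict(features).reshape(-1, 1))
    # Shape [num_samples, NUM_CLASSES], matching how total_predicted_value is used above.
    return np.concatenate(columns, axis=1)
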