Example #1
import json
import os

# save_json is the function under test, imported from the project (import path not shown).


def test_save_json_not_existed_dir(temp_dir):
    """Test saving JSON into a directory that does not exist yet."""
    data = json.dumps({"k": "v", "list": [1, 2, 3]})
    dist = os.path.join(temp_dir, "not_existed")
    save_json(dist, data, step=1)

    # save_json should create the missing directory tree before writing.
    assert os.path.exists(os.path.join(dist, "json", "1.json"))
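
All of the save_json tests in this listing take a temp_dir fixture that is not
shown. Below is a minimal sketch of such a fixture, assuming pytest's built-in
tmp_path machinery; the fixture name and its string return type are assumptions
based on how the tests use it, not code taken from the project.

import pytest

@pytest.fixture
def temp_dir(tmp_path):
    # tmp_path is pytest's built-in per-test temporary directory (a pathlib.Path);
    # the tests build paths with os.path.join, so hand back a plain string.
    return str(tmp_path)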
Example #2
import json
import math
import os

import tensorflow as tf

# environment, executor, file_io, logger, setup_dataset, and save_json are
# project-level helpers assumed to be imported from the surrounding package
# (their import paths are not shown in this excerpt).


def evaluate(config, restore_path, output_dir):
    if restore_path is None:
        restore_file = executor.search_restore_filename(environment.CHECKPOINTS_DIR)
        restore_path = os.path.join(environment.CHECKPOINTS_DIR, restore_file)

    if not file_io.exists("{}.index".format(restore_path)):
        raise Exception(f"restore file {restore_path} does not exist.")

    if output_dir is None:
        output_dir = os.path.join(os.path.dirname(os.path.dirname(restore_path)), "evaluate")

    logger.info(f"restore_path:{restore_path}")

    DatasetClass = config.DATASET_CLASS
    ModelClass = config.NETWORK_CLASS
    network_kwargs = {key.lower(): val for key, val in config.NETWORK.items()}

    if "test" in DatasetClass.available_subsets:
        subset = "test"
    else:
        subset = "validation"

    validation_dataset = setup_dataset(config, subset, seed=0)

    graph = tf.Graph()
    with graph.as_default():

        if ModelClass.__module__.startswith("blueoil.networks.object_detection"):
            model = ModelClass(
                classes=validation_dataset.classes,
                num_max_boxes=validation_dataset.num_max_boxes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        else:
            model = ModelClass(
                classes=validation_dataset.classes,
                is_debug=config.IS_DEBUG,
                **network_kwargs,
            )

        is_training = tf.constant(False, name="is_training")

        images_placeholder, labels_placeholder = model.placeholders()

        output = model.inference(images_placeholder, is_training)

        metrics_ops_dict, metrics_update_op = model.metrics(output, labels_placeholder)
        model.summary(output, labels_placeholder)

        summary_op = tf.compat.v1.summary.merge_all()
        metrics_summary_op = executor.metrics_summary_op(metrics_ops_dict)

        init_op = tf.compat.v1.global_variables_initializer()
        reset_metrics_op = tf.compat.v1.local_variables_initializer()
        saver = tf.compat.v1.train.Saver(max_to_keep=None)

    session_config = None  # tf.compat.v1.ConfigProto(log_device_placement=True)
    sess = tf.compat.v1.Session(graph=graph, config=session_config)
    sess.run([init_op, reset_metrics_op])

    validation_writer = tf.compat.v1.summary.FileWriter(environment.TENSORBOARD_DIR + "/evaluate")

    saver.restore(sess, restore_path)

    last_step = sess.run(model.global_step)

    # Number of batches needed to sweep the validation set once.
    test_step_size = int(math.ceil(validation_dataset.num_per_epoch / config.BATCH_SIZE))
    logger.info(f"test_step_size: {test_step_size}")

    for test_step in range(test_step_size):
        logger.info(f"test_step{test_step}")

        images, labels = validation_dataset.feed()
        feed_dict = {
            images_placeholder: images,
            labels_placeholder: labels,
        }

        # Summarize at only last step.
        if test_step == test_step_size - 1:
            summary, _ = sess.run([summary_op, metrics_update_op], feed_dict=feed_dict)
            validation_writer.add_summary(summary, last_step)
        else:
            sess.run([metrics_update_op], feed_dict=feed_dict)

    metrics_summary = sess.run(metrics_summary_op)
    validation_writer.add_summary(metrics_summary, last_step)

    is_tfds = "TFDS_KWARGS" in config.DATASET
    dataset_name = config.DATASET.TFDS_KWARGS["name"] if is_tfds else config.DATASET_CLASS.__name__
    dataset_path = config.DATASET.TFDS_KWARGS["data_dir"] if is_tfds else ""

    metrics_dict = {
        'task_type': config.TASK.value,
        'network_name': config.NETWORK_CLASS.__name__,
        'dataset_name': dataset_name,
        'dataset_path': dataset_path,
        'last_step': int(last_step),
        'metrics': {k: float(sess.run(op)) for k, op in metrics_ops_dict.items()},
    }
    save_json(output_dir, json.dumps(metrics_dict, indent=4), metrics_dict["last_step"])
    validation_dataset.close()
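
Example #2 relies on executor.metrics_summary_op to turn the metrics value
tensors into a single TensorBoard summary. Below is a minimal sketch of such a
helper in TF1-compat style, inferred from how the op is used above; it is an
assumption, not the project's actual implementation.

import tensorflow as tf

def metrics_summary_op(metrics_ops_dict):
    # One scalar summary per metric value tensor, merged so that a single
    # sess.run() yields a serialized Summary that FileWriter.add_summary accepts.
    summaries = [
        tf.compat.v1.summary.scalar("metrics/" + name, value_op)
        for name, value_op in metrics_ops_dict.items()
    ]
    return tf.compat.v1.summary.merge(summaries)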
Example #3
import json

import pytest


def test_save_json_with_invalid_step(temp_dir):
    """Test that save_json raises ValueError for a non-integer step argument."""
    data = json.dumps({"k": "v", "list": [1, 2, 3]})

    with pytest.raises(ValueError):
        save_json(temp_dir, data, step={"invalid": "dict"})
Example #4
import json
import os


def test_save_json(temp_dir):
    """Test saving JSON to a directory that already exists."""
    data = json.dumps({"k": "v", "list": [1, 2, 3]})
    save_json(temp_dir, data, step=1)

    assert os.path.exists(os.path.join(temp_dir, "json", "1.json"))
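
Taken together, these tests pin down save_json's contract: write the serialized
payload to <output_dir>/json/<step>.json, create any missing directories along
the way, and raise ValueError when step is not an integer. Below is a minimal
sketch satisfying that contract, inferred from the tests rather than taken from
the project's actual implementation.

import os

def save_json(output_dir, data, step):
    if not isinstance(step, int):
        raise ValueError(f"step must be an int, got {type(step).__name__}")

    json_dir = os.path.join(output_dir, "json")
    os.makedirs(json_dir, exist_ok=True)  # create the target tree if missing

    file_path = os.path.join(json_dir, f"{step}.json")
    with open(file_path, "w") as f:
        f.write(data)  # data is already a JSON string (json.dumps upstream)
    return file_path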