Example #1
0
def estimate_saliency():
    """Estimate saliency maps for every image provided under ``results/hpe/``.

    Loads the frozen graph stored at ``SALIENCY_MODEL_PATH``, feeds it the
    images yielded by ``data.get_dataset_iterator("results/hpe/")`` and writes
    each post-processed (JPEG-encoded) map to ``results/saliency/``. Output
    files take the input file's base name with a ``.jpg`` extension.

    The function takes no arguments; the input directory, output directory
    and model path are fixed. A failure to write a single output file is
    reported and the remaining images are still processed.
    """
    next_element, init_op = data.get_dataset_iterator("results/hpe/")
    input_images, original_shape, file_path = next_element

    graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.gfile.Open(SALIENCY_MODEL_PATH, "rb") as file:
        graph_def.ParseFromString(file.read())

    # Route the dataset images into the frozen graph's "input" tensor.
    [predicted_maps] = tf.compat.v1.import_graph_def(
        graph_def,
        input_map={"input": input_images},
        return_elements=["output:0"])

    jpeg = data.postprocess_saliency_map(predicted_maps[0], original_shape[0])

    output_dir = "results/saliency/"
    # Without this every write below would fail when the directory is missing.
    os.makedirs(output_dir, exist_ok=True)

    print(">> Estimating Saliency...")

    with tf.compat.v1.Session() as sess:
        sess.run(init_op)

        while True:
            try:
                output_file, path = sess.run([jpeg, file_path])
            except tf.compat.v1.errors.OutOfRangeError:
                # The iterator is exhausted: every image has been processed.
                break

            path = path[0][0].decode("utf-8")
            filename = os.path.splitext(os.path.basename(path))[0] + ".jpg"

            try:
                with open(output_dir + filename, "wb") as file:
                    file.write(output_file)
            except OSError as err:
                # Best effort: report the failure and continue with the next image.
                print("Failed to write file.", err)
            else:
                print("Done!")
def train_model(dataset, paths, device):
    """Train the saliency network on ``dataset`` and checkpoint after each epoch.

    Builds the training input pipeline and the ``model_bn.MSINET`` graph, then
    runs ``config.PARAMS["n_epochs"]`` epochs inside a single session. Each
    epoch makes one pass over the training batches and one over the
    validation batches. After each epoch it saves the latest checkpoint to
    ``paths["latest"]``, saves the training history, and prints a progress
    summary. Whenever the newest validation error is the lowest recorded so
    far, the model is also saved to ``paths["best"]`` and passed to
    ``msi_net.optimize``.

    Args:
        dataset (str): Name of the dataset to train on; its upper-cased form
            must be an attribute of ``data`` exposing ``n_train``/``n_valid``.
        paths (dict, str): Path elements; "data", "history", "latest" and
            "best" are read here, and the whole dict is passed to
            ``msi_net.restore``.
        device (str): Represents either "cpu" or "gpu".
    """
    next_element, train_init_op, valid_init_op = data.get_dataset_iterator(
        "train", dataset, paths["data"])
    input_images, ground_truths = next_element[:2]

    # Images come from the iterator by default but may be overridden by
    # feeding the tensor named "input".
    input_plhd = tf.placeholder_with_default(input_images,
                                             (None, None, None, 3),
                                             name="input")

    # NOTE(review): the network is built with is_train=True and the same graph
    # is also evaluated for validation below; if MSINET uses batch norm,
    # validation runs in training mode — confirm this is intended.
    msi_net = model_bn.MSINET(is_train=True)
    predicted_maps = msi_net.forward(input_plhd)
    optimizer, loss = msi_net.train(ground_truths, predicted_maps,
                                    config.PARAMS["learning_rate"])

    dataset_info = getattr(data, dataset.upper())
    n_train_data = dataset_info.n_train
    n_valid_data = dataset_info.n_valid

    batch_size = config.PARAMS["batch_size"]
    # Round up so a final partial batch is still visited.
    n_train_batches = int(np.ceil(n_train_data / batch_size))
    n_valid_batches = int(np.ceil(n_valid_data / batch_size))

    history = utils.History(n_train_batches, n_valid_batches, dataset,
                            paths["history"], device)
    progbar = utils.Progbar(n_train_data, n_train_batches, batch_size,
                            config.PARAMS["n_epochs"], history.prior_epochs)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = msi_net.restore(sess, dataset, paths, device)

        print(">> Start training on %s..." % dataset.upper())

        for _epoch in range(config.PARAMS["n_epochs"]):
            # Training pass.
            sess.run(train_init_op)
            for step in range(n_train_batches):
                _, train_error = sess.run([optimizer, loss])
                history.update_train_step(train_error)
                progbar.update_train_step(step)

            # Validation pass: loss only, no optimizer step.
            sess.run(valid_init_op)
            for _ in range(n_valid_batches):
                history.update_valid_step(sess.run(loss))
                progbar.update_valid_step()

            msi_net.save(saver, sess, dataset, paths["latest"], device)
            history.save_history()
            progbar.write_summary(history.get_mean_train_error(),
                                  history.get_mean_valid_error())

            is_best = history.valid_history[-1] == min(history.valid_history)
            if is_best:
                msi_net.save(saver, sess, dataset, paths["best"], device)
                msi_net.optimize(sess, dataset, paths["best"], device)
                print("\tBest model!", flush=True)
def test_model(dataset, paths, device):
    """Run the frozen saliency model on the test split and save JPEG maps.

    Loads the test dataset iterator and the frozen model
    ``model_<dataset>_<device>.pb``. A locally trained model in
    ``paths["best"]`` is preferred. Otherwise the pretrained weights in
    ``paths["weights"]`` are used, and they are downloaded first if they are
    missing. Each predicted map is post-processed and written to
    ``paths["images"]`` as ``<input base name>.jpeg``. Testing only works for
    models trained on the same device as specified in the config file.

    Args:
        dataset (str): Denotes the dataset that was used during training.
        paths (dict, str): A dictionary with all path elements; "data",
            "best", "weights" and "images" are read here.
        device (str): Represents either "cpu" or "gpu".
    """
    next_element, init_op = data.get_dataset_iterator("test", dataset,
                                                      paths["data"])
    input_images, original_shape, file_path = next_element

    model_name = "model_%s_%s.pb" % (dataset, device)

    # Prefer a locally trained model; otherwise fall back to the pretrained
    # weights, downloading them first if they are not on disk yet.
    if os.path.isfile(paths["best"] + model_name):
        model_path = paths["best"] + model_name
    else:
        model_path = paths["weights"] + model_name
        if not os.path.isfile(model_path):
            download.download_pretrained_weights(paths["weights"],
                                                 model_name[:-3])

    graph_def = tf.GraphDef()
    with tf.gfile.Open(model_path, "rb") as file:
        graph_def.ParseFromString(file.read())

    [predicted_maps] = tf.import_graph_def(graph_def,
                                           input_map={"input": input_images},
                                           return_elements=["output:0"])

    jpeg = data.postprocess_saliency_map(predicted_maps[0],
                                         original_shape[0])

    # Loop-invariant: create the output directory once, not once per image.
    os.makedirs(paths["images"], exist_ok=True)

    print(">> Start testing with %s %s model..." % (dataset.upper(), device))

    with tf.Session() as sess:
        sess.run(init_op)

        while True:
            try:
                output_file, path = sess.run([jpeg, file_path])
            except tf.errors.OutOfRangeError:
                # The iterator is exhausted: all test images are done.
                break

            path = path[0][0].decode("utf-8")
            filename = os.path.splitext(os.path.basename(path))[0] + ".jpeg"

            # os.path.join keeps files inside the directory created above even
            # when paths["images"] has no trailing separator.
            with open(os.path.join(paths["images"], filename), "wb") as file:
                file.write(output_file)
Example #4
0
def _write_saliency_video(saliency_video, target_shape, output_file_path):
    """Resize, crop and write stacked saliency maps as a grayscale XVID video.

    Args:
        saliency_video (np.ndarray): Network outputs concatenated along the
            first axis, indexed as (frames, height, width) with an optional
            trailing single-channel axis. Values are multiplied by 255 before
            writing (presumably they lie in [0, 1] — confirm against the model).
        target_shape: Size passed to ``data._resize_image`` /
            ``data._crop_image`` (presumably (height, width) of the source video).
        output_file_path (str): Destination file; written at 25 fps.
    """
    if saliency_video.ndim == 4:
        # Drop only the channel axis; a plain np.squeeze would also collapse
        # the frame axis of a single-frame video and iterate over rows.
        saliency_video = np.squeeze(saliency_video, axis=-1)
    saliency_video = saliency_video * 255

    fourcc = cv2.VideoWriter_fourcc(*"XVID")
    writer = None
    try:
        for frame in saliency_video:
            saliency_map = data._resize_image(frame,
                                              target_shape,
                                              True,
                                              is_numpy=True)
            saliency_map = data._crop_image(saliency_map,
                                            target_shape,
                                            is_numpy=True)
            # Clip before the uint8 cast so interpolation overshoot cannot wrap.
            saliency_map = np.clip(np.round(np.squeeze(saliency_map)), 0, 255)
            saliency_map = saliency_map.astype(np.uint8)

            if writer is None:
                # Size the writer from the first processed frame: VideoWriter
                # drops frames that do not match its declared frame size.
                frame_size = (saliency_map.shape[1], saliency_map.shape[0])
                writer = cv2.VideoWriter(output_file_path, fourcc, 25,
                                         frame_size, False)
            writer.write(saliency_map)
    finally:
        if writer is not None:
            writer.release()


def test_model(dataset, paths, device):
    """Run the frozen saliency model on every test video and save saliency videos.

    For each file returned by ``data._get_file_list(paths["data"])``, the
    dataset iterator is re-initialised on that file and all frames are passed
    through the frozen model ``model_<dataset>_<device>.pb``. A locally
    trained model in ``paths["best"]`` is preferred. Otherwise the pretrained
    weights in ``paths["weights"]`` are used, and they are downloaded first if
    they are missing. The maps are resized/cropped to the original frame size
    and written as a grayscale XVID video under ``paths["images"]``. The
    output mirrors the input's path relative to the common ancestor of
    ``paths["data"]`` and ``paths["images"]``. Videos that yield no frames
    are skipped. Testing only works for models trained on the same device as
    specified in the config file.

    Args:
        dataset (str): Denotes the dataset that was used during training.
        paths (dict, str): A dictionary with all path elements; "data",
            "best", "weights" and "images" are read here.
        device (str): Represents either "cpu" or "gpu".
    """
    video_file = tf.placeholder(tf.string, shape=())
    next_element, init_op = data.get_dataset_iterator("test", dataset,
                                                      paths["data"],
                                                      video_file)
    input_images, original_shape, file_path = next_element

    model_name = "model_%s_%s.pb" % (dataset, device)

    # Prefer a locally trained model; otherwise fall back to the pretrained
    # weights, downloading them first if they are not on disk yet.
    if os.path.isfile(paths["best"] + model_name):
        model_path = paths["best"] + model_name
    else:
        model_path = paths["weights"] + model_name
        if not os.path.isfile(model_path):
            download.download_pretrained_weights(paths["weights"],
                                                 model_name[:-3])

    graph_def = tf.GraphDef()
    with tf.gfile.Open(model_path, "rb") as file:
        graph_def.ParseFromString(file.read())

    [predicted_maps] = tf.import_graph_def(graph_def,
                                           input_map={"input": input_images},
                                           return_elements=["output:0"])

    print(">> Start testing with %s %s model..." % (dataset.upper(), device))

    # Loop-invariant: output paths mirror input paths relative to this root.
    commonpath = os.path.commonpath([paths["data"], paths["images"]])

    with tf.Session() as sess:
        video_files = data._get_file_list(paths["data"])
        for vf in video_files:
            print(vf)
            saliency_images_list = []
            sess.run(init_op, feed_dict={video_file: vf})
            while True:
                try:
                    saliency_images, target_shape, np_file_path = sess.run(
                        [predicted_maps, original_shape, file_path],
                        feed_dict={video_file: vf})
                except tf.errors.OutOfRangeError:
                    break
                saliency_images_list.append(saliency_images)

            if not saliency_images_list:
                # Nothing decoded: np.concatenate would fail and target_shape
                # would be unbound.
                print("No frames decoded from %s, skipping." % vf)
                continue

            saliency_video = np.concatenate(saliency_images_list)
            target_shape = target_shape[0]
            file_path_str = np_file_path[0][0].decode("utf8")

            relative_file_path = os.path.relpath(file_path_str,
                                                 start=commonpath)
            output_file_path = os.path.join(paths["images"],
                                            relative_file_path)
            os.makedirs(os.path.dirname(output_file_path), exist_ok=True)

            _write_saliency_video(saliency_video, target_shape,
                                  output_file_path)