Example #1
    def restore(self, sess, dataset, paths, device):
        """This function allows continued training from a prior checkpoint and
           training from scratch with the pretrained VGG16 weights. In case the
           dataset is either CAT2000 or MIT1003, a prior checkpoint based on
           the SALICON dataset is required.

        Args:
            sess (object): The current TF training session.
            dataset (str): The dataset used for training.
            paths (dict, str): A dictionary with all path elements.
            device (str): Represents either "cpu" or "gpu".

        Returns:
            object: A saver object for saving the model.
        """

        model_name = "model_%s_%s" % (dataset, device)
        salicon_name = "model_salicon_%s" % device
        vgg16_name = "vgg16_hybrid"

        ext1 = ".ckpt.data-00000-of-00001"
        ext2 = ".ckpt.index"

        saver = tf.train.Saver()

        if os.path.isfile(paths["latest"] + model_name + ext1) and \
           os.path.isfile(paths["latest"] + model_name + ext2):
            saver.restore(sess, paths["latest"] + model_name + ".ckpt")
        elif dataset in ("mit1003", "cat2000", "dutomron", "pascals", "osie",
                         "fiwi"):
            if os.path.isfile(paths["best"] + salicon_name + ext1) and \
               os.path.isfile(paths["best"] + salicon_name + ext2):
                saver.restore(sess, paths["best"] + salicon_name + ".ckpt")
            else:
                raise FileNotFoundError("Train model on SALICON first")
        else:
            if not (os.path.isfile(paths["weights"] + vgg16_name + ext1)
                    or os.path.isfile(paths["weights"] + vgg16_name + ext2)):
                download.download_pretrained_weights(paths["weights"],
                                                     "vgg16_hybrid")
            self._pretraining()

            loader = tf.train.Saver(var_list=self._mapping)
            loader.restore(sess, paths["weights"] + vgg16_name + ".ckpt")

        return saver
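
A minimal, self-contained sketch of the same restore-or-initialize pattern with a toy one-variable TF1 graph (not the project's model); it probes for the checkpoint's .index file, restores if found, and otherwise starts from the initializer before saving:

import os
import tensorflow as tf

# Toy graph: a single scalar variable, just so the Saver has something to track.
x = tf.get_variable("x", shape=[], initializer=tf.zeros_initializer())
saver = tf.train.Saver()

ckpt_dir = "checkpoints/"
ckpt = ckpt_dir + "toy_model.ckpt"

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if os.path.isfile(ckpt + ".index"):
        saver.restore(sess, ckpt)        # resume from the existing checkpoint
    else:
        os.makedirs(ckpt_dir, exist_ok=True)
    saver.save(sess, ckpt)               # write/update the checkpoint prefix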
Example #2
File: main.py Project: RJason13/A-ResP
def test_model(ds_name, encoder, paths, categorical=False):
    """The main function for executing network testing. It loads the specified
       dataset iterator and optimized saliency model. By default, when no model
       checkpoint is found locally, the pretrained weights will be downloaded.

    Args:
        ds_name (str): Denotes the dataset that was used during training.
        encoder (str): The name of the encoder to use for prediction.
        paths (dict, str): A dictionary with all path elements.
        categorical (bool, optional): Whether the dataset is organized into
            categories (defaults to False).
    """

    w_filename_template = "/%s_%s_%s_weights.h5" # [encoder]_[ds_name]_weights.h5

    (test_ds, n_test) = data.load_test_dataset(ds_name, paths["data"], categorical)
    
    print(">> Preparing model with encoder %s..." % encoder)

    model = MyModel(encoder, ds_name, "test")

    weights_path = paths["weights"] + w_filename_template % (encoder, ds_name, loss_fn_name)
    if os.path.exists(weights_path):
        print("Weights are loaded!\n    %s"%weights_path)
    else:
        download.download_pretrained_weights(paths["weights"], encoder, ds_name, loss_fn_name)
    model.load_weights(weights_path)
    del weights_path

    print(">> Start predicting using model trained on %s..." % ds_name.upper())
    results_path = paths["results"] + "%s/%s/%s/" % (ds_name, encoder, loss_fn_name)

    # Preparing progbar
    test_progbar = Progbar(n_test)
    for test_images, test_ori_sizes, test_filenames in test_ds:
        preds = test_step(test_images, model)
        for pred, filename, ori_size in zip(preds, test_filenames.numpy(), test_ori_sizes):
            img = data.postprocess_saliency_map(pred, ori_size, as_image=True)
            tf.io.write_file(results_path + filename.decode("utf-8"), img)
        test_progbar.add(test_images.shape[0])
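
A hypothetical invocation of the snippet above; the paths keys are inferred from the code, and the encoder name is only an example:

# Hypothetical: directory layout and encoder name are assumptions, not the
# project's defaults.
paths = {
    "data": "data/",
    "weights": "weights/",
    "results": "results/",
}
test_model("salicon", "resnet50", paths, categorical=False)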
Example #3
def get_tf_objects(paths=None):
    dataset = 'mit1003'
    device = config.PARAMS["device"]
    model_name = "model_%s_%s.pb" % (dataset, device)

    if paths is None:
        current_path = os.path.dirname(os.path.realpath(__file__))
        paths = define_paths(current_path, None)

    # The frozen graph maps a float32 image batch ("input") to saliency maps.
    # input_images is defined here as a plain placeholder; upstream code may
    # instead feed it from a dataset iterator.
    graph_def = tf.GraphDef()
    input_images = tf.placeholder(tf.float32, (None, None, None, 3))

    if os.path.isfile(paths["best"] + model_name):
        with tf.gfile.Open(paths["best"] + model_name, "rb") as file:
            graph_def.ParseFromString(file.read())
    else:
        if not os.path.isfile(paths["weights"] + model_name):
            download.download_pretrained_weights(paths["weights"],
                                                 model_name[:-3])

        with tf.gfile.Open(paths["weights"] + model_name, "rb") as file:
            graph_def.ParseFromString(file.read())

    [predicted_maps] = tf.import_graph_def(graph_def,
                                           input_map={"input": input_images},
                                           return_elements=["output:0"])

    return input_images, predicted_maps
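
A hedged usage sketch for get_tf_objects; the preprocessing that produces image_batch (a float32 array of shape (1, H, W, 3), resized as the model expects) is assumed and not shown:

# Assumed: `image_batch` is an already preprocessed float32 batch of shape
# (1, H, W, 3); a frozen graph needs no variable initialization.
input_images, predicted_maps = get_tf_objects()

with tf.Session() as sess:
    saliency = sess.run(predicted_maps,
                        feed_dict={input_images: image_batch})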
def test_model(dataset, paths, device):
    """The main function for executing network testing. It loads the specified
       dataset iterator and optimized saliency model. By default, when no model
       checkpoint is found locally, the pretrained weights will be downloaded.
       Testing only works for models trained on the same device as specified in
       the config file.

    Args:
        dataset (str): Denotes the dataset that was used during training.
        paths (dict, str): A dictionary with all path elements.
        device (str): Represents either "cpu" or "gpu".
    """

    iterator = data.get_dataset_iterator("test", dataset, paths["data"])

    next_element, init_op = iterator

    input_images, original_shape, file_path = next_element

    #training = tf.placeholder(tf.bool, name="training")   ## For BN
    
    graph_def = tf.GraphDef()

    model_name = "model_%s_%s.pb" % (dataset, device)

    if os.path.isfile(paths["best"] + model_name):
        with tf.gfile.Open(paths["best"] + model_name, "rb") as file:
            graph_def.ParseFromString(file.read())
    else:
        if not os.path.isfile(paths["weights"] + model_name):
            download.download_pretrained_weights(paths["weights"],
                                                 model_name[:-3])

        with tf.gfile.Open(paths["weights"] + model_name, "rb") as file:
            graph_def.ParseFromString(file.read())

    [predicted_maps] = tf.import_graph_def(graph_def,
                                           input_map={"input": input_images},
                                           return_elements=["output:0"])

    jpeg = data.postprocess_saliency_map(predicted_maps[0],
                                         original_shape[0])

    print(">> Start testing with %s %s model..." % (dataset.upper(), device))

    with tf.Session() as sess:
        sess.run(init_op)

        while True:
            try:
                #output_file, path = sess.run([jpeg, file_path], feed_dict={training: False})
                output_file, path = sess.run([jpeg, file_path])
            except tf.errors.OutOfRangeError:
                break

            path = path[0][0].decode("utf-8")

            filename = os.path.basename(path)
            filename = os.path.splitext(filename)[0]
            filename += ".jpeg"

            os.makedirs(paths["images"], exist_ok=True)

            with open(paths["images"] + filename, "wb") as file:
                file.write(output_file)
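
These snippets all index into the same kind of paths dictionary; a hypothetical layout covering the keys that appear in them (the real dictionary is built by define_paths in the project):

# Hypothetical directory layout; only the keys used in the snippets are shown.
paths = {
    "data": "data/",              # dataset images / videos
    "weights": "weights/",        # downloaded pretrained weights and .pb graphs
    "latest": "ckpts/latest/",    # most recent training checkpoint
    "best": "ckpts/best/",        # best checkpoint / exported frozen graph
    "images": "results/images/",  # written saliency maps
}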
Example #5
    left, right, bottom, top = best_window(saliency_resized_arr)
    output = original_arr[bottom:top, left:right, :]

    if show_saliency:
        bounded = overlay_saliency(original_img, saliency_resized_img, left,
                                   right, bottom, top)
        return bounded

    return output


### Model loading code
graph_def = tf.GraphDef()
model_name = "weights/model_mit1003_cpu.pb"

download.download_pretrained_weights('weights/', 'model_mit1003_cpu')

with tf.gfile.Open(model_name, "rb") as file:
    graph_def.ParseFromString(file.read())
    input_plhd = tf.placeholder(tf.float32, (None, None, None, 3))
    [predicted_maps] = tf.import_graph_def(graph_def,
                                           input_map={"input": input_plhd},
                                           return_elements=["output:0"])

sess = tf.Session()

examples = [["images/1.jpg", True], ["images/2.jpg", True]]

thumbnail = "https://ibb.co/hXdbDyD"
io = gr.Interface(test_model, [
    gr.inputs.Image(label="Your Image"),
Example #6
def test_model(dataset, paths, device):
    """The main function for executing network testing. It loads the specified
       dataset iterator and optimized saliency model. By default, when no model
       checkpoint is found locally, the pretrained weights will be downloaded.
       Testing only works for models trained on the same device as specified in
       the config file.

    Args:
        dataset (str): Denotes the dataset that was used during training.
        paths (dict, str): A dictionary with all path elements.
        device (str): Represents either "cpu" or "gpu".
    """

    video_file = tf.placeholder(tf.string, shape=())
    iterator = data.get_dataset_iterator("test", dataset, paths["data"],
                                         video_file)

    next_element, init_op = iterator

    input_images, original_shape, file_path = next_element

    graph_def = tf.GraphDef()

    model_name = "model_%s_%s.pb" % (dataset, device)

    if os.path.isfile(paths["best"] + model_name):
        with tf.gfile.Open(paths["best"] + model_name, "rb") as file:
            graph_def.ParseFromString(file.read())
    else:
        if not os.path.isfile(paths["weights"] + model_name):
            download.download_pretrained_weights(paths["weights"],
                                                 model_name[:-3])

        with tf.gfile.Open(paths["weights"] + model_name, "rb") as file:
            graph_def.ParseFromString(file.read())

    [predicted_maps] = tf.import_graph_def(graph_def,
                                           input_map={"input": input_images},
                                           return_elements=["output:0"])

    print(">> Start testing with %s %s model..." % (dataset.upper(), device))

    with tf.Session() as sess:

        video_files = data._get_file_list(paths["data"])
        for vf in video_files:
            print(vf)
            saliency_images_list = []
            sess.run(init_op, feed_dict={video_file: vf})
            while True:
                try:
                    saliency_images, target_shape, np_file_path = \
                        sess.run([predicted_maps, original_shape, file_path],
                            feed_dict={video_file:vf})
                    saliency_images_list.append(saliency_images)
                except tf.errors.OutOfRangeError:
                    break

            saliency_video = np.concatenate(saliency_images_list)
            target_shape = target_shape[0]
            np_file_path = np_file_path[0][0]

            commonpath = os.path.commonpath([paths["data"], paths["images"]])
            file_path_str = np_file_path.decode("utf8")
            relative_file_path = os.path.relpath(file_path_str,
                                                 start=commonpath)
            output_file_path = os.path.join(paths["images"],
                                            relative_file_path)
            os.makedirs(os.path.dirname(output_file_path), exist_ok=True)

            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            # VideoWriter expects (width, height); frames are resized back to
            # the original resolution below, so use the target shape here.
            frame_size = (target_shape[1], target_shape[0])
            out = cv2.VideoWriter(output_file_path, fourcc, 25, frame_size,
                                  False)

            saliency_video = np.squeeze(saliency_video)
            saliency_video *= 255
            for frame in saliency_video:
                saliency_map = data._resize_image(frame,
                                                  target_shape,
                                                  True,
                                                  is_numpy=True)
                saliency_map = data._crop_image(saliency_map,
                                                target_shape,
                                                is_numpy=True)

                saliency_map = np.round(saliency_map)
                saliency_map = saliency_map.astype(np.uint8)

                out.write(saliency_map)
            out.release()
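
The easy-to-miss detail above is OpenCV's convention: cv2.VideoWriter takes a (width, height) frame size and, with isColor=False, expects single-channel uint8 frames of exactly that size. A standalone sketch writing a short grayscale test clip:

import numpy as np
import cv2

height, width = 120, 160
fourcc = cv2.VideoWriter_fourcc(*"XVID")
# frame_size is (width, height), i.e. the reverse of the numpy shape order.
out = cv2.VideoWriter("test_clip.avi", fourcc, 25, (width, height), False)

for i in range(50):
    frame = np.full((height, width), i * 5, dtype=np.uint8)  # fading gray frames
    out.write(frame)

out.release()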
Example #7
File: main.py Project: RJason13/A-ResP
def train_model(ds_name, encoder, paths):
    """The main function for executing network training. It loads the specified
       dataset iterator, saliency model, and helper classes. Training is then
       performed in a new session by iterating over all batches for a number of
       epochs. After validation on an independent set, the model is saved and
       the training history is updated.

    Args:
        ds_name (str): Denotes the dataset to be used during training.
        encoder (str): The name of the encoder network to use.
        paths (dict, str): A dictionary with all path elements.
    """

    w_filename_template = "/%s_%s_%s_weights.h5" # [encoder]_[ds_name]_[loss_fn_name]_weights.h5

    (train_ds, n_train), (val_ds, n_val) = data.load_train_dataset(ds_name, paths["data"])
    
    print(">> Preparing model with encoder %s..." % encoder)

    model = MyModel(encoder, ds_name, "train")

    if ds_name != "salicon":
        salicon_weights = paths["weights"] + w_filename_template % (encoder, "salicon", loss_fn_name)
        if os.path.exists(salicon_weights):
            print("Salicon weights are loaded!\n    %s"%salicon_weights)
        else:
            download.download_pretrained_weights(paths["weights"], encoder, "salicon", loss_fn_name)
        model.load_weights(salicon_weights)
        del salicon_weights

    model.summary()

    n_epochs = config.PARAMS["n_epochs"]

    # Preparing
    loss_fn = globals().get(loss_fn_name, None)
    optimizer = tf.keras.optimizers.Adam(config.PARAMS["learning_rate"])

    train_metric = tf.keras.metrics.Mean(name="train_loss")
    val_metric = tf.keras.metrics.Mean(name="val_loss")

    ckpts_path = paths["ckpts"] + "%s/%s/%s/" % (encoder, ds_name, loss_fn_name)
    ckpt = tf.train.Checkpoint(net=model, train_metric=train_metric, val_metric=val_metric)
    ckpt_manager = tf.train.CheckpointManager(ckpt, ckpts_path, max_to_keep=n_epochs)
    start_epoch = 0
    
    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint).assert_consumed()
        start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
        print ('Checkpoint restored:\n{}'.format(ckpt_manager.latest_checkpoint))
        train_metric.reset_states()
        val_metric.reset_states()

    print("\n>> Start training model on %s..." % ds_name.upper())
    print(("Training details:" +
    "\n{0:<4}Number of epochs: {n_epochs:d}" +
    "\n{0:<4}Batch size: {batch_size:d}" +
    "\n{0:<4}Learning rate: {learning_rate:.1e}" +
    "\n{0:<4}Loss function: {1}").format(" ", loss_fn_name, **config.PARAMS))
    print("_" * 65)
    if ds_name == "salicon" and start_epoch < 2:
        model.freeze_unfreeze_encoder_trained_layers(True)
    for epoch in range(start_epoch, n_epochs):
        if ds_name == "salicon" and epoch == 2:
            model.freeze_unfreeze_encoder_trained_layers(False)

        train_progbar = Progbar(n_train, stateful_metrics=["train_loss"])
        for train_x, train_y_true, train_ori_sizes, train_filenames in train_ds:
            train_y_pred, train_loss = train_step(train_x, train_y_true, model, loss_fn, optimizer)
            train_metric(train_loss)
            train_progbar.add(train_x.shape[0], [("train_loss", train_metric.result())])

        val_progbar = Progbar(n_val, stateful_metrics=["val_loss"])
        for val_x, val_y_true, val_ori_sizes, val_filenames in val_ds:
            val_y_pred, val_loss = val_step(val_x, val_y_true, model, loss_fn)
            val_metric(val_loss)
            val_progbar.add(val_x.shape[0], [("val_loss", val_metric.result())])

        train_metrics_results = _print_metrics({"train_loss": train_metric})
        val_metrics_results = _print_metrics({"val_loss": val_metric})
        print('Epoch {} - {} - {}'.format(epoch+1, train_metrics_results, val_metrics_results))
        
        ckpt_manager.save()

        # Reset the metrics for the next epoch
        train_metric.reset_states()
        val_metric.reset_states()

    # Picking best result
    print(">> Picking best result")
    min_val_loss = None

    for i, checkpoint in enumerate(ckpt_manager.checkpoints):
        ckpt.restore(checkpoint).assert_consumed()

        train_metrics_results = _print_metrics({"train_loss": train_metric})
        val_metrics_results = _print_metrics({"val_loss": val_metric})
        print('Epoch {} - {} - {}'.format(i+1, train_metrics_results, val_metrics_results))
        val_loss_result = val_metric.result()
        if min_val_loss is None or min_val_loss > val_loss_result:
            min_train_loss = train_metric.result()
            min_val_loss = val_loss_result
            min_index = i
    
    ckpt.restore(ckpt_manager.checkpoints[min_index])
    print("best result picked -> epoch: {0} - train_{1}: {2} - val_{1}: {3}".format(min_index + 1, loss_fn_name,
        ('%.4f' if min_train_loss > 1e-3 else '%.4e') % min_train_loss,
        ('%.4f' if min_val_loss > 1e-3 else '%.4e') % min_val_loss))

    # Saving model's weights
    print(">> Saving model's weights")
    dest_path = paths["weights"] + w_filename_template % (encoder, ds_name, loss_fn_name)
    if min_index < 2:
        model.freeze_unfreeze_encoder_trained_layers(False)
    model.save_weights(dest_path)
    print("weights are saved to:\n%s" % dest_path)
Example #8
File: main.py Project: RJason13/A-ResP
def find_n_high(ds_name, encoder, paths, n, metric, negate=False):
    """The main function for executing network training. It loads the specified
       dataset iterator, saliency model, and helper classes. Training is then
       performed in a new session by iterating over all batches for a number of
       epochs. After validation on an independent set, the model is saved and
       the training history is updated.

    Args:
        ds_name (str): Denotes the dataset to be used during training.
        paths (dict, str): A dictionary with all path elements.
    """

    w_filename_template = "/%s_%s_%s_weights.h5" # [encoder]_[ds_name]_[loss_fn_name]_weights.h5
    
    (eval_ds, n_eval) = data.load_eval_dataset(ds_name, paths["data"])
    
    print(">> Preparing model with encoder %s..." % encoder)

    model = MyModel(encoder, ds_name, "train")

    if "trained_weights" in paths:
        if os.path.exists(paths["trained_weights"]):
            weights_path = paths["trained_weights"]
        else:
            raise ValueError("could not find the specified weights file.\n    specified weights: %s"%paths["trained_weights"])
    else:
        weights_path = paths["weights"] + w_filename_template % (encoder, ds_name, loss_fn_name)

    if os.path.exists(weights_path):
        print("Weights are loaded!\n    %s"%weights_path)
    else:
        download.download_pretrained_weights(paths["weights"], encoder, ds_name, loss_fn_name)
    
    model.load_weights(weights_path)
    del weights_path

    model.summary()

    # Preparing

    print("\n>> Start finding %d %s results for model on %s..." % (n, "worst" if negate else "best",ds_name.upper()))
    print(("Evaluation details:" +
        "\n{0:<4}Metric: {1}").format(" ", metric))
    print("_" * 65)

    eval_progbar = Progbar(n_eval)
    min_heap = []
    count = 0
    sign = -1 if negate else 1
    for eval_x, eval_fixs, eval_y_true, eval_ori_sizes, eval_filenames in eval_ds:
        eval_y_pred = test_step(eval_x, model)
        for pred, y_true, fixs, filename, ori_size in zip(eval_y_pred, eval_fixs, eval_y_true, eval_filenames.numpy(), eval_ori_sizes):
            pred = tf.expand_dims(data.postprocess_saliency_map(pred, ori_size), axis=0)
            fixs = tf.expand_dims(fixs, axis=0)
            y_true = tf.expand_dims(y_true, axis=0)

            score = _calc_metrics([metric], y_true, fixs, pred)[metric].numpy() * sign
            
            if count < n:
                count+=1
                heapq.heappush(min_heap, (score, filename.decode("utf-8")))
            else:
                heapq.heappushpop(min_heap, (score, filename.decode("utf-8")))
        eval_progbar.add(eval_x.shape[0])
    
    min_heap.sort(reverse=True)
    for score, fname in min_heap:
        print(score * sign, fname)  # undo the sign applied for heap ordering
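
The heap logic above keeps only the n highest scores seen so far: heappush fills the heap, and heappushpop then evicts the current minimum whenever a larger score arrives. The same pattern on toy data:

import heapq

scores = [("a", 0.31), ("b", 0.87), ("c", 0.55), ("d", 0.92), ("e", 0.12)]
top_n, n = [], 3

for name, score in scores:
    if len(top_n) < n:
        heapq.heappush(top_n, (score, name))
    else:
        heapq.heappushpop(top_n, (score, name))

print(sorted(top_n, reverse=True))  # [(0.92, 'd'), (0.87, 'b'), (0.55, 'c')]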
Example #9
File: main.py Project: RJason13/A-ResP
def eval_results(ds_name, encoder, paths):
    """The main function for executing network training. It loads the specified
       dataset iterator, saliency model, and helper classes. Training is then
       performed in a new session by iterating over all batches for a number of
       epochs. After validation on an independent set, the model is saved and
       the training history is updated.

    Args:
        ds_name (str): Denotes the dataset to be used during training.
        paths (dict, str): A dictionary with all path elements.
    """

    w_filename_template = "/%s_%s_%s_weights.h5" # [encoder]_[ds_name]_[loss_fn_name]_weights.h5

    (eval_ds, n_eval) = data.load_eval_dataset(ds_name, paths["data"])
    
    print(">> Preparing model with encoder %s..." % encoder)

    model = MyModel(encoder, ds_name, "train")

    if "trained_weights" in paths:
        if os.path.exists(paths["trained_weights"]):
            weights_path = paths["trained_weights"]
        else:
            raise ValueError("could not find the specified weights file.\n    specified weights: %s"%paths["trained_weights"])
    else:
        weights_path = paths["weights"] + w_filename_template % (encoder, ds_name, loss_fn_name)

    if os.path.exists(weights_path):
        print("Weights are loaded!\n    %s"%weights_path)
    else:
        download.download_pretrained_weights(paths["weights"], encoder, ds_name, loss_fn_name)
    
    model.load_weights(weights_path)
    del weights_path

    model.summary()

    # Preparing
    metrics = config.PARAMS["metrics"]

    print("\n>> Start evaluating model on %s..." % ds_name.upper())
    print(("Evaluation details:" +
    "\n{0:<4}Metrics: {2}").format(" ", loss_fn_name, ", ".join(metrics), **config.PARAMS))
    print("_" * 65)

    eval_progbar = Progbar(n_eval)
    categorical = config.SPECS[ds_name].get("categorical", False)
    cat_metrics = {}
    for eval_x, eval_fixs, eval_y_true, eval_ori_sizes, eval_filenames in eval_ds:
        eval_y_pred = test_step(eval_x, model)
        for pred, y_true, fixs, filename, ori_size in zip(eval_y_pred, eval_fixs, eval_y_true, eval_filenames.numpy(), eval_ori_sizes):
            pred = tf.expand_dims(data.postprocess_saliency_map(pred, ori_size), axis=0)
            fixs = tf.expand_dims(fixs, axis=0)
            y_true = tf.expand_dims(y_true, axis=0)

            met_vals = _calc_metrics(metrics, y_true, fixs, pred)
            
            if categorical:
                cat = "/".join(filename.decode("utf-8").split("/")[:-1])
                if cat not in cat_metrics:
                    cat_metrics[cat] = {}
                    for name in metrics:
                        cat_metrics[cat][name] = {"sum":0, "count": 0}
                for name, value in met_vals.items():
                    cat_metrics[cat][name]["sum"] += value
                    cat_metrics[cat][name]["count"] += 1
        eval_progbar.add(eval_x.shape[0], met_vals.items())

    for cat, cat_met in cat_metrics.items():
        to_print = []
        for name, value in cat_met.items():
            _mean = value["sum"]/value["count"]
            to_print.append("{}: {}".format(name, ('%.4f' if _mean > 1e-3 else '%.4e') % _mean))
        print('Results ({}): {}'.format(cat, " - ".join(to_print)))
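
The per-category bookkeeping above is a plain sum/count running mean; the same aggregation on toy values (category and metric names are illustrative only):

samples = [("Action", {"AUC": 0.82, "CC": 0.61}),
           ("Action", {"AUC": 0.78, "CC": 0.66}),
           ("Art",    {"AUC": 0.90, "CC": 0.71})]

cat_metrics = {}
for cat, met_vals in samples:
    bucket = cat_metrics.setdefault(
        cat, {name: {"sum": 0.0, "count": 0} for name in met_vals})
    for name, value in met_vals.items():
        bucket[name]["sum"] += value
        bucket[name]["count"] += 1

for cat, cat_met in cat_metrics.items():
    means = {name: v["sum"] / v["count"] for name, v in cat_met.items()}
    print(cat, means)  # e.g. Action: AUC ~0.80, CC ~0.635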