def restore(self, sess, dataset, paths, device):
    """Restore model variables and return a saver for later checkpoints.

    Exactly one of three sources is used, checked in this order:

    1. A prior checkpoint for this ``dataset``/``device`` under
       ``paths["latest"]``: training continues from it.
    2. For the fine-tuning datasets ("mit1003", "cat2000", "dutomron",
       "pascals", "osie", "fiwi"): the SALICON checkpoint under
       ``paths["best"]``, which must already exist.
    3. Otherwise: the pretrained VGG16 weights under ``paths["weights"]``,
       downloaded first when the local checkpoint is missing or incomplete.

    Args:
        sess (object): The current TF training session.
        dataset (str): The dataset used for training.
        paths (dict, str): A dictionary with all path elements.
        device (str): Represents either "cpu" or "gpu".

    Returns:
        object: A saver object for saving the model.

    Raises:
        FileNotFoundError: If a fine-tuning dataset is requested but no
            SALICON checkpoint is found under ``paths["best"]``.
    """

    model_name = "model_%s_%s" % (dataset, device)
    salicon_name = "model_salicon_%s" % device
    vgg16_name = "vgg16_hybrid"

    ext1 = ".ckpt.data-00000-of-00001"
    ext2 = ".ckpt.index"

    def _checkpoint_exists(prefix):
        # Both the data and the index file are required; every branch
        # below uses the same rule.
        return (os.path.isfile(prefix + ext1) and
                os.path.isfile(prefix + ext2))

    # The full-model saver is created before `_pretraining()` runs, exactly
    # as in the original statement order.
    saver = tf.train.Saver()

    latest_prefix = paths["latest"] + model_name
    salicon_prefix = paths["best"] + salicon_name
    vgg16_prefix = paths["weights"] + vgg16_name

    if _checkpoint_exists(latest_prefix):
        saver.restore(sess, latest_prefix + ".ckpt")
    elif dataset in ("mit1003", "cat2000", "dutomron",
                     "pascals", "osie", "fiwi"):
        if not _checkpoint_exists(salicon_prefix):
            raise FileNotFoundError("Train model on SALICON first")
        saver.restore(sess, salicon_prefix + ".ckpt")
    else:
        # Previously the download was skipped as soon as *either* file
        # existed, so a half-present checkpoint made the restore below fail.
        if not _checkpoint_exists(vgg16_prefix):
            download.download_pretrained_weights(paths["weights"],
                                                 "vgg16_hybrid")

        # NOTE(review): presumably `_pretraining()` prepares `self._mapping`
        # (the variable mapping for the VGG16 checkpoint) - confirm.
        self._pretraining()

        loader = tf.train.Saver(var_list=self._mapping)
        loader.restore(sess, vgg16_prefix + ".ckpt")

    return saver
def test_model(ds_name, encoder, paths, categorical=False):
    """Predict saliency maps for a test dataset and write them to disk.

    The test dataset and a model in "test" mode are created, the weights
    file for the given encoder/dataset/loss combination is loaded (it is
    downloaded first when it does not exist locally), and every predicted
    map is post-processed and written below ``paths["results"]``.

    Args:
        ds_name (str): Denotes the dataset that was used during training.
        encoder (str): The name of the encoder used for prediction.
        paths (dict, str): A dictionary with all path elements.
        categorical (bool): Forwarded unchanged to
            ``data.load_test_dataset``.
    """

    # [encoder]_[ds_name]_[loss_fn_name]_weights.h5
    weights_template = "/%s_%s_%s_weights.h5"

    test_ds, n_test = data.load_test_dataset(ds_name, paths["data"],
                                             categorical)

    print(">> Preparing model with encoder %s..." % encoder)
    model = MyModel(encoder, ds_name, "test")

    weights_file = paths["weights"] + weights_template % (
        encoder, ds_name, loss_fn_name)
    if not os.path.exists(weights_file):
        download.download_pretrained_weights(paths["weights"], encoder,
                                             ds_name, loss_fn_name)
    else:
        print("Weights are loaded!\n %s"%weights_file)
    model.load_weights(weights_file)

    print(">> Start predicting using model trained on %s..." % ds_name.upper())
    results_path = paths["results"] + "%s/%s/%s/" % (
        ds_name, encoder, loss_fn_name)

    # The progress bar counts images, not batches.
    test_progbar = Progbar(n_test)

    for test_images, test_ori_sizes, test_filenames in test_ds:
        batch_preds = test_step(test_images, model)
        names = test_filenames.numpy()
        for saliency, name, size in zip(batch_preds, names, test_ori_sizes):
            encoded = data.postprocess_saliency_map(saliency, size,
                                                    as_image=True)
            tf.io.write_file(results_path + name.decode("utf-8"), encoded)
        test_progbar.add(test_images.shape[0])
def get_tf_objects(paths):
    """Load the frozen MIT1003 model and import it via ``tf.import_graph_def``.

    The model file is taken from the "best" directory when present,
    otherwise from the "weights" directory (downloading it first if it is
    missing there as well).

    NOTE(review): the ``paths`` argument is overwritten by the result of
    ``define_paths`` and therefore never used - confirm this is intended.
    NOTE(review): ``graph_def`` and ``input_images`` are not defined in this
    function; presumably they are module-level objects - confirm.
    NOTE(review): the imported output tensor is not returned; the function
    always returns None.

    Args:
        paths (dict, str): Ignored (see note above).
    """

    dataset = 'mit1003'
    device = config.PARAMS["device"]

    model_stem = "model_%s_%s" % (dataset, device)
    model_name = model_stem + ".pb"

    script_dir = os.path.dirname(os.path.realpath(__file__))
    paths = define_paths(script_dir, None)

    # Prefer a locally trained model; fall back to the pretrained weights.
    model_file = paths["best"] + model_name
    if not os.path.isfile(model_file):
        model_file = paths["weights"] + model_name
        if not os.path.isfile(model_file):
            download.download_pretrained_weights(paths["weights"], model_stem)

    with tf.gfile.Open(model_file, "rb") as handle:
        graph_def.ParseFromString(handle.read())

    # Unpacking into a one-element list asserts that exactly one tensor
    # was returned.
    [predicted_maps] = tf.import_graph_def(graph_def,
                                           input_map={"input": input_images},
                                           return_elements=["output:0"])

    return None
def test_model(dataset, paths, device):
    """Run the frozen saliency model on the test set and save JPEG maps.

    The test dataset iterator is wired to the input of the frozen graph.
    The graph file is read from ``paths["best"]`` when present, otherwise
    from ``paths["weights"]`` (downloaded first if missing). Each predicted
    map is post-processed and written to ``paths["images"]`` under the
    input's base name with a ".jpeg" extension.

    Args:
        dataset (str): Denotes the dataset that was used during training.
        paths (dict, str): A dictionary with all path elements.
        device (str): Represents either "cpu" or "gpu".
    """

    next_element, init_op = data.get_dataset_iterator("test", dataset,
                                                      paths["data"])
    input_images, original_shape, file_path = next_element

    graph_def = tf.GraphDef()

    model_stem = "model_%s_%s" % (dataset, device)
    model_name = model_stem + ".pb"

    # Prefer a locally trained model; fall back to the pretrained weights.
    model_file = paths["best"] + model_name
    if not os.path.isfile(model_file):
        model_file = paths["weights"] + model_name
        if not os.path.isfile(model_file):
            download.download_pretrained_weights(paths["weights"], model_stem)

    with tf.gfile.Open(model_file, "rb") as handle:
        graph_def.ParseFromString(handle.read())

    [predicted_maps] = tf.import_graph_def(graph_def,
                                           input_map={"input": input_images},
                                           return_elements=["output:0"])

    # Only the first element of each batch is post-processed.
    jpeg = data.postprocess_saliency_map(predicted_maps[0],
                                         original_shape[0])

    print(">> Start testing with %s %s model..." % (dataset.upper(), device))

    with tf.Session() as sess:
        sess.run(init_op)

        while True:
            try:
                encoded_map, source_path = sess.run([jpeg, file_path])
            except tf.errors.OutOfRangeError:
                # The iterator is exhausted: every test image was processed.
                break

            source_path = source_path[0][0].decode("utf-8")
            stem = os.path.splitext(os.path.basename(source_path))[0]

            os.makedirs(paths["images"], exist_ok=True)

            with open(paths["images"] + stem + ".jpeg", "wb") as handle:
                handle.write(encoded_map)
left, right, bottom, top = best_window(saliency_resized_arr) output = original_arr[bottom:top, left:right, :] if show_saliency: bounded = overlay_saliency(original_img, saliency_resized_img, left, right, bottom, top) return bounded return output ### Model loading code graph_def = tf.GraphDef() model_name = "weights/model_mit1003_cpu.pb" download.download_pretrained_weights('weights/', 'model_mit1003_cpu') with tf.gfile.Open(model_name, "rb") as file: graph_def.ParseFromString(file.read()) input_plhd = tf.placeholder(tf.float32, (None, None, None, 3)) [predicted_maps] = tf.import_graph_def(graph_def, input_map={"input": input_plhd}, return_elements=["output:0"]) sess = tf.Session() examples = [["images/1.jpg", True], ["images/2.jpg", True]] thumbnail = "https://ibb.co/hXdbDyD" io = gr.Interface(test_model, [ gr.inputs.Image(label="Your Image"),
def test_model(dataset, paths, device):
    """Run the frozen saliency model on every video and save saliency videos.

    For each file returned by ``data._get_file_list(paths["data"])`` the
    dataset iterator is re-initialised with that file, all predicted
    saliency maps are collected, scaled by 255, rounded to ``uint8`` and
    written as a grayscale XVID video at 25 fps. The output path mirrors
    the input's location relative to the common parent of ``paths["data"]``
    and ``paths["images"]``.

    The frozen graph is read from ``paths["best"]`` when present, otherwise
    from ``paths["weights"]`` (downloaded first if missing).

    NOTE(review): frames are written at the resolution of the model output.
    The original code also resized/cropped each frame to the source video's
    shape but then discarded that result, so the dead computation has been
    removed. If original-resolution output is wanted, reintroduce
    ``data._resize_image``/``data._crop_image`` together with a matching
    writer frame size.

    Args:
        dataset (str): Denotes the dataset that was used during training.
        paths (dict, str): A dictionary with all path elements.
        device (str): Represents either "cpu" or "gpu".
    """

    video_file = tf.placeholder(tf.string, shape=())
    iterator = data.get_dataset_iterator("test", dataset, paths["data"],
                                         video_file)

    next_element, init_op = iterator

    input_images, _, file_path = next_element

    graph_def = tf.GraphDef()

    model_stem = "model_%s_%s" % (dataset, device)
    model_name = model_stem + ".pb"

    # Prefer a locally trained model; fall back to the pretrained weights.
    model_file = paths["best"] + model_name
    if not os.path.isfile(model_file):
        model_file = paths["weights"] + model_name
        if not os.path.isfile(model_file):
            download.download_pretrained_weights(paths["weights"], model_stem)

    with tf.gfile.Open(model_file, "rb") as file:
        graph_def.ParseFromString(file.read())

    [predicted_maps] = tf.import_graph_def(graph_def,
                                           input_map={"input": input_images},
                                           return_elements=["output:0"])

    print(">> Start testing with %s %s model..." % (dataset.upper(), device))

    # Loop-invariant values.
    commonpath = os.path.commonpath([paths["data"], paths["images"]])
    fourcc = cv2.VideoWriter_fourcc(*'XVID')

    with tf.Session() as sess:
        video_files = data._get_file_list(paths["data"])
        for vf in video_files:
            print(vf)
            saliency_batches = []
            np_file_path = None

            sess.run(init_op, feed_dict={video_file: vf})
            while True:
                try:
                    saliency_images, np_file_path = sess.run(
                        [predicted_maps, file_path],
                        feed_dict={video_file: vf})
                except tf.errors.OutOfRangeError:
                    break
                saliency_batches.append(saliency_images)

            if not saliency_batches:
                # np.concatenate would raise on an empty list and there is
                # no file path to derive the output name from.
                print(">> Skipping %s: no frames were produced" % vf)
                continue

            saliency_video = np.concatenate(saliency_batches) * 255

            file_path_str = np_file_path[0][0].decode("utf8")
            relative_file_path = os.path.relpath(file_path_str,
                                                 start=commonpath)
            output_file_path = os.path.join(paths["images"],
                                            relative_file_path)
            os.makedirs(os.path.dirname(output_file_path), exist_ok=True)

            # OpenCV expects (width, height); axes 2 and 1 of the stacked
            # predictions are used for that, as in the original code.
            frame_size = (saliency_video.shape[2], saliency_video.shape[1])
            out = cv2.VideoWriter(output_file_path, fourcc, 25, frame_size,
                                  False)
            try:
                for frame in saliency_video:
                    # Squeeze per frame rather than on the whole video: a
                    # global squeeze also drops the leading axis when the
                    # video has a single frame, which made the loop iterate
                    # over rows instead of frames.
                    saliency_map = np.round(np.squeeze(frame))
                    out.write(saliency_map.astype(np.uint8))
            finally:
                # Always finalise the container, even if a write fails.
                out.release()
def train_model(ds_name, encoder, paths):
    """Train the saliency model and save the weights of the best epoch.

    Steps performed:

    1. Load the training and validation datasets and build the model in
       "train" mode. For any dataset other than "salicon" the SALICON
       weights are loaded first (downloaded if missing locally).
    2. Restore the latest checkpoint, if any, and continue from its epoch.
    3. Train and validate for the remaining epochs, saving a checkpoint
       (including both loss metrics) after every epoch. On "salicon" the
       encoder is frozen for the first two epochs.
    4. Re-read every kept checkpoint, pick the one with the lowest
       validation loss, restore it and save the model weights to
       ``paths["weights"]``.

    NOTE(review): the optimizer is not part of the checkpoint, so its state
    starts fresh when training is resumed - confirm this is acceptable.

    Args:
        ds_name (str): Denotes the dataset to be used during training.
        encoder (str): The name of the encoder to build the model with.
        paths (dict, str): A dictionary with all path elements.

    Raises:
        ValueError: If the module-level ``loss_fn_name`` does not name an
            object available in this module.
        RuntimeError: If no checkpoint is available to pick a best result
            from.
    """

    # [encoder]_[ds_name]_[loss_fn_name]_weights.h5
    w_filename_template = "/%s_%s_%s_weights.h5"

    (train_ds, n_train), (val_ds, n_val) = data.load_train_dataset(
        ds_name, paths["data"])

    print(">> Preparing model with encoder %s..." % encoder)
    model = MyModel(encoder, ds_name, "train")

    if ds_name != "salicon":
        salicon_weights = paths["weights"] + w_filename_template % (
            encoder, "salicon", loss_fn_name)
        if os.path.exists(salicon_weights):
            print("Salicon weights are loaded!\n %s"%salicon_weights)
        else:
            download.download_pretrained_weights(paths["weights"], encoder,
                                                 "salicon", loss_fn_name)
        model.load_weights(salicon_weights)

    model.summary()

    n_epochs = config.PARAMS["n_epochs"]

    # Preparing
    loss_fn = globals().get(loss_fn_name)
    if loss_fn is None:
        # Fail here with a clear message instead of a "NoneType is not
        # callable" error inside the first training step.
        raise ValueError("unknown loss function: %s" % loss_fn_name)

    optimizer = tf.keras.optimizers.Adam(config.PARAMS["learning_rate"])

    train_metric = tf.keras.metrics.Mean(name="train_loss")
    val_metric = tf.keras.metrics.Mean(name="val_loss")

    ckpts_path = paths["ckpts"] + "%s/%s/%s/" % (encoder, ds_name,
                                                 loss_fn_name)
    ckpt = tf.train.Checkpoint(net=model,
                               train_metric=train_metric,
                               val_metric=val_metric)
    # Every epoch's checkpoint is kept so the best one can be picked later.
    ckpt_manager = tf.train.CheckpointManager(ckpt, ckpts_path,
                                              max_to_keep=n_epochs)

    start_epoch = 0
    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint).assert_consumed()
        # Checkpoint names end in "-<number>"; that number is used as the
        # count of epochs already completed.
        start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
        print ('Checkpoint restored:\n{}'.format(
            ckpt_manager.latest_checkpoint))
        # The restored metric values belong to the finished epoch.
        train_metric.reset_states()
        val_metric.reset_states()

    print("\n>> Start training model on %s..." % ds_name.upper())
    print(("Training details:" +
           "\n{0:<4}Number of epochs: {n_epochs:d}" +
           "\n{0:<4}Batch size: {batch_size:d}" +
           "\n{0:<4}Learning rate: {learning_rate:.1e}" +
           "\n{0:<4}Loss function: {1}").format(" ", loss_fn_name,
                                                **config.PARAMS))
    print("_" * 65)

    if ds_name == "salicon" and start_epoch < 2:
        model.freeze_unfreeze_encoder_trained_layers(True)

    for epoch in range(start_epoch, n_epochs):
        if ds_name == "salicon" and epoch == 2:
            model.freeze_unfreeze_encoder_trained_layers(False)

        train_progbar = Progbar(n_train, stateful_metrics=["train_loss"])
        for train_x, train_y_true, _sizes, _names in train_ds:
            _, train_loss = train_step(train_x, train_y_true, model,
                                       loss_fn, optimizer)
            train_metric(train_loss)
            train_progbar.add(train_x.shape[0],
                              [("train_loss", train_metric.result())])

        val_progbar = Progbar(n_val, stateful_metrics=["val_loss"])
        for val_x, val_y_true, _sizes, _names in val_ds:
            _, val_loss = val_step(val_x, val_y_true, model, loss_fn)
            val_metric(val_loss)
            val_progbar.add(val_x.shape[0],
                            [("val_loss", val_metric.result())])

        train_metrics_results = _print_metrics({"train_loss": train_metric})
        val_metrics_results = _print_metrics({"val_loss": val_metric})
        print('Epoch {} - {} - {}'.format(epoch+1, train_metrics_results,
                                          val_metrics_results))

        # Save before resetting so the checkpoint carries this epoch's
        # metric values; the best-result search below relies on them.
        ckpt_manager.save()

        # Reset the metrics for the next epoch
        train_metric.reset_states()
        val_metric.reset_states()

    # Picking best result
    print(">> Picking best result")
    if not ckpt_manager.checkpoints:
        raise RuntimeError("no checkpoints found in %s; nothing to pick the "
                           "best result from" % ckpts_path)

    min_val_loss = None
    for i, checkpoint in enumerate(ckpt_manager.checkpoints):
        ckpt.restore(checkpoint).assert_consumed()
        train_metrics_results = _print_metrics({"train_loss": train_metric})
        val_metrics_results = _print_metrics({"val_loss": val_metric})
        print('Epoch {} - {} - {}'.format(i+1, train_metrics_results,
                                          val_metrics_results))
        val_loss_result = val_metric.result()
        if min_val_loss is None or min_val_loss > val_loss_result:
            min_train_loss = train_metric.result()
            min_val_loss = val_loss_result
            min_index = i

    ckpt.restore(ckpt_manager.checkpoints[min_index])
    print("best result picked -> epoch: {0} - train_{1}: {2} - val_{1}: {3}".format(
        min_index + 1,
        loss_fn_name,
        ('%.4f' if min_train_loss > 1e-3 else '%.4e') % min_train_loss,
        ('%.4f' if min_val_loss > 1e-3 else '%.4e') % min_val_loss))

    # Saving model's weights
    print(">> Saving model's weights")
    dest_path = paths["weights"] + w_filename_template % (encoder, ds_name,
                                                          loss_fn_name)
    if min_index < 2:
        # Unfreeze before saving when the best checkpoint is one of the
        # first two.
        model.freeze_unfreeze_encoder_trained_layers(False)
    model.save_weights(dest_path)
    print("weights are saved to:\n%s" % dest_path)
def find_n_high(ds_name, encoder, paths, n, metric, negate=False):
    """Print the ``n`` evaluation samples with the highest metric score.

    The evaluation dataset is run through the model, ``metric`` is computed
    per sample, and the ``n`` samples with the highest score are printed in
    descending order together with their file names. With ``negate=True``
    the scores are multiplied by -1 first, so the ``n`` samples with the
    lowest metric value are selected instead; the printed scores are then
    the negated values.

    Weights are taken from ``paths["trained_weights"]`` when that key is
    present; otherwise from the default weights file for this
    encoder/dataset/loss combination, which is downloaded if missing.

    Args:
        ds_name (str): Denotes the dataset to evaluate on.
        encoder (str): The name of the encoder to build the model with.
        paths (dict, str): A dictionary with all path elements.
        n (int): Number of samples to report.
        metric (str): Name of the metric passed to ``_calc_metrics``.
        negate (bool): Select the lowest instead of the highest scores.

    Raises:
        ValueError: If ``paths["trained_weights"]`` is given but does not
            exist.
    """

    # [encoder]_[ds_name]_[loss_fn_name]_weights.h5
    w_filename_template = "/%s_%s_%s_weights.h5"

    eval_ds, n_eval = data.load_eval_dataset(ds_name, paths["data"])

    print(">> Preparing model with encoder %s..." % encoder)
    model = MyModel(encoder, ds_name, "train")

    if "trained_weights" in paths:
        weights_path = paths["trained_weights"]
        if not os.path.exists(weights_path):
            raise ValueError("could not find the specified weights file.\n specified weights: %s"%paths["trained_weights"])
    else:
        weights_path = paths["weights"] + w_filename_template % (
            encoder, ds_name, loss_fn_name)
        if os.path.exists(weights_path):
            print("Weights are loaded!\n %s"%weights_path)
        else:
            # Download the weights for *this* dataset: the file loaded
            # below is the ds_name one, so fetching the "salicon" weights
            # (as before) left it missing for every other dataset.
            download.download_pretrained_weights(paths["weights"], encoder,
                                                 ds_name, loss_fn_name)
    model.load_weights(weights_path)

    model.summary()

    # Preparing
    print("\n>> Start finding %d %s results for model on %s..." % (
        n, "worst" if negate else "best", ds_name.upper()))
    print(("Evaluation details:" +
           "\n{0:<4}Metric: {1}").format(" ", metric))
    print("_" * 65)

    eval_progbar = Progbar(n_eval)
    # Min-heap holding at most n entries: its smallest element is evicted
    # whenever a new sample is pushed, leaving the n largest scores.
    min_heap = []
    sign = -1 if negate else 1

    for eval_x, eval_fixs, eval_y_true, eval_ori_sizes, eval_filenames in eval_ds:
        eval_y_pred = test_step(eval_x, model)
        # NOTE(review): `y_true` is bound to the elements of `eval_fixs`
        # and `fixs` to those of `eval_y_true`. Either the names are
        # swapped here or in the dataset unpacking above - confirm against
        # data.load_eval_dataset. The original pairing is kept.
        for pred, y_true, fixs, filename, ori_size in zip(
                eval_y_pred, eval_fixs, eval_y_true,
                eval_filenames.numpy(), eval_ori_sizes):
            pred = tf.expand_dims(
                data.postprocess_saliency_map(pred, ori_size), axis=0)
            fixs = tf.expand_dims(fixs, axis=0)
            y_true = tf.expand_dims(y_true, axis=0)
            score = _calc_metrics([metric], y_true, fixs,
                                  pred)[metric].numpy() * sign
            entry = (score, filename.decode("utf-8"))
            if len(min_heap) < n:
                heapq.heappush(min_heap, entry)
            else:
                heapq.heappushpop(min_heap, entry)
        eval_progbar.add(eval_x.shape[0])

    min_heap.sort(reverse=True)
    for score, name in min_heap:
        print(score, name)
def eval_results(ds_name, encoder, paths):
    """Evaluate the model on a dataset with the configured metrics.

    Every sample of the evaluation dataset is run through the model and
    scored with the metrics listed in ``config.PARAMS["metrics"]``. The
    per-sample values are fed to the progress bar. When the dataset spec
    marks the dataset as categorical, the values are also accumulated per
    category (the directory part of the sample's file name) and the
    per-category means are printed at the end.

    Weights are taken from ``paths["trained_weights"]`` when that key is
    present; otherwise from the default weights file for this
    encoder/dataset/loss combination, which is downloaded if missing.

    Args:
        ds_name (str): Denotes the dataset to evaluate on.
        encoder (str): The name of the encoder to build the model with.
        paths (dict, str): A dictionary with all path elements.

    Raises:
        ValueError: If ``paths["trained_weights"]`` is given but does not
            exist.
    """

    # [encoder]_[ds_name]_[loss_fn_name]_weights.h5
    w_filename_template = "/%s_%s_%s_weights.h5"

    eval_ds, n_eval = data.load_eval_dataset(ds_name, paths["data"])

    print(">> Preparing model with encoder %s..." % encoder)
    model = MyModel(encoder, ds_name, "train")

    if "trained_weights" in paths:
        weights_path = paths["trained_weights"]
        if not os.path.exists(weights_path):
            raise ValueError("could not find the specified weights file.\n specified weights: %s"%paths["trained_weights"])
    else:
        weights_path = paths["weights"] + w_filename_template % (
            encoder, ds_name, loss_fn_name)
        if os.path.exists(weights_path):
            print("Weights are loaded!\n %s"%weights_path)
        else:
            # Download the weights for *this* dataset: the file loaded
            # below is the ds_name one, so fetching the "salicon" weights
            # (as before) left it missing for every other dataset.
            download.download_pretrained_weights(paths["weights"], encoder,
                                                 ds_name, loss_fn_name)
    model.load_weights(weights_path)

    model.summary()

    # Preparing
    metrics = config.PARAMS["metrics"]

    print("\n>> Start evaluating model on %s..." % ds_name.upper())
    print(("Evaluation details:" +
           "\n{0:<4}Metrics: {1}").format(" ", ", ".join(metrics)))
    print("_" * 65)

    eval_progbar = Progbar(n_eval)
    categorical = config.SPECS[ds_name].get("categorical", False)
    # category -> metric name -> running {"sum", "count"}
    cat_metrics = {}

    for eval_x, eval_fixs, eval_y_true, eval_ori_sizes, eval_filenames in eval_ds:
        eval_y_pred = test_step(eval_x, model)
        # NOTE(review): `y_true` is bound to the elements of `eval_fixs`
        # and `fixs` to those of `eval_y_true`. Either the names are
        # swapped here or in the dataset unpacking above - confirm against
        # data.load_eval_dataset. The original pairing is kept.
        for pred, y_true, fixs, filename, ori_size in zip(
                eval_y_pred, eval_fixs, eval_y_true,
                eval_filenames.numpy(), eval_ori_sizes):
            pred = tf.expand_dims(
                data.postprocess_saliency_map(pred, ori_size), axis=0)
            fixs = tf.expand_dims(fixs, axis=0)
            y_true = tf.expand_dims(y_true, axis=0)
            met_vals = _calc_metrics(metrics, y_true, fixs, pred)

            if categorical:
                # Everything before the last "/" of the file name.
                cat = filename.decode("utf-8").rpartition("/")[0]
                if cat not in cat_metrics:
                    cat_metrics[cat] = {name: {"sum": 0, "count": 0}
                                        for name in metrics}
                for name, value in met_vals.items():
                    cat_metrics[cat][name]["sum"] += value
                    cat_metrics[cat][name]["count"] += 1

            # Advance by one per sample so the bar's progress and running
            # means stay correct for any batch size.
            eval_progbar.add(1, met_vals.items())

    for cat, cat_met in cat_metrics.items():
        to_print = []
        for name, value in cat_met.items():
            _mean = value["sum"]/value["count"]
            to_print.append("{}: {}".format(
                name, ('%.4f' if _mean > 1e-3 else '%.4e') % _mean))
        print('Results ({}): {}'.format(cat, " - ".join(to_print)))