Example #1
def check_exp_folder_exists_and_create(exp_id, prefix, dir_name):
    if dir_name is not None:
        dir_path, _ = create_dir(os.path.join("../experiments", prefix), dir_name)
        dir_path, exists = create_dir(dir_path, "summary_images_exp_id_{}".format(exp_id))
        if exists:
            print("skipping export for exp_id: {} (directory already exists)".format(exp_id))
            return dir_path
    else:
        # create_dir returns (dir_path, already_existed); unpack so the path, not the tuple, is returned
        dir_path, _ = create_dir(os.path.join("../experiments", prefix), "summary_images_exp_id_{}".format(exp_id))
    return dir_path
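
A minimal usage sketch (hypothetical values, not from the source project), assuming create_dir(base, name) creates base/name if missing and returns (path, already_existed):

# with dir_name set, a pre-existing summary folder prints the skip notice and returns its path
path = check_exp_folder_exists_and_create(exp_id=42, prefix="my_experiment", dir_name="run_1")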
Example #2
File: io.py  Project: cschaefer26/TacoGan
def save(self, path: Path) -> None:
    # drop any suffix so the archive name is built from the stem
    path = path.parent / path.stem
    tmp_dir = Path(str(path) + '_save_tmp')
    create_dir(tmp_dir, overwrite=True)
    torch.save(self.tacotron.state_dict(), tmp_dir / 'tacotron.pyt')
    torch.save(self.gan.state_dict(), tmp_dir / 'gan.pyt')
    torch.save(self.taco_opti.state_dict(), tmp_dir / 'taco_opti.pyt')
    torch.save(self.gen_opti.state_dict(), tmp_dir / 'gen_opti.pyt')
    torch.save(self.disc_opti.state_dict(), tmp_dir / 'disc_opti.pyt')
    self.cfg.save(tmp_dir / 'config.yaml')
    shutil.make_archive(str(path), 'zip', tmp_dir)  # make_archive expects a string base name
    shutil.rmtree(tmp_dir)
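
A minimal call sketch (hypothetical object and paths): since save() reduces the path to parent/stem before archiving, both calls below write checkpoints/latest.zip:

session.save(Path('checkpoints/latest'))
session.save(Path('checkpoints/latest.zip'))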
Example #3
def check_exp_folder_exists_and_create(features, features_index, prefix,
                                       dir_name, cur_batch_it):
    exp_id = features[features_index]['experiment_id']
    if dir_name is not None:
        dir_path, _ = create_dir(os.path.join("../experiments", prefix),
                                 dir_name)
        dir_path, exists = create_dir(
            dir_path,
            "summary_images_batch_{}_exp_id_{}".format(cur_batch_it, exp_id))
        if exists:
            print("skipping export for exp_id: {} (directory already exists)".
                  format(exp_id))
            return False
    else:
        # unpack create_dir's (dir_path, already_existed) tuple so a path is returned
        dir_path, _ = create_dir(
            os.path.join("../experiments", prefix),
            "summary_images_batch_{}_exp_id_{}".format(cur_batch_it, exp_id))
    return dir_path
Example #4
    def compute_losses_over_test_set(self):
        if self.config.n_epochs != 1:
            print("test mode --> n_epochs will be set to 1")
            self.config.n_epochs = 1
        prefix = self.config.exp_name
        print("Computing average losses over full test set with initial_pos_vel_known={}".format(self.config.initial_pos_vel_known))
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        full_batch_loss, full_img_loss, full_vel_loss, full_pos_loss, full_dist_loss, full_iou_loss = [], [], [], [], [], []

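        # drain the one-pass test iterator; tf.data raises OutOfRangeError when the epoch ends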
        while True:
            try:
                batch_total, img_loss, vel_loss, pos_loss, dist_loss, iou_loss, _, _, _, _ = self.test_batch(prefix=prefix,
                                                                    export_images=self.config.export_test_images,
                                                                    initial_pos_vel_known=self.config.initial_pos_vel_known,
                                                                    process_all_nn_outputs=True,
                                                                    sub_dir_name="test_{}_iterations_trained".format(cur_batch_it),
                                                                    output_results=False)
                full_batch_loss.append(batch_total)
                full_img_loss.append(img_loss)
                full_vel_loss.append(vel_loss)
                full_pos_loss.append(pos_loss)
                full_dist_loss.append(dist_loss)
                full_iou_loss.append(iou_loss)

            except tf.errors.OutOfRangeError:
                break
        mean_total_loss = np.mean(full_batch_loss)
        mean_img_loss = np.mean(full_img_loss)
        mean_vel_loss = np.mean(full_vel_loss)
        mean_pos_loss = np.mean(full_pos_loss)
        mean_dist_loss = np.mean(full_dist_loss)
        mean_iou_loss = np.mean(full_iou_loss)

        output_dir, _ = create_dir(os.path.join("../experiments", prefix), "loss_over_test_set")
        dataset_name = os.path.basename(self.config.tfrecords_dir)

        if self.config.loss_type == "cross_entropy_seg_only":
            img_loss_type = "mean binary cross entropy loss"
        else:
            img_loss_type = "mean img loss"

        str_out = "mean loss over all test samples of dataset: {}\nmean total loss: " \
                  "{}\n{}: {}\nmean vel loss: {}\nmean pos loss: {}\nmean dist loss: {}\nmean iou loss: {}".\
            format(self.config.tfrecords_dir, mean_total_loss, img_loss_type, mean_img_loss, mean_vel_loss, mean_pos_loss, mean_dist_loss, mean_iou_loss)

        with open(output_dir + '/mean_losses_over_full_test_set_of_{}.txt'.format(dataset_name), "a+") as text_file:
            text_file.write(str_out + "\n")

        print(str_out)

        return mean_total_loss, mean_vel_loss, mean_pos_loss, mean_dist_loss, mean_iou_loss
Example #5
    def store_latent_vectors(self):
        assert self.config.n_epochs == 1, "set n_epochs to 1 for test mode"
        prefix = self.config.exp_name
        print("Storing latent vectors")
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        df = pd.DataFrame(columns=['latent_vector_core_output_init_img', 'latent_vector_encoder_output_init_img', 'exp_id', 'exp_len'])
        sub_dir_name = "latent_vectors_initial_image_of_full_test_set_{}_iterations_trained".format(cur_batch_it)

        dir_path, _ = create_dir(os.path.join("../experiments", prefix), sub_dir_name)
        dataset_name = os.path.basename(self.config.tfrecords_dir)
        file_name = dir_path + "/latent_vectors_gn_dataset_{}.pkl".format(dataset_name)

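        # one pass over the test set: collect one DataFrame row of latent vectors per experiment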
        while True:
            try:

                features = self.sess.run(self.next_element_test)
                features = convert_dict_to_list_subdicts(features, self.config.test_batch_size)

                for i in range(len(features)):
                    input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(config=self.config, batch_data=features[i],
                                                                                initial_pos_vel_known=self.config.initial_pos_vel_known,
                                                                                batch_processing=False)

                    exp_id = features[i]['experiment_id']
                    exp_len = features[i]["unpadded_experiment_length"]  # the label

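                    # do_step returns the latent graphs as its final value; only the initial frame's graphs are fed here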
                    _, _, _, _, _, _, _, _, _, latent = self.do_step(
                        input_graphs_all_exp[0], target_graphs_all_exp[0], features[i], sigmoid_threshold=0.5, train=False,
                        batch_processing=False)

                    "shape of latent: (n_nodes, latent_dim)"
                    latent_core_output_init_img = latent[0].nodes
                    latent_encoder_output_init_img = latent[1].nodes

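                    # note: DataFrame.append was removed in pandas 2.0; newer pandas would use pd.concat instead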
                    df = df.append({'latent_vector_core_output_init_img': latent_core_output_init_img,
                                    'latent_vector_encoder_output_init_img': latent_encoder_output_init_img,
                                    'exp_id': exp_id,
                                    'exp_len': exp_len}, ignore_index=True)

            except tf.errors.OutOfRangeError:
                df.to_pickle(file_name)
                print("Pandas dataframe with {} rows saved to: {} ".format(len(df.index), file_name))
                break
            else:
                print("continue")
                continue
Example #6
    def compute_metrics_over_test_set(self):
        if self.config.n_epochs != 1:
            print("test mode --> n_epochs will be set to 1")
            self.config.n_epochs = 1
        prefix = self.config.exp_name
        print("Computing IoU, Precision, Recall and F1 score over full test set".format(
            self.config.initial_pos_vel_known))
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        iou_list_test_set = []
        prec_score_list_test_set = []
        rec_score_list_test_set = []
        f1_score_list_test_set = []
        sub_dir_name = "metric_computation_over_full_test_set_{}_iterations_trained".format(cur_batch_it)

        dir_path, _ = create_dir(os.path.join("../experiments", prefix), sub_dir_name)
        dataset_name = os.path.basename(self.config.tfrecords_dir)
        csv_name = "dataset_{}.csv".format(dataset_name)

        with open(os.path.join(dir_path, csv_name), 'w') as csv_file:
            writer = csv.writer(csv_file, delimiter='\t', lineterminator='\n', )
            writer.writerow(["(metrics averaged over n shapes and full trajectory) mean IoU", "mean precision", "mean recall", "mean f1 over n shapes", "exp_id"])
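            # one CSV row per experiment; a final row with the test-set means is written after the loop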
            while True:
                try:
                    losses_total, losses_img, losses_iou, losses_velocity, losses_position = [], [], [], [], []
                    losses_distance, outputs_total, targets_total, exp_id_total = [], [], [], []

                    features = self.sess.run(self.next_element_test)
                    features = convert_dict_to_list_subdicts(features, self.config.test_batch_size)

                    start_time = time.time()
                    last_log_time = start_time

                    for i in range(self.config.test_batch_size):
                        input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(config=self.config,
                                                                                    batch_data=features[i],
                                                                                    initial_pos_vel_known=self.config.initial_pos_vel_known,
                                                                                    batch_processing=False
                                                                                    )
                        output_i, target_i, exp_id_i = [], [], []

                        for j in range(features[i]["unpadded_experiment_length"] - 1):
                            total_loss, output, loss_img, loss_iou, loss_velocity, loss_position, loss_distance, target, _, _ = self.do_step(
                                                                                                                    input_graphs_all_exp[j],
                                                                                                                    target_graphs_all_exp[j],
                                                                                                                    features[i],
                                                                                                                    train=False,
                                                                                                                    batch_processing=False
                                                                                                                    )
                            output = output[0]
                            if total_loss is not None:
                                losses_total.append(total_loss)
                                losses_img.append(loss_img)
                                losses_iou.append(loss_iou)
                                losses_velocity.append(loss_velocity)
                                losses_position.append(loss_position)
                                losses_distance.append(loss_distance)

                            output_i.append(output)
                            target_i.append(target)
                            exp_id_i.append(features[i]['experiment_id'])

                        outputs_total.append((output_i, i))
                        targets_total.append((target_i, i))
                        exp_id_total.append(exp_id_i)

                    the_time = time.time()
                    elapsed_since_last_log = the_time - last_log_time
                    batch_loss, img_batch_loss, iou_batch_loss = np.mean(losses_total), np.mean(losses_img), np.mean(losses_iou)
                    vel_batch_loss, pos_batch_loss, dis_batch_loss = np.mean(losses_velocity), np.mean(losses_position), np.mean(losses_distance)
                    print('total test batch loss: {:<8.6f} | img loss: {:<8.6f} | iou loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} time(s): {:<10.2f}'.format(
                            batch_loss, img_batch_loss, iou_batch_loss, vel_batch_loss, pos_batch_loss, dis_batch_loss,
                            elapsed_since_last_log))

                    predictions_list, ground_truth_list = extract_input_and_output(outputs=outputs_total, targets=targets_total)

                    for pred_experiment, true_experiment, exp_id in zip(predictions_list, ground_truth_list, exp_id_total):
                        iou_scores = []
                        prec_scores = []
                        rec_scores = []
                        f1_scores = []

                        for pred, true in zip(pred_experiment, true_experiment):
                            iou = compute_iou(pred=pred, true=true)
                            mean_obj_prec_score, idx_obj_min_prec, idx_obj_max_prec = compute_precision(pred=pred, true=true)
                            mean_obj_rec_score, idx_obj_min_rec, idx_obj_max_rec = compute_recall(pred=pred, true=true)
                            mean_obj_f1_score, idx_obj_min_f1, idx_obj_max_f1 = compute_f1(pred=pred, true=true)

                            iou_scores.append(iou)
                            prec_scores.append(mean_obj_prec_score)
                            rec_scores.append(mean_obj_rec_score)
                            f1_scores.append(mean_obj_f1_score)

                        iou_traj_mean = np.mean(iou_scores)
                        prec_traj_mean = np.mean(prec_scores)
                        rec_traj_mean = np.mean(rec_scores)
                        f1_traj_mean = np.mean(f1_scores)

                        writer.writerow([iou_traj_mean, prec_traj_mean, rec_traj_mean, f1_traj_mean, exp_id[0]])

                        prec_score_list_test_set.append(prec_traj_mean)
                        rec_score_list_test_set.append(rec_traj_mean)
                        f1_score_list_test_set.append(f1_traj_mean)
                        iou_list_test_set.append(iou_traj_mean)

                    csv_file.flush()
                except tf.errors.OutOfRangeError:
                    break

            iou_test_set_mean = np.mean(iou_list_test_set)
            prec_test_set_mean = np.mean(prec_score_list_test_set)
            rec_test_set_mean = np.mean(rec_score_list_test_set)
            f1_test_set_mean = np.mean(f1_score_list_test_set)

            writer.writerow(["means over full set", " IoU: ", iou_test_set_mean, " Precision: ", prec_test_set_mean, " Recall: ", rec_test_set_mean, "F1: ", f1_test_set_mean])
            print("Done. mean IoU: {}, mean precision: {}, mean recall: {}, mean f1: {}".format(iou_test_set_mean, prec_test_set_mean, rec_test_set_mean, f1_test_set_mean))
Example #7
    def store_latent_vectors(self):
        assert self.config.n_epochs == 1, "set n_epochs to 1 for test mode"
        assert self.config.test_batch_size == 1, "set test_batch_size to 1 for test mode"
        prefix = self.config.exp_name
        print("Storing latent vectors baseline")
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        df = pd.DataFrame(
            columns=['latent_vector_init_img', 'exp_id', 'exp_len'])
        sub_dir_name = "latent_vectors_initial_image_of_full_test_set_{}_iterations_trained".format(
            cur_batch_it)

        dir_path, _ = create_dir(os.path.join("../experiments", prefix),
                                 sub_dir_name)
        dataset_name = os.path.basename(self.config.tfrecords_dir)
        file_name = dir_path + "/latent_vectors_baseline_auto_predictor_dataset_{}.pkl".format(
            dataset_name)

        while True:
            try:
                features = self.sess.run(self.next_element_test)
                features = convert_dict_to_list_subdicts(
                    features, self.config.test_batch_size)

                for i in range(len(features)):
                    input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(
                        config=self.config,
                        batch_data=features[i],
                        initial_pos_vel_known=self.config.initial_pos_vel_known,
                        batch_processing=False)

                    input_graphs_all_exp = [input_graphs_all_exp[0]]
                    target_graphs_all_exp = [target_graphs_all_exp[0]]

                    exp_id = features[i]['experiment_id']
                    exp_len = features[i][
                        "unpadded_experiment_length"]  # the label

                    #print(np.shape(input_graphs_all_exp), np.shape(target_graphs_all_exp))

                    input_graphs_all_exp = [input_graphs_all_exp]
                    target_graphs_all_exp = [target_graphs_all_exp]

                    in_segxyz, in_image, in_control, gt_label = networkx_graphs_to_images(
                        self.config, input_graphs_all_exp,
                        target_graphs_all_exp)

                    #print(np.shape(in_segxyz), np.shape(in_image), np.shape(in_control), np.shape(gt_label), exp_len)

                    loss_img, out_label, latent_init_img = self.sess.run(
                        [
                            self.model.loss_op, self.out_prediction_softmax,
                            self.latent_init_img
                        ],
                        feed_dict={
                            self.in_segxyz_tf: in_segxyz,
                            self.in_image_tf: in_image,
                            self.gt_predictions: gt_label,
                            self.in_control_tf: in_control,
                            self.is_training: True
                        })

                    #print(np.shape(latent_init_img))

                    df = df.append(
                        {
                            'latent_vector_init_img': latent_init_img,
                            'exp_id': exp_id,
                            'exp_len': exp_len
                        },
                        ignore_index=True)

            except tf.errors.OutOfRangeError:
                df.to_pickle(file_name)
                print("Pandas dataframe with {} rows saved to: {} ".format(
                    len(df.index), file_name))
                break
            else:
                print("continue")
                continue
Example #8
    def compute_metrics_over_test_set_multistep(self):
        assert self.config.n_epochs == 1, "test mode --> n_epochs must be set to 1"

        if self.config.model_zoo_file == "baseline_auto_predictor_extended_multistep" \
                and self.config.use_f_interact and not (self.config.train_batch_size == 1 and self.config.test_batch_size == 1):
            print(
                "--- when use_f_interact is True, train and test batch size need to be 1 since f_interact uses train "
                "batch_size to split the latent vector for the computation of pairwise object interactions"
            )
            return

        prefix = self.config.exp_name
        print(
            "Computing IoU, Precision, Recall and F1 score over full test set with initial_pos_vel_known={}".
            format(self.config.initial_pos_vel_known))
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)
        """ this variable is used to distinguish between 1-step and n-step predictions. Setting it to True or False will call different functions:
         a) when set to False: the function uses "gt_label_rec" as the gt data, inputs are split into episode_length/n_prediction -chunks and fragments 
         (if they cannot be evenly divided) will be dismissed, e.g. episode_length 14 with n_predictions=5 will result in two 5-length chunks
         
         b) when set to True: in this mode, the model only predicts 1-steps. The function uses "in_segxyz" as the ground truth data, 
         the inputs don't have to be processed (e.g. splitting them into episode_length/n_predictions chunks) because after one step, 
         the model is reset to ground truth
          """
        test_single_step = False

        if test_single_step:
            mode_txt = "single_step_tested"
        else:
            mode_txt = "{}_step_tested".format(self.config.n_predictions)

        iou_list_test_set = []
        prec_score_list_test_set = []
        rec_score_list_test_set = []
        f1_score_list_test_set = []
        sub_dir_name = "metric_multistep_models_computation_over_full_test_set_{}_iterations_trained".format(
            cur_batch_it)

        dir_path, _ = create_dir(os.path.join("../experiments", prefix),
                                 sub_dir_name)
        dataset_name = os.path.basename(self.config.tfrecords_dir)
        csv_name = "{}_dataset_{}.csv".format(mode_txt, dataset_name)

        with open(os.path.join(dir_path, csv_name), 'w') as csv_file:
            writer = csv.writer(
                csv_file,
                delimiter='\t',
                lineterminator='\n',
            )
            writer.writerow([
                "(metrics averaged over n shapes and full trajectory) mean IoU",
                "mean precision", "mean recall", "mean f1 over n shapes",
                "exp_id"
            ])
            while True:
                try:
                    losses_total, losses_img, losses_iou, losses_velocity, losses_position = [], [], [], [], []
                    losses_distance, outputs_total, targets_total, exp_id_total = [], [], [], []

                    features = self.sess.run(self.next_element_test)
                    features = convert_dict_to_list_subdicts(
                        features, self.config.test_batch_size)

                    start_time = time.time()
                    last_log_time = start_time

                    for i in range(self.config.test_batch_size):
                        input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(
                            config=self.config,
                            batch_data=features[i],
                            initial_pos_vel_known=self.config.initial_pos_vel_known,
                            batch_processing=False,
                            return_only_unpadded=True,
                            start_episode=0)
                        out_label_lst = []
                        in_seg_lst = []
                        exp_id = features[i]['experiment_id']
                        n_objects = features[i]['n_manipulable_objects']

                        if test_single_step:
                            """ create >episode-length< long pairs of input and target as lists to run single prediction steps """
                            assert len(input_graphs_all_exp) == len(
                                target_graphs_all_exp)
                            assert self.config.n_predictions == 1, 'set n_predictions to 1 if single_step is to be tested and re-initialize model'
                            single_step_prediction_chunks_input = [[
                                input_graph
                            ] for input_graph in input_graphs_all_exp]
                            single_step_prediction_chunks_target = [[
                                target_graph
                            ] for target_graph in target_graphs_all_exp]

                            for lst_inp, lst_targ in zip(
                                    single_step_prediction_chunks_input,
                                    single_step_prediction_chunks_target):
                                in_segxyz, in_image, in_control, gt_label, _ = networkx_graphs_to_images(
                                    self.config, [lst_inp], [lst_targ],
                                    multistep=True)

                                gt_latent = np.zeros(
                                    shape=(n_objects *
                                           self.config.n_predictions, 256))

                                loss_img, out_label = self.sess.run(
                                    [
                                        self.model.img_loss,
                                        self.out_prediction_softmax
                                    ],
                                    feed_dict={
                                        self.in_segxyz_tf: in_segxyz,
                                        self.in_image_tf: in_image,
                                        self.gt_predictions: gt_label,
                                        self.in_control_tf: in_control,
                                        self.gt_latent_vectors: gt_latent,
                                        self.is_training: True
                                    })

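                                # binarize the softmax output at 0.5 into hard per-object masks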
                                out_label[out_label >= 0.5] = 1.0
                                out_label[out_label < 0.5] = 0.0

                                loss_velocity = np.array(0.0)
                                loss_position = np.array(0.0)
                                loss_edge = np.array(0.0)
                                loss_iou = 0.0
                                loss_total = loss_img + loss_position + loss_edge + loss_velocity

                                losses_total.append(loss_total)
                                losses_img.append(loss_img)
                                losses_iou.append(loss_iou)
                                losses_velocity.append(loss_velocity)
                                losses_position.append(loss_position)
                                losses_distance.append(loss_edge)

                                in_seg_lst.append(in_segxyz[:, :, :, 0])
                                out_label_lst.append(out_label)

                        else:
                            """ this section is used for producing the output for an entire episode, meaning that the model 
                            is asked to re-predict after n_prediction steps up until the entire episode is covered 
                            (with potential crop in the end if episode length/n_predictions is odd) """
                            assert len(input_graphs_all_exp) == len(
                                target_graphs_all_exp)
                            assert self.config.n_predictions > 1, 'set n_predictions > 1 if multi_step is to be tested and re-initialize model'
                            n_prediction_chunks_target = []
                            n_prediction_chunks_input = []
                            n_predictions = self.config.n_predictions

                            for j in range(0, len(target_graphs_all_exp),
                                           n_predictions):
                                chunk = target_graphs_all_exp[j:j +
                                                              n_predictions]
                                n_prediction_chunks_target.append(chunk)

                                chunk = input_graphs_all_exp[j:j +
                                                             n_predictions]
                                n_prediction_chunks_input.append(chunk)
                            """ if the length of an episode cannot be evenly divided by n_predictions, remove the last 
                            odd list. end_idx ensures the array split is correctly handled later"""
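                            # e.g. episode length 14 with n_predictions=5 yields chunks of 5, 5, 4; the length-4 remainder is filtered out below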
                            assert all(len(inp_lst) == len(targ_lst) for inp_lst, targ_lst in
                                       zip(n_prediction_chunks_input, n_prediction_chunks_target)), \
                                "input and target lists are not equal after fragmenting them into episode length/n_predictions parts"

                            n_prediction_chunks_target_after_filtering = [
                                chunk for chunk in n_prediction_chunks_target
                                if len(chunk) == n_predictions
                            ]
                            n_prediction_chunks_input_after_filtering = [
                                chunk for chunk in n_prediction_chunks_input
                                if len(chunk) == n_predictions
                            ]

                            #number_removed_chunks = (len(n_prediction_chunks_target) - len(n_prediction_chunks_target_after_filtering))
                            #print("number of removed chunks (of up to 2 (2 step prediction) or 5 (5 step prediction) frames): ", number_removed_chunks)

                            for lst_inp, lst_targ in zip(
                                    n_prediction_chunks_input_after_filtering,
                                    n_prediction_chunks_target_after_filtering
                            ):
                                in_segxyz, in_image, in_control, gt_label, gt_label_rec = networkx_graphs_to_images(
                                    self.config, [lst_inp], [lst_targ],
                                    multistep=True)

                                gt_latent = np.zeros(
                                    shape=(n_objects *
                                           self.config.n_predictions, 256))

                                loss_img, out_label = self.sess.run(
                                    [
                                        self.model.img_loss,
                                        self.out_prediction_softmax
                                    ],
                                    feed_dict={
                                        self.in_segxyz_tf: in_segxyz,
                                        self.in_image_tf: in_image,
                                        self.gt_predictions: gt_label,
                                        self.in_control_tf: in_control,
                                        self.gt_latent_vectors: gt_latent,
                                        self.is_training: True
                                    })

                                out_label[out_label >= 0.5] = 1.0
                                out_label[out_label < 0.5] = 0.0

                                loss_velocity = np.array(0.0)
                                loss_position = np.array(0.0)
                                loss_edge = np.array(0.0)
                                loss_iou = 0.0
                                loss_total = loss_img + loss_position + loss_edge + loss_velocity

                                losses_total.append(loss_total)
                                losses_img.append(loss_img)
                                losses_iou.append(loss_iou)
                                losses_velocity.append(loss_velocity)
                                losses_position.append(loss_position)
                                losses_distance.append(loss_edge)
                                """ in multistep prediction, in_segxyz only contains the first image, therefore use 
                                gt_label_rec that contains all ground truth input images. Also split the model output 
                                by number of predictions since the model is set-up to append predictions along the 
                                first dimension batch-wise, i.e. ([batch1, batch2]) while both batches each have shape 
                                (5, 120, 160) for 5 objects. This results in total shape (10, 120, 160) and requires a 
                                split for example: 2 predictions, 5 objects: (10, 120, 160) --> [(5, 120, 160), (5, 120, 160)] """
                                gt_label_rec_split = np.split(gt_label_rec,
                                                              n_predictions,
                                                              axis=0)
                                out_label_split = np.split(out_label,
                                                           n_predictions,
                                                           axis=0)
                                in_seg_lst = in_seg_lst + gt_label_rec_split
                                out_label_lst = out_label_lst + out_label_split

                        out_label_entire_trajectory, in_seg_entire_trajectory = [], []

                        for n in range(n_objects):
                            out_obj_lst = []
                            in_obj_lst = []
                            for time_step_out, time_step_in in zip(
                                    out_label_lst, in_seg_lst):
                                out_obj_lst.append(time_step_out[n])
                                in_obj_lst.append(time_step_in[n])
                            out_label_entire_trajectory.append(
                                np.array(out_obj_lst))
                            in_seg_entire_trajectory.append(
                                np.array(in_obj_lst))

                        outputs_total.append(out_label_entire_trajectory)
                        targets_total.append(in_seg_entire_trajectory)
                        exp_id_total.append(exp_id)

                    the_time = time.time()
                    elapsed_since_last_log = the_time - last_log_time
                    batch_loss, img_batch_loss, iou_batch_loss = np.mean(
                        losses_total), np.mean(losses_img), np.mean(losses_iou)
                    vel_batch_loss, pos_batch_loss, dis_batch_loss = np.mean(
                        losses_velocity), np.mean(losses_position), np.mean(
                            losses_distance)
                    print(
                        'total test batch loss: {:<8.6f} | img loss: {:<8.6f} | iou loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} time(s): {:<10.2f}'
                        .format(batch_loss, img_batch_loss, iou_batch_loss,
                                vel_batch_loss, pos_batch_loss, dis_batch_loss,
                                elapsed_since_last_log))

                    for pred_experiment, true_experiment, exp_id in zip(
                            outputs_total, targets_total, exp_id_total):
                        iou_scores = []
                        prec_scores = []
                        rec_scores = []
                        f1_scores = []

                        # switch (n_objects, exp_len,...) to (exp_len, n_objects) since IoU computed per time step
                        pred_experiment = np.swapaxes(pred_experiment, 0, 1)
                        true_experiment = np.swapaxes(true_experiment, 0, 1)

                        for pred, true in zip(pred_experiment,
                                              true_experiment):
                            iou = compute_iou(pred=pred, true=true)
                            mean_obj_prec_score, idx_obj_min_prec, idx_obj_max_prec = compute_precision(
                                pred=pred, true=true)
                            mean_obj_rec_score, idx_obj_min_rec, idx_obj_max_rec = compute_recall(
                                pred=pred, true=true)
                            mean_obj_f1_score, idx_obj_min_f1, idx_obj_max_f1 = compute_f1(
                                pred=pred, true=true)

                            iou_scores.append(iou)
                            prec_scores.append(mean_obj_prec_score)
                            rec_scores.append(mean_obj_rec_score)
                            f1_scores.append(mean_obj_f1_score)

                        iou_traj_mean = np.mean(iou_scores)
                        prec_traj_mean = np.mean(prec_scores)
                        rec_traj_mean = np.mean(rec_scores)
                        f1_traj_mean = np.mean(f1_scores)

                        if exp_id is None:
                            print(
                                "shapes pred_experiment and true experiment: ",
                                np.shape(pred_experiment),
                                np.shape(true_experiment))

                        writer.writerow([
                            iou_traj_mean, prec_traj_mean, rec_traj_mean,
                            f1_traj_mean, exp_id
                        ])

                        if not (np.isnan(iou_traj_mean)
                                or np.isnan(prec_traj_mean)
                                or np.isnan(rec_traj_mean)
                                or np.isnan(f1_traj_mean)):
                            prec_score_list_test_set.append(prec_traj_mean)
                            rec_score_list_test_set.append(rec_traj_mean)
                            f1_score_list_test_set.append(f1_traj_mean)
                            iou_list_test_set.append(iou_traj_mean)

                    csv_file.flush()
                except tf.errors.OutOfRangeError:
                    break

            iou_test_set_mean = np.mean(iou_list_test_set)
            prec_test_set_mean = np.mean(prec_score_list_test_set)
            rec_test_set_mean = np.mean(rec_score_list_test_set)
            f1_test_set_mean = np.mean(f1_score_list_test_set)

            writer.writerow([
                "means over full set", " IoU: ", iou_test_set_mean,
                " Precision: ", prec_test_set_mean, " Recall: ",
                rec_test_set_mean, "F1: ", f1_test_set_mean
            ])
            if test_single_step:
                mode = "(model trained multi-step but tested single-step)"
            else:
                mode = "(model trained & tested multi-step)"

            print(
                "Done. mean IoU {}: {}, mean precision: {}, mean recall: {}, mean f1: {}"
                .format(mode, iou_test_set_mean, prec_test_set_mean,
                        rec_test_set_mean, f1_test_set_mean))
Example #9
    def compute_metrics_over_test_set(self):
        if self.config.n_epochs != 1:
            print("test mode --> n_epochs will be set to 1")
            self.config.n_epochs = 1
        prefix = self.config.exp_name
        print(
            "Computing IoU, Precision, Recall and F1 score over full test set with initial_pos_vel_known={}".
            format(self.config.initial_pos_vel_known))
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        iou_list_test_set = []
        prec_score_list_test_set = []
        rec_score_list_test_set = []
        f1_score_list_test_set = []
        sub_dir_name = "metric_computation_over_full_test_set_{}_iterations_trained".format(
            cur_batch_it)

        dir_path, _ = create_dir(os.path.join("../experiments", prefix),
                                 sub_dir_name)
        dataset_name = os.path.basename(self.config.tfrecords_dir)
        csv_name = "dataset_{}.csv".format(dataset_name)

        with open(os.path.join(dir_path, csv_name), 'w') as csv_file:
            writer = csv.writer(
                csv_file,
                delimiter='\t',
                lineterminator='\n',
            )
            writer.writerow([
                "(metrics averaged over n shapes and full trajectory) mean IoU",
                "mean precision", "mean recall", "mean f1 over n shapes",
                "exp_id"
            ])
            while True:
                try:
                    losses_total, losses_img, losses_iou, losses_velocity, losses_position = [], [], [], [], []
                    losses_distance, outputs_total, targets_total, exp_id_total = [], [], [], []

                    features = self.sess.run(self.next_element_test)
                    features = convert_dict_to_list_subdicts(
                        features, self.config.test_batch_size)

                    start_time = time.time()
                    last_log_time = start_time

                    for i in range(self.config.test_batch_size):
                        input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(
                            config=self.config,
                            batch_data=features[i],
                            initial_pos_vel_known=self.config.initial_pos_vel_known,
                            batch_processing=False)
                        output_i, target_i, exp_id_i = [], [], []

                        input_graphs_all_exp = [input_graphs_all_exp]
                        target_graphs_all_exp = [target_graphs_all_exp]

                        in_segxyz, in_image, in_control, gt_label = networkx_graphs_to_images(
                            self.config, input_graphs_all_exp,
                            target_graphs_all_exp)

                        loss_img, out_label = self.sess.run(
                            [self.model.loss_op, self.out_prediction_softmax],
                            feed_dict={
                                self.in_segxyz_tf: in_segxyz,
                                self.in_image_tf: in_image,
                                self.gt_predictions: gt_label,
                                self.in_control_tf: in_control,
                                self.is_training: True
                            })

                        loss_velocity = np.array(0.0)
                        loss_position = np.array(0.0)
                        loss_edge = np.array(0.0)
                        loss_iou = 0.0
                        loss_total = loss_img + loss_position + loss_edge + loss_velocity

                        losses_total.append(loss_total)
                        losses_img.append(loss_img)
                        losses_iou.append(loss_iou)
                        losses_velocity.append(loss_velocity)
                        losses_position.append(loss_position)
                        losses_distance.append(loss_edge)

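                        # threshold predictions at 0.5 to obtain binary segmentation masks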
                        out_label[out_label >= 0.5] = 1.0
                        out_label[out_label < 0.5] = 0.0

                        exp_id_i.append(features[i]['experiment_id'])

                        unpad_exp_length = features[i][
                            'unpadded_experiment_length']
                        n_objects = features[i]['n_manipulable_objects']
                        # print(np.shape(out_label))  # debug: shape of stacked predictions
                        out_label_split = np.split(out_label,
                                                   unpad_exp_length - 1)
                        in_seg_split = np.split(in_segxyz[:, :, :, 0],
                                                unpad_exp_length - 1)

                        out_label_entire_trajectory, in_seg_entire_trajectory = [], []

                        for n in range(n_objects):
                            out_obj_lst = []
                            in_obj_lst = []
                            for time_step_out, time_step_in in zip(
                                    out_label_split, in_seg_split):
                                out_obj_lst.append(time_step_out[n])
                                in_obj_lst.append(time_step_in[n])
                            out_label_entire_trajectory.append(
                                np.array(out_obj_lst))
                            in_seg_entire_trajectory.append(
                                np.array(in_obj_lst))

                        outputs_total.append(out_label_entire_trajectory)
                        targets_total.append(in_seg_entire_trajectory)
                        exp_id_total.append(exp_id_i)

                    the_time = time.time()
                    elapsed_since_last_log = the_time - last_log_time
                    batch_loss, img_batch_loss, iou_batch_loss = np.mean(
                        losses_total), np.mean(losses_img), np.mean(losses_iou)
                    vel_batch_loss, pos_batch_loss, dis_batch_loss = np.mean(
                        losses_velocity), np.mean(losses_position), np.mean(
                            losses_distance)
                    print(
                        'total test batch loss: {:<8.6f} | img loss: {:<8.6f} | iou loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} time(s): {:<10.2f}'
                        .format(batch_loss, img_batch_loss, iou_batch_loss,
                                vel_batch_loss, pos_batch_loss, dis_batch_loss,
                                elapsed_since_last_log))

                    for pred_experiment, true_experiment, exp_id in zip(
                            outputs_total, targets_total, exp_id_total):
                        iou_scores = []
                        prec_scores = []
                        rec_scores = []
                        f1_scores = []

                        # switch (n_objects, exp_len,...) to (exp_len, n_objects) since IoU computed per time step
                        pred_experiment = np.swapaxes(pred_experiment, 0, 1)
                        true_experiment = np.swapaxes(true_experiment, 0, 1)

                        for pred, true in zip(pred_experiment,
                                              true_experiment):
                            iou = compute_iou(pred=pred, true=true)
                            mean_obj_prec_score, idx_obj_min_prec, idx_obj_max_prec = compute_precision(
                                pred=pred, true=true)
                            mean_obj_rec_score, idx_obj_min_rec, idx_obj_max_rec = compute_recall(
                                pred=pred, true=true)
                            mean_obj_f1_score, idx_obj_min_f1, idx_obj_max_f1 = compute_f1(
                                pred=pred, true=true)

                            iou_scores.append(iou)
                            prec_scores.append(mean_obj_prec_score)
                            rec_scores.append(mean_obj_rec_score)
                            f1_scores.append(mean_obj_f1_score)

                        iou_traj_mean = np.mean(iou_scores)
                        prec_traj_mean = np.mean(prec_scores)
                        rec_traj_mean = np.mean(rec_scores)
                        f1_traj_mean = np.mean(f1_scores)

                        writer.writerow([
                            iou_traj_mean, prec_traj_mean, rec_traj_mean,
                            f1_traj_mean, exp_id[0]
                        ])

                        prec_score_list_test_set.append(prec_traj_mean)
                        rec_score_list_test_set.append(rec_traj_mean)
                        f1_score_list_test_set.append(f1_traj_mean)
                        iou_list_test_set.append(iou_traj_mean)

                    csv_file.flush()
                except tf.errors.OutOfRangeError:
                    break

            iou_test_set_mean = np.mean(iou_list_test_set)
            prec_test_set_mean = np.mean(prec_score_list_test_set)
            rec_test_set_mean = np.mean(rec_score_list_test_set)
            f1_test_set_mean = np.mean(f1_score_list_test_set)

            writer.writerow([
                "means over full set", " IoU: ", iou_test_set_mean,
                " Precision: ", prec_test_set_mean, " Recall: ",
                rec_test_set_mean, "F1: ", f1_test_set_mean
            ])
            print(
                "Done. mean IoU: {}, mean precision: {}, mean recall: {}, mean f1: {}"
                .format(iou_test_set_mean, prec_test_set_mean,
                        rec_test_set_mean, f1_test_set_mean))
Example #10
    def save_encoder_vectors(self, train=True):
        assert self.config.n_epochs == 1, "set n_epochs to 1 for test mode"

        if "5_objects_50_rollouts_padded_novel" in self.config.tfrecords_dir:
            dir_name = "auto_encoding_features_5_objects_50_rollouts_novel"
        elif "5_objects_50_rollouts" in self.config.tfrecords_dir:
            dir_name = "auto_encoding_features_5_objects_50_rollouts"
        else:
            #dir_name = "auto_encoding_features_3_objects_15_rollouts"
            dir_name = "auto_encoding_features_2_bigger_cubes_dataset_50_rollouts"

        if train:
            next_element = self.next_element_train
            sub_dir_name = "train"
            batch_size = self.config.train_batch_size
        else:
            next_element = self.next_element_test
            sub_dir_name = "test"
            batch_size = self.config.test_batch_size

        dir_path, _ = create_dir(os.path.join("/scr2/fabiof/data/"), dir_name)
        dir_path, _ = create_dir(os.path.join("/scr2/fabiof/data/", dir_name),
                                 sub_dir_name)
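        # NOTE: machine-specific output root; the commented lines below write under ../experiments instead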
        #dir_path, _ = create_dir(os.path.join("../experiments/"), dir_name)
        #dir_path, _ = create_dir(os.path.join("../experiments/", dir_name), sub_dir_name)
        iterator = 0
        while True:
            try:
                features = self.sess.run(next_element)
                features = convert_dict_to_list_subdicts(features, batch_size)

                for i in range(len(features)):
                    input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(
                        config=self.config,
                        batch_data=features[i],
                        initial_pos_vel_known=self.config.initial_pos_vel_known,
                        batch_processing=False,
                        return_only_unpadded=True,
                        start_episode=0)

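                    # run the encoder frame by frame, collecting one latent vector per time step of the episode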
                    encoder_outputs = []
                    exp_id = features[i]['experiment_id']
                    iterator = iterator + 1

                    assert len(input_graphs_all_exp) == len(
                        target_graphs_all_exp)
                    single_step_prediction_chunks_input = [
                        [input_graph] for input_graph in input_graphs_all_exp
                    ]
                    single_step_prediction_chunks_target = [[
                        target_graph
                    ] for target_graph in target_graphs_all_exp]

                    for lst_inp, lst_targ in zip(
                            single_step_prediction_chunks_input,
                            single_step_prediction_chunks_target):
                        in_segxyz, in_image, in_control, gt_label, _ = networkx_graphs_to_images(
                            self.config, [lst_inp], [lst_targ], multistep=True)

                        in_images, out_reconstructions, encoder_output = self.sess.run(
                            [
                                self.in_rgb_seg_xyz,
                                self.out_prediction_softmax,
                                self.encoder_outputs
                            ],
                            feed_dict={
                                self.in_segxyz_tf: in_segxyz,
                                self.in_image_tf: in_image,
                                self.gt_predictions: gt_label,
                                self.in_control_tf: in_control,
                                self.is_training: True
                            })

                        encoder_outputs.append(encoder_output)
                    print("saved encoder vector number {} under {}".format(
                        iterator, os.path.join(dir_path, str(exp_id))))
                    np.savez_compressed(os.path.join(dir_path, str(exp_id)),
                                        encoder_outputs=encoder_outputs,
                                        exp_id=exp_id)

            except tf.errors.OutOfRangeError:
                break