Example #1
    def initialize_test_model(self):
        next_element = self.test_data.get_next_batch()
        features = self.sess.run(next_element)
        features = convert_dict_to_list_subdicts(features,
                                                 self.config.test_batch_size)

        input_ph, target_ph = create_placeholders(self.config,
                                                  features[0],
                                                  batch_processing=False)

        self.model.input_ph_test = input_ph
        self.model.target_ph_test = target_ph

        self.model.is_training = False
        self.model.output_ops_test, self.model.latent_core_output_init_img_test, self.model.latent_encoder_output_init_img_test = self.model(
            self.model.input_ph_test, self.model.target_ph_test, 1,
            self.model.is_training)

        total_loss_ops_test, loss_ops_test_img, loss_ops_test_iou, loss_ops_test_velocity, loss_ops_test_position, loss_ops_test_distance, loss_ops_test_global = create_loss_ops(
            self.config, self.model.target_ph_test, self.model.output_ops_test)
        ''' remove all inf values --> correspond to padded entries '''
        self.model.loss_op_test_total = total_loss_ops_test
        self.model.loss_ops_test_img = loss_ops_test_img
        self.model.loss_ops_test_iou = loss_ops_test_iou
        self.model.loss_ops_test_velocity = loss_ops_test_velocity
        self.model.loss_ops_test_position = loss_ops_test_position
        self.model.loss_ops_test_distance = loss_ops_test_distance
        self.model.loss_ops_test_global = loss_ops_test_global

    def store_latent_vectors(self):
        assert self.config.n_epochs == 1, "set n_epochs to 1 for test mode"
        prefix = self.config.exp_name
        print("Storing latent vectors")
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        df = pd.DataFrame(columns=['latent_vector_core_output_init_img', 'latent_vector_encoder_output_init_img', 'exp_id', 'exp_len'])
        sub_dir_name = "latent_vectors_initial_image_of_full_test_set_{}_iterations_trained".format(cur_batch_it)

        dir_path, _ = create_dir(os.path.join("../experiments", prefix), sub_dir_name)
        dataset_name = os.path.basename(self.config.tfrecords_dir)
        file_name = dir_path + "/latent_vectors_gn_dataset_{}.pkl".format(dataset_name)

        while True:
            try:

                features = self.sess.run(self.next_element_test)
                features = convert_dict_to_list_subdicts(features, self.config.test_batch_size)

                for i in range(len(features)):
                    input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(config=self.config, batch_data=features[i],
                                                                                initial_pos_vel_known=self.config.initial_pos_vel_known,
                                                                                batch_processing=False)

                    exp_id = features[i]['experiment_id']
                    exp_len = features[i]["unpadded_experiment_length"]  # the label

                    _, _, _, _, _, _, _, _, _, latent = self.do_step(
                        input_graphs_all_exp[0], target_graphs_all_exp[0], features[i], sigmoid_threshold=0.5, train=False,
                        batch_processing=False)

                    "shape of latent: (n_nodes, latent_dim)"
                    latent_core_output_init_img = latent[0].nodes
                    latent_encoder_output_init_img = latent[1].nodes

                    df = df.append({'latent_vector_core_output_init_img': latent_core_output_init_img,
                                    'latent_vector_encoder_output_init_img': latent_encoder_output_init_img,
                                    'exp_id': exp_id,
                                    'exp_len': exp_len}, ignore_index=True)

            except tf.errors.OutOfRangeError:
                df.to_pickle(file_name)
                print("Pandas dataframe with {} rows saved to: {} ".format(len(df.index), file_name))
                break
            else:
                print("continue")
                continue

    def train_multiple_batches(self, prefix, n_batches_trained, n_batches_trained_since_last_save):
        features = self.sess.run(self.next_element_train)
        features = convert_dict_to_list_subdicts(features, self.config.train_batch_size)

        input_graphs_batches, target_graphs_batches, _ = create_graphs(config=self.config,
                                                                    batch_data=features,
                                                                    initial_pos_vel_known=self.config.initial_pos_vel_known,
                                                                    batch_processing=True
                                                                    )

        if len(input_graphs_batches[-1]) != self.config.train_batch_size:
            input_graphs_batches = input_graphs_batches[:-1]
            target_graphs_batches = target_graphs_batches[:-1]

        for input_batch, target_batch in zip(input_graphs_batches, target_graphs_batches):
            start_time = time.time()
            last_log_time = start_time

            total_loss, _, loss_img, loss_iou, loss_velocity, loss_position, loss_distance, _, loss_global, _ = self.do_step(input_batch, target_batch, features, train=True, batch_processing=True)

            self.sess.run(self.model.increment_cur_batch_tensor)
            cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

            the_time = time.time()
            elapsed_since_last_log = the_time - last_log_time

            print(
                'batch: {:<8} total loss: {:<8.6f} | img loss: {:<8.6f} | iou loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss: {:<8.6f} | edge loss: {:<8.6f} | global loss: {:<8.6f} | time(s): {:<10.2f} '
                    .format(cur_batch_it, total_loss, loss_img, loss_iou, loss_velocity, loss_position,
                            loss_distance, loss_global, elapsed_since_last_log)
            )
            summaries_dict = {prefix + '_total_loss': total_loss,
                              prefix + '_img_loss': loss_img,
                              prefix + '_iou_loss': loss_iou,
                              prefix + '_velocity_loss': loss_velocity,
                              prefix + '_position_loss': loss_position,
                              prefix + '_distance_loss': loss_distance,
                              prefix + '_global_loss': loss_global
                              }
            self.logger.summarize(cur_batch_it, summaries_dict=summaries_dict, summarizer="train")

        return cur_batch_it, n_batches_trained + len(input_graphs_batches), n_batches_trained_since_last_save + len(input_graphs_batches)
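
The comment in initialize_test_model above ("remove all inf values --> correspond to padded entries") refers to what create_loss_ops does internally; that function is not part of this listing. A minimal TF1 sketch of the idea, assuming padded time steps are marked with inf in a per-step loss tensor (an assumption for illustration, not the repo's actual implementation):

import tensorflow as tf

def mean_loss_without_padding(per_step_losses):
    # keep only finite entries; inf entries are treated as padded steps
    finite_mask = tf.is_finite(per_step_losses)
    finite_losses = tf.boolean_mask(per_step_losses, finite_mask)
    return tf.reduce_mean(finite_losses)
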
Example #4
    def initialize_train_model(self):
        next_element = self.train_data.get_next_batch()
        features = self.sess.run(next_element)
        features = convert_dict_to_list_subdicts(features,
                                                 self.config.train_batch_size)

        if self.config.batch_processing:
            input_ph, target_ph = create_placeholders(self.config,
                                                      features,
                                                      batch_processing=True)
        else:
            input_ph, target_ph = create_placeholders(self.config,
                                                      features[0],
                                                      batch_processing=False)

        self.model.input_ph = input_ph
        self.model.target_ph = target_ph
        self.model.is_training = True
        self.model.output_ops_train, self.model.latent_core_output_init_img_train, self.model.latent_encoder_output_init_img_train = self.model(
            self.model.input_ph, self.model.target_ph, 1,
            self.model.is_training)

        total_loss_ops, loss_ops_img, loss_ops_iou, loss_ops_velocity, loss_ops_position, loss_ops_distance, loss_ops_global = create_loss_ops(
            self.config, self.model.target_ph, self.model.output_ops_train)
        ''' remove all inf values --> correspond to padded entries '''
        self.model.loss_op_train_total = total_loss_ops
        self.model.loss_ops_train_img = loss_ops_img  # just for summary, is already included in loss_op_train
        self.model.loss_ops_train_iou = loss_ops_iou
        self.model.loss_ops_train_velocity = loss_ops_velocity
        self.model.loss_ops_train_position = loss_ops_position
        self.model.loss_ops_train_distance = loss_ops_distance
        self.model.loss_ops_train_global = loss_ops_global
        #self.model.train_logits = logits

        self.model.step_op = self.model.optimizer.minimize(
            self.model.loss_op_train_total,
            global_step=self.model.global_step_tensor)

    def train_batch(self, prefix):
        losses = []
        losses_img = []
        losses_iou = []
        losses_velocity = []
        losses_position = []
        losses_distance = []
        losses_global = []

        features = self.sess.run(self.next_element_train)

        features = convert_dict_to_list_subdicts(features, self.config.train_batch_size)

        start_time = time.time()
        last_log_time = start_time


        for i in range(self.config.train_batch_size):
            input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(config=self.config,
                                                                        batch_data=features[i],
                                                                        initial_pos_vel_known=self.config.initial_pos_vel_known,
                                                                        batch_processing=False
                                                                        )


            for j in range(features[i]["unpadded_experiment_length"]-1):
                total_loss, _, loss_img, loss_iou, loss_velocity, loss_position, loss_distance, _, loss_global, _ = self.do_step(input_graphs_all_exp[j],
                                                                                                       target_graphs_all_exp[j],
                                                                                                       features[i],
                                                                                                       train=True,
                                                                                                       batch_processing=False
                                                                                                       )

                if total_loss is not None:
                    losses.append(total_loss)
                    losses_img.append(loss_img)
                    losses_iou.append(loss_iou)
                    losses_velocity.append(loss_velocity)
                    losses_position.append(loss_position)
                    losses_distance.append(loss_distance)
                    losses_global.append(loss_global)

            the_time = time.time()
            elapsed_since_last_log = the_time - last_log_time

        self.sess.run(self.model.increment_cur_batch_tensor)
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        if losses:
            batch_loss = np.mean(losses)
            img_batch_loss = np.mean(losses_img)
            iou_batch_loss = np.mean(losses_iou)
            vel_batch_loss = np.mean(losses_velocity)
            pos_batch_loss = np.mean(losses_position)
            dis_batch_loss = np.mean(losses_distance)
            glob_batch_loss = np.mean(losses_global)

            print(
                'batch: {:<8} total loss: {:<8.6f} | img loss: {:<8.6f} | iou loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss: {:<8.6f} | edge loss: {:<8.6f} | global loss: {:<8.6f} | time(s): {:<10.2f} '
                .format(cur_batch_it, batch_loss, img_batch_loss, iou_batch_loss, vel_batch_loss, pos_batch_loss,
                        dis_batch_loss, glob_batch_loss, elapsed_since_last_log)
                )
            summaries_dict = {prefix + '_total_loss': batch_loss,
                              prefix + '_img_loss': img_batch_loss,
                              prefix + '_iou_loss': iou_batch_loss,
                              prefix + '_velocity_loss': vel_batch_loss,
                              prefix + '_position_loss': pos_batch_loss,
                              prefix + '_distance_loss': dis_batch_loss,
                              prefix + '_global_loss': glob_batch_loss
                              }
            self.logger.summarize(cur_batch_it, summaries_dict=summaries_dict, summarizer="train")

        return cur_batch_it

    def compute_metrics_over_test_set(self):
        if not self.config.n_epochs == 1:
            print("test mode --> n_epochs will be set to 1")
            self.config.n_epochs = 1
        prefix = self.config.exp_name
        print("Computing IoU, Precision, Recall and F1 score over full test set".format(
            self.config.initial_pos_vel_known))
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        iou_list_test_set = []
        prec_score_list_test_set = []
        rec_score_list_test_set = []
        f1_score_list_test_set = []
        sub_dir_name = "metric_computation_over_full_test_set_{}_iterations_trained".format(cur_batch_it)

        dir_path, _ = create_dir(os.path.join("../experiments", prefix), sub_dir_name)
        dataset_name = os.path.basename(self.config.tfrecords_dir)
        csv_name = "dataset_{}.csv".format(dataset_name)

        with open(os.path.join(dir_path, csv_name), 'w') as csv_file:
            writer = csv.writer(csv_file, delimiter='\t', lineterminator='\n', )
            writer.writerow(["(metrics averaged over n shapes and full trajectory) mean IoU", "mean precision", "mean recall", "mean f1 over n shapes", "exp_id"])
            while True:
                try:
                    losses_total, losses_img, losses_iou, losses_velocity, losses_position = [], [], [], [], []
                    losses_distance, outputs_total, targets_total, exp_id_total = [], [], [], []

                    features = self.sess.run(self.next_element_test)
                    features = convert_dict_to_list_subdicts(features, self.config.test_batch_size)

                    start_time = time.time()
                    last_log_time = start_time

                    for i in range(self.config.test_batch_size):
                        input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(config=self.config,
                                                                                    batch_data=features[i],
                                                                                    initial_pos_vel_known=self.config.initial_pos_vel_known,
                                                                                    batch_processing=False
                                                                                    )
                        output_i, target_i, exp_id_i = [], [], []

                        for j in range(features[i]["unpadded_experiment_length"] - 1):
                            total_loss, output, loss_img, loss_iou, loss_velocity, loss_position, loss_distance, target, _, _ = self.do_step(
                                                                                                                    input_graphs_all_exp[j],
                                                                                                                    target_graphs_all_exp[j],
                                                                                                                    features[i],
                                                                                                                    train=False,
                                                                                                                    batch_processing=False
                                                                                                                    )
                            output = output[0]
                            if total_loss is not None:
                                losses_total.append(total_loss)
                                losses_img.append(loss_img)
                                losses_iou.append(loss_iou)
                                losses_velocity.append(loss_velocity)
                                losses_position.append(loss_position)
                                losses_distance.append(loss_distance)

                            output_i.append(output)
                            target_i.append(target)
                            exp_id_i.append(features[i]['experiment_id'])

                        outputs_total.append((output_i, i))
                        targets_total.append((target_i, i))
                        exp_id_total.append(exp_id_i)

                    the_time = time.time()
                    elapsed_since_last_log = the_time - last_log_time
                    batch_loss, img_batch_loss, iou_batch_loss = np.mean(losses_total), np.mean(losses_img), np.mean(losses_iou)
                    vel_batch_loss, pos_batch_loss, dis_batch_loss = np.mean(losses_velocity), np.mean(losses_position), np.mean(losses_distance)
                    print('total test batch loss: {:<8.6f} | img loss: {:<8.6f} | iou loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} time(s): {:<10.2f}'.format(
                            batch_loss, img_batch_loss, iou_batch_loss, vel_batch_loss, pos_batch_loss, dis_batch_loss,
                            elapsed_since_last_log))

                    predictions_list, ground_truth_list = extract_input_and_output(outputs=outputs_total, targets=targets_total)

                    for pred_experiment, true_experiment, exp_id in zip(predictions_list, ground_truth_list, exp_id_total):
                        iou_scores = []
                        prec_scores = []
                        rec_scores = []
                        f1_scores = []

                        for pred, true in zip(pred_experiment, true_experiment):
                            iou = compute_iou(pred=pred, true=true)
                            mean_obj_prec_score, idx_obj_min_prec, idx_obj_max_prec = compute_precision(pred=pred, true=true)
                            mean_obj_rec_score, idx_obj_min_rec, idx_obj_max_rec = compute_recall(pred=pred, true=true)
                            mean_obj_f1_score, idx_obj_min_f1, idx_obj_max_f1 = compute_f1(pred=pred, true=true)

                            iou_scores.append(iou)
                            prec_scores.append(mean_obj_prec_score)
                            rec_scores.append(mean_obj_rec_score)
                            f1_scores.append(mean_obj_f1_score)

                        iou_traj_mean = np.mean(iou_scores)
                        prec_traj_mean = np.mean(prec_scores)
                        rec_traj_mean = np.mean(rec_scores)
                        f1_traj_mean = np.mean(f1_scores)

                        writer.writerow([iou_traj_mean, prec_traj_mean, rec_traj_mean, f1_traj_mean, exp_id[0]])

                        prec_score_list_test_set.append(prec_traj_mean)
                        rec_score_list_test_set.append(rec_traj_mean)
                        f1_score_list_test_set.append(f1_traj_mean)
                        iou_list_test_set.append(iou_traj_mean)

                    csv_file.flush()
                except tf.errors.OutOfRangeError:
                    break

            iou_test_set_mean = np.mean(iou_list_test_set)
            prec_test_set_mean = np.mean(prec_score_list_test_set)
            rec_test_set_mean = np.mean(rec_score_list_test_set)
            f1_test_set_mean = np.mean(f1_score_list_test_set)

            writer.writerow(["means over full set", " IoU: ", iou_test_set_mean, " Precision: ", prec_test_set_mean, " Recall: ", rec_test_set_mean, "F1: ", f1_test_set_mean])
            print("Done. mean IoU: {}, mean precision: {}, mean recall: {}, mean f1: {}".format(iou_test_set_mean, prec_test_set_mean, rec_test_set_mean, f1_test_set_mean))
def main():
    # create tensorflow session
    sess = tf.Session()

    try:
        args = get_args()
        config = process_config(args.config)

        config.old_tfrecords = args.old_tfrecords
        config.normalize_data = args.normalize_data

    except Exception as e:
        print("An error occurred during processing the configuration file")
        print(e)
        exit(0)

    # create the experiments dirs
    create_dirs(
        [config.summary_dir, config.checkpoint_dir, config.config_file_dir])

    # create your data generator
    train_data = DataGenerator(config, sess, train=True)
    test_data = DataGenerator(config, sess, train=False)

    logger = Logger(sess, config)

    print("using {} rollout steps".format(config.n_rollouts))

    inp_rgb = tf.placeholder("float", [None, 120, 160, 7])
    control = tf.placeholder("float", [None, 6])
    gt_seg = tf.placeholder("float", [None, 120, 160])

    pred = cnnmodel(inp_rgb, control)

    predictions = tf.reshape(
        pred, [-1, pred.get_shape()[1] * pred.get_shape()[2]])
    labels = tf.reshape(gt_seg,
                        [-1, gt_seg.get_shape()[1] * gt_seg.get_shape()[2]])

    global_step_tensor = tf.Variable(0, trainable=False, name='global_step')

    loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                logits=predictions))
    optimizer = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate).minimize(
            loss, global_step=global_step_tensor)

    with tf.variable_scope('cur_epoch'):
        cur_epoch_tensor = tf.Variable(0, trainable=False, name='cur_epoch')
        increment_cur_epoch_tensor = tf.assign(cur_epoch_tensor,
                                               cur_epoch_tensor + 1)

    with tf.variable_scope('global_step'):
        cur_batch_tensor = tf.Variable(0, trainable=False, name='cur_batch')
        increment_cur_batch_tensor = tf.assign(cur_batch_tensor,
                                               cur_batch_tensor + 1)

    next_element_train = train_data.get_next_batch()
    next_element_test = test_data.get_next_batch()

    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    sess.run(init)

    saver = tf.train.Saver(max_to_keep=config.max_checkpoints_to_keep)

    latest_checkpoint = tf.train.latest_checkpoint(config.checkpoint_dir)
    if latest_checkpoint:
        print("Loading model checkpoint {} ...\n".format(latest_checkpoint))
        saver.restore(sess, latest_checkpoint)
        print("Model loaded")

    def _process_rollouts(feature, train=True):
        gt_merged_seg_rollout_batch = []
        input_merged_images_rollout_batch = []
        gripper_pos_vel_rollout_batch = []
        for step in range(config.n_rollouts - 1):
            if step < feature["unpadded_experiment_length"]:
                obj_segments = feature["object_segments"][step]
                """ transform (3,120,160,7) into (1,120,160,7) by merging the rgb,depth and seg masks """
                input_merged_images = create_full_images_of_object_masks(
                    obj_segments)

                obj_segments_gt = feature["object_segments"][step + 1]
                gt_merged_seg = create_full_images_of_object_masks(
                    obj_segments_gt)[:, :, 3]

                gripper_pos = feature["gripperpos"][step + 1]
                gripper_vel = feature["grippervel"][step + 1]
                gripper_pos_vel = np.concatenate([gripper_pos, gripper_vel])

                gt_merged_seg_rollout_batch.append(gt_merged_seg)
                input_merged_images_rollout_batch.append(input_merged_images)
                gripper_pos_vel_rollout_batch.append(gripper_pos_vel)

        if train:
            retrn = sess.run(
                [optimizer, loss, pred],
                feed_dict={
                    inp_rgb: input_merged_images_rollout_batch,
                    control: gripper_pos_vel_rollout_batch,
                    gt_seg: gt_merged_seg_rollout_batch
                })
            return retrn[1], retrn[2]

        else:
            retrn = sess.run(
                [loss, pred],
                feed_dict={
                    inp_rgb: input_merged_images_rollout_batch,
                    control: gripper_pos_vel_rollout_batch,
                    gt_seg: gt_merged_seg_rollout_batch
                })
            """ sigmoid cross entropy runs logits through sigmoid but only during train time """
            seg_data = sigmoid(retrn[1])
            seg_data[seg_data >= 0.5] = 1.0
            seg_data[seg_data < 0.5] = 0.0
            return retrn[0], seg_data, gt_merged_seg_rollout_batch

    for cur_epoch in range(cur_epoch_tensor.eval(sess), config.n_epochs + 1,
                           1):
        while True:
            try:
                features = sess.run(next_element_train)
                features = convert_dict_to_list_subdicts(
                    features, config.train_batch_size)
                loss_batch = []
                sess.run(increment_cur_batch_tensor)
                for _ in range(config.train_batch_size):
                    for feature in features:
                        loss_train, _ = _process_rollouts(feature)
                        loss_batch.append([loss_train])

                cur_batch_it = cur_batch_tensor.eval(sess)
                loss_mean_batch = np.mean(loss_batch)

                print('train loss batch {0:} is: {1:.4f}'.format(
                    cur_batch_it, loss_mean_batch))
                summaries_dict = {config.exp_name + '_loss': loss_mean_batch}
                logger.summarize(cur_batch_it,
                                 summaries_dict=summaries_dict,
                                 summarizer="train")

                if cur_batch_it % config.test_interval == 1:
                    print("Executing test batch")
                    features_idx = 0  # always take first element for testing
                    features = sess.run(next_element_test)
                    features = convert_dict_to_list_subdicts(
                        features, config.test_batch_size)
                    loss_test_batch = []

                    for i in range(config.test_batch_size):
                        loss_test, seg_data, gt_seg_data = _process_rollouts(
                            features[features_idx], train=False)
                        loss_test_batch.append(loss_test)

                    loss_test_mean_batch = np.mean(loss_test_batch)
                    summaries_dict = {
                        config.exp_name + '_test_loss': loss_test_mean_batch
                    }
                    logger.summarize(cur_batch_it,
                                     summaries_dict=summaries_dict,
                                     summarizer="test")

                    print('test loss is: {0:.4f}'.format(loss_test_mean_batch))
                    if seg_data is not None and gt_seg_data is not None:
                        """ create gif here """
                        create_seg_gif(features,
                                       features_idx,
                                       config,
                                       seg_data,
                                       gt_seg_data,
                                       dir_name="tests_during_training",
                                       cur_batch_it=cur_batch_it)

                if cur_batch_it % config.model_save_step_interval == 1:
                    print("Saving model...")
                    saver.save(sess, config.checkpoint_dir, global_step_tensor)
                    print("Model saved")

            except tf.errors.OutOfRangeError:
                break

        sess.run(increment_cur_epoch_tensor)

    return None
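
The sigmoid helper applied to the network output in _process_rollouts above is imported from elsewhere and not shown in this listing. A minimal NumPy stand-in, assuming it is the plain element-wise logistic function (an assumption for illustration only):

import numpy as np

def sigmoid(x):
    # element-wise logistic function applied to the raw logits at test time
    return 1.0 / (1.0 + np.exp(-x))
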
    def test_batch(self,
                   prefix,
                   initial_pos_vel_known,
                   export_images=False,
                   process_all_nn_outputs=False,
                   sub_dir_name=None,
                   export_latent_data=True,
                   output_results=True):

        losses_total = []
        losses_img = []
        losses_velocity = []
        losses_position = []
        losses_edge = []
        outputs_total = []
        summaries_dict_images = {}

        features = self.sess.run(self.next_element_test)
        features = convert_dict_to_list_subdicts(features,
                                                 self.config.test_batch_size)

        start_time = time.time()
        last_log_time = start_time

        if self.config.do_multi_step_prediction:
            multistep = True
        else:
            multistep = False
        start_idx = 0
        end_idx = self.config.n_predictions

        for i in range(self.config.test_batch_size):
            input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(
                config=self.config,
                batch_data=features[i],
                initial_pos_vel_known=self.config.initial_pos_vel_known,
                batch_processing=False)

            if multistep:
                input_graphs_all_exp = [
                    input_graphs_all_exp[start_idx:end_idx]
                ]
                target_graphs_all_exp = [
                    target_graphs_all_exp[start_idx:end_idx]
                ]

            in_segxyz, in_image, in_control, _, gt_reconstructions = networkx_graphs_to_images(
                self.config,
                input_graphs_all_exp,
                target_graphs_all_exp,
                multistep=multistep)

            loss_img, out_reconstructions, latent_img_feature = self.sess.run(
                [
                    self.model.loss_op, self.out_prediction_softmax,
                    self.encoder_outputs
                ],
                feed_dict={
                    self.in_segxyz_tf: in_segxyz,
                    self.in_image_tf: in_image,
                    self.gt_predictions:
                    gt_reconstructions,  # this is intentional to maintain same names
                    self.in_control_tf:
                    in_control,  # not actually used for auto-encoding
                    self.is_training: False
                })

            loss_velocity = np.array(0.0)
            loss_position = np.array(0.0)
            loss_edge = np.array(0.0)
            loss_total = loss_img + loss_position + loss_edge + loss_velocity

            losses_total.append(loss_total)
            losses_img.append(loss_img)
            losses_velocity.append(loss_velocity)
            losses_position.append(loss_position)
            losses_edge.append(loss_edge)

            out_reconstructions[out_reconstructions >= 0.5] = 1.0
            out_reconstructions[out_reconstructions < 0.5] = 0.0

            outputs_total.append((out_reconstructions, in_segxyz, in_image,
                                  in_control, i, (start_idx, end_idx)))

        the_time = time.time()
        elapsed_since_last_log = the_time - last_log_time
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        if not process_all_nn_outputs:
            """ due to brevity, just use last output """
            outputs_total = [outputs_total[-1]]

        batch_loss = np.mean(losses_total)
        img_batch_loss = np.mean(losses_img)
        vel_batch_loss = np.mean(losses_velocity)
        pos_batch_loss = np.mean(losses_position)
        edge_batch_loss = np.mean(losses_edge)

        print(
            'total test batch loss: {:<8.6f} | img loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} time(s): {:<10.2f}'
            .format(batch_loss, img_batch_loss, vel_batch_loss, pos_batch_loss,
                    edge_batch_loss, elapsed_since_last_log))

        summaries_dict = {
            prefix + '_total_loss': batch_loss,
            prefix + '_img_loss': img_batch_loss,
            prefix + '_velocity_loss': vel_batch_loss,
            prefix + '_position_loss': pos_batch_loss,
            prefix + '_edge_loss': edge_batch_loss
        }

        if outputs_total and output_results:
            for output in outputs_total:
                summaries_dict_images = generate_and_export_image_dicts(
                    output=output,
                    features=features,
                    config=self.config,
                    prefix=prefix,
                    cur_batch_it=cur_batch_it,
                    dir_name=sub_dir_name,
                    reduce_dict=True,
                    multistep=multistep)

            if summaries_dict_images:
                summaries_dict = {**summaries_dict, **summaries_dict_images}

        self.logger.summarize(cur_batch_it,
                              summaries_dict=summaries_dict,
                              summarizer="test")

        return batch_loss, img_batch_loss, vel_batch_loss, pos_batch_loss, edge_batch_loss, cur_batch_it

    def test_specific_exp_ids(self):
        assert self.config.n_epochs == 1, "set n_epochs to 1 for test mode"
        prefix = self.config.exp_name
        print("Running tests with initial_pos_vel_known={}".format(self.config.initial_pos_vel_known))
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        #exp_ids_to_export = [13873, 3621, 8575, 439, 2439, 1630, 14526, 4377, 15364, 6874, 11031, 8962]  # big 3 object dataset
        #dir_name = "3_objects"
        exp_ids_to_export = [2815, 608, 1691, 49, 1834, 1340, 2596, 2843, 306]  # big 5 object dataset
        dir_name = "5_objects"
        #exp_ids_to_export = [10, 1206, 880, 1189, 1087, 2261, 194, 1799]  # big 5 object novel shapes dataset
        #dir_name = "5_novel_objects"

        export_images = self.config.export_test_images
        export_latent_data = True
        process_all_nn_outputs = True

        thresholds_to_test = [0.5]

        for thresh in thresholds_to_test:
            sub_dir_name = "test_{}_specific_exp_ids_{}_iterations_trained_sigmoid_threshold_{}".format(dir_name, cur_batch_it, thresh)
            while True:
                try:
                    losses_total = []
                    losses_img = []
                    losses_iou = []
                    losses_velocity = []
                    losses_position = []
                    losses_distance = []
                    losses_global = []

                    outputs_total = []

                    features = self.sess.run(self.next_element_test)

                    features = convert_dict_to_list_subdicts(features, self.config.test_batch_size)

                    if exp_ids_to_export:
                        features_to_export = []
                        for dct in features:
                            if dct["experiment_id"] in exp_ids_to_export:
                                features_to_export.append(dct)
                                print("added", dct["experiment_id"])

                        features = features_to_export

                    if exp_ids_to_export and not features_to_export:
                        continue


                    start_time = time.time()
                    last_log_time = start_time

                    for i in range(len(features)):
                        input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(config=self.config,
                                                                            batch_data=features[i],
                                                                            initial_pos_vel_known=self.config.initial_pos_vel_known,
                                                                            batch_processing=False
                                                                            )
                        output_i = []

                        for j in range(features[i]["unpadded_experiment_length"] - 1):
                            total_loss, output, loss_img, loss_iou, loss_velocity, loss_position, loss_distance, _, loss_global, _ = self.do_step(input_graphs_all_exp[j],
                                                                                                           target_graphs_all_exp[j],
                                                                                                           features[i],
                                                                                                           sigmoid_threshold=thresh,
                                                                                                           train=False,
                                                                                                           batch_processing=False
                                                                                                           )
                            output = output[0]
                            if total_loss is not None:
                                losses_total.append(total_loss)
                                losses_img.append(loss_img)
                                losses_iou.append(loss_iou)
                                losses_velocity.append(loss_velocity)
                                losses_position.append(loss_position)
                                losses_distance.append(loss_distance)
                                losses_global.append(loss_global)

                            output_i.append(output)

                        outputs_total.append((output_i, i))

                    the_time = time.time()
                    elapsed_since_last_log = the_time - last_log_time
                    cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

                    if not process_all_nn_outputs:
                        """ due to brevity, just use last results """
                        outputs_total = [outputs_total[-1]]

                    if losses_total:
                        batch_loss = np.mean(losses_total)
                        img_batch_loss = np.mean(losses_img)
                        iou_batch_loss = np.mean(losses_iou)
                        vel_batch_loss = np.mean(losses_velocity)
                        pos_batch_loss = np.mean(losses_position)
                        dis_batch_loss = np.mean(losses_distance)
                        glob_batch_loss = np.mean(losses_global)

                        print('total test batch loss: {:<8.6f} | img loss: {:<10.6f} | iou loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} | global loss {:<8.6f} | time(s): {:<10.2f}'.format(
                                batch_loss, img_batch_loss, iou_batch_loss, vel_batch_loss, pos_batch_loss, dis_batch_loss,
                                glob_batch_loss, elapsed_since_last_log))

                    if outputs_total:
                        if self.config.parallel_batch_processing:
                            with parallel_backend('loky', n_jobs=-2):
                                Parallel()(delayed(generate_results)(output, self.config, prefix, features, cur_batch_it,
                                                                     export_images, export_latent_data, sub_dir_name)
                                           for output in outputs_total)
                        else:
                            for output in outputs_total:
                                generate_results(output=output,
                                                        config=self.config,
                                                        prefix=prefix,
                                                        features=features,
                                                        cur_batch_it=cur_batch_it,
                                                        export_images=export_images,
                                                        export_latent_data=export_latent_data,
                                                        dir_name=sub_dir_name, reduce_dict=True)
                except tf.errors.OutOfRangeError:
                    break
                else:
                    print("continue")
                    continue

    def train_batch(self, prefix):
        features = self.sess.run(self.next_element_train)
        features = convert_dict_to_list_subdicts(features,
                                                 self.config.train_batch_size)

        if self.config.do_multi_step_prediction:
            multistep = True
        else:
            multistep = False

        start_time = time.time()
        last_log_time = start_time

        input_graphs_batches, target_graphs_batches, random_episode_idx_starts = create_graphs(
            config=self.config,
            batch_data=features,
            initial_pos_vel_known=self.config.initial_pos_vel_known,
            batch_processing=True,
            multistep=multistep)

        input_graphs_batches = input_graphs_batches[0]
        target_graphs_batches = target_graphs_batches[0]
        """ gt_label_rec (taken from input graphs) is shifted by -1 compared to gt_label (taken from target graphs) """
        in_segxyz, in_image, in_control, gt_label, gt_label_rec = networkx_graphs_to_images(
            self.config,
            input_graphs_batches,
            target_graphs_batches,
            multistep=multistep)

        gt_encoder_output, gt_mlp_output = get_encoding_vectors(
            config=self.config,
            random_episode_idx_starts=random_episode_idx_starts,
            train=True)

        if len(gt_encoder_output) == 0 or len(gt_mlp_output) == 0:
            print("at least one .npz file not found, move on")
            cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)
            return cur_batch_it

        gt_latent = np.concatenate([gt_encoder_output, gt_mlp_output])

        _, loss_img, loss_perception, out_predictions, in_rgb_seg_xyz, out_encoder_vectors, dbg_control = self.sess.run(
            [
                self.model.train_op, self.model.img_loss,
                self.model.perception_loss, self.out_prediction_softmax,
                self.in_rgb_seg_xyz, self.out_latent_vectors,
                self.debug_in_control
            ],
            feed_dict={
                self.in_segxyz_tf: in_segxyz,
                self.in_image_tf: in_image,
                self.gt_predictions: gt_label,
                self.in_control_tf: in_control,
                self.gt_latent_vectors: gt_latent,
                self.is_training: True
            })
        loss_velocity = np.array(0.0)
        loss_position = np.array(0.0)
        loss_edge = np.array(0.0)
        loss_total = loss_img + loss_perception + loss_position + loss_edge + loss_velocity

        self.sess.run(self.model.increment_cur_batch_tensor)
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        the_time = time.time()
        elapsed_since_last_log = the_time - last_log_time

        print(
            'batch: {:<8} total loss: {:<8.6f} | img loss: {:<8.6f} | latent loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss: {:<8.6f} | edge loss: {:<8.6f} time(s): {:<10.2f} '
            .format(cur_batch_it, loss_total, loss_img, loss_perception,
                    loss_velocity, loss_position, loss_edge,
                    elapsed_since_last_log))
        summaries_dict = {
            prefix + '_total_loss': loss_total,
            prefix + '_img_loss': loss_img,
            prefix + '_perception_loss': loss_perception,
            prefix + '_velocity_loss': loss_velocity,
            prefix + '_position_loss': loss_position,
            prefix + '_edge_loss': loss_edge
        }

        self.logger.summarize(cur_batch_it,
                              summaries_dict=summaries_dict,
                              summarizer="train")

        return cur_batch_it
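
The docstring and chunking logic in compute_metrics_over_test_set_multistep below split each episode's graph list into n_predictions-sized pieces and drop a shorter trailing piece (e.g. an episode of length 14 with n_predictions=5 yields two 5-length chunks). A standalone sketch of that chunking, for illustration only (the helper name is hypothetical):

def split_into_prediction_chunks(graphs, n_predictions):
    # cut a per-episode list into n_predictions-sized chunks, dropping a shorter remainder
    chunks = [graphs[j:j + n_predictions] for j in range(0, len(graphs), n_predictions)]
    return [chunk for chunk in chunks if len(chunk) == n_predictions]
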
    def compute_metrics_over_test_set_multistep(self):
        assert self.config.n_epochs == 1, "test mode --> n_epochs must be set to 1"

        if self.config.model_zoo_file == "baseline_auto_predictor_extended_multistep" \
                and self.config.use_f_interact and not (self.config.train_batch_size == 1 and self.config.test_batch_size == 1):
            print(
                "--- when use_f_interact is True, train and test batch size need to be 1 since f_interact uses train"
                "batch_size to split the latent vector for the computation of pairwise object interactions"
            )
            return

        prefix = self.config.exp_name
        print(
            "Computing IoU, Precision, Recall and F1 score over full test set")
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)
        """ this variable is used to distinguish between 1-step and n-step predictions. Setting it to True or False will call different functions:
         a) when set to False: the function uses "gt_label_rec" as the gt data, inputs are split into episode_length/n_prediction -chunks and fragments 
         (if they cannot be evenly divided) will be dismissed, e.g. episode_length 14 with n_predictions=5 will result in two 5-length chunks
         
         b) when set to True: in this mode, the model only predicts 1-steps. The function uses "in_segxyz" as the ground truth data, 
         the inputs don't have to be processed (e.g. splitting them into episode_length/n_predictions chunks) because after one step, 
         the model is reset to ground truth
          """
        test_single_step = False

        if test_single_step:
            mode_txt = "single_step_tested"
        else:
            mode_txt = "{}_step_tested".format(self.config.n_predictions)

        iou_list_test_set = []
        prec_score_list_test_set = []
        rec_score_list_test_set = []
        f1_score_list_test_set = []
        sub_dir_name = "metric_multistep_models_computation_over_full_test_set_{}_iterations_trained".format(
            cur_batch_it)

        dir_path, _ = create_dir(os.path.join("../experiments", prefix),
                                 sub_dir_name)
        dataset_name = os.path.basename(self.config.tfrecords_dir)
        csv_name = "{}_dataset_{}.csv".format(mode_txt, dataset_name)

        with open(os.path.join(dir_path, csv_name), 'w') as csv_file:
            writer = csv.writer(
                csv_file,
                delimiter='\t',
                lineterminator='\n',
            )
            writer.writerow([
                "(metrics averaged over n shapes and full trajectory) mean IoU",
                "mean precision", "mean recall", "mean f1 over n shapes",
                "exp_id"
            ])
            while True:
                try:
                    losses_total, losses_img, losses_iou, losses_velocity, losses_position = [], [], [], [], []
                    losses_distance, outputs_total, targets_total, exp_id_total = [], [], [], []

                    features = self.sess.run(self.next_element_test)
                    features = convert_dict_to_list_subdicts(
                        features, self.config.test_batch_size)

                    start_time = time.time()
                    last_log_time = start_time

                    for i in range(self.config.test_batch_size):
                        input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(
                            config=self.config,
                            batch_data=features[i],
                            initial_pos_vel_known=self.config.
                            initial_pos_vel_known,
                            batch_processing=False,
                            return_only_unpadded=True,
                            start_episode=0)
                        out_label_lst = []
                        in_seg_lst = []
                        exp_id = features[i]['experiment_id']
                        n_objects = features[i]['n_manipulable_objects']

                        if test_single_step:
                            """ create >episode-length< long pairs of input and target as lists to run single prediction steps """
                            assert len(input_graphs_all_exp) == len(
                                target_graphs_all_exp)
                            assert self.config.n_predictions == 1, 'set n_predictions to 1 if single_step is to be tested and re-initialize model'
                            single_step_prediction_chunks_input = [[
                                input_graph
                            ] for input_graph in input_graphs_all_exp]
                            single_step_prediction_chunks_target = [[
                                target_graph
                            ] for target_graph in target_graphs_all_exp]

                            for lst_inp, lst_targ in zip(
                                    single_step_prediction_chunks_input,
                                    single_step_prediction_chunks_target):
                                in_segxyz, in_image, in_control, gt_label, _ = networkx_graphs_to_images(
                                    self.config, [lst_inp], [lst_targ],
                                    multistep=True)

                                gt_latent = np.zeros(
                                    shape=(n_objects *
                                           self.config.n_predictions, 256))

                                loss_img, out_label = self.sess.run(
                                    [
                                        self.model.img_loss,
                                        self.out_prediction_softmax
                                    ],
                                    feed_dict={
                                        self.in_segxyz_tf: in_segxyz,
                                        self.in_image_tf: in_image,
                                        self.gt_predictions: gt_label,
                                        self.in_control_tf: in_control,
                                        self.gt_latent_vectors: gt_latent,
                                        self.is_training: True
                                    })

                                out_label[out_label >= 0.5] = 1.0
                                out_label[out_label < 0.5] = 0.0

                                loss_velocity = np.array(0.0)
                                loss_position = np.array(0.0)
                                loss_edge = np.array(0.0)
                                loss_iou = 0.0
                                loss_total = loss_img + loss_position + loss_edge + loss_velocity

                                losses_total.append(loss_total)
                                losses_img.append(loss_img)
                                losses_iou.append(loss_iou)
                                losses_velocity.append(loss_velocity)
                                losses_position.append(loss_position)
                                losses_distance.append(loss_edge)

                                in_seg_lst.append(in_segxyz[:, :, :, 0])
                                out_label_lst.append(out_label)

                        else:
                            """ this section is used for producing the output for an entire episode, meaning that the model 
                            is asked to re-predict after n_prediction steps up until the entire episode is covered 
                            (with potential crop in the end if episode length/n_predictions is odd) """
                            assert len(input_graphs_all_exp) == len(
                                target_graphs_all_exp)
                            assert self.config.n_predictions > 1, 'set n_predictions > 1 if multi_step is to be tested and re-initialize model'
                            n_prediction_chunks_target = []
                            n_prediction_chunks_input = []
                            n_predictions = self.config.n_predictions

                            for j in range(0, len(target_graphs_all_exp),
                                           n_predictions):
                                chunk = target_graphs_all_exp[j:j +
                                                              n_predictions]
                                n_prediction_chunks_target.append(chunk)

                                chunk = input_graphs_all_exp[j:j +
                                                             n_predictions]
                                n_prediction_chunks_input.append(chunk)
                            """ if the length of an episode cannot be evenly divided by n_predictions, remove the last 
                            odd list. end_idx ensures the array split is correctly handled later"""
                            assert all(len(inp_lst) == len(targ_lst) for inp_lst, targ_lst in
                                       zip(n_prediction_chunks_input, n_prediction_chunks_target)), \
                                "input and target lists are not equal after fragmenting them into episode length/n_predictions parts"

                            n_prediction_chunks_target_after_filtering = [
                                chunk for chunk in n_prediction_chunks_target
                                if len(chunk) == n_predictions
                            ]
                            n_prediction_chunks_input_after_filtering = [
                                chunk for chunk in n_prediction_chunks_input
                                if len(chunk) == n_predictions
                            ]

                            #number_removed_chunks = (len(n_prediction_chunks_target) - len(n_prediction_chunks_target_after_filtering))
                            #print("number of removed chunks (of up to 2 (2 step prediction) or 5 (5 step prediction) frames): ", number_removed_chunks)

                            for lst_inp, lst_targ in zip(
                                    n_prediction_chunks_input_after_filtering,
                                    n_prediction_chunks_target_after_filtering
                            ):
                                in_segxyz, in_image, in_control, gt_label, gt_label_rec = networkx_graphs_to_images(
                                    self.config, [lst_inp], [lst_targ],
                                    multistep=True)

                                gt_latent = np.zeros(
                                    shape=(n_objects *
                                           self.config.n_predictions, 256))

                                loss_img, out_label = self.sess.run(
                                    [
                                        self.model.img_loss,
                                        self.out_prediction_softmax
                                    ],
                                    feed_dict={
                                        self.in_segxyz_tf: in_segxyz,
                                        self.in_image_tf: in_image,
                                        self.gt_predictions: gt_label,
                                        self.in_control_tf: in_control,
                                        self.gt_latent_vectors: gt_latent,
                                        self.is_training: True
                                    })

                                out_label[out_label >= 0.5] = 1.0
                                out_label[out_label < 0.5] = 0.0

                                loss_velocity = np.array(0.0)
                                loss_position = np.array(0.0)
                                loss_edge = np.array(0.0)
                                loss_iou = 0.0
                                loss_total = loss_img + loss_position + loss_edge + loss_velocity

                                losses_total.append(loss_total)
                                losses_img.append(loss_img)
                                losses_iou.append(loss_iou)
                                losses_velocity.append(loss_velocity)
                                losses_position.append(loss_position)
                                losses_distance.append(loss_edge)
                                """ in multistep prediction, in_segxyz only contains the first image, therefore use 
                                gt_label_rec that contains all ground truth input images. Also split the model output 
                                by number of predictions since the model is set-up to append predictions along the 
                                first dimension batch-wise, i.e. ([batch1, batch2]) while both batches each have shape 
                                (5, 120, 160) for 5 objects. This results in total shape (10, 120, 160) and requires a 
                                split for example: 2 predictions, 5 objects: (10, 120, 160) --> [(5, 120, 160), (5, 120, 160)] """
                                gt_label_rec_split = np.split(gt_label_rec,
                                                              n_predictions,
                                                              axis=0)
                                out_label_split = np.split(out_label,
                                                           n_predictions,
                                                           axis=0)
                                in_seg_lst += gt_label_rec_split
                                out_label_lst += out_label_split
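                                # accumulate the per-chunk splits; the per-object trajectories are
                                # re-assembled from these lists right after the chunk loop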

                        out_label_entire_trajectory, in_seg_entire_trajectory = [], []

                        for n in range(n_objects):
                            out_obj_lst = []
                            in_obj_lst = []
                            for time_step_out, time_step_in in zip(
                                    out_label_lst, in_seg_lst):
                                out_obj_lst.append(time_step_out[n])
                                in_obj_lst.append(time_step_in[n])
                            out_label_entire_trajectory.append(
                                np.array(out_obj_lst))
                            in_seg_entire_trajectory.append(
                                np.array(in_obj_lst))

                        outputs_total.append(out_label_entire_trajectory)
                        targets_total.append(in_seg_entire_trajectory)
                        exp_id_total.append(exp_id)

                    the_time = time.time()
                    elapsed_since_last_log = the_time - last_log_time
                    batch_loss, img_batch_loss, iou_batch_loss = np.mean(
                        losses_total), np.mean(losses_img), np.mean(losses_iou)
                    vel_batch_loss, pos_batch_loss, dis_batch_loss = np.mean(
                        losses_velocity), np.mean(losses_position), np.mean(
                            losses_distance)
                    print(
                        'total test batch loss: {:<8.6f} | img loss: {:<8.6f} | iou loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} time(s): {:<10.2f}'
                        .format(batch_loss, img_batch_loss, iou_batch_loss,
                                vel_batch_loss, pos_batch_loss, dis_batch_loss,
                                elapsed_since_last_log))

                    for pred_experiment, true_experiment, exp_id in zip(
                            outputs_total, targets_total, exp_id_total):
                        iou_scores = []
                        prec_scores = []
                        rec_scores = []
                        f1_scores = []

                        # switch (n_objects, exp_len,...) to (exp_len, n_objects) since IoU computed per time step
                        pred_experiment = np.swapaxes(pred_experiment, 0, 1)
                        true_experiment = np.swapaxes(true_experiment, 0, 1)

                        for pred, true in zip(pred_experiment,
                                              true_experiment):
                            iou = compute_iou(pred=pred, true=true)
                            mean_obj_prec_score, idx_obj_min_prec, idx_obj_max_prec = compute_precision(
                                pred=pred, true=true)
                            mean_obj_rec_score, idx_obj_min_rec, idx_obj_max_rec = compute_recall(
                                pred=pred, true=true)
                            mean_obj_f1_score, idx_obj_min_f1, idx_obj_max_f1 = compute_f1(
                                pred=pred, true=true)

                            iou_scores.append(iou)
                            prec_scores.append(mean_obj_prec_score)
                            rec_scores.append(mean_obj_rec_score)
                            f1_scores.append(mean_obj_f1_score)

                        iou_traj_mean = np.mean(iou_scores)
                        prec_traj_mean = np.mean(prec_scores)
                        rec_traj_mean = np.mean(rec_scores)
                        f1_traj_mean = np.mean(f1_scores)

                        if exp_id is None:
                            print(
                                "shapes pred_experiment and true experiment: ",
                                np.shape(pred_experiment),
                                np.shape(true_experiment))

                        writer.writerow([
                            iou_traj_mean, prec_traj_mean, rec_traj_mean,
                            f1_traj_mean, exp_id
                        ])

                        if not (np.isnan(iou_traj_mean)
                                or np.isnan(prec_traj_mean)
                                or np.isnan(rec_traj_mean)
                                or np.isnan(f1_traj_mean)):
                            prec_score_list_test_set.append(prec_traj_mean)
                            rec_score_list_test_set.append(rec_traj_mean)
                            f1_score_list_test_set.append(f1_traj_mean)
                            iou_list_test_set.append(iou_traj_mean)

                    csv_file.flush()
                except tf.errors.OutOfRangeError:
                    break

            iou_test_set_mean = np.mean(iou_list_test_set)
            prec_test_set_mean = np.mean(prec_score_list_test_set)
            rec_test_set_mean = np.mean(rec_score_list_test_set)
            f1_test_set_mean = np.mean(f1_score_list_test_set)

            writer.writerow([
                "means over full set", " IoU: ", iou_test_set_mean,
                " Precision: ", prec_test_set_mean, " Recall: ",
                rec_test_set_mean, "F1: ", f1_test_set_mean
            ])
            if test_single_step:
                mode = "(model trained multi-step but tested single-step)"
            else:
                mode = "(model trained & tested multi-step)"

            print(
                "Done. mean IoU {}: {}, mean precision: {}, mean recall: {}, mean f1: {}"
                .format(mode, iou_test_set_mean, prec_test_set_mean,
                        rec_test_set_mean, f1_test_set_mean))
    def test_batch(self,
                   prefix,
                   initial_pos_vel_known,
                   export_images=False,
                   process_all_nn_outputs=False,
                   sub_dir_name=None,
                   export_latent_data=True,
                   output_results=True):

        losses_total = []
        losses_img = []
        losses_perception = []
        losses_velocity = []
        losses_position = []
        losses_edge = []
        outputs_total = []
        summaries_dict_images = {}

        features = self.sess.run(self.next_element_test)
        features = convert_dict_to_list_subdicts(features,
                                                 self.config.test_batch_size)

        start_time = time.time()
        last_log_time = start_time

        if self.config.do_multi_step_prediction:
            multistep = True
        else:
            multistep = False
        start_idx = 0
        end_idx = self.config.n_predictions
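        # in multistep mode, only the window [start_idx:end_idx] of n_predictions graphs per
        # experiment is fed to the model below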

        for i in range(self.config.test_batch_size):
            input_graphs_all_exp, target_graphs_all_exp, random_episode_idx_starts = create_graphs(
                config=self.config,
                batch_data=features[i],
                initial_pos_vel_known=self.config.initial_pos_vel_known,
                batch_processing=False)

            if multistep:
                """ for gt_reconstruction we also need all input graphs """
                input_graphs_all_exp = [
                    input_graphs_all_exp[start_idx:end_idx]
                ]
                target_graphs_all_exp = [
                    target_graphs_all_exp[start_idx:end_idx]
                ]

            gt_encoder_output, gt_mlp_output = get_encoding_vectors(
                config=self.config,
                random_episode_idx_starts=random_episode_idx_starts,
                train=False)

            if len(gt_encoder_output) == 0 or len(gt_mlp_output) == 0:
                print("at least one .npz file not found, move on")
                return

            gt_latent = np.concatenate([gt_encoder_output, gt_mlp_output])
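            # gt_latent is the concatenation of the pre-computed encoder and MLP ground-truth outputs loaded above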
            """ below implementation is correct! this is required because f_interact requires a dynamic split in 
            the batch size of the latent vector which is not implemented --> instead, we assume the batch size for 
            train and test are the same, requiring the test batch to pad up to length of a train batch"""
            if self.config.use_f_interact:
                input_graphs_all_exp = [
                    input_graphs_all_exp[0]
                    for _ in range(self.config.train_batch_size)
                ]
                target_graphs_all_exp = [
                    target_graphs_all_exp[0]
                    for _ in range(self.config.train_batch_size)
                ]

                gt_latent = [
                    gt_latent for _ in range(self.config.train_batch_size)
                ]
                gt_latent = np.concatenate(gt_latent, axis=0)

            in_segxyz, in_image, in_control, gt_label, _ = networkx_graphs_to_images(
                self.config,
                input_graphs_all_exp,
                target_graphs_all_exp,
                multistep=multistep)

            loss_img, loss_perception, out_predictions = self.sess.run(
                [
                    self.model.img_loss, self.model.perception_loss,
                    self.out_prediction_softmax
                ],
                feed_dict={
                    self.in_segxyz_tf: in_segxyz,
                    self.in_image_tf: in_image,
                    self.gt_predictions: gt_label,
                    self.in_control_tf: in_control,
                    self.gt_latent_vectors: gt_latent,
                    self.is_training: False
                })

            if self.config.use_f_interact:
                pred_splits = np.split(out_predictions,
                                       self.config.n_predictions)
                preds = []
                n_objects = features[i]['n_manipulable_objects']
                """ pred has shape (n_objects*train_batch_size, 120, 160) --> get first batch, i.e. :n_objects"""
                for pred in pred_splits:
                    preds.append(pred[:n_objects])

                out_predictions = np.concatenate(preds, axis=0)
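                # keep only the first n_objects rows of every prediction split (the actual test sample);
                # the remaining rows are the padded copies created above for f_interact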

            loss_velocity = np.array(0.0)
            loss_position = np.array(0.0)
            loss_edge = np.array(0.0)
            loss_total = loss_img + loss_perception + loss_position + loss_edge + loss_velocity

            losses_total.append(loss_total)
            losses_img.append(loss_img)
            losses_perception.append(loss_perception)
            losses_velocity.append(loss_velocity)
            losses_position.append(loss_position)
            losses_edge.append(loss_edge)

            out_predictions[out_predictions >= 0.5] = 1.0
            out_predictions[out_predictions < 0.5] = 0.0

            outputs_total.append((out_predictions, in_segxyz, in_image,
                                  in_control, i, (start_idx, end_idx)))

        the_time = time.time()
        elapsed_since_last_log = the_time - last_log_time
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        if not process_all_nn_outputs:
            """ due to brevity, just use last output """
            outputs_total = [outputs_total[-1]]

        batch_loss = np.mean(losses_total)
        img_batch_loss = np.mean(losses_img)
        perception_batch_loss = np.mean(losses_perception)
        vel_batch_loss = np.mean(losses_velocity)
        pos_batch_loss = np.mean(losses_position)
        edge_batch_loss = np.mean(losses_edge)

        print(
            'total test batch loss: {:<8.6f} | img loss: {:<8.6f} | latent loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} time(s): {:<10.2f}'
            .format(batch_loss, img_batch_loss, perception_batch_loss,
                    vel_batch_loss, pos_batch_loss, edge_batch_loss,
                    elapsed_since_last_log))

        summaries_dict = {
            prefix + '_total_loss': batch_loss,
            prefix + '_img_loss': img_batch_loss,
            prefix + '_perception_loss': perception_batch_loss,
            prefix + '_velocity_loss': vel_batch_loss,
            prefix + '_position_loss': pos_batch_loss,
            prefix + '_edge_loss': edge_batch_loss
        }

        if outputs_total and output_results:
            for output in outputs_total:
                summaries_dict_images = generate_and_export_image_dicts(
                    output=output,
                    features=features,
                    config=self.config,
                    prefix=prefix,
                    cur_batch_it=cur_batch_it,
                    dir_name=sub_dir_name,
                    reduce_dict=True,
                    multistep=multistep)

            if summaries_dict_images:
                summaries_dict = {**summaries_dict, **summaries_dict_images}

        self.logger.summarize(cur_batch_it,
                              summaries_dict=summaries_dict,
                              summarizer="test")

        return batch_loss, img_batch_loss, vel_batch_loss, pos_batch_loss, edge_batch_loss, cur_batch_it
    def test_specific_exp_ids(self):
        if self.config.n_epochs != 1:
            print(
                "test mode for specific exp ids --> n_epochs will be set to 1")
            self.config.n_epochs = 1
        prefix = self.config.exp_name
        print("Running tests with initial_pos_vel_known={}".format(
            self.config.initial_pos_vel_known))
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        if "5_objects_50_rollouts_padded_novel" in self.config.tfrecords_dir:
            exp_ids_to_export = [10, 1206, 880, 1189, 1087, 2261, 194,
                                 1799]  # big 5 object novel shapes dataset
            dir_name = "5_novel_objects"
        elif "5_objects_50_rollouts" in self.config.tfrecords_dir:
            exp_ids_to_export = [
                2815, 608, 1691, 49, 1834, 1340, 2596, 2843, 306
            ]  # big 5 object dataset
            dir_name = "5_objects"
        else:
            exp_ids_to_export = [
                13873, 3621, 8575, 439, 2439, 1630, 14526, 4377, 15364, 6874,
                11031, 8962
            ]  # big 3 object dataset
            dir_name = "3_objects"

        process_all_nn_outputs = True
        thresholds_to_test = 0.5

        reset_after_n_predictions = False
        start_idx = 0
        end_idx = self.config.n_predictions
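        # three modes below: multistep prediction shown from the episode start, multistep prediction
        # reset after every n_predictions steps, or plain single-step prediction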

        if self.config.n_predictions > 1 and not reset_after_n_predictions:
            multistep = True
            dir_suffix = "show_pred_from_start"
            start_episode = 0
        elif self.config.n_predictions > 1 and reset_after_n_predictions:
            multistep = True
            dir_suffix = "reset_pred_after_n_predictions"
            start_episode = None
        else:
            multistep = False
            dir_suffix = ""
            start_episode = 0

        sub_dir_name = "test_{}_specific_exp_ids_{}_iterations_trained_sigmoid_threshold_{}_mode_{}".format(
            dir_name, cur_batch_it, thresholds_to_test, dir_suffix)
        while True:
            try:
                losses_total = []
                losses_img = []
                losses_velocity = []
                losses_position = []
                losses_edge = []
                outputs_total = []

                features = self.sess.run(self.next_element_test)

                features = convert_dict_to_list_subdicts(
                    features, self.config.test_batch_size)

                if exp_ids_to_export:
                    features_to_export = []
                    for dct in features:
                        if dct["experiment_id"] in exp_ids_to_export:
                            features_to_export.append(dct)
                            print("added", dct["experiment_id"])

                    features = features_to_export

                if exp_ids_to_export and not features_to_export:
                    continue

                start_time = time.time()
                last_log_time = start_time

                for i in range(len(features)):
                    input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(
                        config=self.config,
                        batch_data=features[i],
                        initial_pos_vel_known=self.config.initial_pos_vel_known,
                        batch_processing=False,
                        return_only_unpadded=True,
                        start_episode=start_episode)

                    if multistep:
                        input_graphs_all_exp = [
                            input_graphs_all_exp[start_idx:end_idx]
                        ]
                        target_graphs_all_exp = [
                            target_graphs_all_exp[start_idx:end_idx]
                        ]

                    in_segxyz, in_image, in_control, gt_label, _ = networkx_graphs_to_images(
                        self.config,
                        input_graphs_all_exp,
                        target_graphs_all_exp,
                        multistep=multistep)

                    loss_img, out_label = self.sess.run(
                        [self.model.loss_op, self.out_prediction_softmax],
                        feed_dict={
                            self.in_segxyz_tf: in_segxyz,
                            self.in_image_tf: in_image,
                            self.gt_predictions: gt_label,
                            self.in_control_tf: in_control,
                            self.is_training: True
                        })
                    loss_velocity = np.array(0.0)
                    loss_position = np.array(0.0)
                    loss_edge = np.array(0.0)
                    loss_total = loss_img + loss_position + loss_edge + loss_velocity

                    losses_total.append(loss_total)
                    losses_img.append(loss_img)
                    losses_velocity.append(loss_velocity)
                    losses_position.append(loss_position)
                    losses_edge.append(loss_edge)

                    out_label[out_label >= 0.5] = 1.0
                    out_label[out_label < 0.5] = 0.0

                    outputs_total.append((out_label, in_segxyz, in_image,
                                          in_control, i, (start_idx, end_idx)))

                the_time = time.time()
                elapsed_since_last_log = the_time - last_log_time
                cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

                if not process_all_nn_outputs:
                    """ due to brevity, just use last output """
                    outputs_total = [outputs_total[-1]]

                batch_loss = np.mean(losses_total)
                img_batch_loss = np.mean(losses_img)
                vel_batch_loss = np.mean(losses_velocity)
                pos_batch_loss = np.mean(losses_position)
                edge_batch_loss = np.mean(losses_edge)

                print(
                    'total test batch loss: {:<8.6f} | img loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} time(s): {:<10.2f}'
                    .format(batch_loss, img_batch_loss, vel_batch_loss,
                            pos_batch_loss, edge_batch_loss,
                            elapsed_since_last_log))

                if outputs_total:
                    for output in outputs_total:
                        generate_and_export_image_dicts(
                            output=output,
                            features=features,
                            config=self.config,
                            prefix=prefix,
                            cur_batch_it=cur_batch_it,
                            dir_name=sub_dir_name,
                            reduce_dict=True,
                            multistep=multistep)

            except tf.errors.OutOfRangeError:
                break
            else:
                print("continue")
                continue
    def train_batch(self, prefix):
        features = self.sess.run(self.next_element_train)
        features = convert_dict_to_list_subdicts(features,
                                                 self.config.train_batch_size)

        if self.config.do_multi_step_prediction:
            multistep = True
        else:
            multistep = False

        start_time = time.time()
        last_log_time = start_time

        input_graphs_batches, target_graphs_batches, _ = create_graphs(
            config=self.config,
            batch_data=features,
            initial_pos_vel_known=self.config.initial_pos_vel_known,
            batch_processing=True,
            shuffle=False,
            multistep=multistep)

        input_graphs_batches = input_graphs_batches[0][0]
        target_graphs_batches = target_graphs_batches[0][0]
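        # create_graphs with batch_processing=True appears to return nested per-batch lists;
        # unwrap the first chunk here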
        """ gt_label_rec (taken from input graphs) is shifted by -1 compared to gt_label (taken from target graphs) """
        in_segxyz, in_image, in_control, _, gt_reconstruction = networkx_graphs_to_images(
            self.config,
            input_graphs_batches,
            target_graphs_batches,
            multistep=multistep)

        _, loss_img, out_reconstructions, in_rgb_seg_xyz, latent_feature_img = self.sess.run(
            [
                self.model.train_op, self.model.loss_op,
                self.out_prediction_softmax, self.in_rgb_seg_xyz,
                self.encoder_outputs
            ],
            feed_dict={
                self.in_segxyz_tf: in_segxyz,
                self.in_image_tf: in_image,
                self.gt_predictions: gt_reconstruction,  # intentional: the same placeholder names are reused
                self.in_control_tf: in_control,  # not actually used in auto-encoding
                self.is_training: True
            })
        loss_velocity = np.array(0.0)
        loss_position = np.array(0.0)
        loss_edge = np.array(0.0)
        loss_total = loss_img + loss_position + loss_edge + loss_velocity
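        # velocity/position/edge losses are not computed in this auto-encoding step; they are
        # logged as zeros to keep the summary format consistent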

        self.sess.run(self.model.increment_cur_batch_tensor)
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        the_time = time.time()
        elapsed_since_last_log = the_time - last_log_time

        print(
            'batch: {:<8} total loss: {:<8.6f} | img loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss: {:<8.6f} | edge loss: {:<8.6f} time(s): {:<10.2f} '
            .format(cur_batch_it, loss_total, loss_img, loss_velocity,
                    loss_position, loss_edge, elapsed_since_last_log))
        summaries_dict = {
            prefix + '_total_loss': loss_total,
            prefix + '_img_loss': loss_img,
            prefix + '_velocity_loss': loss_velocity,
            prefix + '_position_loss': loss_position,
            prefix + '_edge_loss': loss_edge
        }

        self.logger.summarize(cur_batch_it,
                              summaries_dict=summaries_dict,
                              summarizer="train")

        return cur_batch_it
    def compute_metrics_over_test_set(self):
        if self.config.n_epochs != 1:
            print("test mode --> n_epochs will be set to 1")
            self.config.n_epochs = 1
        prefix = self.config.exp_name
        print(
            "Computing IoU, Precision, Recall and F1 score over full test set")
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        iou_list_test_set = []
        prec_score_list_test_set = []
        rec_score_list_test_set = []
        f1_score_list_test_set = []
        sub_dir_name = "metric_computation_over_full_test_set_{}_iterations_trained".format(
            cur_batch_it)

        dir_path, _ = create_dir(os.path.join("../experiments", prefix),
                                 sub_dir_name)
        dataset_name = os.path.basename(self.config.tfrecords_dir)
        csv_name = "dataset_{}.csv".format(dataset_name)

        with open(os.path.join(dir_path, csv_name), 'w') as csv_file:
            writer = csv.writer(
                csv_file,
                delimiter='\t',
                lineterminator='\n',
            )
            writer.writerow([
                "(metrics averaged over n shapes and full trajectory) mean IoU",
                "mean precision", "mean recall", "mean f1 over n shapes",
                "exp_id"
            ])
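            # one CSV row is written per test trajectory; means over the full test set are appended
            # as a final row after the loop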
            while True:
                try:
                    losses_total, losses_img, losses_iou, losses_velocity, losses_position = [], [], [], [], []
                    losses_distance, outputs_total, targets_total, exp_id_total = [], [], [], []

                    features = self.sess.run(self.next_element_test)
                    features = convert_dict_to_list_subdicts(
                        features, self.config.test_batch_size)

                    start_time = time.time()
                    last_log_time = start_time

                    for i in range(self.config.test_batch_size):
                        input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(
                            config=self.config,
                            batch_data=features[i],
                            initial_pos_vel_known=self.config.initial_pos_vel_known,
                            batch_processing=False)
                        exp_id_i = []

                        input_graphs_all_exp = [input_graphs_all_exp]
                        target_graphs_all_exp = [target_graphs_all_exp]

                        in_segxyz, in_image, in_control, gt_label = networkx_graphs_to_images(
                            self.config, input_graphs_all_exp,
                            target_graphs_all_exp)
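                        # single-step conversion: this call of networkx_graphs_to_images returns only four
                        # values (no reconstruction target), unlike the five-value multistep calls used elsewhere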

                        loss_img, out_label = self.sess.run(
                            [self.model.loss_op, self.out_prediction_softmax],
                            feed_dict={
                                self.in_segxyz_tf: in_segxyz,
                                self.in_image_tf: in_image,
                                self.gt_predictions: gt_label,
                                self.in_control_tf: in_control,
                                self.is_training: True
                            })

                        loss_velocity = np.array(0.0)
                        loss_position = np.array(0.0)
                        loss_edge = np.array(0.0)
                        loss_iou = 0.0
                        loss_total = loss_img + loss_position + loss_edge + loss_velocity

                        losses_total.append(loss_total)
                        losses_img.append(loss_img)
                        losses_iou.append(loss_iou)
                        losses_velocity.append(loss_velocity)
                        losses_position.append(loss_position)
                        losses_distance.append(loss_edge)

                        out_label[out_label >= 0.5] = 1.0
                        out_label[out_label < 0.5] = 0.0

                        exp_id_i.append(features[i]['experiment_id'])

                        unpad_exp_length = features[i][
                            'unpadded_experiment_length']
                        n_objects = features[i]['n_manipulable_objects']
                        print(np.shape(out_label))
                        out_label_split = np.split(out_label,
                                                   unpad_exp_length - 1)
                        in_seg_split = np.split(in_segxyz[:, :, :, 0],
                                                unpad_exp_length - 1)
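                        # split the model output and the ground-truth segmentation into the
                        # (unpadded_experiment_length - 1) time steps of the episode, one array of
                        # shape (n_objects, height, width) per step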

                        out_label_entire_trajectory, in_seg_entire_trajectory = [], []

                        for n in range(n_objects):
                            out_obj_lst = []
                            in_obj_lst = []
                            for time_step_out, time_step_in in zip(
                                    out_label_split, in_seg_split):
                                out_obj_lst.append(time_step_out[n])
                                in_obj_lst.append(time_step_in[n])
                            out_label_entire_trajectory.append(
                                np.array(out_obj_lst))
                            in_seg_entire_trajectory.append(
                                np.array(in_obj_lst))

                        outputs_total.append(out_label_entire_trajectory)
                        targets_total.append(in_seg_entire_trajectory)
                        exp_id_total.append(exp_id_i)

                    the_time = time.time()
                    elapsed_since_last_log = the_time - last_log_time
                    batch_loss, img_batch_loss, iou_batch_loss = np.mean(
                        losses_total), np.mean(losses_img), np.mean(losses_iou)
                    vel_batch_loss, pos_batch_loss, dis_batch_loss = np.mean(
                        losses_velocity), np.mean(losses_position), np.mean(
                            losses_distance)
                    print(
                        'total test batch loss: {:<8.6f} | img loss: {:<8.6f} | iou loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} time(s): {:<10.2f}'
                        .format(batch_loss, img_batch_loss, iou_batch_loss,
                                vel_batch_loss, pos_batch_loss, dis_batch_loss,
                                elapsed_since_last_log))

                    for pred_experiment, true_experiment, exp_id in zip(
                            outputs_total, targets_total, exp_id_total):
                        iou_scores = []
                        prec_scores = []
                        rec_scores = []
                        f1_scores = []

                        # switch (n_objects, exp_len,...) to (exp_len, n_objects) since IoU computed per time step
                        pred_experiment = np.swapaxes(pred_experiment, 0, 1)
                        true_experiment = np.swapaxes(true_experiment, 0, 1)

                        for pred, true in zip(pred_experiment,
                                              true_experiment):
                            iou = compute_iou(pred=pred, true=true)
                            mean_obj_prec_score, idx_obj_min_prec, idx_obj_max_prec = compute_precision(
                                pred=pred, true=true)
                            mean_obj_rec_score, idx_obj_min_rec, idx_obj_max_rec = compute_recall(
                                pred=pred, true=true)
                            mean_obj_f1_score, idx_obj_min_f1, idx_obj_max_f1 = compute_f1(
                                pred=pred, true=true)

                            iou_scores.append(iou)
                            prec_scores.append(mean_obj_prec_score)
                            rec_scores.append(mean_obj_rec_score)
                            f1_scores.append(mean_obj_f1_score)

                        iou_traj_mean = np.mean(iou_scores)
                        prec_traj_mean = np.mean(prec_scores)
                        rec_traj_mean = np.mean(rec_scores)
                        f1_traj_mean = np.mean(f1_scores)

                        writer.writerow([
                            iou_traj_mean, prec_traj_mean, rec_traj_mean,
                            f1_traj_mean, exp_id[0]
                        ])

                        prec_score_list_test_set.append(prec_traj_mean)
                        rec_score_list_test_set.append(rec_traj_mean)
                        f1_score_list_test_set.append(f1_traj_mean)
                        iou_list_test_set.append(iou_traj_mean)

                    csv_file.flush()
                except tf.errors.OutOfRangeError:
                    break

            iou_test_set_mean = np.mean(iou_list_test_set)
            prec_test_set_mean = np.mean(prec_score_list_test_set)
            rec_test_set_mean = np.mean(rec_score_list_test_set)
            f1_test_set_mean = np.mean(f1_score_list_test_set)

            writer.writerow([
                "means over full set", " IoU: ", iou_test_set_mean,
                " Precision: ", prec_test_set_mean, " Recall: ",
                rec_test_set_mean, "F1: ", f1_test_set_mean
            ])
            print(
                "Done. mean IoU: {}, mean precision: {}, mean recall: {}, mean f1: {}"
                .format(iou_test_set_mean, prec_test_set_mean,
                        rec_test_set_mean, f1_test_set_mean))
    def test_batch(self, prefix, initial_pos_vel_known, export_images=False, process_all_nn_outputs=False, sub_dir_name=None,
                   export_latent_data=False, output_results=True):

        losses_total = []
        losses_img = []
        losses_iou = []
        losses_velocity = []
        losses_position = []
        losses_distance = []
        losses_global = []

        outputs_total = []
        targets_total = []
        exp_id_total = []
        summaries_dict = {}
        summaries_dict_images = {}

        features = self.sess.run(self.next_element_test)
        features = convert_dict_to_list_subdicts(features, self.config.test_batch_size)

        start_time = time.time()
        last_log_time = start_time

        for i in range(self.config.test_batch_size):
            input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(config=self.config,
                                                                        batch_data=features[i],
                                                                        initial_pos_vel_known=self.config.initial_pos_vel_known,
                                                                        batch_processing=False
                                                                        )
            output_i = []
            target_i = []
            exp_id_i = []


            for j in range(features[i]["unpadded_experiment_length"] - 1):
                total_loss, output, loss_img, loss_iou, loss_velocity, loss_position, loss_distance, target, loss_global, _ = self.do_step(input_graphs_all_exp[j],
                                                                                                       target_graphs_all_exp[j],
                                                                                                       features[i],
                                                                                                       train=False,
                                                                                                       batch_processing=False
                                                                                                       )
                output = output[0]
                if total_loss is not None:
                    losses_total.append(total_loss)
                    losses_img.append(loss_img)
                    losses_iou.append(loss_iou)
                    losses_velocity.append(loss_velocity)
                    losses_position.append(loss_position)
                    losses_distance.append(loss_distance)
                    losses_global.append(loss_global)

                output_i.append(output)
                target_i.append(target)
                exp_id_i.append(features[i]['experiment_id'])


            outputs_total.append((output_i, i))
            targets_total.append((target_i, i))
            exp_id_total.append(exp_id_i)

        the_time = time.time()
        elapsed_since_last_log = the_time - last_log_time
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        if not process_all_nn_outputs:
            """ due to brevity, just use last output """
            outputs_total = [outputs_total[-1]]

        if losses_total:
            batch_loss = np.mean(losses_total)
            img_batch_loss = np.mean(losses_img)
            iou_batch_loss = np.mean(losses_iou)
            vel_batch_loss = np.mean(losses_velocity)
            pos_batch_loss = np.mean(losses_position)
            dis_batch_loss = np.mean(losses_distance)
            glob_batch_loss = np.mean(losses_global)

            print('total test batch loss: {:<8.6f} | img loss: {:<8.6f} | iou loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} | global loss {:<8.6f} | time(s): {:<10.2f}'.format(
                batch_loss, img_batch_loss, iou_batch_loss, vel_batch_loss, pos_batch_loss, dis_batch_loss, glob_batch_loss, elapsed_since_last_log))

            summaries_dict = {prefix + '_total_loss': batch_loss,
                              prefix + '_img_loss': img_batch_loss,
                              prefix + '_iou_loss': iou_batch_loss,
                              prefix + '_velocity_loss': vel_batch_loss,
                              prefix + '_position_loss': pos_batch_loss,
                              prefix + '_edge_loss': dis_batch_loss,
                              prefix + '_global_loss': glob_batch_loss
                              }

        else:
            batch_loss, img_batch_loss, vel_batch_loss, pos_batch_loss, dis_batch_loss, iou_batch_loss = 0, 0, 0, 0, 0, 0

        if outputs_total and output_results:
            if self.config.parallel_batch_processing:
                with parallel_backend('threading', n_jobs=-2):
                    results = Parallel()(delayed(generate_results)(output,
                                                                      self.config,
                                                                      prefix,
                                                                      features,
                                                                      cur_batch_it,
                                                                      export_images,
                                                                      export_latent_data,
                                                                      sub_dir_name,
                                                                      True,
                                                                      ['seg']) for output in outputs_total)

            else:
                for output in outputs_total:
                    summaries_dict_images, summaries_pos_dict_images, _ = generate_results(output=output,
                                                                         config=self.config,
                                                                         prefix=prefix,
                                                                         features=features,
                                                                         cur_batch_it=cur_batch_it,
                                                                         export_images=export_images,
                                                                         export_latent_data=export_latent_data,
                                                                         dir_name=sub_dir_name,
                                                                         reduce_dict=True,
                                                                         output_selection=['seg', 'rgb', 'depth']
                                                                        )

            if summaries_dict_images:
                if self.config.parallel_batch_processing:
                    """ parallel mode returns list, just use first element as a summary for the logger """
                    summaries_dict_images = results[0]
                    summaries_pos_dict_images = results[1]

                    summaries_dict_images = summaries_dict_images[0]
                    if summaries_pos_dict_images is not None: 
                        summaries_pos_dict_images = summaries_pos_dict_images[0]

                if summaries_pos_dict_images is not None:
                    summaries_dict = {**summaries_dict, **summaries_dict_images, **summaries_pos_dict_images}
                else:
                    summaries_dict = {**summaries_dict, **summaries_dict_images}
                cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)
                self.logger.summarize(cur_batch_it, summaries_dict=summaries_dict, summarizer="test")

        return batch_loss, img_batch_loss, vel_batch_loss, pos_batch_loss, dis_batch_loss, iou_batch_loss, cur_batch_it, outputs_total, targets_total, exp_id_total
    def test_specific_exp_ids(self):
        assert self.config.n_epochs == 1, "test mode for specific exp ids --> n_epochs must be set to 1"
        if self.config.model_zoo_file == "baseline_auto_predictor_extended_multistep" \
                and self.config.use_f_interact and not (self.config.train_batch_size == 1 and self.config.test_batch_size == 1):
            print(
                "--- when use_f_interact is True, train and test batch size need to be 1 since f_interact uses the "
                "train batch_size to split the latent vector for the computation of pairwise object interactions"
            )
            return

        prefix = self.config.exp_name
        print("Running tests with initial_pos_vel_known={}".format(
            self.config.initial_pos_vel_known))
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        if "5_objects_50_rollouts_padded_novel" in self.config.tfrecords_dir:
            exp_ids_to_export = [
                1899, 1045, 1790, 1472, 980, 2080, 1464, 985, 141, 2521, 2643,
                735, 620, 1667, 62
            ]
            #exp_ids_to_export = [10, 1206, 880, 1189, 1087, 2261, 194, 1799]  # big 5 object novel shapes dataset
            dir_name = "5_novel_objects_more"
            #dir_name = "5_novel_objects"
        elif "5_objects_50_rollouts" in self.config.tfrecords_dir:
            exp_ids_to_export = [
                2815, 608, 1691, 49, 1834, 1340, 2596, 2843, 306
            ]  # big 5 object dataset
            dir_name = "5_objects"
        elif "2_bigger_cubes" in self.config.tfrecords_dir:
            path = "/scr2/fabiof/repos/GNforInteraction/mains/test_ids_2_bigger_cubes_dataset.txt"
            with open(path) as f:
                exp_ids_to_export = [
                    int(x) for line in f for x in line.split()
                ]
            print(exp_ids_to_export)
            #exp_ids_to_export = [2815, 608, 1691, 49, 1834, 1340, 2596, 2843, 306]  # big 5 object dataset
            dir_name = "2_cubes"
        else:
            exp_ids_to_export = [
                13873, 3621, 8575, 439, 2439, 1630, 14526, 4377, 15364, 6874,
                11031, 8962
            ]  # big 3 object dataset
            #exp_ids_to_export = [9896, 7140, 12844, 15693, 3770, 13327, 8437, 314, 1428, 402, 5355, 9303, 8474, 78] # train ids
            #dir_name = "3_objects_TRAINSET"
            dir_name = "3_objects"
        """ set this to true if full episodes should be predicted with the multistep model, i.e. after every 
        n_predictions, the model input will be reset to the actual ground truth to allow to fully observe the 
        predictions for an entire episode"""
        reset_after_n_predictions = True

        start_idx = 0
        end_idx = self.config.n_predictions

        if self.config.n_predictions > 1 and not reset_after_n_predictions:
            multistep = True
            dir_suffix = "show_pred_from_start"
            start_episode = 0
            pad_ground_truth_to_exp_len = False
        elif self.config.n_predictions > 1 and reset_after_n_predictions:
            multistep = True
            dir_suffix = "reset_model_input_after_{}_predictions".format(
                self.config.n_predictions)
            start_episode = None
            pad_ground_truth_to_exp_len = False
        else:
            multistep = False
            dir_suffix = ""
            start_episode = 0
            pad_ground_truth_to_exp_len = True

        sub_dir_name = "test_{}_specific_exp_ids_{}_iterations_trained_sigmoid_threshold_{}_mode_{}".format(
            dir_name, cur_batch_it, 0.5, dir_suffix)
        while True:
            try:
                losses_total = []
                losses_img = []
                losses_velocity = []
                losses_position = []
                losses_edge = []
                outputs_total = []

                features = self.sess.run(self.next_element_test)

                features = convert_dict_to_list_subdicts(
                    features, self.config.test_batch_size)

                if exp_ids_to_export:
                    features_to_export = []
                    for dct in features:
                        if dct["experiment_id"] in exp_ids_to_export:
                            features_to_export.append(dct)
                            print("added", dct["experiment_id"])

                    features = features_to_export

                if exp_ids_to_export and not features_to_export:
                    continue

                start_time = time.time()
                last_log_time = start_time

                for i in range(len(features)):
                    input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(
                        config=self.config,
                        batch_data=features[i],
                        initial_pos_vel_known=self.config.initial_pos_vel_known,
                        batch_processing=False,
                        return_only_unpadded=True,
                        start_episode=start_episode)

                    n_objects = features[i]['n_manipulable_objects']

                    if reset_after_n_predictions:
                        """ this section is used for producing the output for an entire episode, meaning that the model 
                        is asked to re-predict after n_prediction steps up until the entire episode is covered 
                        (with potential crop in the end if episode length/n_predictions is odd) """
                        assert len(input_graphs_all_exp) == len(
                            target_graphs_all_exp)
                        n_prediction_chunks_target = []
                        n_prediction_chunks_input = []
                        n_predictions = self.config.n_predictions

                        for j in range(0, len(target_graphs_all_exp),
                                       n_predictions):
                            chunk = target_graphs_all_exp[j:j + n_predictions]
                            n_prediction_chunks_target.append(chunk)

                            chunk = input_graphs_all_exp[j:j + n_predictions]
                            n_prediction_chunks_input.append(chunk)
                        """ if the length of an episode cannot be evenly divided by n_predictions, remove the last 
                        odd list. end_idx ensures the array split is correctly handled later"""
                        n_prediction_chunks_target = [
                            chunk for chunk in n_prediction_chunks_target
                            if len(chunk) == n_predictions
                        ]
                        n_prediction_chunks_input = [
                            chunk for chunk in n_prediction_chunks_input
                            if len(chunk) == n_predictions
                        ]
                        end_idx = len(
                            n_prediction_chunks_target) * n_predictions
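                        # end_idx marks the last frame covered by complete prediction chunks; any
                        # trailing frames beyond it are cropped from the episode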

                        out_label_lst = []
                        in_segxyz_lst = []
                        in_image_lst = []
                        in_control_lst = []

                        for lst_inp, lst_targ in zip(
                                n_prediction_chunks_input,
                                n_prediction_chunks_target):
                            in_segxyz, in_image, in_control, gt_label, _ = networkx_graphs_to_images(
                                self.config, [lst_inp], [lst_targ],
                                multistep=True)

                            gt_latent = np.zeros(
                                shape=(n_objects * self.config.n_predictions,
                                       256))

                            loss_img, out_label = self.sess.run(
                                [
                                    self.model.img_loss,
                                    self.out_prediction_softmax
                                ],
                                feed_dict={
                                    self.in_segxyz_tf: in_segxyz,
                                    self.in_image_tf: in_image,
                                    self.gt_predictions: gt_label,
                                    self.in_control_tf: in_control,
                                    self.gt_latent_vectors: gt_latent,
                                    self.is_training: True
                                })

                            out_label[out_label >= 0.5] = 1.0
                            out_label[out_label < 0.5] = 0.0

                            loss_velocity = np.array(0.0)
                            loss_position = np.array(0.0)
                            loss_edge = np.array(0.0)
                            loss_total = loss_img + loss_position + loss_edge + loss_velocity

                            losses_total.append(loss_total)
                            losses_img.append(loss_img)
                            losses_velocity.append(loss_velocity)
                            losses_position.append(loss_position)
                            losses_edge.append(loss_edge)

                            out_label_lst.append(out_label)
                            in_segxyz_lst.append(in_segxyz)
                            in_image_lst.append(in_image)
                            in_control_lst.append(in_control)

                        out_label, in_segxyz, in_image, in_control = np.concatenate(out_label_lst, axis=0), \
                                                                     np.concatenate(in_segxyz_lst, axis=0), \
                                                                     np.concatenate(in_image_lst, axis=0), \
                                                                     np.concatenate(in_control_lst, axis=0)
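                        # re-assemble the per-chunk outputs and inputs into single arrays covering
                        # the (possibly cropped) full episode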

                        outputs_total.append(
                            (out_label, in_segxyz, in_image, in_control, i,
                             (start_idx, end_idx)))

                    else:
                        if multistep:
                            input_graphs_all_exp = [
                                input_graphs_all_exp[start_idx:end_idx]
                            ]
                            target_graphs_all_exp = [
                                target_graphs_all_exp[start_idx:end_idx]
                            ]

                        in_segxyz, in_image, in_control, gt_label, _ = networkx_graphs_to_images(
                            self.config,
                            input_graphs_all_exp,
                            target_graphs_all_exp,
                            multistep=multistep)

                        gt_latent = np.zeros(shape=(n_objects, 256))

                        loss_img, out_label = self.sess.run(
                            [self.model.img_loss, self.out_prediction_softmax],
                            feed_dict={
                                self.in_segxyz_tf: in_segxyz,
                                self.in_image_tf: in_image,
                                self.gt_predictions: gt_label,
                                self.in_control_tf: in_control,
                                self.gt_latent_vectors: gt_latent,
                                self.is_training: True
                            })
                        loss_velocity = np.array(0.0)
                        loss_position = np.array(0.0)
                        loss_edge = np.array(0.0)
                        loss_total = loss_img + loss_position + loss_edge + loss_velocity

                        losses_total.append(loss_total)
                        losses_img.append(loss_img)
                        losses_velocity.append(loss_velocity)
                        losses_position.append(loss_position)
                        losses_edge.append(loss_edge)

                        out_label[out_label >= 0.5] = 1.0
                        out_label[out_label < 0.5] = 0.0

                        outputs_total.append(
                            (out_label, in_segxyz, in_image, in_control, i,
                             (start_idx, end_idx)))

                the_time = time.time()
                elapsed_since_last_log = the_time - last_log_time
                cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

                batch_loss = np.mean(losses_total)
                img_batch_loss = np.mean(losses_img)
                vel_batch_loss = np.mean(losses_velocity)
                pos_batch_loss = np.mean(losses_position)
                edge_batch_loss = np.mean(losses_edge)

                print(
                    'total test batch loss: {:<8.6f} | img loss: {:<8.6f} | vel loss: {:<8.6f} | pos loss {:<8.6f} | edge loss {:<8.6f} time(s): {:<10.2f}'
                    .format(batch_loss, img_batch_loss, vel_batch_loss,
                            pos_batch_loss, edge_batch_loss,
                            elapsed_since_last_log))

                if outputs_total:
                    for output in outputs_total:
                        generate_and_export_image_dicts(
                            output=output,
                            features=features,
                            config=self.config,
                            prefix=prefix,
                            cur_batch_it=cur_batch_it,
                            dir_name=sub_dir_name,
                            reduce_dict=True,
                            multistep=multistep,
                            pad_ground_truth_to_exp_len=pad_ground_truth_to_exp_len)

            except tf.errors.OutOfRangeError:
                break
            else:
                print("continue")
                continue
    def store_latent_vectors(self):
        assert self.config.n_epochs == 1, "set n_epochs to 1 for test mode"
        assert self.config.test_batch_size == 1, "set test_batch_size to 1 for test mode"
        prefix = self.config.exp_name
        print("Storing latent vectors baseline")
        cur_batch_it = self.model.cur_batch_tensor.eval(self.sess)

        df = pd.DataFrame(
            columns=['latent_vector_init_img', 'exp_id', 'exp_len'])
        sub_dir_name = "latent_vectors_initial_image_of_full_test_set_{}_iterations_trained".format(
            cur_batch_it)

        dir_path, _ = create_dir(os.path.join("../experiments", prefix),
                                 sub_dir_name)
        dataset_name = os.path.basename(self.config.tfrecords_dir)
        file_name = dir_path + "/latent_vectors_baseline_auto_predictor_dataset_{}.pkl".format(
            dataset_name)

        while True:
            try:
                features = self.sess.run(self.next_element_test)
                features = convert_dict_to_list_subdicts(
                    features, self.config.test_batch_size)

                for i in range(len(features)):
                    input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(
                        config=self.config,
                        batch_data=features[i],
                        initial_pos_vel_known=self.config.initial_pos_vel_known,
                        batch_processing=False)

                    input_graphs_all_exp = [input_graphs_all_exp[0]]
                    target_graphs_all_exp = [target_graphs_all_exp[0]]
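                    # keep only the first graph pair of each episode since only the latent code of
                    # the initial image is stored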

                    exp_id = features[i]['experiment_id']
                    exp_len = features[i][
                        "unpadded_experiment_length"]  # the label

                    #print(np.shape(input_graphs_all_exp), np.shape(target_graphs_all_exp))

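                    # wrap into the nested list structure expected by networkx_graphs_to_images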
                    input_graphs_all_exp = [input_graphs_all_exp]
                    target_graphs_all_exp = [target_graphs_all_exp]

                    in_segxyz, in_image, in_control, gt_label = networkx_graphs_to_images(
                        self.config, input_graphs_all_exp,
                        target_graphs_all_exp)

                    #print(np.shape(in_segxyz), np.shape(in_image), np.shape(in_control), np.shape(gt_label), exp_len)

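                    # single forward pass on the initial image to obtain its latent vector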
                    loss_img, out_label, latent_init_img = self.sess.run(
                        [
                            self.model.loss_op, self.out_prediction_softmax,
                            self.latent_init_img
                        ],
                        feed_dict={
                            self.in_segxyz_tf: in_segxyz,
                            self.in_image_tf: in_image,
                            self.gt_predictions: gt_label,
                            self.in_control_tf: in_control,
                            self.is_training: True
                        })

                    #print(np.shape(latent_init_img))

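                    # collect one row per experiment (DataFrame.append is removed in
                    # pandas >= 2.0; pd.concat would be the replacement there)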
                    df = df.append(
                        {
                            'latent_vector_init_img': latent_init_img,
                            'exp_id': exp_id,
                            'exp_len': exp_len
                        },
                        ignore_index=True)

            except tf.errors.OutOfRangeError:
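                # end of the test set reached: persist the collected latent vectors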
                df.to_pickle(file_name)
                print("Pandas dataframe with {} rows saved to: {} ".format(
                    len(df.index), file_name))
                break
            else:
                print("continue")
                continue

    def save_encoder_vectors(self, train=True):
        assert self.config.n_epochs == 1, "set n_epochs to 1 for test mode"

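        # choose the output directory name based on the dataset referenced in tfrecords_dir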
        if "5_objects_50_rollouts_padded_novel" in self.config.tfrecords_dir:
            dir_name = "auto_encoding_features_5_objects_50_rollouts_novel"
        elif "5_objects_50_rollouts" in self.config.tfrecords_dir:
            dir_name = "auto_encoding_features_5_objects_50_rollouts"
        else:
            #dir_name = "auto_encoding_features_3_objects_15_rollouts"
            dir_name = "auto_encoding_features_2_bigger_cubes_dataset_50_rollouts"

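        # pick the train or test iterator together with the matching batch size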
        if train:
            next_element = self.next_element_train
            sub_dir_name = "train"
            batch_size = self.config.train_batch_size
        else:
            next_element = self.next_element_test
            sub_dir_name = "test"
            batch_size = self.config.test_batch_size

        dir_path, _ = create_dir(os.path.join("/scr2/fabiof/data/"), dir_name)
        dir_path, _ = create_dir(os.path.join("/scr2/fabiof/data/", dir_name),
                                 sub_dir_name)
        #dir_path, _ = create_dir(os.path.join("../experiments/"), dir_name)
        #dir_path, _ = create_dir(os.path.join("../experiments/", dir_name), sub_dir_name)
        iterator = 0
        while True:
            try:
                features = self.sess.run(next_element)
                features = convert_dict_to_list_subdicts(features, batch_size)

                for i in range(len(features)):
                    input_graphs_all_exp, target_graphs_all_exp, _ = create_graphs(
                        config=self.config,
                        batch_data=features[i],
                        initial_pos_vel_known=self.config.initial_pos_vel_known,
                        batch_processing=False,
                        return_only_unpadded=True,
                        start_episode=0)

                    encoder_outputs = []
                    exp_id = features[i]['experiment_id']
                    iterator = iterator + 1

                    assert len(input_graphs_all_exp) == len(
                        target_graphs_all_exp)
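                    # split the rollout into single-step chunks, one graph per chunk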
                    single_step_prediction_chunks_input = [
                        [input_graph] for input_graph in input_graphs_all_exp
                    ]
                    single_step_prediction_chunks_target = [
                        [target_graph] for target_graph in target_graphs_all_exp
                    ]

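                    # run the network once per single-step chunk and collect the encoder output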
                    for lst_inp, lst_targ in zip(
                            single_step_prediction_chunks_input,
                            single_step_prediction_chunks_target):
                        in_segxyz, in_image, in_control, gt_label, _ = networkx_graphs_to_images(
                            self.config, [lst_inp], [lst_targ], multistep=True)

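                        # only encoder_output is kept; in_images and out_reconstructions
                        # are fetched here but not used further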
                        in_images, out_reconstructions, encoder_output = self.sess.run(
                            [
                                self.in_rgb_seg_xyz,
                                self.out_prediction_softmax,
                                self.encoder_outputs
                            ],
                            feed_dict={
                                self.in_segxyz_tf: in_segxyz,
                                self.in_image_tf: in_image,
                                self.gt_predictions: gt_label,
                                self.in_control_tf: in_control,
                                self.is_training: True
                            })

                        encoder_outputs.append(encoder_output)
                    print("saved encoder vector number {} under {}".format(
                        iterator, os.path.join(dir_path, str(exp_id))))
                    np.savez_compressed(os.path.join(dir_path, str(exp_id)),
                                        encoder_outputs=encoder_outputs,
                                        exp_id=exp_id)

            except tf.errors.OutOfRangeError:
                break