Exemplo n.º 1
0
    def pack_and_save_metrics(self, start_time, create_summary_csv, train_losses, val_losses):
        """
        Merge the train and validation losses of the epoch that just finished, log them to the
        statistics csv file and print a one-line summary.
        :param start_time: The time at which the current epoch started
        :param create_summary_csv: Boolean flag; when true, the column names are first written to a
        newly created statistics file before the values of this epoch are saved
        :param train_losses: A dictionary with the current train losses
        :param val_losses: A dictionary with the current val losses
        :return: The current time, to be used as the start time of the next epoch.
        """
        summary = self.merge_two_dicts(first_dict=train_losses, second_dict=val_losses)
        # The printable summary is built before the bookkeeping fields are added,
        # so it only contains the loss entries.
        summary_text = self.build_loss_summary_string(summary)
        summary["epoch"] = self.epoch
        summary['epoch_run_time'] = time.time() - start_time

        if create_summary_csv:
            column_names = list(summary.keys())
            self.summary_statistics_filepath = save_statistics(self.logs_filepath, column_names, create=True)
            self.create_summary_csv = False

        next_start_time = time.time()
        print("epoch {} -> {}".format(summary["epoch"], summary_text))

        row_values = list(summary.values())
        self.summary_statistics_filepath = save_statistics(self.logs_filepath, row_values)
        return next_start_time
    def run_experiment(self):
        """
        Train and evaluate the model once per epoch, logging metrics along the way.

        For every epoch in [self.starting_epoch, self.configs['num_of_epochs']) this runs the
        pre-epoch hook, one training epoch and one evaluation epoch, tracks the best validation
        precision seen so far, writes the scalars to self.writer, updates self.state, optionally
        saves a checkpoint and passes the accumulated metrics to save_statistics ('summary.csv').
        The writer is flushed and closed when the loop ends, including when an epoch raises.
        :return: None
        """
        total_losses = {
            "loss": [],
            "precision": [],
            "hr": [],
            "diversity": [],
            "curr_epoch": []
        }

        try:
            for epoch_idx in range(self.starting_epoch,
                                   self.configs['num_of_epochs']):
                print(f"Epoch: {epoch_idx}")
                self.pre_epoch_init_function()

                average_loss = self.run_training_epoch()
                precision_mean, hr_mean, diversity = self.run_evaluation_epoch()

                # Remember which epoch produced the highest validation precision.
                if precision_mean > self.best_val_model_precision:
                    self.best_val_model_precision = precision_mean
                    self.best_val_model_idx = epoch_idx

                self.writer.add_scalar('Average training loss per epoch',
                                       average_loss, epoch_idx)

                self.writer.add_scalar('Precision', precision_mean, epoch_idx)
                self.writer.add_scalar('Hit Ratio', hr_mean, epoch_idx)
                self.writer.add_scalar('Diversity', diversity, epoch_idx)

                print(
                    f'HR: {hr_mean}, Precision: {precision_mean}, Diversity: {diversity}'
                )

                self.state['current_epoch_idx'] = epoch_idx
                self.state[
                    'best_val_model_precision'] = self.best_val_model_precision
                self.state['best_val_model_idx'] = self.best_val_model_idx

                if self.configs['save_model']:
                    self.save_model(model_save_dir=self.experiment_saved_models,
                                    model_save_name="train_model",
                                    model_idx=epoch_idx,
                                    state=self.state)

                total_losses['loss'].append(average_loss)
                total_losses['precision'].append(precision_mean)
                total_losses['hr'].append(hr_mean)
                total_losses['diversity'].append(diversity)
                total_losses['curr_epoch'].append(epoch_idx)

                # NOTE(review): current_epoch receives the absolute epoch index, while
                # total_losses only holds entries for the epochs run in this call. If
                # save_statistics uses it as a list index this breaks when resuming
                # (starting_epoch > 0) - confirm against save_statistics.
                save_statistics(
                    experiment_log_dir=self.experiment_logs,
                    filename='summary.csv',
                    stats_dict=total_losses,
                    current_epoch=epoch_idx,
                    continue_from_mode=(self.starting_epoch != 0
                                        or epoch_idx > 0))
        finally:
            # Release the writer even when an epoch fails, so already logged
            # scalars are not lost.
            self.writer.flush()
            self.writer.close()
Exemplo n.º 3
0
    def run_experiment(self):
        """
        Run the train/validation loop and, unless the experiment is named "test", a final test pass.

        Each epoch runs one training and one validation epoch, tracks the epoch with the best mean
        validation mIoU, passes the accumulated per-epoch means to save_statistics ('summary.csv'),
        prints a summary line and stores the progress in self.state. For experiments not named
        "test" the model is saved under the index 'latest' after every epoch; after the loop that
        model is loaded, evaluated with run_testing_epoch and its mean metrics are passed to
        save_statistics ('test_summary.csv').
        :return: A tuple (total_losses, test_losses). total_losses maps each metric name to the list
        of its per-epoch means; test_losses maps each test metric to a one-element list, or is 0
        when the experiment is named "test".
        """
        metric_names = ("train_miou", "train_acc", "train_loss", "val_miou", "val_acc", "val_loss")
        total_losses = {name: [] for name in metric_names}
        total_losses["curr_epoch"] = []

        for i, epoch_idx in enumerate(range(self.starting_epoch, self.num_epochs)):
            epoch_start_time = time.time()

            current_epoch_losses = {name: [] for name in metric_names}
            current_epoch_losses = self.run_training_epoch(current_epoch_losses)
            current_epoch_losses = self.run_validation_epoch(current_epoch_losses)

            # Mean of every metric collected during this epoch, in the order returned.
            epoch_means = {key: np.mean(value) for key, value in current_epoch_losses.items()}

            val_mean_miou = epoch_means['val_miou']
            if val_mean_miou > self.best_val_model_acc:
                self.best_val_model_acc = val_mean_miou
                self.best_val_model_idx = epoch_idx

            for key, mean_value in epoch_means.items():
                total_losses[key].append(mean_value)
            total_losses['curr_epoch'].append(epoch_idx)

            # The position within this run (i), not the absolute epoch, is passed as current_epoch.
            append_to_existing = self.starting_epoch != 0 or i > 0
            save_statistics(experiment_log_dir=self.experiment_logs, filename='summary.csv',
                            stats_dict=total_losses, current_epoch=i,
                            continue_from_mode=append_to_existing)

            out_string = "_".join(
                "{}_{:.4f}".format(key, mean_value) for key, mean_value in epoch_means.items())
            epoch_elapsed_time = "{:.4f}".format(time.time() - epoch_start_time)
            print("Epoch {}:".format(epoch_idx), "Iteration {}:".format(self.scheduler.last_epoch),
                  out_string, "epoch time", epoch_elapsed_time, "seconds")

            self.state['current_epoch_idx'] = epoch_idx
            self.state['best_val_model_acc'] = self.best_val_model_acc
            self.state['best_val_model_idx'] = self.best_val_model_idx

            if self.experiment_name != "test":
                # Only the most recent model is kept; it is overwritten every epoch.
                self.save_model(model_save_dir=self.experiment_saved_models,
                                model_save_name="train_model", model_idx='latest', state=self.state)

        if self.experiment_name == "test":
            return total_losses, 0

        print("Generating test set evaluation metrics")
        self.load_model(model_save_dir=self.experiment_saved_models, model_idx='latest',
                        model_save_name="train_model")

        test_epoch_losses = self.run_testing_epoch(
            current_epoch_losses={"test_miou": [], "test_acc": [], "test_loss": []})
        test_losses = {key: [np.mean(value)] for key, value in test_epoch_losses.items()}

        save_statistics(experiment_log_dir=self.experiment_logs, filename='test_summary.csv',
                        stats_dict=test_losses, current_epoch=0, continue_from_mode=False)
        return total_losses, test_losses
Exemplo n.º 4
0
    def evaluated_test_set_using_the_best_models(self, top_n_models):
        """
        Evaluate an ensemble of the checkpoints with the highest validation accuracy on the test set.

        The epochs recorded in self.state['per_epoch_statistics']['val_accuracy_mean'] are ranked by
        validation accuracy, the best ones are loaded one after another (this overwrites self.state
        with the loaded state) and run over the test batches. The per-model predictions are averaged,
        the arg-max over axis 2 is compared with the targets collected for the first model, and the
        resulting accuracy mean and std are passed to save_statistics ("test_summary.csv").
        :param top_n_models: Maximum number of checkpoints to ensemble. When fewer epochs have been
        recorded, only the available ones are used.
        :return: None
        """
        per_epoch_statistics = self.state['per_epoch_statistics']
        val_acc = np.copy(per_epoch_statistics['val_accuracy_mean'])
        val_idx = np.arange(len(val_acc))
        # Epoch indices ordered from highest to lowest validation accuracy, truncated to top-n.
        sorted_idx = np.argsort(val_acc, axis=0).astype(dtype=np.int32)[::-1][:top_n_models]

        sorted_val_acc = val_acc[sorted_idx]
        val_idx = val_idx[sorted_idx]
        print(sorted_idx)
        print(sorted_val_acc)

        top_n_idx = val_idx[:top_n_models]
        # Size the containers by the number of checkpoints actually selected. Sizing them by
        # top_n_models would leave empty trailing lists (and a ragged input to np.mean below)
        # whenever more models are requested than epochs were recorded.
        num_models = len(top_n_idx)
        per_model_per_batch_preds = [[] for _ in range(num_models)]
        per_model_per_batch_targets = [[] for _ in range(num_models)]
        total_batches = int(self.args.num_evaluation_tasks / self.args.batch_size)

        for idx, model_idx in enumerate(top_n_idx):
            # Checkpoints are loaded with model_idx + 1 (presumably saved 1-based - confirm).
            self.state = \
                self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model",
                                      model_idx=model_idx + 1)
            with tqdm.tqdm(total=total_batches) as pbar_test:
                for sample_idx, test_sample in enumerate(
                        self.data.get_test_batches(total_batches=total_batches,
                                                   augment_images=False)):
                    per_model_per_batch_targets[idx].extend(np.array(test_sample[3]))
                    per_model_per_batch_preds = self.test_evaluation_iteration(val_sample=test_sample,
                                                                               sample_idx=sample_idx,
                                                                               model_idx=idx,
                                                                               per_model_per_batch_preds=per_model_per_batch_preds,
                                                                               pbar_test=pbar_test)

        # Ensemble by averaging the predictions over the model axis.
        per_batch_preds = np.mean(per_model_per_batch_preds, axis=0)
        per_batch_max = np.argmax(per_batch_preds, axis=2)
        # Only the targets gathered for the first model are used as reference
        # (assumes every model sees the same test batches - TODO confirm).
        per_batch_targets = np.array(per_model_per_batch_targets[0]).reshape(per_batch_max.shape)
        correct = np.equal(per_batch_targets, per_batch_max)
        accuracy = np.mean(correct)
        accuracy_std = np.std(correct)

        test_losses = {"test_accuracy_mean": accuracy, "test_accuracy_std": accuracy_std}

        # First call passes the column names with create=True, second call the values.
        _ = save_statistics(self.logs_filepath,
                            list(test_losses.keys()),
                            create=True, filename="test_summary.csv")

        summary_statistics_filepath = save_statistics(self.logs_filepath,
                                                      list(test_losses.values()),
                                                      create=False, filename="test_summary.csv")
        print(test_losses)
        print("saved test performance at", summary_statistics_filepath)
Exemplo n.º 5
0
 def write_task_lang_log(self, log):
     """
     Pass every row of a train-iteration log to save_statistics for the "task_lang_log.csv" file.
     :param log: iterable of rows, each a list containing [task name, language, iteration,
     support loss, support accuracy, query loss, query accuracy]
     :return: None
     """
     for row in log:
         # create=False: the rows are handed over one at a time without requesting a new file.
         save_statistics(self.logs_filepath, row, create=False, filename="task_lang_log.csv")
Exemplo n.º 6
0
    target_placeholder=data_targets,
    dropout_rate=dropout_rate,
    batch_size=batch_size,
    num_channels=train_data.inputs.shape[2],
    n_classes=train_data.num_classes,
    is_training=training_phase,
    augment_rotate_flag=rotate_data,
    strided_dim_reduction=strided_dim_reduction,
    use_batch_normalization=batch_norm)  # initialize our computational graph

# A continue_from_epoch of -1 marks a new experiment (not the continuation of a previous one).
if continue_from_epoch == -1:  # new experiment: request a new statistics file
    # The list below holds the column names handed to save_statistics with create=True.
    save_statistics(
        logs_filepath,
        "result_summary_statistics", [
            "epoch", "train_c_loss", "train_c_accuracy", "val_c_loss",
            "val_c_accuracy", "test_c_loss", "test_c_accuracy"
        ],
        create=True)

start_epoch = continue_from_epoch if continue_from_epoch != -1 else 0  # new experiment starts from epoch 0,
# otherwise continue from the requested epoch

summary_op, losses_ops, c_error_opt_op = classifier_network.init_train(
)  # get graph operations (ops)

# Number of batches per split, as reported by the data providers.
total_train_batches = train_data.num_batches
total_val_batches = val_data.num_batches
total_test_batches = test_data.num_batches

best_epoch = 0
    args.samples_per_class,
    args.use_full_context_embeddings,
    full_context_unroll_k=args.full_context_unroll_k,
    args=args)
# Train, validation and test all use the same number of iterations per epoch.
total_train_batches = args.total_iter_per_epoch
total_val_batches = args.total_iter_per_epoch
total_test_batches = args.total_iter_per_epoch

# Build the experiment folders; returns the model checkpoint path and the logs path.
saved_models_filepath, logs_filepath = build_experiment_folder(
    args.experiment_title)

# Hand the statistics column names to save_statistics with create=True.
# NOTE(review): this call is unconditional, although args.continue_from_epoch is checked
# further down when loading a checkpoint - confirm that an existing statistics file is
# not reset when an experiment is resumed.
save_statistics(logs_filepath, [
    "epoch", "total_train_c_loss_mean", "total_train_c_loss_std",
    "total_train_accuracy_mean", "total_train_accuracy_std",
    "total_val_c_loss_mean", "total_val_c_loss_std", "total_val_accuracy_mean",
    "total_val_accuracy_std", "total_test_c_loss_mean",
    "total_test_c_loss_std", "total_test_accuracy_mean",
    "total_test_accuracy_std"
],
                create=True)

# Experiment initialization and running
with tf.Session() as sess:
    sess.run(init)
    train_saver = tf.train.Saver()
    val_saver = tf.train.Saver()
    if args.continue_from_epoch != -1:  #load checkpoint if needed
        checkpoint = "saved_models/{}_{}.ckpt".format(args.experiment_title,
                                                      args.continue_from_epoch)
        variables_to_restore = []
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
Exemplo n.º 8
0
    def evaluated_test_set_using_the_best_models(self, top_n_models,
                                                 dataset_name):
        """
        Evaluate the checkpoint with the best validation accuracy and save a test summary.

        The epochs in self.state['per_epoch_statistics']['val_accuracy_mean'] are ranked by
        validation accuracy and the ranking is printed. The evaluation itself is hard-wired to the
        single best checkpoint: it is loaded, run over the selected batches, per-task accuracies are
        printed, and the losses returned by the last evaluation iteration plus the overall accuracy
        mean and std are passed to save_statistics ("test_summary.csv"). self.state is restored to
        its value on entry before returning, also when the evaluation raises.
        :param top_n_models: Only limits how many entries of the validation ranking are printed;
        the number of evaluated models is overridden to 1.
        :param dataset_name: 'train' evaluates self.data.get_test_train_batches over
        self.num_train_tasks tasks; any other value evaluates self.data.get_test_batches over
        self.num_test_tasks tasks.
        :return: Array with one mean accuracy per index in range(number of tasks); entries whose
        index is missing from the collected accuracies stay at -1.
        """
        per_epoch_statistics = self.state['per_epoch_statistics']
        val_acc = np.copy(per_epoch_statistics['val_accuracy_mean'])
        val_idx = np.arange(len(val_acc))
        # Epoch indices ordered from highest to lowest validation accuracy, truncated to top-n.
        sorted_idx = np.argsort(
            val_acc, axis=0).astype(dtype=np.int32)[::-1][:top_n_models]

        sorted_val_acc = val_acc[sorted_idx]
        val_idx = val_idx[sorted_idx]
        print(sorted_idx)
        print(sorted_val_acc)
        # From here on the caller's value is ignored: only the best model is evaluated.
        top_n_models = 1

        top_n_idx = val_idx[:top_n_models]
        per_model_per_batch_preds = [[] for _ in range(top_n_models)]
        per_model_per_batch_targets = [[] for _ in range(top_n_models)]
        total_batches = int(self.args.num_evaluation_tasks /
                            self.args.batch_size)
        old_state = self.state

        try:
            for idx, model_idx in enumerate(top_n_idx):
                total_losses = dict()
                total_accs = dict()
                # Checkpoints are loaded with model_idx + 1 (presumably saved 1-based - confirm).
                self.state = \
                    self.model.load_model(model_save_dir=self.saved_models_filepath, model_name="train_model",
                                          model_idx=model_idx + 1)
                with tqdm.tqdm(total=total_batches) as pbar_test:
                    if dataset_name == 'train':
                        batches = self.data.get_test_train_batches(
                            total_batches=total_batches,
                            augment_images=False)
                    else:
                        batches = self.data.get_test_batches(
                            total_batches=total_batches,
                            augment_images=False)
                    for sample_idx, test_sample in enumerate(batches):
                        per_model_per_batch_targets[idx].extend(
                            np.array(test_sample[3]))
                        per_model_per_batch_preds, tst_losses, total_losses, total_accs = self.test_evaluation_iteration(
                            val_sample=test_sample,
                            total_losses=total_losses,
                            total_accs=total_accs,
                            sample_idx=sample_idx,
                            model_idx=idx,
                            per_model_per_batch_preds=per_model_per_batch_preds,
                            pbar_test=pbar_test)
                    if idx == 0:
                        # Converted once after the batch loop; only the final value is used below.
                        per_mean = np.asarray(per_model_per_batch_preds[0])

                    if dataset_name == 'train':
                        nn = self.num_train_tasks
                    else:
                        nn = self.num_test_tasks
                    # -1 marks indices for which no accuracy was collected. These placeholders
                    # are included in the statistics printed below.
                    accs = -1 * np.ones(nn)
                    for ii in range(nn):
                        if ii in total_accs:
                            accs[ii] = np.mean(np.asarray(total_accs[ii]))

                    print("ACCURACIES")
                    print(accs)
                    print(np.mean(accs))
                    print(np.min(accs))
                    print(np.std(accs))
                    print(np.max(accs))
                    sorted_accs = np.argsort(accs)
                    # Indices of the three lowest accuracies.
                    print(sorted_accs[:3])

            per_batch_preds = per_mean
            per_batch_max = np.argmax(per_batch_preds, axis=2)
            per_batch_targets = np.array(per_model_per_batch_targets[0]).reshape(
                per_batch_max.shape)
            correct = np.equal(per_batch_targets, per_batch_max)
            accuracy = np.mean(correct)
            accuracy_std = np.std(correct)

            # Extend the losses returned by the last evaluation iteration with the accuracy stats.
            tst_losses["test_accuracy_mean"] = accuracy
            tst_losses["test_accuracy_std"] = accuracy_std

            # First call passes the column names with create=True, second call the values.
            _ = save_statistics(self.logs_filepath,
                                list(tst_losses.keys()),
                                create=True,
                                filename="test_summary.csv")

            summary_statistics_filepath = save_statistics(
                self.logs_filepath,
                list(tst_losses.values()),
                create=False,
                filename="test_summary.csv")
            print("saved test performance at", summary_statistics_filepath)
            return accs
        finally:
            # Loading a checkpoint replaced self.state; always put the caller's state back.
            self.state = old_state
Exemplo n.º 9
0
    def run_experiment(self):
        """
        Run the adversarial training loop inside a single TensorFlow session.

        The method initialises the graph variables, creates the summary writers and savers,
        optionally restores variables from a checkpoint (when self.continue_from_epoch != -1),
        prepares the latent vectors used for sample generation and then, for every epoch from the
        start epoch up to self.total_epochs:
          * runs self.total_train_batches iterations, each made of self.disc_iter discriminator
            updates followed by self.gen_iter generator updates, evaluating the same losses on a
            validation batch after every update;
          * prints the mean/std of the collected discriminator and generator losses;
          * writes generated sample images for the last training batch and for
            self.total_gen_batches generation batches;
          * saves a checkpoint, plus a "val" checkpoint whenever the mean validation
            discriminator loss improves;
          * passes the epoch statistics to save_statistics.
        :return: None
        """
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(self.init)
            self.train_writer = tf.summary.FileWriter("{}/train_logs/".format(self.log_path),
                                                      graph=tf.get_default_graph())
            self.validation_writer = tf.summary.FileWriter("{}/validation_logs/".format(self.log_path),
                                                           graph=tf.get_default_graph())
            self.train_saver = tf.train.Saver()
            # NOTE(review): val_saver is created but not used in this method; the "val"
            # checkpoint below is written with train_saver.
            self.val_saver = tf.train.Saver()

            start_from_epoch = 0
            if self.continue_from_epoch!=-1:
                start_from_epoch = self.continue_from_epoch
                # NOTE(review): this path has no "train_saved_model_" / "val_saved_model_" prefix,
                # unlike the checkpoints saved at the end of each epoch below - confirm it
                # matches the files that are meant to be restored.
                checkpoint = "{}/{}_{}.ckpt".format(self.saved_models_filepath, self.experiment_name, self.continue_from_epoch)
                # Restore every global variable; variables missing from the checkpoint are ignored.
                variables_to_restore = []
                for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
                    print(var)
                    variables_to_restore.append(var)

                tf.logging.info('Fine-tuning from %s' % checkpoint)

                fine_tune = slim.assign_from_checkpoint_fn(
                    checkpoint,
                    variables_to_restore,
                    ignore_missing_vars=True)
                fine_tune(sess)

            self.iter_done = 0
            # Discriminator / generator updates per training iteration.
            self.disc_iter = 5
            self.gen_iter = 1
            best_d_val_loss = np.inf

            # Latent vectors reused for every sample image, so images of different epochs
            # are generated from the same z values.
            if self.spherical_interpolation:
                dim = int(np.sqrt(self.num_generations)*2)
                self.z_2d_vectors = interpolations.create_mine_grid(rows=dim,
                                                                    cols=dim,
                                                                    dim=self.z_dim, space=3, anchors=None,
                                                                    spherical=True, gaussian=True)
                self.z_vectors = interpolations.create_mine_grid(rows=1, cols=self.num_generations, dim=self.z_dim,
                                                                 space=3, anchors=None, spherical=True, gaussian=True)
            else:
                self.z_vectors = np.random.normal(size=(self.num_generations, self.z_dim))
                self.z_2d_vectors = np.random.normal(size=(self.num_generations, self.z_dim))

            with tqdm.tqdm(total=self.total_epochs-start_from_epoch) as pbar_e:
                for e in range(start_from_epoch, self.total_epochs):

                    # Per-epoch loss histories, one entry per update step.
                    train_g_loss = []
                    val_g_loss = []
                    train_d_loss = []
                    val_d_loss = []

                    with tqdm.tqdm(total=self.total_train_batches) as pbar_train:
                        # NOTE(review): the loop variable shadows the builtin `iter`.
                        for iter in range(self.total_train_batches):

                            # NOTE(review): cur_sample is incremented below but never read.
                            cur_sample = 0

                            # Discriminator phase: one optimisation step on a training batch,
                            # then the same loss evaluated (no update) on a validation batch.
                            for n in range(self.disc_iter):
                                x_train_i, x_train_j = self.data.get_train_batch()
                                x_val_i, x_val_j = self.data.get_val_batch()

                                _, d_train_loss_value = sess.run(
                                    [self.graph_ops["d_opt_op"], self.losses["d_losses"]],
                                    feed_dict={self.input_x_i: x_train_i,
                                               self.input_x_j: x_train_j,
                                               self.dropout_rate: self.dropout_rate_value,
                                               self.training_phase: True, self.random_rotate: True})

                                d_val_loss_value = sess.run(
                                    self.losses["d_losses"],
                                    feed_dict={self.input_x_i: x_val_i,
                                               self.input_x_j: x_val_j,
                                               self.dropout_rate: self.dropout_rate_value,
                                               self.training_phase: False, self.random_rotate: False})

                                cur_sample += 1
                                train_d_loss.append(d_train_loss_value)
                                val_d_loss.append(d_val_loss_value)

                            # Generator phase: same pattern, additionally fetching the summaries.
                            for n in range(self.gen_iter):
                                x_train_i, x_train_j = self.data.get_train_batch()
                                x_val_i, x_val_j = self.data.get_val_batch()
                                _, g_train_loss_value, train_summaries = sess.run(
                                    [self.graph_ops["g_opt_op"], self.losses["g_losses"],
                                     self.summary],
                                    feed_dict={self.input_x_i: x_train_i,
                                               self.input_x_j: x_train_j,
                                               self.dropout_rate: self.dropout_rate_value,
                                               self.training_phase: True, self.random_rotate: True})

                                g_val_loss_value, val_summaries = sess.run(
                                    [self.losses["g_losses"], self.summary],
                                    feed_dict={self.input_x_i: x_val_i,
                                               self.input_x_j: x_val_j,
                                               self.dropout_rate: self.dropout_rate_value,
                                               self.training_phase: False, self.random_rotate: False})

                                cur_sample += 1
                                train_g_loss.append(g_train_loss_value)
                                val_g_loss.append(g_val_loss_value)

                                # Write summaries every tensorboard_update_interval iterations.
                                if iter % (self.tensorboard_update_interval) == 0:
                                    self.train_writer.add_summary(train_summaries, global_step=self.iter_done)
                                    self.validation_writer.add_summary(val_summaries, global_step=self.iter_done)


                            self.iter_done = self.iter_done + 1
                            # The progress bar shows the most recent loss values of this iteration.
                            iter_out = "{}_train_d_loss: {}, train_g_loss: {}, " \
                                       "val_d_loss: {}, val_g_loss: {}".format(self.iter_done,
                                                                               d_train_loss_value, g_train_loss_value,
                                                                               d_val_loss_value,
                                                                               g_val_loss_value)
                            pbar_train.set_description(iter_out)
                            pbar_train.update(1)

                    # Epoch-level training statistics.
                    total_d_train_loss_mean = np.mean(train_d_loss)
                    total_d_train_loss_std = np.std(train_d_loss)
                    total_g_train_loss_mean = np.mean(train_g_loss)
                    total_g_train_loss_std = np.std(train_g_loss)

                    print(
                        "Epoch {}: d_train_loss_mean: {}, d_train_loss_std: {},"
                                  "g_train_loss_mean: {}, g_train_loss_std: {}"
                        .format(e, total_d_train_loss_mean,
                                total_d_train_loss_std,
                                total_g_train_loss_mean,
                                total_g_train_loss_std))

                    # Epoch-level validation statistics.
                    total_d_val_loss_mean = np.mean(val_d_loss)
                    total_d_val_loss_std = np.std(val_d_loss)
                    total_g_val_loss_mean = np.mean(val_g_loss)
                    total_g_val_loss_std = np.std(val_g_loss)

                    print(
                        "Epoch {}: d_val_loss_mean: {}, d_val_loss_std: {},"
                        "g_val_loss_mean: {}, g_val_loss_std: {}, "
                            .format(e, total_d_val_loss_mean,
                                    total_d_val_loss_std,
                                    total_g_val_loss_mean,
                                    total_g_val_loss_std))



                    # Sample images for the training data: x_train_i is the last training
                    # batch fetched in the generator phase above.
                    sample_generator(num_generations=self.num_generations, sess=sess, same_images=self.same_images,
                                     inputs=x_train_i,
                                     data=self.data, batch_size=self.batch_size, z_input=self.z_input,
                                     file_name="{}/train_z_variations_{}_{}.png".format(self.save_image_path,
                                                                                        self.experiment_name,
                                                                                        e),
                                     input_a=self.input_x_i, training_phase=self.training_phase,
                                     z_vectors=self.z_vectors, dropout_rate=self.dropout_rate,
                                     dropout_rate_value=self.dropout_rate_value)

                    sample_two_dimensions_generator(sess=sess,
                                                    same_images=self.same_images,
                                                    inputs=x_train_i,
                                                    data=self.data, batch_size=self.batch_size, z_input=self.z_input,
                                                    file_name="{}/train_z_spherical_{}_{}".format(self.save_image_path,
                                                                                                  self.experiment_name,
                                                                                                  e),
                                                    input_a=self.input_x_i, training_phase=self.training_phase,
                                                    dropout_rate=self.dropout_rate,
                                                    dropout_rate_value=self.dropout_rate_value,
                                                    z_vectors=self.z_2d_vectors)

                    # Sample images for every generation batch, one pair of files per batch.
                    with tqdm.tqdm(total=self.total_gen_batches) as pbar_samp:
                        for i in range(self.total_gen_batches):
                            x_gen_a = self.data.get_gen_batch()
                            sample_generator(num_generations=self.num_generations, sess=sess,
                                             same_images=self.same_images,
                                             inputs=x_gen_a,
                                             data=self.data, batch_size=self.batch_size, z_input=self.z_input,
                                             file_name="{}/test_z_variations_{}_{}_{}.png".format(self.save_image_path,
                                                                                                  self.experiment_name,
                                                                                                  e, i),
                                             input_a=self.input_x_i, training_phase=self.training_phase,
                                             z_vectors=self.z_vectors, dropout_rate=self.dropout_rate,
                                             dropout_rate_value=self.dropout_rate_value)

                            sample_two_dimensions_generator(sess=sess,
                                                            same_images=self.same_images,
                                                            inputs=x_gen_a,
                                                            data=self.data, batch_size=self.batch_size,
                                                            z_input=self.z_input,
                                                            file_name="{}/val_z_spherical_{}_{}_{}".format(
                                                                self.save_image_path,
                                                                self.experiment_name,
                                                                e, i),
                                                            input_a=self.input_x_i,
                                                            training_phase=self.training_phase,
                                                            dropout_rate=self.dropout_rate,
                                                            dropout_rate_value=self.dropout_rate_value,
                                                            z_vectors=self.z_2d_vectors)

                            pbar_samp.update(1)

                    # A checkpoint is written after every epoch (train_save_path is not used further).
                    train_save_path = self.train_saver.save(sess, "{}/train_saved_model_{}_{}.ckpt".format(
                        self.saved_models_filepath,
                        self.experiment_name, e))

                    # Keep a separate checkpoint whenever the mean validation discriminator
                    # loss reaches a new minimum.
                    if total_d_val_loss_mean<best_d_val_loss:
                        best_d_val_loss = total_d_val_loss_mean
                        val_save_path = self.train_saver.save(sess, "{}/val_saved_model_{}_{}.ckpt".format(
                            self.saved_models_filepath,
                            self.experiment_name, e))
                        print("Saved current best val model at", val_save_path)

                    # Row order: epoch, then (train, val) pairs of d mean, d std, g mean, g std.
                    save_statistics(self.log_path, [e, total_d_train_loss_mean, total_d_val_loss_mean,
                                                total_d_train_loss_std, total_d_val_loss_std,
                                                total_g_train_loss_mean, total_g_val_loss_mean,
                                                total_g_train_loss_std, total_g_val_loss_std])

                    pbar_e.update(1)
Exemplo n.º 10
0
    def __init__(self, parser, data):
        """
        Build the DAGAN training graph together with the bookkeeping needed to run an experiment.

        Resets the default TensorFlow graph, creates the input placeholders, instantiates the DAGAN model and, when
        not resuming (``continue_from_epoch == -1``), creates the statistics file with its header row.

        :param parser: Argument parser whose ``parse_args()`` result provides continue_from_epoch, experiment_title,
        num_of_gpus, batch_size, generator_inner_layers, discriminator_inner_layers, z_dim, num_generations and
        dropout_rate_value
        :param data: Data provider exposing image_height, image_width, image_channel, training_data_size and
        generation_data_size
        """
        tf.reset_default_graph()

        args = parser.parse_args()
        self.continue_from_epoch = args.continue_from_epoch
        self.experiment_name = args.experiment_title
        self.saved_models_filepath, self.log_path, self.save_image_path = build_experiment_folder(self.experiment_name)
        self.num_gpus = args.num_of_gpus
        self.batch_size = args.batch_size
        gen_depth_per_layer = args.generator_inner_layers
        discr_depth_per_layer = args.discriminator_inner_layers
        self.z_dim = args.z_dim
        self.num_generations = args.num_generations
        self.dropout_rate_value = args.dropout_rate_value
        self.data = data
        self.reverse_channels = False

        generator_layers = [64, 64, 128, 128]
        discriminator_layers = [64, 64, 128, 128]

        # Every stage uses the same number of inner layers, taken from the command line.
        gen_inner_layers = [gen_depth_per_layer] * len(generator_layers)
        discr_inner_layers = [discr_depth_per_layer] * len(discriminator_layers)
        generator_layer_padding = ["SAME"] * len(generator_layers)

        image_height = data.image_height
        image_width = data.image_width
        image_channel = data.image_channel

        # Image placeholders carry a leading per-GPU axis followed by the batch axis.
        image_batch_shape = [self.num_gpus, self.batch_size, image_height, image_width, image_channel]
        self.input_x_i = tf.placeholder(tf.float32, image_batch_shape, 'inputs-1')
        self.input_x_j = tf.placeholder(tf.float32, image_batch_shape, 'inputs-2-same-class')

        self.z_input = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], 'z-input')
        self.training_phase = tf.placeholder(tf.bool, name='training-flag')
        self.random_rotate = tf.placeholder(tf.bool, name='rotation-flag')
        self.dropout_rate = tf.placeholder(tf.float32, name='dropout-prob')

        dagan = DAGAN(batch_size=self.batch_size, input_x_i=self.input_x_i, input_x_j=self.input_x_j,
                      dropout_rate=self.dropout_rate, generator_layer_sizes=generator_layers,
                      generator_layer_padding=generator_layer_padding, num_channels=data.image_channel,
                      is_training=self.training_phase, augment=self.random_rotate,
                      discriminator_layer_sizes=discriminator_layers,
                      discr_inner_conv=discr_inner_layers,
                      gen_inner_conv=gen_inner_layers, num_gpus=self.num_gpus, z_dim=self.z_dim, z_inputs=self.z_input)

        self.summary, self.losses, self.graph_ops = dagan.init_train()
        self.same_images = dagan.sample_same_images()

        self.total_train_batches = int(data.training_data_size / (self.batch_size * self.num_gpus))

        self.total_gen_batches = int(data.generation_data_size / (self.batch_size * self.num_gpus))

        self.init = tf.global_variables_initializer()
        self.spherical_interpolation = True
        # Clamp to at least 1: with fewer than 100 * num_gpus training batches the truncated interval would be 0, and
        # a zero interval raises ZeroDivisionError wherever it is used as a modulus (e.g. ``i % interval``).
        self.tensorboard_update_interval = max(1, int(self.total_train_batches / 100 / self.num_gpus))
        self.total_epochs = 200

        if self.continue_from_epoch == -1:
            # Fresh run: create the statistics file and write its header row.
            save_statistics(self.log_path, ['epoch', 'total_d_train_loss_mean', 'total_d_val_loss_mean',
                                            'total_d_train_loss_std', 'total_d_val_loss_std',
                                            'total_g_train_loss_mean', 'total_g_val_loss_mean',
                                            'total_g_train_loss_std', 'total_g_val_loss_std'], create=True)
    args)

# Resume bookkeeping: ask the helpers where training should start and which checkpoint path to load, and make the
# load path available to later code through ``args``.
start_epoch, latest_loadpath = get_start_epoch(args)
args.latest_loadpath = latest_loadpath
best_epoch, best_test_acc = get_best_epoch(args)
# NOTE(review): a negative best_epoch is presumably the helper's "no evaluation recorded yet" marker - confirm.
if best_epoch >= 0:
    print('Best evaluation acc so far at {} epochs: {:0.2f}'.format(
        best_epoch, best_test_acc))

# Only a fresh (non-resumed) run creates the statistics file; the list below is passed as its column names.
if not args.resume:
    save_statistics(logs_filepath,
                    "result_summary_statistics", [
                        "epoch",
                        "train_loss",
                        "test_loss",
                        "train_loss_c",
                        "test_loss_c",
                        "train_acc",
                        "test_acc",
                    ],
                    create=True)

######################################################################################################### Model
# Every dataset is treated as 10-class except 'Cifar-100', which uses 100 classes.
num_classes = 10 if args.dataset != 'Cifar-100' else 100
net = ModelSelector(in_shape=in_shape,
                    num_classes=num_classes).select(args.model, args)
print_network_stats(net)
net = net.to(device)

######################################################################################################### Optimisation
params = net.parameters()
Exemplo n.º 12
0
    def run_experiment(self):
        """
        Train the GAN from ``start_from_epoch`` up to ``self.total_epochs`` inside a single TensorFlow session.

        When ``self.continue_from_epoch`` is not -1 the global variables are first restored from the matching
        checkpoint. Each epoch then: saves a checkpoint, renders sample images from one training batch and from every
        generation batch, runs ``self.disc_iter`` discriminator updates and ``self.gen_iter`` generator updates per
        training batch, evaluates both losses on test batches without running the optimiser ops, and appends one row
        (epoch, train d/g loss, val d/g loss) to the statistics file.
        """
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(self.init)
            self.writer = tf.summary.FileWriter(self.log_path,
                                                graph=tf.get_default_graph())
            self.saver = tf.train.Saver()
            start_from_epoch = 0
            if self.continue_from_epoch != -1:
                # Resume: restore every global variable present in the checkpoint of the requested epoch.
                start_from_epoch = self.continue_from_epoch
                checkpoint = "{}/{}_{}.ckpt".format(self.saved_models_filepath,
                                                    self.experiment_name,
                                                    self.continue_from_epoch)
                variables_to_restore = []
                for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
                    print(var)
                    variables_to_restore.append(var)

                tf.logging.info('Fine-tuning from %s' % checkpoint)

                # ignore_missing_vars=True: variables absent from the checkpoint keep their initialised values.
                fine_tune = slim.assign_from_checkpoint_fn(
                    checkpoint, variables_to_restore, ignore_missing_vars=True)
                fine_tune(sess)

            self.iter_done = 0
            # Number of discriminator and generator updates performed per training batch index.
            self.disc_iter = 5
            self.gen_iter = 1

            # Latent vectors are created once and reused for every sampling call, so images from different epochs
            # are generated from the same z values.
            if self.spherical_interpolation:
                dim = int(np.sqrt(self.num_generations) * 2)
                self.z_2d_vectors = interpolations.create_mine_grid(
                    rows=dim,
                    cols=dim,
                    dim=self.z_dim,
                    space=3,
                    anchors=None,
                    spherical=True,
                    gaussian=True)
                self.z_vectors = interpolations.create_mine_grid(
                    rows=1,
                    cols=self.num_generations,
                    dim=self.z_dim,
                    space=3,
                    anchors=None,
                    spherical=True,
                    gaussian=True)
            else:
                self.z_vectors = np.random.normal(size=(self.num_generations,
                                                        self.z_dim))
                self.z_2d_vectors = np.random.normal(
                    size=(self.num_generations, self.z_dim))

            with tqdm.tqdm(total=self.total_epochs -
                           start_from_epoch) as pbar_e:
                for e in range(start_from_epoch, self.total_epochs):

                    total_g_loss = 0.
                    total_d_loss = 0.
                    # The checkpoint is written before this epoch's training, so file ``_e`` holds the weights as
                    # they were at the start of epoch e.
                    save_path = self.saver.save(
                        sess,
                        "{}/{}_{}.ckpt".format(self.saved_models_filepath,
                                               self.experiment_name, e))
                    print("Model saved at", save_path)
                    with tqdm.tqdm(
                            total=self.total_train_batches) as pbar_train:
                        # One training batch is drawn purely to render this epoch's sample images.
                        x_train_a_gan_list, x_train_b_gan_same_class_list = self.data.get_train_batch(
                        )

                        sample_generator(
                            num_generations=self.num_generations,
                            sess=sess,
                            same_images=self.same_images,
                            inputs=x_train_a_gan_list,
                            data=self.data,
                            batch_size=self.batch_size,
                            z_input=self.z_input,
                            file_name="{}/train_z_variations_{}_{}.png".format(
                                self.save_image_path, self.experiment_name, e),
                            input_a=self.input_x_i,
                            training_phase=self.training_phase,
                            z_vectors=self.z_vectors,
                            dropout_rate=self.dropout_rate,
                            dropout_rate_value=self.dropout_rate_value)

                        sample_two_dimensions_generator(
                            sess=sess,
                            same_images=self.same_images,
                            inputs=x_train_a_gan_list,
                            data=self.data,
                            batch_size=self.batch_size,
                            z_input=self.z_input,
                            file_name="{}/train_z_spherical_{}_{}".format(
                                self.save_image_path, self.experiment_name, e),
                            input_a=self.input_x_i,
                            training_phase=self.training_phase,
                            dropout_rate=self.dropout_rate,
                            dropout_rate_value=self.dropout_rate_value,
                            z_vectors=self.z_2d_vectors)

                        # Render the same two kinds of sample image for every generation batch.
                        with tqdm.tqdm(
                                total=self.total_gen_batches) as pbar_samp:
                            for i in range(self.total_gen_batches):
                                x_gen_a = self.data.get_gen_batch()
                                sample_generator(
                                    num_generations=self.num_generations,
                                    sess=sess,
                                    same_images=self.same_images,
                                    inputs=x_gen_a,
                                    data=self.data,
                                    batch_size=self.batch_size,
                                    z_input=self.z_input,
                                    file_name=
                                    "{}/test_z_variations_{}_{}_{}.png".format(
                                        self.save_image_path,
                                        self.experiment_name, e, i),
                                    input_a=self.input_x_i,
                                    training_phase=self.training_phase,
                                    z_vectors=self.z_vectors,
                                    dropout_rate=self.dropout_rate,
                                    dropout_rate_value=self.dropout_rate_value)

                                sample_two_dimensions_generator(
                                    sess=sess,
                                    same_images=self.same_images,
                                    inputs=x_gen_a,
                                    data=self.data,
                                    batch_size=self.batch_size,
                                    z_input=self.z_input,
                                    file_name="{}/val_z_spherical_{}_{}_{}".
                                    format(self.save_image_path,
                                           self.experiment_name, e, i),
                                    input_a=self.input_x_i,
                                    training_phase=self.training_phase,
                                    dropout_rate=self.dropout_rate,
                                    dropout_rate_value=self.dropout_rate_value,
                                    z_vectors=self.z_2d_vectors)

                                pbar_samp.update(1)

                        for i in range(self.total_train_batches):

                            # Discriminator phase: a fresh batch is drawn for each of the disc_iter updates.
                            for j in range(self.disc_iter):
                                x_train_a_gan_list, x_train_b_gan_same_class_list = self.data.get_train_batch(
                                )
                                _, d_loss_value = sess.run(
                                    [
                                        self.graph_ops["d_opt_op"],
                                        self.losses["d_losses"]
                                    ],
                                    feed_dict={
                                        self.input_x_i: x_train_a_gan_list,
                                        self.input_x_j:
                                        x_train_b_gan_same_class_list,
                                        self.dropout_rate:
                                        self.dropout_rate_value,
                                        self.training_phase: True,
                                        self.random_rotate: True
                                    })
                                total_d_loss += d_loss_value

                            # Generator phase: also fetches the merged summary op alongside the loss.
                            for j in range(self.gen_iter):
                                x_train_a_gan_list, x_train_b_gan_same_class_list = \
                                    self.data.get_train_batch()
                                _, g_loss_value, summaries, = sess.run(
                                    [
                                        self.graph_ops["g_opt_op"],
                                        self.losses["g_losses"], self.summary
                                    ],
                                    feed_dict={
                                        self.input_x_i: x_train_a_gan_list,
                                        self.input_x_j:
                                        x_train_b_gan_same_class_list,
                                        self.dropout_rate:
                                        self.dropout_rate_value,
                                        self.training_phase: True,
                                        self.random_rotate: True
                                    })

                                total_g_loss += g_loss_value

                            # NOTE(review): the modulo raises ZeroDivisionError if tensorboard_update_interval is 0,
                            # and add_summary is called without a global_step - confirm both are intended.
                            if i % (self.tensorboard_update_interval) == 0:
                                self.writer.add_summary(summaries)
                            self.iter_done = self.iter_done + 1
                            # The progress bar shows the losses of the last discriminator/generator step only.
                            iter_out = "d_loss: {}, g_loss: {}".format(
                                d_loss_value, g_loss_value)
                            pbar_train.set_description(iter_out)
                            pbar_train.update(1)

                    # Turn the accumulated sums into per-update means for the epoch.
                    total_g_loss /= (self.total_train_batches * self.gen_iter)

                    total_d_loss /= (self.total_train_batches * self.disc_iter)

                    print("Epoch {}: d_loss: {}, wg_loss: {}".format(
                        e, total_d_loss, total_g_loss))

                    total_g_val_loss = 0.
                    total_d_val_loss = 0.

                    # Evaluation pass: only the loss tensors are fetched, so no optimiser op runs here. The values
                    # are reported as "val" but are computed on batches from get_test_batch().
                    # NOTE(review): self.total_test_batches is not assigned in this method; it must be set elsewhere.
                    with tqdm.tqdm(total=self.total_test_batches) as pbar_val:
                        for i in range(self.total_test_batches):

                            for j in range(self.disc_iter):
                                x_test_a, x_test_b = self.data.get_test_batch()
                                # NOTE(review): dropout_rate is fed the training value even though
                                # training_phase is False - confirm the graph ignores it in that case.
                                d_loss_value = sess.run(
                                    self.losses["d_losses"],
                                    feed_dict={
                                        self.input_x_i: x_test_a,
                                        self.input_x_j: x_test_b,
                                        self.training_phase: False,
                                        self.random_rotate: False,
                                        self.dropout_rate:
                                        self.dropout_rate_value
                                    })

                                total_d_val_loss += d_loss_value

                            for j in range(self.gen_iter):
                                x_test_a, x_test_b = self.data.get_test_batch()
                                g_loss_value = sess.run(
                                    self.losses["g_losses"],
                                    feed_dict={
                                        self.input_x_i: x_test_a,
                                        self.input_x_j: x_test_b,
                                        self.training_phase: False,
                                        self.random_rotate: False,
                                        self.dropout_rate:
                                        self.dropout_rate_value
                                    })

                                total_g_val_loss += (g_loss_value)

                            # iter_done counts evaluation steps as well as training steps.
                            self.iter_done = self.iter_done + 1
                            iter_out = "d_loss: {}, g_loss: {}".format(
                                d_loss_value, g_loss_value)
                            pbar_val.set_description(iter_out)
                            pbar_val.update(1)

                    total_g_val_loss /= (self.total_test_batches *
                                         self.gen_iter)
                    total_d_val_loss /= (self.total_test_batches *
                                         self.disc_iter)

                    print("Epoch {}: d_val_loss: {}, wg_val_loss: {}".format(
                        e, total_d_val_loss, total_g_val_loss))

                    # Row order: epoch, train d loss, train g loss, val d loss, val g loss.
                    save_statistics(self.log_path, [
                        e, total_d_loss, total_g_loss, total_d_val_loss,
                        total_g_val_loss
                    ])

                    pbar_e.update(1)
Exemplo n.º 13
0
    def evaluate_test_set_using_the_best_models(self, top_n_models):
        """
        Evaluate the test set with the checkpoints that reached the lowest mean validation loss.

        The per-epoch validation losses stored in ``self.state["per_epoch_statistics"]`` are ranked in ascending
        order, the best ``top_n_models`` checkpoints are loaded one after another, and each is run over the test
        batches. The per-batch losses are averaged across models and their mean and standard deviation are written
        to ``test_summary.csv``.

        :param top_n_models: Maximum number of checkpoints to evaluate; if fewer epochs were recorded, only the
        recorded ones are used
        """
        per_epoch_statistics = self.state["per_epoch_statistics"]
        val_loss = np.copy(per_epoch_statistics["val_loss_mean"])
        # argsort is ascending, so the leading entries are the epochs with the lowest validation loss.
        top_n_idx = np.argsort(val_loss, axis=0).astype(dtype=np.int32)[:top_n_models]

        print(top_n_idx)
        print(val_loss[top_n_idx])

        # Size the accumulator by the number of checkpoints actually selected. Sizing it by top_n_models leaves
        # trailing empty lists when fewer epochs exist, which makes the np.mean below operate on ragged input.
        per_model_per_batch_loss = [[] for _ in range(len(top_n_idx))]
        total_batches = int(self.args.num_evaluation_tasks / self.args.batch_size)

        for idx, model_idx in enumerate(top_n_idx):
            # Checkpoint files are loaded with index epoch + 1.
            self.state = self.model.load_model(
                model_save_dir=self.saved_models_filepath,
                model_name="train_model",
                model_idx=model_idx + 1,
            )
            with tqdm.tqdm(total=total_batches) as pbar_test:
                test_batches = self.data.get_test_batches(
                    total_batches=total_batches,
                    augment_images=False,
                )
                for sample_idx, test_sample in enumerate(test_batches):
                    per_model_per_batch_loss = self.test_evaluation_iteration(
                        val_sample=test_sample,
                        sample_idx=sample_idx,
                        model_idx=idx,
                        per_model_per_batch_preds=per_model_per_batch_loss,
                        pbar_test=pbar_test,
                    )

        # Average over models first, then summarise over batches.
        per_batch_loss = np.mean(per_model_per_batch_loss, axis=0)
        loss = np.mean(per_batch_loss)
        loss_std = np.std(per_batch_loss)

        test_losses = {"test_loss_mean": loss, "test_loss_std": loss_std}

        # First call (re)creates the file with the header row, second call appends the values.
        _ = save_statistics(
            self.logs_filepath,
            list(test_losses.keys()),
            create=True,
            filename="test_summary.csv",
        )

        summary_statistics_filepath = save_statistics(
            self.logs_filepath,
            list(test_losses.values()),
            create=False,
            filename="test_summary.csv",
        )
        print(test_losses)
        print("saved test performance at", summary_statistics_filepath)
Exemplo n.º 14
0
    def run_experiment(self):
        """
        Runs a full training experiment with evaluations of the model on the val set at every epoch. Whenever an
        epoch improves on the best validation accuracy seen so far, the test set is evaluated with the current model
        and the result is written to ``test_summary.csv``.

        The process exits via ``sys.exit()`` once ``self.total_epochs_before_pause`` epochs have been completed in
        this run.
        """
        with tqdm.tqdm(initial=self.state['current_iter'],
                       total=int(self.args.total_iter_per_epoch * self.args.total_epochs)) as pbar_train:

            while self.state['current_iter'] < (self.args.total_epochs * self.args.total_iter_per_epoch):
                # A single generator is asked for every batch still owed to the iteration budget, so this loop
                # body normally executes once and epoch boundaries are detected inside the for loop.
                remaining_batches = int(self.args.total_iter_per_epoch *
                                        self.args.total_epochs) - self.state['current_iter']

                for train_sample in self.data.get_train_batches(total_batches=remaining_batches,
                                                                augment_images=self.augment_flag):
                    train_losses, _, self.state['current_iter'] = self.train_iteration(
                        train_sample=train_sample,
                        total_losses=self.total_losses,
                        epoch_idx=(self.state['current_iter'] /
                                   self.args.total_iter_per_epoch),
                        pbar_train=pbar_train,
                        current_iter=self.state['current_iter'],
                        sample_idx=self.state['current_iter'])

                    if self.state['current_iter'] % self.args.total_iter_per_epoch != 0:
                        continue

                    # ---- epoch boundary ----
                    # Reset per epoch. Resetting only once per while-iteration would leave the flag stuck at True
                    # after the first improvement, so every later epoch would overwrite test_summary.csv with the
                    # test results of a model that is not the best one.
                    better_val_model = False

                    val_totals = dict()
                    val_losses = dict()
                    with tqdm.tqdm(total=self.args.total_iter_per_epoch) as pbar_val:
                        for val_sample in self.data.get_val_batches(
                                total_batches=int(self.args.total_iter_per_epoch),
                                augment_images=False):
                            val_losses, val_totals = self.evaluation_iteration(val_sample=val_sample,
                                                                               total_losses=val_totals,
                                                                               pbar_val=pbar_val)

                        if val_losses["val_accuracy_mean"] > self.state['best_val_acc']:
                            print("Best validation accuracy", val_losses["val_accuracy_mean"])
                            self.state['best_val_acc'] = val_losses["val_accuracy_mean"]
                            self.state['best_val_iter'] = self.state['current_iter']
                            self.state['best_epoch'] = int(
                                self.state['best_val_iter'] / self.args.total_iter_per_epoch)
                            better_val_model = True

                    self.epoch += 1
                    # Fold the latest train and val metrics into the state that is saved with the checkpoint.
                    state_with_train = self.merge_two_dicts(first_dict=self.state, second_dict=train_losses)
                    self.state = self.merge_two_dicts(first_dict=state_with_train, second_dict=val_losses)
                    self.save_models(model=self.model, epoch=self.epoch, state=self.state)

                    self.start_time = self.pack_and_save_metrics(start_time=self.start_time,
                                                                 create_summary_csv=self.create_summary_csv,
                                                                 train_losses=train_losses, val_losses=val_losses)

                    # Start the next epoch with a fresh training-loss accumulator.
                    self.total_losses = dict()

                    self.epochs_done_in_this_run += 1

                    # NOTE(review): `% 1` is always 0 for integer epochs; presumably a hook for a coarser
                    # test-evaluation frequency.
                    if self.epoch % 1 == 0 and better_val_model:
                        test_totals = dict()
                        test_losses = dict()
                        with tqdm.tqdm(total=self.args.total_iter_per_epoch) as pbar_test:
                            for test_sample in self.data.get_test_batches(
                                    total_batches=int(self.args.total_iter_per_epoch),
                                    augment_images=False):
                                test_losses, test_totals = self.evaluation_iteration(val_sample=test_sample,
                                                                                     total_losses=test_totals,
                                                                                     pbar_val=pbar_test,
                                                                                     phase='test')

                        # First call (re)creates the file with the header row, second call appends the values.
                        _ = save_statistics(self.logs_filepath,
                                            list(test_losses.keys()),
                                            create=True, filename="test_summary.csv")
                        summary_statistics_filepath = save_statistics(self.logs_filepath,
                                                                      list(test_losses.values()),
                                                                      create=False, filename="test_summary.csv")

                        print("saved test performance at", summary_statistics_filepath)

                    if self.epochs_done_in_this_run >= self.total_epochs_before_pause:
                        print("train_seed {}, val_seed: {}, at pause time".format(self.data.dataset.seed["train"],
                                                                                  self.data.dataset.seed["val"]))
                        sys.exit()
    def run_experiment(self):
        """
        Train and evaluate the model for every epoch from ``self.starting_epoch`` to ``configs['num_of_epochs']``.

        Before the loop the per-epoch KL weight schedule ``self.KL_weight`` is built according to
        ``configs['type']`` ('linear', 'sigmoid', 'cosine' or 'constant'). Each epoch runs one training and one
        evaluation pass, logs the metrics to the summary writer, tracks the epoch with the best precision, optionally
        saves the model and writes the accumulated metrics to ``summary.csv``. The writer is flushed and closed at
        the end.
        """
        total_losses = {
            "loss": [],
            "precision": [],
            "hr": [],
            "F1 Score": [],
            "diversity": [],
            "CC": [],
            "curr_epoch": []
        }

        schedule_type = self.configs['type']
        assert schedule_type in [
            'linear', 'sigmoid', 'cosine', 'constant'
        ], "unknown KL schedule type: {}".format(schedule_type)

        if schedule_type == 'linear':
            self.KL_weight = cycle_linear(0.001, self.configs['max_beta'],
                                          self.configs['num_of_epochs'],
                                          self.configs['cycles'],
                                          self.configs['ratio'])
        elif schedule_type == 'sigmoid':
            self.KL_weight = cycle_sigmoid(0.001, self.configs['max_beta'],
                                           self.configs['num_of_epochs'],
                                           self.configs['cycles'],
                                           self.configs['ratio'])
        elif schedule_type == 'cosine':
            self.KL_weight = cycle_cosine(0.001, self.configs['max_beta'],
                                          self.configs['num_of_epochs'],
                                          self.configs['cycles'],
                                          self.configs['ratio'])
        else:
            # 'constant': the same weight (max_beta) for every epoch.
            self.KL_weight = np.full(self.configs['num_of_epochs'],
                                     self.configs['max_beta'])

        for epoch_idx in range(self.starting_epoch,
                               self.configs['num_of_epochs']):
            print(f"Epoch: {epoch_idx}")
            average_loss = self.run_training_epoch(epoch_idx)
            precision_mean, hr_mean, diversity, cc = self.run_evaluation_epoch(
                self.movie_categories.shape[1], epoch_idx)

            # The harmonic mean is undefined when precision and hit ratio are both zero; report 0 in that case
            # instead of dividing by zero.
            f1_denominator = precision_mean + hr_mean
            if f1_denominator != 0:
                f1_score = 2 * (precision_mean * hr_mean) / f1_denominator
            else:
                f1_score = 0.0

            if precision_mean > self.best_val_model_precision:
                self.best_val_model_precision = precision_mean
                self.best_val_model_idx = epoch_idx

            self.writer.add_scalar('Average training loss per epoch',
                                   average_loss, epoch_idx)

            self.writer.add_scalar('Precision', precision_mean, epoch_idx)
            self.writer.add_scalar('Hit Ratio', hr_mean, epoch_idx)
            self.writer.add_scalar('F1 Score', f1_score, epoch_idx)
            self.writer.add_scalar('Diversity', diversity, epoch_idx)
            self.writer.add_scalar('CC', cc, epoch_idx)

            print(
                f'HR: {hr_mean}, Precision: {precision_mean}, F1: {f1_score}, Diversity: {diversity}, CC: {cc}'
            )

            self.state['current_epoch_idx'] = epoch_idx
            self.state[
                'best_val_model_precision'] = self.best_val_model_precision
            self.state['best_val_model_idx'] = self.best_val_model_idx

            if self.configs['save_model']:
                self.save_model(model_save_dir=self.experiment_saved_models,
                                model_save_name="train_model",
                                model_idx=epoch_idx,
                                state=self.state)

            total_losses['loss'].append(average_loss)
            total_losses['precision'].append(precision_mean)
            total_losses['hr'].append(hr_mean)
            total_losses['F1 Score'].append(f1_score)
            total_losses['diversity'].append(diversity)
            total_losses['CC'].append(cc)
            total_losses['curr_epoch'].append(epoch_idx)

            # continue_from_mode is False only for epoch 0 of a run that started at epoch 0.
            save_statistics(
                experiment_log_dir=self.experiment_logs,
                filename='summary.csv',
                stats_dict=total_losses,
                current_epoch=epoch_idx,
                continue_from_mode=(self.starting_epoch != 0 or epoch_idx > 0))

        self.writer.flush()
        self.writer.close()
    def evaluate_test_set_using_the_best_models(self, top_n_models):
        """
        Evaluate the test set with an ensemble of models and write the accuracy to ``test_summary.csv``.

        When ``self.state`` contains 'per_epoch_statistics', the epochs are ranked by 'val_accuracy_mean' in
        descending order and the best ``top_n_models`` checkpoints are loaded one after another. Otherwise no
        checkpoint is loaded and the model currently held by ``self.model`` is evaluated ``top_n_models`` times.
        The predictions are averaged across models, the arg-max over the last axis is compared with the targets,
        and the mean and standard deviation of the matches are saved.

        :param top_n_models: Maximum number of models in the ensemble; if fewer epochs were recorded, only the
        recorded ones are used
        """
        # Decide once whether checkpoints are loaded: load_model replaces self.state, and the selection below was
        # made from the state as it was on entry.
        has_epoch_statistics = 'per_epoch_statistics' in self.state
        if has_epoch_statistics:
            per_epoch_statistics = self.state['per_epoch_statistics']
            val_acc = np.copy(per_epoch_statistics['val_accuracy_mean'])
            # argsort is ascending; reverse it so the highest validation accuracies come first.
            top_n_idx = np.argsort(
                val_acc, axis=0).astype(dtype=np.int32)[::-1][:top_n_models]
            print(top_n_idx)
            print(val_acc[top_n_idx])
        else:
            top_n_idx = list(range(top_n_models))

        # Size the accumulators by the number of models actually selected. Sizing them by top_n_models leaves
        # trailing empty lists when fewer epochs exist, which makes the np.mean below operate on ragged input.
        num_models = len(top_n_idx)
        per_model_per_batch_preds = [[] for _ in range(num_models)]
        per_model_per_batch_targets = [[] for _ in range(num_models)]
        total_batches = int(self.num_evaluation_tasks / self.batch_size)

        for idx, model_idx in enumerate(top_n_idx):
            if has_epoch_statistics:
                # Checkpoint files are loaded with index epoch + 1.
                self.state = self.model.load_model(model_save_dir=self.saved_models_filepath,
                                                   model_name="train_model",
                                                   model_idx=model_idx + 1)

            with tqdm.tqdm(total=total_batches) as pbar_test:
                for sample_idx, test_sample in enumerate(self.data['test']):
                    test_sample = self.convert_into_continual_tasks(
                        test_sample)
                    # Only the target labels are needed here; the sample is a 6-tuple.
                    _, _, _, y_target_set, _, _ = test_sample
                    per_model_per_batch_targets[idx].extend(
                        np.array(y_target_set))
                    per_model_per_batch_preds = self.test_evaluation_iteration(
                        val_sample=test_sample,
                        sample_idx=sample_idx,
                        model_idx=idx,
                        per_model_per_batch_preds=per_model_per_batch_preds,
                        pbar_test=pbar_test)

        # Ensemble by averaging predictions over the model axis, then take the arg-max over the last axis.
        per_batch_preds = np.mean(per_model_per_batch_preds, axis=0)

        per_batch_max = np.argmax(per_batch_preds, axis=2)
        # Targets of the first model are used as the reference for the whole ensemble.
        per_batch_targets = np.array(per_model_per_batch_targets[0]).reshape(
            per_batch_max.shape)

        matches = np.equal(per_batch_targets, per_batch_max)
        accuracy = np.mean(matches)
        accuracy_std = np.std(matches)

        test_losses = {
            "test_accuracy_mean": accuracy,
            "test_accuracy_std": accuracy_std
        }

        # First call (re)creates the file with the header row, second call appends the values.
        _ = save_statistics(self.logs_filepath,
                            list(test_losses.keys()),
                            create=True,
                            filename="test_summary.csv")

        summary_statistics_filepath = save_statistics(
            self.logs_filepath,
            list(test_losses.values()),
            create=False,
            filename="test_summary.csv")
        print(test_losses)
        print("saved test performance at", summary_statistics_filepath)