コード例 #1
0
ファイル: experiment_builder.py プロジェクト: tnaren3/maml
    def run_experiment(self):
        """
        Train the model, validate after every epoch and finally test the best checkpoints.

        Training resumes at ``self.state['current_iter']`` and continues until
        ``total_epochs * total_iter_per_epoch`` iterations have run; it is skipped
        entirely when ``self.args.evaluate_on_test_set_only`` is set. Whenever the
        iteration counter reaches a multiple of ``total_iter_per_epoch``:

        * the validation batches are evaluated and the best ``val_accuracy_mean``
          (with its iteration and epoch) is recorded in ``self.state``;
        * train/val losses and metrics are merged into ``self.state``, a
          checkpoint is saved and ``summary_statistics.json`` is rewritten;
        * the process exits via ``sys.exit()`` once ``total_epochs_before_pause``
          epochs have been completed in this run.

        Afterwards ``evaluated_test_set_using_the_best_models(top_n_models=5)`` is
        called; its result is not returned, so this method returns ``None``.
        """
        args = self.args
        with tqdm.tqdm(initial=self.state['current_iter'],
                       total=int(args.total_iter_per_epoch *
                                 args.total_epochs)) as pbar_train:

            # A fresh batch generator is requested on every pass, sized to the
            # iterations still outstanding. The flag is compared with ``== False``
            # on purpose so that non-bool values behave exactly as they always have.
            while (self.state['current_iter'] <
                   args.total_epochs * args.total_iter_per_epoch and
                   args.evaluate_on_test_set_only == False):  # noqa: E712
                outstanding = (int(args.total_iter_per_epoch * args.total_epochs) -
                               self.state['current_iter'])
                train_batches = self.data.get_train_batches(
                    total_batches=outstanding, augment_images=self.augment_flag)

                for batch in train_batches:
                    step = self.state['current_iter']
                    (train_losses, _, train_metrics, _,
                     self.state['current_iter']) = self.train_iteration(
                         train_sample=batch,
                         total_losses=self.total_losses,
                         total_metrics=self.total_metrics,
                         epoch_idx=step / args.total_iter_per_epoch,
                         pbar_train=pbar_train,
                         current_iter=step,
                         sample_idx=step)

                    # Only epoch boundaries trigger validation and bookkeeping.
                    if self.state['current_iter'] % args.total_iter_per_epoch != 0:
                        continue

                    val_batch_count = int(args.num_evaluation_tasks / args.batch_size)
                    val_loss_totals = dict()
                    val_metric_totals = dict()
                    val_losses = dict()
                    val_metrics = dict()
                    with tqdm.tqdm(total=val_batch_count) as pbar_val:
                        for val_sample in self.data.get_val_batches(
                                total_batches=val_batch_count,
                                augment_images=False):
                            (val_losses, val_loss_totals, val_metrics,
                             val_metric_totals) = self.evaluation_iteration(
                                 val_sample=val_sample,
                                 total_losses=val_loss_totals,
                                 total_metrics=val_metric_totals,
                                 pbar_val=pbar_val,
                                 phase='val')

                        # val_losses holds whatever the final evaluation_iteration
                        # call returned.
                        epoch_val_acc = val_losses["val_accuracy_mean"]
                        if epoch_val_acc > self.state['best_val_acc']:
                            print("Best validation accuracy", epoch_val_acc)
                            self.state['best_val_acc'] = epoch_val_acc
                            self.state['best_val_iter'] = self.state['current_iter']
                            self.state['best_epoch'] = int(
                                self.state['best_val_iter'] / args.total_iter_per_epoch)

                    self.epoch += 1
                    # Merge, in this order: train losses, val losses, train
                    # metrics, val metrics.
                    new_state = self.state
                    for stats in (train_losses, val_losses, train_metrics, val_metrics):
                        new_state = self.merge_two_dicts(first_dict=new_state,
                                                         second_dict=stats)
                    self.state = new_state

                    self.save_models(model=self.model, epoch=self.epoch,
                                     state=self.state)

                    self.start_time, self.state = self.pack_and_save_metrics(
                        start_time=self.start_time,
                        create_summary_csv=self.create_summary_csv,
                        train_losses=train_losses,
                        train_metrics=train_metrics,
                        val_losses=val_losses,
                        val_metrics=val_metrics,
                        state=self.state)

                    # Start the next epoch with empty training accumulators.
                    self.total_losses = dict()
                    self.total_metrics = dict()
                    self.epochs_done_in_this_run += 1

                    summary_path = os.path.join(self.logs_filepath,
                                                "summary_statistics.json")
                    save_to_json(filename=summary_path,
                                 dict_to_store=self.state['per_epoch_statistics'])

                    if self.epochs_done_in_this_run >= self.total_epochs_before_pause:
                        seeds = self.data.dataset.seed
                        print("train_seed {}, val_seed: {}, at pause time".format(
                            seeds["train"], seeds["val"]))
                        sys.exit()

            self.evaluated_test_set_using_the_best_models(top_n_models=5)
コード例 #2
0
    def run_experiment(self):
        """
        Train the model, report per-task accuracies each epoch, then test the best checkpoints.

        Training resumes at ``self.state['current_iter']`` and continues until
        ``total_epochs * total_iter_per_epoch`` iterations have run; it is skipped
        entirely when ``self.args.evaluate_on_test_set_only`` is set. At every
        multiple of ``total_iter_per_epoch`` iterations:

        * the model is evaluated (``phase='val'``) on ``self.data.get_test_batches``;
          those metrics update ``best_val_acc``/``best_val_iter``/``best_epoch`` and
          are the ``val_losses`` merged into ``self.state`` and saved;
        * per-task accuracies are printed for the test tasks and, from a second
          evaluation pass over ``self.data.get_test_train_batches``, for the train
          tasks (that second pass feeds the report only);
        * a checkpoint and ``summary_statistics.json`` are written, and the process
          exits via ``sys.exit()`` once ``total_epochs_before_pause`` epochs have
          been completed in this run.

        Afterwards the best checkpoints are evaluated with
        ``evaluated_test_set_using_the_best_models`` on the ``'train'`` and
        ``'test'`` datasets. Their results are not returned (returns ``None``).
        """

        def _print_task_accuracies(title, accs_by_task, num_tasks):
            # Print the mean accuracy of each task index in range(num_tasks), then
            # mean/min/std/max and the three lowest-scoring task indices. Task
            # indices absent from accs_by_task keep the -1 placeholder, which is
            # included in the summary statistics.
            accs = -1 * np.ones(num_tasks)
            for task_idx in range(num_tasks):
                if task_idx in accs_by_task:
                    accs[task_idx] = np.mean(np.asarray(accs_by_task[task_idx]))
            print(title)
            print(accs)
            print(np.mean(accs))
            print(np.min(accs))
            print(np.std(accs))
            print(np.max(accs))
            print(np.argsort(accs)[:3])

        torch.cuda.empty_cache()

        with tqdm.tqdm(initial=self.state['current_iter'],
                       total=int(self.args.total_iter_per_epoch *
                                 self.args.total_epochs)) as pbar_train:

            # ``== False`` is kept deliberately so non-bool flag values behave as before.
            while (self.state['current_iter'] <
                   (self.args.total_epochs * self.args.total_iter_per_epoch)
                   ) and (self.args.evaluate_on_test_set_only == False):  # noqa: E712

                # Each train_sample holds the data for one training iteration.
                for train_sample in self.data.get_train_batches(
                        total_batches=int(self.args.total_iter_per_epoch *
                                          self.args.total_epochs) -
                        self.state['current_iter'],
                        augment_images=self.augment_flag):
                    train_losses, self.state[
                        'current_iter'] = self.train_iteration(
                            train_sample=train_sample,
                            total_losses=self.total_losses,
                            epoch_idx=(self.state['current_iter'] /
                                       self.args.total_iter_per_epoch),
                            pbar_train=pbar_train,
                            current_iter=self.state['current_iter'],
                            sample_idx=self.state['current_iter'])

                    if self.state['current_iter'] % self.args.total_iter_per_epoch != 0:
                        continue

                    num_eval_batches = int(self.args.num_evaluation_tasks /
                                           self.args.batch_size)
                    total_losses = dict()
                    val_losses = dict()
                    total_accs = dict()
                    with tqdm.tqdm(total=num_eval_batches) as pbar_val:
                        # NOTE(review): this "val" pass iterates get_test_batches, so
                        # best-model selection is driven by test-task accuracy —
                        # confirm this is intended.
                        for val_sample in self.data.get_test_batches(
                                total_batches=num_eval_batches,
                                augment_images=False):
                            val_losses, total_losses, total_accs = self.evaluation_iteration(
                                val_sample=val_sample,
                                total_losses=total_losses,
                                total_accs=total_accs,
                                pbar_val=pbar_val,
                                phase='val')

                        if val_losses["val_accuracy_mean"] > self.state['best_val_acc']:
                            print("Best validation accuracy",
                                  val_losses["val_accuracy_mean"])
                            self.state['best_val_acc'] = val_losses["val_accuracy_mean"]
                            self.state['best_val_iter'] = self.state['current_iter']
                            self.state['best_epoch'] = int(
                                self.state['best_val_iter'] /
                                self.args.total_iter_per_epoch)

                        _print_task_accuracies("ACCURACIES", total_accs,
                                               self.num_test_tasks)

                        # Second pass over the train tasks, for the per-task report
                        # only. It uses its own accumulators so that val_losses
                        # (saved below as this epoch's val statistics, and the same
                        # numbers the best-model check above used) is not overwritten.
                        train_task_losses = dict()
                        train_task_accs = dict()
                        for train_task_sample in self.data.get_test_train_batches(
                                total_batches=num_eval_batches,
                                augment_images=False):
                            _, train_task_losses, train_task_accs = self.evaluation_iteration(
                                val_sample=train_task_sample,
                                total_losses=train_task_losses,
                                total_accs=train_task_accs,
                                pbar_val=pbar_val,
                                phase='val')

                        _print_task_accuracies("ACCURACIES TRAIN", train_task_accs,
                                               self.num_train_tasks)

                    self.epoch += 1
                    self.state = self.merge_two_dicts(
                        first_dict=self.merge_two_dicts(
                            first_dict=self.state,
                            second_dict=train_losses),
                        second_dict=val_losses)

                    self.save_models(model=self.model,
                                     epoch=self.epoch,
                                     state=self.state)

                    self.start_time, self.state = self.pack_and_save_metrics(
                        start_time=self.start_time,
                        create_summary_csv=self.create_summary_csv,
                        train_losses=train_losses,
                        val_losses=val_losses,
                        state=self.state)

                    # Start the next epoch with an empty training-loss accumulator.
                    self.total_losses = dict()

                    self.epochs_done_in_this_run += 1

                    save_to_json(
                        filename=os.path.join(self.logs_filepath,
                                              "summary_statistics.json"),
                        dict_to_store=self.state['per_epoch_statistics'])

                    if self.epochs_done_in_this_run >= self.total_epochs_before_pause:
                        print("train_seed {}, val_seed: {}, at pause time".
                              format(self.data.dataset.seed["train"],
                                     self.data.dataset.seed["val"]))
                        sys.exit()

            self.evaluated_test_set_using_the_best_models(
                top_n_models=5, dataset_name='train')
            self.evaluated_test_set_using_the_best_models(
                top_n_models=5, dataset_name='test')
コード例 #3
0
    def run_experiment(self):
        """
        Train the model, validating on the whole validation set after every epoch.

        Training resumes at ``self.state['current_iter']`` and runs until
        ``total_epochs * total_iter_per_epoch`` iterations are done. At each epoch
        boundary:

        * the module-level ``epoch_val_preds`` / ``epoch_val_labels`` lists are
          reset, ``self.data.get_val_batches()`` is evaluated, and
          ``val_c_index`` is computed from those lists with ``ci_index``;
        * the lowest ``val_loss_mean`` seen so far is recorded in ``self.state``
          together with its c-index, iteration and epoch;
        * train/val losses are merged into ``self.state``, a checkpoint is saved
          and ``summary_statistics.json`` is rewritten;
        * the process exits via ``sys.exit()`` once ``total_epochs_before_pause``
          epochs have been completed in this run.

        No test-set evaluation is done here and the method returns ``None``.
        """
        # These module-level lists are only reset here and read by ci_index below;
        # presumably evaluation_iteration appends to them — TODO confirm.
        global epoch_val_preds, epoch_val_labels

        args = self.args
        with tqdm.tqdm(initial=self.state['current_iter'],
                       total=int(args.total_iter_per_epoch *
                                 args.total_epochs)) as pbar_train:

            while self.state['current_iter'] < (args.total_epochs *
                                                args.total_iter_per_epoch):
                outstanding = (int(args.total_iter_per_epoch * args.total_epochs) -
                               self.state['current_iter'])

                for batch in self.data.get_train_batches(total_batches=outstanding):
                    step = self.state['current_iter']
                    train_losses, _, self.state['current_iter'] = self.train_iteration(
                        train_sample=batch,
                        total_losses=self.total_losses,
                        epoch_idx=step / args.total_iter_per_epoch,
                        pbar_train=pbar_train,
                        current_iter=step,
                        sample_idx=step)

                    # Everything below runs only at an epoch boundary.
                    if self.state['current_iter'] % args.total_iter_per_epoch != 0:
                        continue

                    epoch_val_preds = []
                    epoch_val_labels = []
                    val_loss_totals = dict()
                    val_losses = dict()
                    val_steps = int(len(self.data.dataset_val.main_df) /
                                    args.batch_size)
                    with tqdm.tqdm(total=val_steps) as pbar_val:
                        for val_sample in self.data.get_val_batches():
                            val_losses, val_loss_totals = self.evaluation_iteration(
                                val_sample=val_sample,
                                total_losses=val_loss_totals,
                                pbar_val=pbar_val,
                                phase='val')

                        val_losses["val_c_index"] = ci_index(epoch_val_preds,
                                                             epoch_val_labels)

                        epoch_val_loss = val_losses["val_loss_mean"]
                        if epoch_val_loss < self.state['best_val_loss']:
                            print("Best validation loss", epoch_val_loss)
                            self.state['best_val_loss'] = epoch_val_loss
                            self.state['best_c_index'] = val_losses["val_c_index"]
                            self.state['best_val_iter'] = self.state['current_iter']
                            self.state['best_epoch'] = int(
                                self.state['best_val_iter'] / args.total_iter_per_epoch)

                    self.epoch += 1
                    with_train = self.merge_two_dicts(first_dict=self.state,
                                                      second_dict=train_losses)
                    self.state = self.merge_two_dicts(first_dict=with_train,
                                                      second_dict=val_losses)

                    self.save_models(model=self.model, epoch=self.epoch,
                                     state=self.state)

                    self.start_time, self.state = self.pack_and_save_metrics(
                        start_time=self.start_time,
                        create_summary_csv=self.create_summary_csv,
                        train_losses=train_losses,
                        val_losses=val_losses,
                        state=self.state)

                    # Start the next epoch with an empty training-loss accumulator.
                    self.total_losses = dict()
                    self.epochs_done_in_this_run += 1

                    summary_path = os.path.join(self.logs_filepath,
                                                "summary_statistics.json")
                    save_to_json(filename=summary_path,
                                 dict_to_store=self.state['per_epoch_statistics'])

                    if self.epochs_done_in_this_run >= self.total_epochs_before_pause:
                        sys.exit()
コード例 #4
0
    def run_experiment(self):
        """
        Train with per-epoch validation, TensorBoard logging and early stopping.

        Training resumes at ``self.state["current_iter"]`` and runs until
        ``total_epochs * total_iter_per_epoch`` iterations have been done; it is
        skipped when ``self.args.evaluate_on_test_set_only`` is set. At each epoch
        boundary the model is validated:

        * on the whole task set via ``full_task_set_evaluation`` when
          ``self.args.eval_using_full_task_set`` is set;
        * otherwise on ``num_evaluation_tasks / batch_size`` batches from
          ``get_val_batches``.

        Train/val loss and accuracy are then logged to ``self.writer``, together
        with a histogram of every parameter in ``self.model.named_parameters()``.
        The best ``val_accuracy_mean`` is tracked in ``self.state``. A checkpoint is
        saved (flagged with ``new_best``), and the per-epoch statistics are written
        to ``summary_statistics.json``.

        The run prints a test-set ``full_task_set_evaluation`` and stops with
        ``sys.exit()`` in either of two cases:

        * ``total_epochs_before_pause`` epochs have been completed in this run;
        * the count of consecutive epochs without improvement exceeds
          ``self.patience``.

        If training finishes normally, the test-set evaluation is printed and the
        method returns ``None``.
        """

        def _evaluate_test_set_and_exit(reason):
            # Print why the run stops, the test-set full_task_set_evaluation
            # result and the data seeds, then terminate the process.
            print(reason)
            print(
                self.full_task_set_evaluation(
                    set_name="test", epoch=self.epoch))
            print("train_seed {}, val_seed: {}, at pause time".format(
                self.data.dataset.seed["train"],
                self.data.dataset.seed["val"],
            ))
            sys.exit()

        with tqdm.tqdm(
                initial=self.state["current_iter"],
                total=int(self.args.total_iter_per_epoch *
                          self.args.total_epochs),
        ) as pbar_train:

            # ``== False`` is kept deliberately so non-bool flag values behave as before.
            while (self.state["current_iter"] <
                   (self.args.total_epochs * self.args.total_iter_per_epoch)
                   ) and (self.args.evaluate_on_test_set_only == False):  # noqa: E712

                for train_sample in self.data.get_train_batches(
                        total_batches=int(self.args.total_iter_per_epoch *
                                          self.args.total_epochs) -
                        self.state["current_iter"]):
                    (
                        train_losses,
                        _,
                        self.state["current_iter"],
                    ) = self.train_iteration(
                        train_sample=train_sample,
                        total_losses=self.total_losses,
                        epoch_idx=(self.state["current_iter"] /
                                   self.args.total_iter_per_epoch),
                        pbar_train=pbar_train,
                        current_iter=self.state["current_iter"],
                        sample_idx=self.state["current_iter"],
                    )

                    # Everything below runs only at an epoch boundary.
                    if self.state["current_iter"] % self.args.total_iter_per_epoch != 0:
                        continue

                    epoch = (self.state["current_iter"] //
                             self.args.total_iter_per_epoch)
                    val_losses = dict()
                    new_best = False

                    if self.args.eval_using_full_task_set:
                        # Evaluate on the whole available task set.
                        val_losses = self.full_task_set_evaluation(epoch=epoch)
                    else:
                        # Evaluate in few-shot fashion, on the query set only.
                        total_losses = dict()
                        num_val_batches = int(self.args.num_evaluation_tasks /
                                              self.args.batch_size)
                        with tqdm.tqdm(total=num_val_batches) as pbar_val:
                            for val_sample in self.data.get_val_batches(
                                    total_batches=num_val_batches):
                                (
                                    val_losses,
                                    total_losses,
                                ) = self.evaluation_iteration(
                                    val_sample=val_sample,
                                    total_losses=total_losses,
                                    pbar_val=pbar_val,
                                    phase="val",
                                )

                    # Write train/val loss and accuracy to TensorBoard.
                    self.writer.add_scalars(
                        "loss",
                        {
                            "train": train_losses["train_loss_mean"],
                            "val": val_losses["val_loss_mean"],
                        },
                        epoch,
                    )
                    self.writer.add_scalars(
                        "Accuracy",
                        {
                            "train": train_losses["train_accuracy_mean"],
                            "val": val_losses["val_accuracy_mean"],
                        },
                        epoch,
                    )
                    # Histograms of the current parameter values; gradients are
                    # not logged here.
                    for param_name, param in self.model.named_parameters():
                        self.writer.add_histogram(param_name, param, epoch)
                    self.writer.flush()

                    best_val_accuracy = self.state["best_val_accuracy"]
                    if isinstance(best_val_accuracy, (tuple, list)):
                        # State saved before this fix may hold the accuracy wrapped
                        # in a 1-tuple; compare against the scalar it wraps.
                        best_val_accuracy = best_val_accuracy[0]

                    if val_losses["val_accuracy_mean"] > best_val_accuracy:
                        self.num_epoch_no_improvements = 0
                        new_best = True
                        print(
                            "Best validation accuracy",
                            val_losses["val_accuracy_mean"],
                            "with loss",
                            val_losses["val_loss_mean"],
                        )
                        # Store the scalar itself, not a 1-tuple: a tuple cannot be
                        # compared with a float accuracy on later epochs.
                        self.state["best_val_accuracy"] = val_losses["val_accuracy_mean"]
                        self.state["best_val_iter"] = self.state["current_iter"]
                        self.state["best_epoch"] = int(
                            self.state["best_val_iter"] /
                            self.args.total_iter_per_epoch)
                    else:
                        self.num_epoch_no_improvements += 1

                    self.epoch += 1
                    self.state = self.merge_two_dicts(
                        first_dict=self.merge_two_dicts(
                            first_dict=self.state,
                            second_dict=train_losses),
                        second_dict=val_losses,
                    )

                    self.save_models(
                        model=self.model,
                        epoch=self.epoch,
                        state=self.state,
                        new_best=new_best,
                    )

                    self.start_time, self.state = self.pack_and_save_metrics(
                        start_time=self.start_time,
                        create_summary_csv=self.create_summary_csv,
                        train_losses=train_losses,
                        val_losses=val_losses,
                        state=self.state,
                    )

                    # Start the next epoch with an empty training-loss accumulator.
                    self.total_losses = dict()

                    self.epochs_done_in_this_run += 1

                    save_to_json(
                        filename=os.path.join(self.logs_filepath,
                                              "summary_statistics.json"),
                        dict_to_store=self.state["per_epoch_statistics"],
                    )

                    if (self.epochs_done_in_this_run >=
                            self.total_epochs_before_pause):
                        _evaluate_test_set_and_exit(
                            "Pause time, evaluating on test set...")
                    if self.num_epoch_no_improvements > self.patience:
                        _evaluate_test_set_and_exit(
                            "{} epochs no improvement, early stopping applied."
                            .format(self.num_epoch_no_improvements))

            print(
                self.full_task_set_evaluation(epoch=self.epoch,
                                              set_name="test"))
コード例 #5
0
    def run_experiment(self):
        """
        Train on continual-task batches, validating every epoch, then test the best checkpoints.

        Every train and val batch is first passed through
        ``convert_into_continual_tasks``. Training resumes at
        ``self.state['current_iter']`` and runs until
        ``total_epochs * total_iter_per_epoch`` iterations have been done; it is
        skipped when ``self.evaluate_on_test_set_only`` is set. At each epoch
        boundary:

        * the whole ``self.data['val']`` loader is evaluated;
        * the best ``val_accuracy_mean`` (with its iteration and epoch) is recorded
          in ``self.state``;
        * train/val losses are merged into the state and the metrics are saved;
        * a checkpoint is saved and ``summary_statistics.json`` is rewritten.

        Afterwards ``evaluate_test_set_using_the_best_models(top_n_models=5)`` is
        called; its result is not returned (returns ``None``).
        """
        with tqdm.tqdm(initial=self.state['current_iter'],
                       total=int(self.total_iter_per_epoch *
                                 self.total_epochs)) as pbar_train:

            # Hand the resume position to the training dataset (presumably so its
            # sampling continues where a previous run stopped — confirm).
            self.data['train'].dataset.set_current_iter_idx(
                self.state['current_iter'])

            # ``== False`` is kept deliberately so non-bool flag values behave as before.
            while (self.state['current_iter'] <
                   (self.total_epochs * self.total_iter_per_epoch)) and (
                       self.evaluate_on_test_set_only == False):  # noqa: E712

                for train_sample in self.data['train']:
                    train_sample = self.convert_into_continual_tasks(
                        train_sample)

                    train_losses, _, self.state[
                        'current_iter'] = self.train_iteration(
                            train_sample=train_sample,
                            total_losses=self.total_losses,
                            epoch_idx=(self.state['current_iter'] /
                                       self.total_iter_per_epoch),
                            pbar_train=pbar_train,
                            current_iter=self.state['current_iter'],
                            sample_idx=self.state['current_iter'])

                    # Everything below runs only at an epoch boundary.
                    if self.state['current_iter'] % self.total_iter_per_epoch != 0:
                        continue

                    total_losses = dict()
                    val_losses = dict()
                    with tqdm.tqdm(
                            total=len(self.data['val'])) as pbar_val:
                        for val_sample in self.data['val']:
                            val_sample = self.convert_into_continual_tasks(
                                val_sample)
                            val_losses, total_losses = self.evaluation_iteration(
                                val_sample=val_sample,
                                total_losses=total_losses,
                                pbar_val=pbar_val,
                                phase='val')

                        if val_losses["val_accuracy_mean"] > self.state[
                                'best_val_acc']:
                            print("Best validation accuracy",
                                  val_losses["val_accuracy_mean"])
                            self.state['best_val_acc'] = val_losses[
                                "val_accuracy_mean"]
                            self.state['best_val_iter'] = self.state[
                                'current_iter']
                            self.state['best_epoch'] = int(
                                self.state['best_val_iter'] /
                                self.total_iter_per_epoch)

                    self.epoch += 1
                    self.state = self.merge_two_dicts(
                        first_dict=self.merge_two_dicts(
                            first_dict=self.state,
                            second_dict=train_losses),
                        second_dict=val_losses)

                    # Metrics are packed into the state before the checkpoint is
                    # written, so the saved state includes them.
                    self.start_time, self.state = self.pack_and_save_metrics(
                        start_time=self.start_time,
                        create_summary_csv=self.create_summary_csv,
                        train_losses=train_losses,
                        val_losses=val_losses,
                        state=self.state)
                    self.save_models(model=self.model,
                                     epoch=self.epoch,
                                     state=self.state)

                    # Start the next epoch with an empty training-loss accumulator.
                    self.total_losses = dict()

                    self.epochs_done_in_this_run += 1
                    save_to_json(
                        filename=os.path.join(self.logs_filepath,
                                              "summary_statistics.json"),
                        dict_to_store=self.state['per_epoch_statistics'])

            self.evaluate_test_set_using_the_best_models(top_n_models=5)