Example #1
    def run_demo(self):
        self.load()
        self.show_intro()
        meta_batch = MetaBatch(self.samples[Splits.test], self.device)
        task = self.analyze_task(meta_batch)
        self.analyze_sample(task)
        print("Demo Done")
Example #2
    def loop_outer(self):
        steps = self.hp.steps_outer
        interval = max(1, steps // 10)  # evaluate roughly 10 times over training
        loader = self.loaders[Splits.train]
        tracker = MetricsTracker(prefix=Splits.train)
        stop_saver = EarlyStopSaver(self.net)

        for i, batch in enumerate(
                generate(loader, limit=steps, show_progress=True)):
            self.opt_outer.zero_grad()
            for task in MetaBatch(batch, self.device).get_tasks():
                metrics = self.loop_inner(task)
                tracker.store(metrics)
                self.opt_outer.store_grad()  # accumulate meta-gradients across tasks
            self.opt_outer.step(i)

            if i % interval == 0:
                metrics = tracker.get_average()
                tracker.reset()
                metrics.update(self.loop_eval(Splits.val))
                print({k: round(v, 3) for k, v in metrics.items()})
                stop_saver.check_stop(metrics["val_loss"])  # track the best weights
            self.save()  # checkpoint every outer step
        stop_saver.load_best()  # restore the best weights before the final save
        self.save()
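`EarlyStopSaver` is only used in this example, and the return value of check_stop is ignored, which suggests a restore-best-at-the-end strategy rather than actual early termination. A minimal sketch, assuming that API:

from copy import deepcopy

class EarlyStopSaver:
    def __init__(self, net):
        self.net = net
        self.best_loss = float("inf")
        self.best_state = None

    def check_stop(self, val_loss: float) -> bool:
        # Keep a copy of the best weights seen so far; return True once the
        # validation loss has stopped improving
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_state = deepcopy(self.net.state_dict())
            return False
        return True

    def load_best(self):
        if self.best_state is not None:
            self.net.load_state_dict(self.best_state)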
Example #3
    def loop_val(self) -> dict:
        tracker = MetricsTracker(prefix=Splits.val)
        for task in MetaBatch(self.batch_val, torch.device("cpu")).get_tasks():
            x_train, y_train, x_test, y_test = task
            model = self.model
            model.fit(x_train.numpy(), y_train.numpy())

            # predict_proba returns class probabilities, not logits; their
            # argmax is identical, so the accuracy is unaffected
            probs = torch.from_numpy(model.predict_proba(x_test.numpy()))
            acc = ReptileSystem.get_accuracy(probs, y_test)
            tracker.store(dict(acc=acc))
        return tracker.get_average()
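`MetricsTracker` appears in every loop with the same four calls: construct with a prefix, store per-task dicts, average, reset. A minimal sketch consistent with that usage, assuming this is its full API:

class MetricsTracker:
    def __init__(self, prefix: str = ""):
        self.prefix = prefix
        self.history = []

    def store(self, metrics: dict):
        self.history.append(metrics)

    def reset(self):
        self.history = []

    def get_average(self) -> dict:
        # Average each metric over stored tasks, e.g. "acc" -> "val_acc"
        if not self.history:
            return {}
        return {
            f"{self.prefix}_{k}" if self.prefix else k:
            sum(m[k] for m in self.history) / len(self.history)
            for k in self.history[0]
        }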
Example #4
    def loop_eval(self) -> dict:
        tracker = MetricsTracker(prefix=Splits.val)
        for task in MetaBatch(self.batch_val, torch.device("cpu")).get_tasks():
            x_train, y_train = task.train
            x_test, y_test = task.test
            # model = svm.SVC(kernel="rbf")
            model = linear_model.RidgeClassifier()
            model.fit(x_train.numpy(), y_train.numpy())

            # predict returns class labels, not logits, so there is no need to
            # round-trip through torch; accuracy_score expects (y_true, y_pred)
            preds = model.predict(x_test.numpy())
            acc = accuracy_score(y_test.numpy(), preds)
            tracker.store(dict(acc=acc))
        return tracker.get_average()
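The fit/predict/score pattern above can be exercised standalone. A runnable toy on synthetic data, with shapes chosen arbitrarily to mimic a 5-way 5-shot task:

import numpy as np
from sklearn import linear_model
from sklearn.metrics import accuracy_score

rng = np.random.default_rng(0)
x_train = rng.normal(size=(25, 64))   # 5 classes x 5 shots, 64 features
y_train = np.repeat(np.arange(5), 5)
x_test = rng.normal(size=(25, 64))
y_test = np.repeat(np.arange(5), 5)

model = linear_model.RidgeClassifier()
model.fit(x_train, y_train)
print(accuracy_score(y_test, model.predict(x_test)))  # near chance on noise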
Example #5
    def loop_val(self) -> dict:
        tracker = MetricsTracker(prefix=Splits.val)
        # Snapshot the weights and inner-optimizer state so every validation
        # task adapts from the same starting point
        net_before = deepcopy(self.net.state_dict())
        opt_before = deepcopy(self.opt_inner.state_dict())

        def reset_state():
            self.net.load_state_dict(deepcopy(net_before))
            self.opt_inner.load_state_dict(deepcopy(opt_before))

        for task in MetaBatch(self.batch_val, self.device).get_tasks():
            metrics = self.loop_inner(task)
            tracker.store(metrics)
            reset_state()

        return tracker.get_average()
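The deepcopy calls are load-bearing: state_dict() returns tensors that share storage with the live parameters, so an un-copied "snapshot" would be silently mutated by the inner optimizer's in-place updates. A small demonstration:

import torch
from copy import deepcopy

net = torch.nn.Linear(2, 2)
shallow = net.state_dict()            # shares storage with the live weights
safe = deepcopy(net.state_dict())     # independent copy
with torch.no_grad():
    net.weight.add_(1.0)              # simulate an in-place optimizer update
print(torch.equal(shallow["weight"], net.weight))  # True: snapshot mutated
print(torch.equal(safe["weight"], net.weight))     # False: copy preserved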
Example #6
    def loop_eval(self, data_split: str) -> dict:
        tracker = MetricsTracker(prefix=data_split)
        net_before = deepcopy(self.net.state_dict())
        opt_before = deepcopy(self.opt_inner.state_dict())

        def reset_state():
            self.net.load_state_dict(deepcopy(net_before))
            self.opt_inner.load_state_dict(deepcopy(opt_before))

        for task in MetaBatch(self.samples[data_split],
                              self.device).get_tasks():
            metrics = self.loop_inner(task)
            tracker.store(metrics)
            reset_state()

        return tracker.get_average()
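A hypothetical call site (the `system` name is assumed for illustration): the split argument makes the same method reusable for validation during training and for final testing.

# Hypothetical usage: evaluate any split without disturbing training state
print(system.loop_eval(Splits.val))   # e.g. {"val_loss": ..., "val_acc": ...}
print(system.loop_eval(Splits.test))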
Example #7
    def loop_outer(self):
        steps = self.hparams.steps_outer
        interval = max(1, steps // 10)  # evaluate roughly 10 times over training
        loader = self.loaders[Splits.train]
        tracker = MetricsTracker(prefix=Splits.train)

        for i, batch in enumerate(
                generate(loader, limit=steps, show_progress=True)):
            self.opt_outer.zero_grad()
            for task in MetaBatch(batch, self.device).get_tasks():
                metrics = self.loop_inner(task, self.hparams.bs_inner,
                                          self.hparams.steps_inner)
                tracker.store(metrics)
                self.opt_outer.store_grad()
            self.opt_outer.step(i)

            if i % interval == 0:
                metrics = tracker.get_average()
                tracker.reset()
                metrics.update(self.loop_val())
                print({k: round(v, 3) for k, v in metrics.items()})
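`generate` is not shown either; its call signature implies an iterator that re-cycles the DataLoader until a step budget is exhausted. A minimal sketch under that assumption:

from tqdm import tqdm

def generate(loader, limit: int, show_progress: bool = False):
    # Yield up to `limit` batches, restarting the loader as needed
    progress = tqdm(total=limit, disable=not show_progress)
    done = 0
    while done < limit:
        for batch in loader:
            if done == limit:
                break
            yield batch
            done += 1
            progress.update(1)
    progress.close()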
Example #8
    def analyze_task(self, meta_batch: MetaBatch):
        st.header("Task Analysis")
        st.write(
            """
        Each task or episode is created by randomly sampling a subset of classes
        from the original dataset, with a fixed number of examples per class.
        This simulates the few-shot learning setting across diverse tasks:
        the meta-learner is trained to generalize quickly to unseen tasks.
        """
        )
        st.write(f"Number of samples per class: {self.hp.num_shots}")
        st.write(f"Number of classes: {self.hp.num_ways}")
        st.write(f"Samples per split (train/val/test): **{self.hp.num_ways * self.hp.num_shots}**")
        tasks_all = meta_batch.get_tasks()
        i = st.selectbox(label="Task Index", options=range(len(tasks_all))[::-1])
        task = tasks_all[i]
        self.show_task(task)
        st.write("Overall test set performance for this task:")
        st.write(self.loop_inner(task))
        st.subheader("Model Output Visualization")
        self.analyze_model_outputs(task)
        return task
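Every training and evaluation loop above delegates to self.loop_inner, whose body is not included in these examples. A minimal Reptile-style sketch, assuming tasks expose .train/.test pairs as in Example #4, that self.net / self.opt_inner are the fast-adapted model and its optimizer, and the usual imports (torch, torch.nn.functional as F); the default batch size and step count are placeholders:

    def loop_inner(self, task, bs_inner: int = 10, steps_inner: int = 50) -> dict:
        # Adapt the network to the task's train split with a few SGD steps
        x_train, y_train = task.train
        x_test, y_test = task.test
        for _ in range(steps_inner):
            idx = torch.randperm(len(x_train))[:bs_inner]  # random mini-batch
            self.opt_inner.zero_grad()
            loss = F.cross_entropy(self.net(x_train[idx]), y_train[idx])
            loss.backward()
            self.opt_inner.step()
        # Report loss/accuracy on the task's held-out test split
        with torch.no_grad():
            logits = self.net(x_test)
            return dict(
                loss=F.cross_entropy(logits, y_test).item(),
                acc=(logits.argmax(dim=-1) == y_test).float().mean().item(),
            )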