class NoComponentsEWC(NoComponentsLearner):
    """No-components learner trained through a ``KFAC_EWC`` preconditioner.

    The preconditioner is built over ``self.net.components``. It is stepped
    with ``update_params=True`` on every training batch and with
    ``update_stats=True`` once per call to :meth:`update_multitask_cost`.
    """

    def __init__(self, net, ewc_lambda=1e-3, results_dir='./tmp/results/'):
        """Initialise the base learner and attach the preconditioner.

        Args:
            net: Network forwarded to the parent initialiser; its
                ``components`` attribute is handed to ``KFAC_EWC``.
            ewc_lambda: Forwarded unchanged to ``KFAC_EWC`` (presumably the
                weight of the EWC penalty -- confirm against ``KFAC_EWC``).
            results_dir: Forwarded to the parent initialiser.
        """
        super().__init__(net, results_dir)
        self.preconditioner = KFAC_EWC(self.net.components,
                                       ewc_lambda=ewc_lambda)

    def train(self,
              trainloader,
              task_id,
              component_update_freq=100,
              num_epochs=100,
              save_freq=1,
              testloaders=None):
        """Train on one task, evaluating and saving along the way.

        Args:
            trainloader: Iterable yielding ``(X, Y)`` batches.
            task_id: Identifier of the task being trained; passed to the
                network, the preconditioner and ``save_data``.
            component_update_freq: Accepted for signature compatibility;
                this method does not use it.
            num_epochs: Number of passes over ``trainloader``.
            save_freq: ``save_data`` (and evaluation, if enabled) runs on
                every epoch index divisible by ``save_freq`` and on the
                final epoch.
            testloaders: Passed to ``self.evaluate``. When ``None``,
                evaluation is skipped and ``save_data`` is called with
                ``save_eval=False``.
        """
        # Count each task only the first time it is seen.
        if task_id not in self.observed_tasks:
            self.observed_tasks.add(task_id)
            self.T += 1

        eval_bool = testloaders is not None
        if eval_bool:
            self.evaluate(testloaders)
        # Snapshot taken before any training on this task (epoch 0).
        self.save_data(0, task_id, save_eval=eval_bool)

        # The first ``num_init_tasks`` tasks are delegated to ``init_train``.
        if self.T <= self.net.num_init_tasks:
            self.init_train(trainloader, task_id, num_epochs, save_freq,
                            testloaders)
            return

        last_epoch = num_epochs - 1
        for epoch in range(num_epochs):
            for X, Y in trainloader:
                X = X.to(self.net.device, non_blocking=True)
                Y = Y.to(self.net.device, non_blocking=True)
                Y_hat = self.net(X, task_id=task_id)
                loss = self.loss(Y_hat, Y)
                self.optimizer.zero_grad()
                self.preconditioner.zero_grad()
                loss.backward()
                # The preconditioner acts on the freshly computed gradients
                # before the optimizer consumes them.
                self.preconditioner.step(task_id,
                                         update_stats=False,
                                         update_params=True)
                self.optimizer.step()
            if epoch % save_freq == 0 or epoch == last_epoch:
                if eval_bool:
                    self.evaluate(testloaders)
                self.save_data(epoch + 1, task_id, save_eval=eval_bool)

        # NOTE(review): the preconditioner statistics are refreshed only on
        # this branch, not after ``init_train`` -- confirm that is intended.
        self.update_multitask_cost(trainloader, task_id)

    def update_multitask_cost(self, loader, task_id):
        """Refresh the preconditioner statistics from one batch of ``loader``.

        Only the first batch is used: the loop exits after a single
        iteration. An empty ``loader`` makes this a no-op. Parameters are
        not updated here (``update_params=False``).
        """
        for X, Y in loader:
            X = X.to(self.net.device, non_blocking=True)
            Y = Y.to(self.net.device, non_blocking=True)
            Y_hat = self.net(X, task_id=task_id)
            loss = self.loss(Y_hat, Y)
            self.preconditioner.zero_grad()
            loss.backward()
            self.preconditioner.step(task_id,
                                     update_stats=True,
                                     update_params=False)
            break  # a single batch is enough for the statistics update
class CompositionalEWC(CompositionalLearner):
    """Compositional learner whose module updates go through ``KFAC_EWC``.

    Both public methods toggle the network's freeze state on entry
    (``freeze_modules(freeze=False)``, ``freeze_structure(freeze=True)``)
    and flip it back on exit for the given ``task_id``.
    """

    def __init__(self, net, ewc_lambda=1e-3, results_dir='./tmp/results/'):
        """Initialise the parent learner and build the preconditioner.

        ``net`` and ``results_dir`` go to the parent initialiser;
        ``ewc_lambda`` goes to ``KFAC_EWC`` together with
        ``self.net.components``.
        """
        super().__init__(net, results_dir)
        self.preconditioner = KFAC_EWC(
            self.net.components,
            ewc_lambda=ewc_lambda,
        )

    def update_modules(self, trainloader, task_id):
        """Run one pass over ``trainloader`` updating the modules.

        Each batch: forward, loss, zero gradients on optimizer and
        preconditioner, backward, preconditioner step with
        ``update_params=True``, then optimizer step.
        """
        net = self.net
        net.freeze_modules(freeze=False)
        net.freeze_structure(freeze=True)

        for inputs, targets in trainloader:
            inputs = inputs.to(self.net.device, non_blocking=True)
            targets = targets.to(self.net.device, non_blocking=True)
            batch_loss = self.loss(self.net(inputs, task_id=task_id), targets)
            self.optimizer.zero_grad()
            self.preconditioner.zero_grad()
            batch_loss.backward()
            self.preconditioner.step(
                task_id,
                update_stats=False,
                update_params=True,
            )
            self.optimizer.step()

        # Restore the freeze state for this task.
        net.freeze_modules(freeze=True)
        net.freeze_structure(freeze=False, task_id=task_id)

    def update_multitask_cost(self, loader, task_id):
        """Update the preconditioner statistics from the first batch only.

        The loop stops after one iteration; an empty ``loader`` performs no
        statistics update. The freeze toggling still happens either way.
        """
        net = self.net
        net.freeze_modules(freeze=False)
        net.freeze_structure(freeze=True)

        for inputs, targets in loader:
            inputs = inputs.to(self.net.device, non_blocking=True)
            targets = targets.to(self.net.device, non_blocking=True)
            batch_loss = self.loss(self.net(inputs, task_id=task_id), targets)
            self.preconditioner.zero_grad()
            batch_loss.backward()
            self.preconditioner.step(
                task_id,
                update_stats=True,
                update_params=False,
            )
            break  # only the first batch is consumed

        # Restore the freeze state for this task.
        net.freeze_modules(freeze=True)
        net.freeze_structure(freeze=False, task_id=task_id)
 def __init__(self, net, ewc_lambda=1e-3, results_dir='./tmp/results/'):
     """Set up the parent learner, then attach a ``KFAC_EWC`` preconditioner.

     ``net`` and ``results_dir`` are passed to the parent initialiser;
     ``ewc_lambda`` is passed to ``KFAC_EWC`` along with
     ``self.net.components``.
     """
     super().__init__(net, results_dir)
     components = self.net.components
     self.preconditioner = KFAC_EWC(components, ewc_lambda=ewc_lambda)