Example #1
    def test_module_clone_shared_params(self):
        # Tests proper use of the memo parameter: self.seq and self.head are
        # shared with self.net, so each must be cloned exactly once.

        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                cnn = [
                    torch.nn.Conv2d(3, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                ]
                self.seq = torch.nn.Sequential(*cnn)
                self.head = torch.nn.Sequential(
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 100, 3, 2, 1),
                )
                self.net = torch.nn.Sequential(self.seq, self.head)

            def forward(self, x):
                return self.net(x)

        original = TestModule()
        clone = l2l.clone_module(original)
        self.assertTrue(
            len(list(clone.parameters())) == len(list(original.parameters())),
            'clone and original do not have same number of parameters.',
        )

        orig_params = [p.data_ptr() for p in original.parameters()]
        duplicates = [p.data_ptr() in orig_params for p in clone.parameters()]
        self.assertTrue(not any(duplicates), 'clone() forgot some parameters.')
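
Both assertions rest on the core property of clone_module: parameters are copied into fresh storage, yet remain differentiable functions of the originals. A minimal sketch of that property (not part of the test suite):

import torch
import learn2learn as l2l

model = torch.nn.Linear(4, 2)
clone = l2l.clone_module(model)
# The clone owns fresh storage...
assert clone.weight.data_ptr() != model.weight.data_ptr()
# ...but stays attached to the original's graph: backprop through the clone
# accumulates gradients on the original (meta-) parameters.
clone(torch.randn(1, 4)).sum().backward()
assert all(p.grad is not None for p in model.parameters())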
Example #2
def do_eval(model, train_dl):
    adaptation_step = 5
    step_size = 0.1
    # model.linear.bias.register_hook(investigate)
    model.eval()
    accuracies = []

    mean = lambda x: sum(x) / len(x)

    for i, batch in enumerate(train_dl):
        for task_id, (support, query, support_targets,
                      query_targets) in enumerate(batch):
            support = support.to(device)  # B,C,W,H
            query = query.to(device)
            support_targets = support_targets.to(device)
            query_targets = query_targets.to(device)

            clone = clone_module(model)
            clone = adapt(clone,
                          support,
                          support_targets,
                          adaptation_steps=adaptation_step,
                          step_size=step_size)

            logits = clone(query)
            acc = accuracy(logits, query_targets)
            accuracies.append(acc)

    return mean(accuracies)
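
The adapt helper is not defined in this snippet. A plausible reconstruction, assuming a few steps of differentiable SGD on the support set (the body below is our guess, not the original code; update_module is learn2learn's, imported like clone_module above):

def adapt(clone, support, support_targets, adaptation_steps, step_size):
    # Hypothetical sketch of the undefined `adapt` helper.
    ce = torch.nn.CrossEntropyLoss()
    for _ in range(adaptation_steps):
        loss = ce(clone(support), support_targets)
        # create_graph=True keeps second-order terms for the training
        # counterpart of this loop (see Example #9).
        grads = torch.autograd.grad(loss, clone.parameters(), create_graph=True)
        update_module(clone, updates=[-step_size * g for g in grads])
    return clone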
Example #3
    def clone(
            self,
            first_order=None,
            allow_unused=None,
            allow_nograd=None,
            adapt_transform=None,
            ):
        """
        **Description**

        Similar to `MAML.clone()`.

        **Arguments**

        * **first_order** (bool, *optional*, default=None) - Whether the clone uses first-
            or second-order updates. Defaults to self.first_order.
        * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation
            of unused parameters. Defaults to self.allow_unused.
        * **allow_nograd** (bool, *optional*, default=None) - Whether to allow adaptation with
            parameters that have `requires_grad = False`. Defaults to self.allow_nograd.
        * **adapt_transform** (bool, *optional*, default=None) - Whether the transform's
            parameters are also updated during fast-adaptation. Defaults to self.adapt_transform.

        """
        if first_order is None:
            first_order = self.first_order
        if allow_unused is None:
            allow_unused = self.allow_unused
        if allow_nograd is None:
            allow_nograd = self.allow_nograd
        if adapt_transform is None:
            adapt_transform = self.adapt_transform
        module_clone = l2l.clone_module(self.module)
        update_clone = l2l.clone_module(self.compute_update)
        return GBML(
            module=module_clone,
            transform=self.transform,
            lr=self.lr,
            adapt_transform=adapt_transform,
            first_order=first_order,
            allow_unused=allow_unused,
            allow_nograd=allow_nograd,
            compute_update=update_clone,
        )
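
A hypothetical usage sketch for this method; gbml, loss_fn, and the support/query tensors are placeholders, not part of the source:

learner = gbml.clone()  # differentiable per-task copy of module and update rule
for step in range(adaptation_steps):
    learner.adapt(loss_fn(learner(support), support_labels))
query_loss = loss_fn(learner(query), query_labels)
query_loss.backward()  # gradients reach gbml's meta-parameters through the clone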
Example #4
def meta_surrogate_loss(iteration_replays, iteration_policies, policy,
                        baseline, tau, gamma, adapt_lr):
    mean_loss = 0.0
    mean_kl = 0.0
    for task_replays, old_policy in tqdm(zip(iteration_replays,
                                             iteration_policies),
                                         total=len(iteration_replays),
                                         desc='Surrogate Loss',
                                         leave=False):
        policy.reset_context()
        train_replays = task_replays[:-1]
        valid_episodes = task_replays[-1]
        new_policy = l2l.clone_module(policy)

        # Fast Adapt
        for train_episodes in train_replays:
            new_policy = fast_adapt_a2c(new_policy,
                                        train_episodes,
                                        adapt_lr,
                                        baseline,
                                        gamma,
                                        tau,
                                        first_order=False)

        # Useful values
        states = valid_episodes.state()
        actions = valid_episodes.action()
        next_states = valid_episodes.next_state()
        rewards = valid_episodes.reward()
        dones = valid_episodes.done()

        # Compute KL
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
        kl = kl_divergence(new_densities, old_densities).mean()
        mean_kl += kl

        # Compute Surrogate Loss
        advantages = compute_advantages(baseline, tau, gamma, rewards, dones,
                                        states, next_states)
        advantages = ch.normalize(advantages).detach()
        old_log_probs = old_densities.log_prob(actions).mean(
            dim=1, keepdim=True).detach()
        new_log_probs = new_densities.log_prob(actions).mean(dim=1,
                                                             keepdim=True)
        mean_loss += trpo.policy_loss(new_log_probs, old_log_probs, advantages)
    mean_kl /= len(iteration_replays)
    mean_loss /= len(iteration_replays)
    return mean_loss, mean_kl
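
For orientation, the per-task term accumulated into mean_loss is the standard TRPO importance-sampling surrogate. A sketch, assuming cherry's trpo.policy_loss follows the textbook definition:

import torch

def policy_loss_sketch(new_log_probs, old_log_probs, advantages):
    # Importance-weighted advantage, negated so that minimizing improves the policy.
    ratio = torch.exp(new_log_probs - old_log_probs)
    return -(ratio * advantages).mean()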
Example #5
    def test_clone_module_nomodule(self):
        # Tests that clone_module passes non-module entries through unchanged
        # (here, a None inside a ModuleList)
        class TrickyModule(torch.nn.Module):
            def __init__(self):
                super(TrickyModule, self).__init__()
                self.tricky_modules = torch.nn.ModuleList([
                    torch.nn.Linear(2, 1),
                    None,
                    torch.nn.Linear(1, 1),
                ])

        model = TrickyModule()
        clone = l2l.clone_module(model)
        for i, submodule in enumerate(clone.tricky_modules):
            if i % 2 == 0:
                self.assertTrue(submodule is not None)
            else:
                self.assertTrue(submodule is None)
Example #6
def meta_surrogate_loss(iter_replays, iter_policies, policy, baseline, params,
                        anil):
    mean_loss = 0.0
    mean_kl = 0.0
    for task_replays, old_policy in zip(iter_replays, iter_policies):
        train_replays = task_replays[:-1]
        valid_episodes = task_replays[-1]
        new_policy = clone_module(policy)

        # Fast Adapt to the training episodes
        for train_episodes in train_replays:
            new_policy = trpo_update(train_episodes,
                                     new_policy,
                                     baseline,
                                     params['inner_lr'],
                                     params['gamma'],
                                     params['tau'],
                                     anil=anil,
                                     first_order=False)

        # Calculate KL from the validation episodes
        states, actions, rewards, dones, next_states = get_episode_values(
            valid_episodes)

        # Compute KL
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
        kl = kl_divergence(new_densities, old_densities).mean()
        mean_kl += kl

        # Compute Surrogate Loss
        advantages = compute_advantages(baseline, params['tau'],
                                        params['gamma'], rewards, dones,
                                        states, next_states)
        advantages = ch.normalize(advantages).detach()
        old_log_probs = old_densities.log_prob(actions).mean(
            dim=1, keepdim=True).detach()
        new_log_probs = new_densities.log_prob(actions).mean(dim=1,
                                                             keepdim=True)
        mean_loss += trpo.policy_loss(new_log_probs, old_log_probs, advantages)

    mean_kl /= len(iter_replays)
    mean_loss /= len(iter_replays)
    return mean_loss, mean_kl
Example #7
    def test_module_update_shared_params(self):
        # Tests proper use of the memo parameter: self.seq and self.head are
        # shared with self.net, so each must be updated exactly once.

        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                cnn = [
                    torch.nn.Conv2d(3, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                ]
                self.seq = torch.nn.Sequential(*cnn)
                self.head = torch.nn.Sequential(
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 100, 3, 2, 1),
                )
                self.net = torch.nn.Sequential(self.seq, self.head)

            def forward(self, x):
                return self.net(x)

        original = TestModule()
        num_original = len(list(original.parameters()))
        clone = l2l.clone_module(original)
        updates = [torch.randn_like(p) for p in clone.parameters()]
        l2l.update_module(clone, updates)
        num_clone = len(list(clone.parameters()))
        self.assertTrue(
            num_original == num_clone,
            'clone and original do not have same number of parameters.',
        )
        for p, c, u in zip(original.parameters(), clone.parameters(), updates):
            self.assertTrue(
                torch.norm(p + u - c, p=2) <= EPSILON,
                'clone is not original + update.')

        orig_params = [p.data_ptr() for p in original.parameters()]
        duplicates = [p.data_ptr() in orig_params for p in clone.parameters()]
        self.assertTrue(not any(duplicates), 'clone() forgot some parameters.')
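
Distilled, the contract under test is that update_module adds each update tensor to the matching parameter while the clone keeps fresh storage. A minimal sketch:

import torch
import learn2learn as l2l

lin = torch.nn.Linear(2, 2)
clone = l2l.clone_module(lin)
updates = [torch.ones_like(p) for p in clone.parameters()]
l2l.update_module(clone, updates)
for p, c in zip(lin.parameters(), clone.parameters()):
    assert torch.allclose(p + 1.0, c)  # clone == original + update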
Example #8
    def test_clone_module_models(self):
        ref_models = [
            l2l.vision.models.OmniglotCNN(10),
            l2l.vision.models.MiniImagenetCNN(10)
        ]
        l2l_models = [copy.deepcopy(m) for m in ref_models]
        inputs = [torch.randn(5, 1, 28, 28), torch.randn(5, 3, 84, 84)]

        # Compute reference gradients
        ref_grads = []
        for model, X in zip(ref_models, inputs):
            for iteration in range(10):
                model.zero_grad()
                clone = ref_clone_module(model)
                out = clone(X)
                out.norm(p=2).backward()
                self.optimizer_step(model,
                                    [p.grad for p in model.parameters()])
                ref_grads.append(
                    [p.grad.clone().detach() for p in model.parameters()])

        # Compute cloned gradients
        l2l_grads = []
        for model, X in zip(l2l_models, inputs):
            for iteration in range(10):
                model.zero_grad()
                clone = l2l.clone_module(model)
                out = clone(X)
                out.norm(p=2).backward()
                self.optimizer_step(model,
                                    [p.grad for p in model.parameters()])
                l2l_grads.append(
                    [p.grad.clone().detach() for p in model.parameters()])

        # Compare gradients and model parameters
        for ref_g, l2l_g in zip(ref_grads, l2l_grads):
            for r_g, l_g in zip(ref_g, l2l_g):
                self.assertTrue(torch.equal(r_g, l_g))
        for ref_model, l2l_model in zip(ref_models, l2l_models):
            for ref_p, l2l_p in zip(ref_model.parameters(),
                                    l2l_model.parameters()):
                self.assertTrue(torch.equal(ref_p, l2l_p))
Example #9
def do_train(model, train_dl):
    adaptation_step = 5
    step_size = 0.1
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    ce = torch.nn.CrossEntropyLoss()
    model.train()

    mean = lambda x: sum(x) / len(x)
    for i, batch in enumerate(train_dl):
        optimizer.zero_grad()
        batch_accuracies = []
        batch_losses = []
        for task_id, (support, query, support_targets,
                      query_targets) in enumerate(batch):
            support = support.to(device)  # B,C,W,H
            query = query.to(device)
            support_targets = support_targets.to(device)
            query_targets = query_targets.to(device)

            clone = clone_module(model)
            clone = adapt(clone,
                          support,
                          support_targets,
                          adaptation_steps=adaptation_step,
                          step_size=step_size)

            logits = clone(query)
            outer_loss = ce(logits, query_targets)
            outer_loss.backward()
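            # This backward pass reaches model's parameters (the meta-parameters)
            # through the clone's graph, accumulating the outer-loop gradient
            # that optimizer.step() applies below.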

            acc = accuracy(logits, query_targets)
            batch_accuracies.append(acc)
            batch_losses.append(outer_loss.item())

        if i % 50 == 0:
            print('@ {} | {:.2f} {:.4f}'.format(i, mean(batch_accuracies),
                                                mean(batch_losses)))

        optimizer.step()
Example #10
    def test_clone_module_basics(self):
        original_output = self.model(self.input)
        original_loss = self.loss_func(original_output, torch.tensor([[0., 0.]]))
        original_gradients = torch.autograd.grad(original_loss,
                                                 self.model.parameters(),
                                                 retain_graph=True,
                                                 create_graph=True)

        cloned_model = l2l.clone_module(self.model)
        self.optimizer_step(self.model, original_gradients)

        cloned_output = cloned_model(self.input)
        cloned_loss = self.loss_func(cloned_output, torch.tensor([[0., 0.]]))

        cloned_gradients = torch.autograd.grad(cloned_loss,
                                               cloned_model.parameters(),
                                               retain_graph=True,
                                               create_graph=True)

        self.optimizer_step(cloned_model, cloned_gradients)

        for a, b in zip(self.model.parameters(), cloned_model.parameters()):
            self.assertTrue(torch.equal(a, b))
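
The clone is taken before the original model is stepped, so both copies start from the same parameter values, produce the same gradients on the same input, and land on identical parameters after one optimizer step; the final loop checks exactly that.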
Example #11
def main(env_name='Particles2D-v1',
         adapt_lr=0.1,
         meta_lr=1.0,
         adapt_steps=1,
         num_iterations=1000,
         meta_bsz=20,
         adapt_bsz=20,
         tau=1.00,
         gamma=0.95,
         seed=42,
         num_workers=10,
         cuda=0,
         num_context_params=2):
    cuda = bool(cuda)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device_name = 'cpu'
    if cuda:
        torch.cuda.manual_seed(seed)
        device_name = 'cuda'
    device = torch.device(device_name)

    def make_env():
        env = gym.make(env_name)
        env = ch.envs.ActionSpaceScaler(env)
        return env

    env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
    env.seed(seed)
    env.set_task(env.sample_tasks(1)[0])
    env = ch.envs.Torch(env)
    policy = CaviaDiagNormalPolicy(env.state_size,
                                   env.action_size,
                                   num_context_params=num_context_params,
                                   device=device)
    baseline = LinearValue(env.state_size, env.action_size)

    for iteration in range(num_iterations):
        iteration_reward = 0.0
        iteration_replays = []
        iteration_policies = []

        for task_config in tqdm(env.sample_tasks(meta_bsz),
                                leave=False,
                                desc='Data'):  # Samples a new config
            # deepcopy does not work here: it detaches the copy from the
            # computation graph, so meta-gradients could not flow back
            clone = l2l.clone_module(policy)

            env.set_task(task_config)
            env.reset()
            policy.reset_context()
            task = ch.envs.Runner(env)
            task_replay = []

            # Fast Adapt
            for step in range(adapt_steps):
                train_episodes = task.run(clone, episodes=adapt_bsz)
                if cuda:
                    train_episodes = train_episodes.to(device,
                                                       non_blocking=True)
                clone = fast_adapt_a2c(clone,
                                       train_episodes,
                                       adapt_lr,
                                       baseline,
                                       gamma,
                                       tau,
                                       first_order=True)
                task_replay.append(train_episodes)

            # Compute Validation Loss
            valid_episodes = task.run(clone, episodes=adapt_bsz)
            task_replay.append(valid_episodes)
            iteration_reward += valid_episodes.reward().sum().item() / adapt_bsz
            iteration_replays.append(task_replay)
            iteration_policies.append(clone)

        # Print statistics
        print('\nIteration', iteration)
        adaptation_reward = iteration_reward / meta_bsz
        print('adaptation_reward', adaptation_reward)

        # TRPO meta-optimization
        backtrack_factor = 0.8
        ls_max_steps = 15
        max_kl = 0.01
        if cuda:
            baseline = baseline.to(device, non_blocking=True)
            iteration_replays = [[
                r.to(device, non_blocking=True) for r in task_replays
            ] for task_replays in iteration_replays]

        # Compute CG step direction
        old_loss, old_kl = meta_surrogate_loss(iteration_replays,
                                               iteration_policies, policy,
                                               baseline, tau, gamma, adapt_lr)
        grad = autograd.grad(old_loss, policy.parameters(), retain_graph=True)
        grad = parameters_to_vector([g.detach() for g in grad])
        Fvp = trpo.hessian_vector_product(old_kl, policy.parameters())
        step = trpo.conjugate_gradient(Fvp, grad)
        shs = 0.5 * torch.dot(step, Fvp(step))
        lagrange_multiplier = torch.sqrt(shs / max_kl)
        step = step / lagrange_multiplier
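        # Scaling by sqrt(shs / max_kl) makes the quadratic KL estimate of the
        # scaled step hit the trust region exactly: 0.5 * step^T F step == max_kl.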
        step_ = [torch.zeros_like(p.data) for p in policy.parameters()]
        vector_to_parameters(step, step_)
        step = step_
        del old_kl, Fvp, grad
        old_loss.detach_()

        # Line-search
        for ls_step in range(ls_max_steps):
            stepsize = backtrack_factor**ls_step * meta_lr
            clone = l2l.clone_module(policy)

            for p, u in zip(clone.parameters(), step):
                p.data.add_(u.data, alpha=-stepsize)
            new_loss, kl = meta_surrogate_loss(iteration_replays,
                                               iteration_policies, clone,
                                               baseline, tau, gamma, adapt_lr)
            if new_loss < old_loss and kl < max_kl:
                for p, u in zip(policy.parameters(), step):
                    p.data.add_(u.data, alpha=-stepsize)
                break
Example #12
def main(
    fast_lr: float = 0.1,
    meta_lr: float = 0.003,
    num_iterations: int = 40000,
    meta_batch_size: int = 16,
    adaptation_steps: int = 5,
    dataset: str = 'cifarfs',
    layers: int = 4,
    shots: int = 5,
    ways: int = 5,
    cuda: int = 1,  # 0 or 1 only
    seed: int = 1234,
):
    args = dict(locals())
    wandb.init(
        project='kfo',
        group='MAML',
        name='maml-' + dataset,
        config=args,
    )
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Tasksets using the benchmark interface
    tasksets = l2l.vision.benchmarks.get_tasksets(
        name=dataset,
        train_samples=2 * shots,
        train_ways=ways,
        test_samples=2 * shots,
        test_ways=ways,
        root='~/data',
    )

    # Create model and learnable update
    if dataset == 'cifarfs':
        model = CIFARCNN(output_size=ways, hidden_size=32, layers=layers)
    elif dataset == 'mini-imagenet':
        model = l2l.vision.models.CNN4(
            output_size=ways,
            hidden_size=32,
            layers=layers,
        )
    else:
        raise ValueError('Unknown dataset: ' + dataset)
    model.to(device)
    kfo_transform = l2l.optim.transforms.KroneckerTransform(
        kronecker_cls=l2l.nn.KroneckerLinear,
        psd=True,
    )
    fast_update = l2l.optim.ParameterUpdate(
        parameters=model.parameters(),
        transform=kfo_transform,
    )
    fast_update.to(device)
    diff_sgd = l2l.optim.DifferentiableSGD(lr=fast_lr)
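    # ParameterUpdate turns gradients into updates by passing them through the
    # learnable Kronecker-factored transform; DifferentiableSGD applies those
    # updates without breaking the graph, so the transform receives meta-gradients.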

    all_parameters = list(model.parameters()) + list(fast_update.parameters())
    opt = torch.optim.Adam(all_parameters, meta_lr)
    loss = torch.nn.CrossEntropyLoss(reduction='mean')

    for iteration in tqdm.trange(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-training loss
            task_model = l2l.clone_module(model)
            task_update = l2l.clone_module(fast_update)
            batch = tasksets.train.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch,
                task_model,
                task_update,
                diff_sgd,
                loss,
                adaptation_steps,
                shots,
                ways,
                device,
            )
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            task_model = l2l.clone_module(model)
            task_update = l2l.clone_module(fast_update)
            batch = tasksets.validation.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch,
                task_model,
                task_update,
                diff_sgd,
                loss,
                adaptation_steps,
                shots,
                ways,
                device,
            )
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

        # log some metrics
        wandb.log(
            {
                'Train Error': meta_train_error / meta_batch_size,
                'Train Accuracy': meta_train_accuracy / meta_batch_size,
                'Validation Error': meta_valid_error / meta_batch_size,
                'Validation Accuracy': meta_valid_accuracy / meta_batch_size,
            },
            step=iteration)

        # Average the accumulated gradients and optimize
        for p in model.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        for p in fast_update.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()

        meta_test_error = 0.0
        meta_test_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-testing loss
            task_model = l2l.clone_module(model)
            task_update = l2l.clone_module(fast_update)
            batch = tasksets.test.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(
                batch,
                task_model,
                task_update,
                diff_sgd,
                loss,
                adaptation_steps,
                shots,
                ways,
                device,
            )
            meta_test_error += evaluation_error.item()
            meta_test_accuracy += evaluation_accuracy.item()
        wandb.log(
            {
                'Test Error': meta_test_error / meta_batch_size,
                'Test Accuracy': meta_test_accuracy / meta_batch_size,
            },
            step=iteration)
Example #13
def main(
    fast_lr=0.1,
    meta_lr=0.003,
    num_iterations=10000,
    meta_batch_size=16,
    adaptation_steps=5,
    shots=5,
    ways=5,
    cuda=1,
    seed=1234
    ):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device('cpu')
    if cuda and torch.cuda.device_count():
        torch.cuda.manual_seed(seed)
        device = torch.device('cuda')

    # Create Tasksets using the benchmark interface
    tasksets = l2l.vision.benchmarks.get_tasksets(
        name='cifarfs',
        train_samples=2*shots,
        train_ways=ways,
        test_samples=2*shots,
        test_ways=ways,
        root='~/data',
    )

    # Create model and learnable update
    model = CifarCNN(output_size=ways)
    model.to(device)
    features = model.features
    classifier = model.linear
    kfo_transform = l2l.optim.transforms.KroneckerTransform(l2l.nn.KroneckerLinear)
    fast_update = l2l.optim.ParameterUpdate(
        parameters=classifier.parameters(),
        transform=kfo_transform,
    )
    fast_update.to(device)
    diff_sgd = l2l.optim.DifferentiableSGD(lr=fast_lr)
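    # Unlike the previous example, the learnable transform covers only the
    # classifier's parameters; the cloned features are handled separately by
    # fast_adapt (whose definition is not shown here).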

    all_parameters = list(model.parameters()) + list(fast_update.parameters())
    opt = torch.optim.Adam(all_parameters, meta_lr)
    loss = torch.nn.CrossEntropyLoss(reduction='mean')

    for iteration in range(num_iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        for task in range(meta_batch_size):
            # Compute meta-training loss
            task_features = l2l.clone_module(features)
            task_classifier = l2l.clone_module(classifier)
            task_update = l2l.clone_module(fast_update)
            batch = tasksets.train.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               task_features,
                                                               task_classifier,
                                                               task_update,
                                                               diff_sgd,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            # Compute meta-validation loss
            task_features = l2l.clone_module(features)
            task_classifier = l2l.clone_module(classifier)
            task_update = l2l.clone_module(fast_update)
            batch = tasksets.validation.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               task_features,
                                                               task_classifier,
                                                               task_update,
                                                               diff_sgd,
                                                               loss,
                                                               adaptation_steps,
                                                               shots,
                                                               ways,
                                                               device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

        # Print some metrics
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_batch_size)
        print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size)
        print('Meta Valid Error', meta_valid_error / meta_batch_size)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size)

        # Average the accumulated gradients and optimize
        for p in model.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        for p in fast_update.parameters():
            p.grad.data.mul_(1.0 / meta_batch_size)
        opt.step()

    meta_test_error = 0.0
    meta_test_accuracy = 0.0
    for task in range(meta_batch_size):
        # Compute meta-testing loss
        task_features = l2l.clone_module(features)
        task_classifier = l2l.clone_module(classifier)
        task_update = l2l.clone_module(fast_update)
        batch = tasksets.test.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           task_features,
                                                           task_classifier,
                                                           task_update,
                                                           diff_sgd,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        meta_test_error += evaluation_error.item()
        meta_test_accuracy += evaluation_accuracy.item()
    print('Meta Test Error', meta_test_error / meta_batch_size)
    print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size)