Example No. 1
def dp_sgd_backward(dis, real_imgs, fake_imgs, device, clip_norm, noise_factor,
                    dis_opt):
    """
  since only one part of the loss depends on the data, we compute its gradients separately and noise them up
  in order to maintain similar gradient sizes, gradients for the other loss are clipped to the same per sample norm
  """
    # real data loss first:
    params = list(dis.parameters())
    loss_real = -pt.mean(dis(real_imgs))
    with backpack(BatchGrad(), BatchL2Grad()):
        loss_real.backward(retain_graph=True)

    squared_param_norms_real = [
        p.batch_l2 for p in params
    ]  # first we get all the squared parameter norms...
    global_norms_real = pt.sqrt(
        pt.sum(pt.stack(squared_param_norms_real),
               dim=0))  # ...then compute the global norms...
    global_clips_real = pt.clamp_max(
        clip_norm / global_norms_real,
        1.)  # ...and finally get a vector of clipping factors

    perturbed_grads = []
    for idx, param in enumerate(params):
        clipped_sample_grads = param.grad_batch * expand_vector(
            global_clips_real, param.grad_batch)
        clipped_grad = pt.sum(clipped_sample_grads,
                              dim=0)  # after clipping we sum over the batch

        noise_sdev = noise_factor * 2 * clip_norm  # gaussian noise standard dev is computed (sensitivity is 2*clip)...
        perturbed_grad = clipped_grad + pt.randn_like(
            clipped_grad, device=device) * noise_sdev  # ...and applied
        perturbed_grads.append(perturbed_grad)  # store perturbed grads

    dis_opt.zero_grad()
    # now add fake data loss gradients:
    loss_fake = pt.mean(dis(fake_imgs))
    with backpack(BatchGrad(), BatchL2Grad()):
        loss_fake.backward()

    squared_param_norms_fake = [
        p.batch_l2 for p in params
    ]  # first we get all the squared parameter norms...
    global_norms_fake = pt.sqrt(
        pt.sum(pt.stack(squared_param_norms_fake),
               dim=0))  # ...then compute the global norms...
    global_clips_fake = pt.clamp_max(
        clip_norm / global_norms_fake,
        1.)  # ...and finally get a vector of clipping factors

    for idx, param in enumerate(params):
        clipped_sample_grads = param.grad_batch * expand_vector(
            global_clips_fake, param.grad_batch)
        clipped_grad = pt.sum(clipped_sample_grads,
                              dim=0)  # after clipping we sum over the batch

        param.grad = clipped_grad + perturbed_grads[idx]

    ld = loss_real.item() + loss_fake.item()
    return global_norms_real, global_clips_real, global_norms_fake, global_clips_fake, ld
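
The helper ``expand_vector`` is used above but not shown; a plausible sketch (an assumption, not the original implementation) reshapes the per-sample clipping factors so they broadcast against each per-sample gradient tensor:

def expand_vector(vec, tensor):
    # vec has shape (batch_size,); append singleton dims so it broadcasts
    # against a tensor of shape (batch_size, ...)
    return vec.view(-1, *([1] * (tensor.dim() - 1)))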
Example No. 2
def test_extension_hook_param_before_savefield_exists(problem):
    """Extension hooks iterating over parameters may get called before BackPACK.

    This leads to the case, that the BackPACK quantities might not be calculated yet.
    Thus, derived quantities cannot be calculated.

    Sequential containers just work fine.
    Custom containers crash.

    Args:
        problem: problem consisting of model, loss, and problem_string

    Raises:
        NotImplementedError: if problem_string is unknown
    """
    _, loss, problem_string = problem

    params_without_grad_batch = []

    def check_grad_batch(module):
        """Check whether the module has a grad_batch attribute.

        Args:
            module: the module to check

        Raises:
            AssertionError: if a parameter does not have grad_batch attribute.
        """
        for p in module.parameters():
            if not hasattr(p, "grad_batch"):
                params_without_grad_batch.append(id(p))
                raise AssertionError(
                    f"Param {id(p)} has no 'grad_batch' attribute")

    if problem_string == NESTED_SEQUENTIAL:
        with backpack(BatchGrad(), extension_hook=check_grad_batch,
                      debug=True):
            loss.backward()

        assert len(params_without_grad_batch) == 0
    elif problem_string == CUSTOM_CONTAINER:
        with raises(AssertionError):
            with backpack(BatchGrad(),
                          extension_hook=check_grad_batch,
                          debug=True):
                loss.backward()
        assert len(params_without_grad_batch) > 0
    else:
        raise NotImplementedError(f"unknown problem_string={problem_string}")
Example No. 3
def test_for_loop_replace() -> None:
    """Application of retain_graph: replace an outer for-loop.

    This test is based on issue #220 opened by Romain3Ch216.
    It computes per-component individual gradients of a tensor-valued output
    with a for loop over components, rather than over samples and components.
    """
    manual_seed(0)
    B = 5
    M = 3
    h = 2

    x = randn(B, h)
    fc = extend(Linear(h, M))
    A = fc(x)

    grad_autograd = zeros(B, M, *fc.weight.shape)
    for b in range(B):
        for m in range(M):
            with backpack(retain_graph=True):
                grads = autograd.grad(A[b, m], fc.weight, retain_graph=True)
            grad_autograd[b, m] = grads[0]

    grad_backpack = zeros(B, M, *fc.weight.shape)
    for i in range(M):
        with backpack(BatchGrad(), retain_graph=True):
            A[:, i].backward(ones_like(A[:, i]), retain_graph=True)
        grad_backpack[:, i] = fc.weight.grad_batch

    check_sizes_and_values(grad_backpack, grad_autograd)
Example No. 4
 def select(self, context):
     tensor = torch.from_numpy(context).float().cuda()
     mu = self.func(tensor)
     # calculate gradient
     sum_mu = torch.sum(mu)
     with backpack(BatchGrad()):
         sum_mu.backward()
     g_list = torch.cat(
         [
             p.grad_batch.flatten(start_dim=1).detach()
             for p in self.func.parameters()
         ],
         dim=1,
     )
     # calculate CB
     sigma = torch.sqrt(
         torch.diag(
             torch.matmul(torch.matmul(g_list, self.Uinv),
                          torch.transpose(g_list, 0, 1))))
     # calculate UCB and select the arm
     score = torch.normal(mu.view(-1), self.nu * sigma.view(-1))
     arm = torch.argmax(score)
     # update self.Uinv
     g = g_list[arm]
     self.Uinv = self.Uinv - (torch.matmul(
         torch.matmul(torch.matmul(self.Uinv, g.view(-1, 1)), g.view(
             1, -1)), self.Uinv)) / (1 + torch.matmul(
                 torch.matmul(g.view(1, -1), self.Uinv), g.view(-1, 1)))
     return arm, [], [], []
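
The ``self.Uinv`` update above is the Sherman-Morrison rank-one update of an inverse matrix. A small self-contained check of the identity it implements (illustrative only, not part of the original snippet):

import torch

U = 2.0 * torch.eye(4) + 0.1 * torch.ones(4, 4)
Uinv = torch.linalg.inv(U)
g = torch.randn(4)

# (U + g g^T)^{-1} = U^{-1} - (U^{-1} g g^T U^{-1}) / (1 + g^T U^{-1} g)
Uinv_updated = Uinv - (Uinv @ torch.outer(g, g) @ Uinv) / (1 + g @ Uinv @ g)

print(torch.allclose(Uinv_updated, torch.linalg.inv(U + torch.outer(g, g)), atol=1e-5))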
Example No. 5
    def update(self, rollouts, explore_policy, exploration):
        obs_shape = rollouts.obs.size()[2:]
        action_shape = rollouts.actions.size()[-1]
        num_steps, num_processes, _ = rollouts.rewards.size()

        qs = self.model(rollouts.obs.view(-1, *obs_shape)).view(
                num_steps + 1, num_processes, -1)
        values = qs[:-1].gather(-1, rollouts.actions).view(num_steps, num_processes, 1)
        
        probs, _ = explore_policy(qs[-1].detach(), exploration)
        next_values = (probs * qs[-1]).sum(-1).unsqueeze(-1).unsqueeze(0) 
        
        advantages = rollouts.returns[:-1] - values
        
        with backpack(BatchGrad()):
            self.optimizer.zero_grad()
            torch.cat([values, next_values], dim=0).sum().backward()
            # store the td errors and masks to use inside tdprop
            self.optimizer.temp_store_td_errors(advantages.detach())
            self.optimizer.temp_store_masks(rollouts.cumul_masks)
            # extract grads: grad = -2 * td * grad_v 
            self.optimizer.extract_grads_from_batch()
            total_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.optimizer.clip_coef = self.max_grad_norm / (total_norm + 1e-6)
            self.optimizer.step()

        return advantages.pow(2).mean().item() 
def self_interferences(loss, params, t=False):
  with backpack(BatchGrad()):
    sumloss(loss).backward()
  grads = torch.cat([i.grad_batch.reshape((loss.shape[0], -1)) for i in params], 1)
  if t:
    return (grads ** 2).sum(1)
  return (grads ** 2).sum(1).cpu().data.numpy()
def cross_interferences_diag(lossA, lossB, params, t=False):
  """lossA and B must be callbacks, because for backpack to work it
  seems there must only be one computation graph associated with the
  parameters at a time!
  Or like it checks the most recent one only? I'm confused
  """
  with backpack(BatchGrad()):
    sumloss(lossA()).backward()
  gradsA = torch.cat([i.grad_batch.reshape((i.grad_batch.shape[0], -1)) for i in params], 1)
  with backpack(BatchGrad()):
    sumloss(lossB()).backward()
  gradsB = torch.cat([i.grad_batch.reshape((i.grad_batch.shape[0], -1)) for i in params], 1)
  x = (gradsA * gradsB).sum(1)
  gc.collect()
  if t: return x
  return x.cpu().data.numpy()
def interferences(loss, params):
  with backpack(BatchGrad()):
    sumloss(loss).backward()
  grads = torch.cat([i.grad_batch.reshape((loss.shape[0], -1)) for i in params], 1)
  all_dots = []
  for i in range(len(grads)-1):
    all_dots.append((grads[i][None, :] @ grads[i+1:].t())[0])
  return torch.cat(all_dots, 0).cpu().data.numpy()
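
``sumloss`` is not defined in the snippets above; a plausible sketch (an assumption) reduces the per-sample losses to a scalar so that a single backward pass under ``backpack(BatchGrad())`` populates ``grad_batch``:

def sumloss(loss):
    # reduce a vector of per-sample losses to a scalar for .backward()
    return loss.sum()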
Example No. 9
 def get_anchor_gradients(self, net, loss_func):
     """Get the n x p matrix of gradients based on public data."""
     public_inputs, public_targets = self.public_inputs, self.public_targets
     outputs = net(public_inputs)
     loss = loss_func(outputs, public_targets)
     with backpack(BatchGrad()):
         loss.backward()
     cur_batch_grad_list = []
     for p in net.parameters():
         cur_batch_grad_list.append(p.grad_batch)
         del p.grad_batch
     return flatten_tensor(cur_batch_grad_list)  # n x p
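
``flatten_tensor`` is not part of this snippet; a minimal sketch (an assumption inferred from the "n x p" comment) flattens each per-sample gradient and concatenates along the parameter dimension:

def flatten_tensor(grad_batch_list):
    # each entry has shape (n, *param_shape); flatten and concatenate -> (n, p)
    return torch.cat([g.reshape(g.shape[0], -1) for g in grad_batch_list], dim=1)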
Example No. 10
 def _select(self, context):
     tensor = torch.from_numpy(context).float().to(self.device)
     mu = self.model(tensor)
     sum_mu = torch.sum(mu)
     with backpack(BatchGrad()):
         sum_mu.backward()
     g_list = torch.cat([p.grad_batch.flatten(start_dim=1).detach() for p in self.model.parameters()], dim=1)
     sigma = torch.sqrt(torch.sum(g_list * g_list / self.A, dim=1))
     sample_r = torch.normal(mu.view(-1), self.alpha * sigma.view(-1))
     arm = torch.argmax(sample_r)
     self.A += g_list[arm] * g_list[arm]
     return arm, g_list[arm].norm().item(), 0, 0
 def get_anchor_gradients(self, net, loss_func):
     public_inputs, public_targets = self.public_inputs, self.public_targets
     outputs = net(public_inputs)
     loss = loss_func(outputs, public_targets)
     with backpack(BatchGrad()):
         loss.backward()
     cur_batch_grad_list = []
     for p in net.parameters():
         cur_batch_grad_list.append(
             p.grad_batch.reshape(p.grad_batch.shape[0], -1))
         del p.grad_batch
     return flatten_tensor(cur_batch_grad_list)
Example No. 12
 def select(self, context):
     tensor = torch.from_numpy(context).float().cuda()
     mu = self.func(tensor)
     sum_mu = torch.sum(mu)
     with backpack(BatchGrad()):
         sum_mu.backward()
     g_list = torch.cat([p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()], dim=1)
     sigma = torch.sqrt(torch.sum(self.lamdba * self.nu * g_list * g_list / self.U, dim=1))
     if self.style == 'ts':
         sample_r = torch.normal(mu.view(-1), sigma.view(-1))
     elif self.style == 'ucb':
         sample_r = mu.view(-1) + sigma.view(-1)
     arm = torch.argmax(sample_r)
     self.U += g_list[arm] * g_list[arm]
     return arm, g_list[arm].norm().item(), 0, 0
Example No. 13
    def extensions(self, global_step):
        """Return list of BackPACK extensions required for the computation.

        Args:
            global_step (int): The current iteration number.

        Returns:
            list: (Potentially empty) list with required BackPACK quantities.
        """
        ext = []

        if self.should_compute(global_step):
            ext.append(BatchGrad())

        return ext
Example No. 14
 def select(self, context):
     tensor = torch.from_numpy(context).float().cuda()
     mu = self.func(tensor)
     sum_mu = torch.sum(mu)
     with backpack(BatchGrad()):
         sum_mu.backward()
     g_list = torch.cat([
         p.grad_batch.flatten(start_dim=1).detach()
         for p in self.func.parameters()
     ],
                        dim=1)
     sigma = torch.sqrt(torch.sum(g_list * g_list / self.U, dim=1))
     sample_r = mu.view(-1) + self.nu * sigma.view(-1)
     arm = torch.argmax(sample_r)
     self.U += g_list[arm] * g_list[arm]
     # print("selected arm: {} {}".format(mu.view(-1)[arm], sigma[arm]))
     return arm, g_list[arm].norm().item(), 0, 0
Example No. 15
    def _train_batch_(self):
        (data, target) = next(self.train_loader_it)
        data, target = data.to(self.device), target.to(self.device)
        self.current_batch_size = data.size()[0]
        output = self.model(data)
        loss = self.loss_function(output, target)

        with backpack(BatchGrad()):
            loss.mean().backward()

        loss_value = loss.mean()

        self.loss_batch = loss
        self.current_training_loss = torch.unsqueeze(loss_value.detach(),
                                                     dim=0)
        self.train_batch_index += 1
        self._current_validation_loss.calculated = False
Example No. 16
 def decide(self, pool_articles, userID, k=1):
     self.a = len(pool_articles)
     
     user_vec = torch.cat(self.a*[self.user_feature[userID - 1].view(1, -1)])
     article_vec = torch.cat([
         torch.from_numpy(x.contextFeatureVector[:self.dimension]).view(1, -1).to(torch.float32)
         for x in pool_articles])
     score = self.learner(user_vec, article_vec.cuda()).view(-1)
     sum_score = torch.sum(score)
     with backpack(BatchGrad()):
         sum_score.backward()
     
     grad = torch.cat([p.grad_batch.view(self.a, -1) for p in self.learner.parameters()], dim=1)
     sigma = torch.sqrt(torch.sum(grad * grad / self.U, dim=1))
     self.reg = self.nu * torch.mean(sigma).item()
     arm = torch.argmax(score + self.nu * sigma).item()
     self.g = grad[arm]
     return [pool_articles[arm]]
Example No. 17
def test_retain_graph():
    """Tests whether retain_graph works as expected.

    Does several forward and backward passes.
    In between, it is tested whether BackPACK quantities are present or not.
    """
    manual_seed(0)
    model = extend(Sequential(Linear(4, 6), Linear(6, 5)))
    loss_fn = extend(CrossEntropyLoss())

    # after a forward pass, the graph (BackPACK IO) is not cleared
    inputs = rand(8, 4)
    labels = randint(5, (8, ))
    loss = loss_fn(model(inputs), labels)
    with raises(AssertionError):
        _check_no_io(model)

    # after a normal backward pass, the graph should be cleared
    loss.backward()
    _check_no_io(model)

    # after a backward pass with retain_graph=True, the graph is not cleared
    loss = loss_fn(model(inputs), labels)
    with backpack(retain_graph=True):
        loss.backward(retain_graph=True)
    with raises(AssertionError):
        _check_no_io(model)

    # doing several backward passes with retain_graph=True
    for _ in range(3):
        with backpack(retain_graph=True):
            loss.backward(retain_graph=True)
    with raises(AssertionError):
        _check_no_io(model)

    # finally, a normal backward pass verifies that the graph is cleared again
    with backpack(BatchGrad()):
        loss.backward()
    _check_no_io(model)
Example No. 18
def main(args):
    print(args)
    assert args.dpsgd
    torch.backends.cudnn.benchmark = True

    train_data, train_labels = get_data(args)
    model = model_dict[args.experiment](vocab_size=args.max_features).cuda()
    model = extend(model)
    optimizer = DP_SGD(model.parameters(),
                       lr=args.learning_rate,
                       **dpsgd_kwargs[args.experiment])
    loss_function = (nn.CrossEntropyLoss()
                     if args.experiment != 'logreg' else nn.BCELoss())

    timings = []
    for epoch in range(1, args.epochs + 1):
        start = time.perf_counter()
        dataloader = data.dataloader(train_data, train_labels, args.batch_size)
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
            model.zero_grad()
            outputs = model(x)
            loss = loss_function(outputs, y)
            with backpack(BatchGrad(), BatchL2Grad()):
                loss.backward()
            optimizer.step()
        torch.cuda.synchronize()
        duration = time.perf_counter() - start
        print("Time Taken for Epoch: ", duration)
        timings.append(duration)

    if not args.no_save:
        utils.save_runtimes(__file__.split('.')[0], args, timings)
    else:
        print('Not saving!')
    print('Done!')
Example No. 19
    def backward_and_step(self, vt, vtp, delta, gamma_k):
        """Performs backward pass and step for n-step TD(0), does not work for TD(lambda)
        Requires:
         - v: V(s_t)
         - vp: V(s_t+n)
         - delta: the TD delta, i.e. (v - (r + gamma*vp)) for TD(0), (v - (G^5 + gamma^5*vp)) for 5s-TD
             careful to profide (v - y) rather than (y - v)
         - gamma_k: the multiplicative factor on vp applied in delta, e.g. gamma^5 for 5-step TD,
             can be 0 e.g. if s_t+n is terminal
        """
        self.zero_grad()

        with backpack(BatchGrad()):
            torch.cat([vt, vtp], dim=0).sum().backward()
        mbs = vt.shape[0]
        batch_g = self.batch_p2v([i.grad_batch.data for i in self.parameters],
                                 mbs * 2)
        batch_dvt = batch_g[:mbs]
        batch_dvtp = batch_g[mbs:] * gamma_k[:, None]
        # derivative of delta^2
        batch_dL = 2 * delta[:, None] * batch_dvt
        dL = batch_dL.mean(0)

        if self.weight_decay:
            # We ignore weight decay when computing the correction
            # This should not be problematic but further testing required
            dL.add_(self.p2v(self.parameters).detach(),
                    alpha=self.weight_decay)

        self._step += 1
        bias_correction1 = 1 - self.beta1**self._step if self.mom_correct_bias else 1
        bias_correction2 = 1 - self.beta2**self._step

        # \/V(s) - \/V(s')
        gdiff = batch_dvt - batch_dvtp

        if self.do_correction:
            # last corrected momentum
            mu_tm1 = self.mu - self.eta

            # Update eta
            if self.diagonal:
                z = (gdiff * batch_dvt).mean(0)
                #z = (gdiff.mean(0) * batch_dvt.mean(0))
                self.eta.mul_(self.beta).add_(self.alpha * self.beta *
                                              (self.zeta * mu_tm1))  # eta
            elif self.block_size > 0:
                pad = torch.zeros((self.nparams_padding), device=self.device)
                batch_pad = torch.zeros((mbs, self.nparams_padding),
                                        device=self.device)
                batch_shape = mbs, self.nblocks, self.block_size
                # batch_block_shape = mbs, self.nblocks, self.block_size, self.block_size
                gdiff_padded = torch.cat([gdiff, batch_pad],
                                         1).reshape(batch_shape)
                batch_dvt_padded = torch.cat([batch_dvt, batch_pad],
                                             1).reshape(batch_shape)
                z = torch.einsum('ija,ijb->jab', gdiff_padded,
                                 batch_dvt_padded) / mbs
                mu_padded = torch.cat([mu_tm1, pad], 0).reshape(
                    (self.nblocks, self.block_size))
                zeta_T_times_mu = (torch.einsum('ikj,ik->ij', self.zeta,
                                                mu_padded).reshape(
                                                    (-1, ))[:self.nparams]
                                   )  # unpad
                self.eta.mul_(self.beta).add_(self.alpha * self.beta *
                                              zeta_T_times_mu)  # eta
            else:
                z = torch.einsum('ij,ik->jk', gdiff,
                                 batch_dvt) / mbs  # sum of outer product
                #z = torch.einsum('j,k->jk', gdiff.mean(0), batch_dvt.mean(0))
                self.eta.mul_(self.beta).add_(self.alpha * self.beta *
                                              (self.zeta.T @ mu_tm1))  # eta
            # Update zeta
            self.zeta.mul_(self.beta).add_(z, alpha=1 - self.dampening)

        # Update momentum mu
        self.mu.mul_(self.beta).add_(dL, alpha=1 - self.dampening)
        # Compute (corrected) momentum \mu_t
        mu_t = self.mu - self.eta if self.do_correction else self.mu

        # TDProp update
        if self.beta2 > 0:
            # diag[(\/V(s) - \/V(s')) ^T \/V(s)]
            diag_H = (2 * gdiff * batch_dvt).pow(2).mean(0)
            # Update TDprop denominator
            self.z_denom.mul_(self.beta2).add_(diag_H, alpha=1 - self.beta2)
            # Compute bias corrected TDprop denom
            denom = (self.z_denom.sqrt() / np.sqrt(bias_correction2)).add_(
                self.epsilon)
            # Update parameters
            for p, g, d in zip(self.parameters, self.v2p(mu_t),
                               self.v2p(denom)):
                p.data.addcdiv_(g, d, value=-(self.alpha / bias_correction1))
        # Normal or corrected momentum update
        else:
            for p, g in zip(self.parameters, self.v2p(mu_t)):
                p.data.add_(g, alpha=-(self.alpha / bias_correction1))
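
The helpers ``p2v``, ``batch_p2v``, and ``v2p`` are not shown here; plausible sketches (assumptions inferred from how they are called; the ``v2p`` below takes the parameter list explicitly, unlike the method above) convert between parameter lists and flat vectors:

def p2v(params):
    # flatten a list of parameter tensors into a single vector
    return torch.cat([p.reshape(-1) for p in params])

def batch_p2v(grad_batches, n):
    # per-sample gradients (each of shape (n, *param_shape)) -> matrix (n, num_params)
    return torch.cat([g.reshape(n, -1) for g in grad_batches], dim=1)

def v2p(vec, params):
    # split a flat vector back into chunks shaped like the parameters
    chunks, offset = [], 0
    for p in params:
        chunks.append(vec[offset:offset + p.numel()].view_as(p))
        offset += p.numel()
    return chunks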
Example No. 20
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    t0 = time.time()
    steps = n_training//args.batchsize

    if(train_samples is None): # using pytorch data loader for CIFAR10
        loader = iter(trainloader)
    else: # manually sample minibatches for SVHN
        sample_idxes = np.arange(n_training)
        np.random.shuffle(sample_idxes)
    is_first = 0
    for batch_idx in range(steps):
        if(args.dataset=='svhn'):
            current_batch_idxes = sample_idxes[batch_idx*args.batchsize : (batch_idx+1)*args.batchsize]
            inputs, targets = train_samples[current_batch_idxes], train_labels[current_batch_idxes]
        else:
            inputs, targets = next(loader)
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        if(args.private):
            logging = batch_idx % 20 == 0
            ## compute anchor subspace
            optimizer.zero_grad()
            if is_first ==0:
                net.gep.get_anchor_space(net, loss_func=loss_func, logging=logging)
                is_first+=1
            ## collect batch gradients
            batch_grad_list = []
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_func(outputs, targets)
            with backpack(BatchGrad()):
                loss.backward()
            for p in net.parameters():
                batch_grad_list.append(p.grad_batch.reshape(p.grad_batch.shape[0], -1))
                del p.grad_batch
            ## compute gradient embeddings and residual gradients
            clipped_theta, residual_grad, target_grad = net.gep(flatten_tensor(batch_grad_list), logging = logging)
            ## add noise to guarantee differential privacy
            theta_noise = torch.normal(0, noise_multiplier0*args.clip0/args.batchsize, size=clipped_theta.shape, device=clipped_theta.device)
            grad_noise = torch.normal(0, noise_multiplier1*args.clip1/args.batchsize, size=residual_grad.shape, device=residual_grad.device)
            clipped_theta += theta_noise
            residual_grad += grad_noise
            ## update with Biased-GEP or GEP
            if(args.rgp):
                noisy_grad = gep.get_approx_grad(clipped_theta) + residual_grad
            else:
                noisy_grad = gep.get_approx_grad(clipped_theta)
            if(logging):
                print('target grad norm: %.2f, noisy approximation norm: %.2f'%(target_grad.norm().item(), noisy_grad.norm().item()))
            ## make use of noisy gradients
            offset = 0
            for p in net.parameters():
                shape = p.grad.shape
                numel = p.grad.numel()
                p.grad.data = noisy_grad[offset:offset+numel].view(shape) #+ 0.1*torch.mean(pub_grad, dim=0).view(shape)
                offset+=numel
        else:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_func(outputs, targets)
            loss.backward()
            try:
                for p in net.parameters():
                    del p.grad_batch
            except AttributeError:
                pass
        optimizer.step()
        step_loss = loss.item()
        if(args.private):
            step_loss /= inputs.shape[0]
        train_loss += step_loss
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).float().cpu().sum()
        acc = 100.*float(correct)/float(total)
    t1 = time.time()
    print('Train loss:%.5f'%(train_loss/(batch_idx+1)), 'time: %d s'%(t1-t0), 'train acc:', acc, '\n')
    return (train_loss/batch_idx, acc)

model = extend(MyFirstResNet()).to(DEVICE)

# %%
# Using :py:class:`BatchGrad <backpack.extensions.BatchGrad>` in a
# :py:class:`with backpack(...) <backpack.backpack>` block,
# we can access the individual gradients for each sample.
#
# The loss does not need to be extended in this case either, as it does not
# have model parameters and BackPACK does not need to know about it for
# first order extensions. This also means you can use any custom loss function.

model.zero_grad()
loss = F.cross_entropy(model(x), y, reduction="sum")
with backpack(BatchGrad()):
    loss.backward()

print("{:<20}  {:<30} {:<30}".format("Param", "grad", "grad (batch)"))
print("-" * 80)
for name, p in model.named_parameters():
    print(
        "{:<20}: {:<30} {:<30}".format(name, str(p.grad.shape), str(p.grad_batch.shape))
    )

# %%
# To check that everything works, let's compute one individual gradient with
# PyTorch (using a single sample in a forward and backward pass)
# and compare it with the one computed by BackPACK.

sample_to_check = 1
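
# %%
# A hedged sketch of that comparison (the original tutorial continues beyond this
# excerpt; ``sample_to_check``, ``model``, ``x`` and ``y`` come from above, and we
# assume the network has no batch-dependent layers such as BatchNorm). Because the
# loss above uses ``reduction="sum"``, the individual gradient of one sample can be
# compared directly to a backward pass over that sample alone.

model.zero_grad()
single_loss = F.cross_entropy(
    model(x[sample_to_check].unsqueeze(0)),
    y[sample_to_check].unsqueeze(0),
    reduction="sum",
)
single_loss.backward()

for name, p in model.named_parameters():
    match = torch.allclose(p.grad, p.grad_batch[sample_to_check])
    print(f"{name:<20} individual gradient matches BackPACK? {match}")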
Example No. 22
lossfunc = extend(lossfunc)

# %%
# Individual gradients for a mini-batch subset
# --------------------------------------------
#
# Let's say we only want to compute individual gradients for samples 0, 1,
# 13, and 42. Naively, we could perform the computation for all samples, then
# slice out the samples we care about.

# selected samples
subsampling = [0, 1, 13, 42]

loss = lossfunc(model(X), y)

with backpack(BatchGrad()):
    loss.backward()

# naive approach: compute for all, slice out relevant
naive = [p.grad_batch[subsampling] for p in model.parameters()]

# %%
# This is inefficient: individual gradients are computed for all samples, and
# most of them are discarded afterwards. We can do better by specifying the active
# samples directly with the ``subsampling`` argument of
# :py:class:`BatchGrad <backpack.extensions.BatchGrad>`.

loss = lossfunc(model(X), y)

# efficient approach: specify active samples in backward pass
with backpack(BatchGrad(subsampling=subsampling)):
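    loss.backward()

# %%
# (The tutorial is truncated here; the lines below are an assumed continuation.)
# ``grad_batch`` now only holds rows for the requested samples, so the efficient
# result can be compared against the naive one.

efficient = [p.grad_batch for p in model.parameters()]

match = all(torch.allclose(n, e) for n, e in zip(naive, efficient))
print(f"Naive and efficient subsampled gradients match? {match}")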
Example No. 23
            print("\tmodule.input0.shape:", module.input0.shape)
            # gradient w.r.t output
            print("\tg_out[0].shape:     ", g_out[0].shape)

        # actual computation
        return (g_out[0] *
                module.input0).flatten(start_dim=1).sum(axis=1).unsqueeze(-1)


# %%
# Lastly, we need to register the mapping between layer (``ScaleModule``) and layer
# extension (``ScaleModuleBatchGrad``) in an instance of
# :py:class:`BatchGrad <backpack.extensions.BatchGrad>`.

# register module-computation mapping
extension = BatchGrad()
extension.set_module_extension(ScaleModule, ScaleModuleBatchGrad())

# %%
# That's it. We can now pass ``extension`` to a
# :py:class:`with backpack(...) <backpack.backpack>` context and compute individual
# gradients with respect to ``ScaleModule``'s ``weight`` parameter.

# %%
# Test custom module
# ------------------
# Here, we verify the custom module extension on a small net with random inputs.
# Let's create these.

batch_size = 10
batch_axis = 0
Example No. 24
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        model = _MyCustomModule()
    else:
        raise NotImplementedError(
            f"problem={problem_string} but no test setting for this.")

    model = extend(model.to(device))
    lossfunc = extend(CrossEntropyLoss(reduction="mean").to(device))
    loss = lossfunc(model(X), y)
    yield model, loss, problem_string


@mark.parametrize("extension", [BatchGrad(), DiagGGNExact()],
                  ids=["BatchGrad", "DiagGGNExact"])
def test_extension_hook_multiple_parameter_visits(
        problem, extension: BackpropExtension):
    """Tests whether each parameter is visited exactly once.

    For those cases where parameters are visited more than once (e.g. Custom containers),
    it tests that an error is raised.

    Furthermore, it is tested whether first order extensions run fine in either case,
    and second order extensions raise an error in the case of custom containers.

    Args:
        problem: test problem, consisting of model, loss, and problem_string
        extension: first or second order extension to test
Example No. 25
def train(epoch):
    torch.set_printoptions(precision=16)
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    step_st_time = time.time()
    epoch_time = 0
    print('\nKFAC/KBFGS damping: %f' % damping)
    print('\nNGD damping: %f' % (damping))

    # 
    desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (tag, lr_scheduler.get_last_lr()[0], 0, 0, correct, total))

    writer.add_scalar('train/lr', lr_scheduler.get_last_lr()[0], epoch)

    prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True)
    for batch_idx, (inputs, targets) in prog_bar:

        if optim_name in ['kfac', 'skfac', 'ekfac', 'sgd', 'adam']:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            if optim_name in ['kfac', 'skfac', 'ekfac'] and optimizer.steps % optimizer.TCov == 0:
                # compute true fisher
                optimizer.acc_stats = True
                with torch.no_grad():
                    sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),1).squeeze().to(args.device)
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)
                optimizer.acc_stats = False
                optimizer.zero_grad()  # clear the gradient for computing true-fisher.
            loss.backward()
            optimizer.step()
        elif optim_name in ['kbfgs', 'kbfgsl', 'kbfgsl_2loop', 'kbfgsl_mem_eff']:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net.forward(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            # do another forward-backward pass over batch inside step()
            def closure():
                return inputs, targets, criterion, False # is_autoencoder = False
            optimizer.step(closure)
        elif optim_name == 'exact_ngd':
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            # update Fisher inverse
            if batch_idx % args.freq == 0:
              # compute true fisher
              with torch.no_grad():
                sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),1).squeeze().to(args.device)
              # use backpack extension to compute individual gradient in a batch
              batch_grad = []
              with backpack(BatchGrad()):
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)

              for name, param in net.named_parameters():
                if hasattr(param, "grad_batch"):
                  batch_grad.append(args.batch_size * param.grad_batch.reshape(args.batch_size, -1))
                else:
                  raise NotImplementedError

              J = torch.cat(batch_grad, 1)
              fisher = torch.matmul(J.t(), J) / args.batch_size
              inv = torch.linalg.inv(fisher + damping * torch.eye(fisher.size(0)).to(fisher.device))
              # clean the gradient to compute the true fisher
              optimizer.zero_grad()

            loss.backward()
            # compute the step direction p = F^-1 @ g
            grad_list = []
            for name, param in net.named_parameters():
              grad_list.append(param.grad.data.reshape(-1, 1))
            g = torch.cat(grad_list, 0)
            p = torch.matmul(inv, g)

            start = 0
            for name, param in net.named_parameters():
              end = start + param.data.reshape(-1, 1).size(0)
              param.grad.copy_(p[start:end].reshape(param.grad.data.shape))
              start = end

            optimizer.step()

        ### new optimizer test
        elif optim_name in ['kngd'] :
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            if  optimizer.steps % optimizer.freq == 0:
                # compute true fisher
                optimizer.acc_stats = True
                with torch.no_grad():
                    sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device)
                loss_sample = criterion(outputs, sampled_y)
                loss_sample.backward(retain_graph=True)
                optimizer.acc_stats = False
                optimizer.zero_grad()  # clear the gradient for computing true-fisher.
                if args.partial_backprop == 'true':
                  idx = sampled_y != targets
                  loss = criterion(outputs[idx,:], targets[idx])
                  # print('extra:', idx.sum().item())
            loss.backward()
            optimizer.step()

        elif optim_name == 'ngd':
            if batch_idx % args.freq == 0:
                store_io_(True)
                inputs, targets = inputs.to(args.device), targets.to(args.device)
                optimizer.zero_grad()
                # net.set_require_grad(True)

                outputs = net(inputs)
                damp = damping
                loss = criterion(outputs, targets)
                loss.backward(retain_graph=True)

                # storing original gradient for later use
                grad_org = []
                # grad_dict = {}
                for name, param in net.named_parameters():
                    grad_org.append(param.grad.reshape(1, -1))
                #     grad_dict[name] = param.grad.clone()
                grad_org = torch.cat(grad_org, 1)

                ###### now we have to compute the true fisher
                with torch.no_grad():
                # gg = torch.nn.functional.softmax(outputs, dim=1)
                    sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device)
                
                if args.trial == 'true':
                    update_list, loss = optimal_JJT_v2(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient, super_opt=args.super_opt)
                else:
                    update_list, loss = optimal_JJT(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient)

                # optimizer.zero_grad()
                # update_list, loss = optimal_JJT_fused(outputs, sampled_y, args.batch_size, damping=damp)

                optimizer.zero_grad()
   
                # last part of SMW formula
                grad_new = []
                for name, param in net.named_parameters():
                    param.grad.copy_(update_list[name])
                    grad_new.append(param.grad.reshape(1, -1))
                grad_new = torch.cat(grad_new, 1)   
                # grad_new = grad_org
                store_io_(False)
            else:
                inputs, targets = inputs.to(args.device), targets.to(args.device)
                optimizer.zero_grad()
                # net.set_require_grad(True)

                outputs = net(inputs)
                damp = damping
                loss = criterion(outputs, targets)
                loss.backward()

                # storing original gradient for later use
                grad_org = []
                # grad_dict = {}
                for name, param in net.named_parameters():
                    grad_org.append(param.grad.reshape(1, -1))
                #     grad_dict[name] = param.grad.clone()
                grad_org = torch.cat(grad_org, 1)

                ###### now we have to compute the true fisher
                # with torch.no_grad():
                # gg = torch.nn.functional.softmax(outputs, dim=1)
                    # sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device)
                all_modules = net.modules()

                for m in net.modules():
                    if hasattr(m, "NGD_inv"):                    
                        grad = m.weight.grad
                        if isinstance(m, nn.Linear):
                            I = m.I
                            G = m.G
                            n = I.shape[0]
                            NGD_inv = m.NGD_inv
                            grad_prod = einsum("ni,oi->no", (I, grad))
                            grad_prod = einsum("no,no->n", (grad_prod, G))
                            v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                            gv = einsum("n,no->no", (v, G))
                            gv = einsum("no,ni->oi", (gv, I))
                            gv = gv / n
                            update = (grad - gv)/damp
                            m.weight.grad.copy_(update)
                        elif isinstance(m, nn.Conv2d):
                            if hasattr(m, "AX"):

                                if args.low_rank.lower() == 'true':
                                    ###### using low rank structure
                                    U = m.U
                                    S = m.S
                                    V = m.V
                                    NGD_inv = m.NGD_inv
                                    n = NGD_inv.shape[0]

                                    grad_reshape = grad.reshape(grad.shape[0], -1)
                                    grad_prod = V @ grad_reshape.t().reshape(-1, 1)
                                    grad_prod = torch.diag(S) @ grad_prod
                                    grad_prod = U @ grad_prod
                                    
                                    grad_prod = grad_prod.squeeze()
                                    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                    gv = U.t() @ v.unsqueeze(1)
                                    gv = torch.diag(S) @ gv
                                    gv = V.t() @ gv

                                    gv = gv.reshape(grad_reshape.shape[1], grad_reshape.shape[0]).t()
                                    gv = gv.view_as(grad)
                                    gv = gv / n
                                    update = (grad - gv)/damp
                                    m.weight.grad.copy_(update)
                                else:
                                    AX = m.AX
                                    NGD_inv = m.NGD_inv
                                    n = AX.shape[0]

                                    grad_reshape = grad.reshape(grad.shape[0], -1)
                                    grad_prod = einsum("nkm,mk->n", (AX, grad_reshape))
                                    v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                    gv = einsum("nkm,n->mk", (AX, v))
                                    gv = gv.view_as(grad)
                                    gv = gv / n
                                    update = (grad - gv)/damp
                                    m.weight.grad.copy_(update)
                            elif hasattr(m, "I"):
                                I = m.I
                                if args.memory_efficient == 'true':
                                    I = unfold_func(m)(I)
                                G = m.G
                                n = I.shape[0]
                                NGD_inv = m.NGD_inv
                                grad_reshape = grad.reshape(grad.shape[0], -1)
                                x1 = einsum("nkl,mk->nml", (I, grad_reshape))
                                grad_prod = einsum("nml,nml->n", (x1, G))
                                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                gv = einsum("n,nml->nml", (v, G))
                                gv = einsum("nml,nkl->mk", (gv, I))
                                gv = gv.view_as(grad)
                                gv = gv / n
                                update = (grad - gv)/damp
                                m.weight.grad.copy_(update)
                        elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                            if args.batchnorm == 'true':
                                dw = m.dw
                                n = dw.shape[0]
                                NGD_inv = m.NGD_inv
                                grad_prod = einsum("ni,i->n", (dw, grad))

                                v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze()
                                gv = einsum("n,ni->i", (v, dw))
                                
                                gv = gv / n
                                update = (grad - gv)/damp
                                m.weight.grad.copy_(update)
                        
                        

                # last part of SMW formula
                grad_new = []
                for name, param in net.named_parameters():
                    grad_new.append(param.grad.reshape(1, -1))
                grad_new = torch.cat(grad_new, 1)   
                # grad_new = grad_org


            ##### do kl clip
            lr = lr_scheduler.get_last_lr()[0]
            # vg_sum = 0
            # vg_sum += (grad_new * grad_org ).sum()
            # vg_sum = vg_sum * (lr ** 2)
            # nu = min(1.0, math.sqrt(args.kl_clip / vg_sum))
            # for name, param in net.named_parameters():
            #     param.grad.mul_(nu)

            # optimizer.step()
            # manual optimizing:
            with torch.no_grad():
                for name, param in net.named_parameters():
                    d_p = param.grad.data
                    # print('=== step ===')

                    # apply momentum
                    # if args.momentum != 0:
                    #     buf[name].mul_(args.momentum).add_(d_p)
                    #     d_p.copy_(buf[name])

                    # apply weight decay
                    if args.weight_decay != 0:
                        d_p.add_(param.data, alpha=args.weight_decay)

                    lr = lr_scheduler.get_last_lr()[0]
                    param.data.add_(d_p, alpha=-lr)
                    # print('d_p:', d_p.shape)
                    # print(d_p)



        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (tag, lr_scheduler.get_last_lr()[0], train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
        prog_bar.set_description(desc, refresh=True)
        if args.step_info == 'true' and (batch_idx % 50 == 0 or batch_idx == len(prog_bar) - 1):
            step_saved_time = time.time() - step_st_time
            epoch_time += step_saved_time
            test_acc, test_loss = test(epoch)
            TRAIN_INFO['train_acc'].append(float("{:.4f}".format(100. * correct / total)))
            TRAIN_INFO['test_acc'].append(float("{:.4f}".format(test_acc)))
            TRAIN_INFO['train_loss'].append(float("{:.4f}".format(train_loss/(batch_idx + 1))))
            TRAIN_INFO['test_loss'].append(float("{:.4f}".format(test_loss)))
            TRAIN_INFO['total_time'].append(float("{:.4f}".format(step_saved_time)))
            if args.debug_mem == 'true':
                TRAIN_INFO['memory'].append(torch.cuda.memory_reserved())
            step_st_time = time.time()
            net.train()

    writer.add_scalar('train/loss', train_loss/(batch_idx + 1), epoch)
    writer.add_scalar('train/acc', 100. * correct / total, epoch)
    acc = 100. * correct / total
    train_loss = train_loss/(batch_idx + 1)
    if args.step_info == 'true':
        TRAIN_INFO['epoch_time'].append(float("{:.4f}".format(epoch_time)))
    # save diagonal blocks of exact Fisher inverse or its approximations
    if args.save_inv == 'true':
      all_modules = net.modules()

      count = 0
      start, end = 0, 0
      if optim_name == 'ngd':
        for m in all_modules:
          if m.__class__.__name__ == 'Linear':
            with torch.no_grad():
              I = m.I
              G = m.G
              J = torch.einsum('ni,no->nio', I, G)
              J = J.reshape(J.size(0), -1)
              JTDJ = torch.matmul(J.t(), torch.matmul(m.NGD_inv, J)) / args.batch_size

              with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
                np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy())
                count += 1

          elif m.__class__.__name__ == 'Conv2d':
            with torch.no_grad():
              AX = m.AX
              AX = AX.reshape(AX.size(0), -1)
              JTDJ = torch.matmul(AX.t(), torch.matmul(m.NGD_inv, AX)) / args.batch_size
              with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
                np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy())
                count += 1

      elif optim_name == 'exact_ngd':
        for m in all_modules:
          if m.__class__.__name__ in ['Conv2d', 'Linear']:
            with open('exact/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
              end = start + m.weight.data.reshape(1, -1).size(1)
              np.save(f, inv[start:end,start:end].cpu().numpy())
              start = end + m.bias.data.size(0)
              count += 1

      elif optim_name == 'kfac':
        for m in all_modules:
          if m.__class__.__name__ in ['Conv2d', 'Linear']:
            with open('kfac/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f:
              G = optimizer.m_gg[m]
              A = optimizer.m_aa[m]

              H_g = torch.linalg.inv(G + math.sqrt(damping) * torch.eye(G.size(0)).to(G.device))
              H_a = torch.linalg.inv(A + math.sqrt(damping) * torch.eye(A.size(0)).to(A.device))

              end = m.weight.data.reshape(1, -1).size(1)
              kfac_inv = torch.kron(H_a, H_g)[:end,:end]
              np.save(f, kfac_inv.cpu().numpy())
              count += 1

    return acc, train_loss
Example No. 26
model = Sequential(Flatten(), Linear(784, 10))
lossfunc = CrossEntropyLoss()

model = extend(model)
lossfunc = extend(lossfunc)

# %%
# First order extensions
# ----------------------

# %%
# Batch gradients

loss = lossfunc(model(X), y)
with backpack(BatchGrad()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:             ", param.grad.shape)
    print(".grad_batch.shape:       ", param.grad_batch.shape)

# %%
# Variance

loss = lossfunc(model(X), y)
with backpack(Variance()):
    loss.backward()

for name, param in model.named_parameters():
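    # (assumed continuation; the source is truncated here) the ``Variance``
    # extension stores its result in the ``variance`` attribute
    print(name)
    print(".variance.shape:          ", param.variance.shape)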
Example No. 27
tproblem = extend_with_access_unreduced_loss(tproblem)

# forward pass
batch_loss, _ = tproblem.get_batch_loss_and_accuracy()

# individual loss
savefield = "_unreduced_loss"
individual_loss = getattr(batch_loss, savefield)

print("Individual loss shape:   ", individual_loss.shape)
print("Mini-batch loss:         ", batch_loss)
print("Averaged individual loss:", individual_loss.mean())

# It is still possible to use BackPACK in the backward pass
with backpack(
        BatchGrad(),
        Variance(),
        SumGradSquared(),
        BatchL2Grad(),
        DiagGGNExact(),
        DiagGGNMC(),
        KFAC(),
        KFLR(),
        KFRA(),
        DiagHessian(),
):
    batch_loss.backward()

# print info
for name, param in tproblem.net.named_parameters():
    print(name)
# Computing clipped individual gradients
# -----------------------------------------------------------------
#
# Before writing the optimizer class, let's see how we can use ``BackPACK``
# on a single batch to compute the clipped gradients, without the overhead
# of the optimizer class.
#
# We take a single batch from the data loader, compute the loss,
# and use the ``with(backpack(...))`` syntax to activate two extensions;
# ``BatchGrad`` and ``BatchL2Grad``.

x, y = next(iter(mnist_dataloader))
x, y = x.to(DEVICE), y.to(DEVICE)

loss = loss_function(model(x), y)
with backpack(BatchL2Grad(), BatchGrad()):
    loss.backward()

# %%
# ``BatchGrad`` computes individual gradients and ``BatchL2Grad`` their norm (squared),
# which get stored in the ``grad_batch`` and ``batch_l2`` attributes of the parameters

for p in model.parameters():
    print("{:28} {:32} {}".format(str(p.grad.shape), str(p.grad_batch.shape),
                                  str(p.batch_l2.shape)))

# %%
# To compute the clipped gradients, we need to know the norms of the complete
# individual gradients, but at the moment they are split across parameters,
# so let's reduce over the parameters.
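
# %%
# A hedged sketch of that reduction (an assumed continuation of the truncated
# tutorial; ``C`` is a placeholder clipping threshold, and ``model`` refers to the
# model used above): stack the squared per-parameter norms, sum over parameters,
# take the square root, and derive per-sample clipping factors.

l2_norms = torch.sqrt(torch.stack([p.batch_l2 for p in model.parameters()]).sum(dim=0))
C = 0.1  # hypothetical clipping threshold
scaling_factors = torch.clamp_max(C / l2_norms, 1.0)
print("per-sample gradient norms:", l2_norms.shape)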
def run_exp(meta_seed, nhid, nlayers, n_train_seeds):
    torch.manual_seed(meta_seed)
    np.random.seed(meta_seed)
    gamma = 0.9

    def init_weights(m):
        if isinstance(m, torch.nn.Linear):
            k = np.sqrt(6 / (np.sum(m.weight.shape)))
            m.weight.data.uniform_(-k, k)
            m.bias.data.fill_(0)
        if isinstance(m, torch.nn.Conv2d):
            u, v, w, h = m.weight.shape
            k = np.sqrt(6 / (w * h * u + w * h * v))
            m.weight.data.uniform_(-k, k)
            m.bias.data.fill_(0)

    env = CifarWindowEnv(_train_x, _train_y)
    test_env = CifarWindowEnvBatch()
    env.step_reward = 0.05
    test_env.step_reward = 0.05

    ##nhid = 32
    act = torch.nn.LeakyReLU()
    #act = torch.nn.Tanh()
    model = torch.nn.Sequential(*([
        torch.nn.Conv2d(4, nhid, 5, stride=2), act,
        torch.nn.Conv2d(nhid, nhid * 2, 3), act,
        torch.nn.Conv2d(nhid * 2, nhid * 4, 3), act
    ] + sum([[torch.nn.Conv2d(nhid * 4, nhid * 4, 3, padding=1), act]
             for i in range(nlayers)], []) + [
                 torch.nn.Flatten(),
                 torch.nn.Linear(nhid * 4 * 10 * 10, nhid * 4), act,
                 torch.nn.Linear(nhid * 4, 14)
             ]))
    model.to(device)
    model.apply(init_weights)
    if 1:
        model = extend(model)
    opt = torch.optim.Adam(model.parameters(), 1e-5)  #, weight_decay=1e-5)

    #opt = torch.optim.SGD(model.parameters(), 1e-4, momentum=0.9)

    def run_test(X, Y, dataacc=None):
        obs = test_env.reset(X, Y)
        accr = np.zeros(len(X))
        for i in range(test_env.max_steps):
            o = model(obs)
            cls_act = Categorical(logits=o[:, :10]).sample()
            mov_act = Categorical(logits=o[:, 10:]).sample()
            actions = torch.stack([cls_act, mov_act]).data.cpu().numpy()
            obs, r, done, _ = test_env.step(actions)
            accr += r
            if dataacc is not None:
                dataacc.append(obs[np.random.randint(0, len(obs))])
                dataacc.append(obs[np.random.randint(0, len(obs))])
            if done.all():
                break
        return test_env.correct_answers / len(X), test_env.acc_reward

    train_perf = []
    test_perf = []
    all_dots = []
    all_dots_test = []
    tds = []
    qs = []
    xent = torch.nn.CrossEntropyLoss()
    tau = 0.1

    n_rp = 1000
    rp_s = torch.zeros((n_rp, 4, 32, 32), device=device)
    rp_a = torch.zeros((n_rp, 2), device=device, dtype=torch.long)
    rp_g = torch.zeros((n_rp, ), device=device)
    rp_idx = 0
    rp_fill = 0

    obs = env.reset(np.random.randint(0, n_train_seeds))
    ntest = 128
    epsilon = 0.9
    ep_reward = 0
    ep_rewards = []
    ep_start = 0
    for i in tqdm(range(200001)):
        epsilon = 0.9 * (1 - min(i, 100000) / 100000) + 0.05
        if not i % 1000:
            t0 = time.time()
            dataacc = []
            with torch.no_grad():
                train_perf.append(
                    run_test(_train_x[:min(ntest, n_train_seeds)],
                             _train_y[:min(ntest, n_train_seeds)]))
                test_perf.append(
                    run_test(test_x[:ntest], test_y[:ntest], dataacc=dataacc))
            print(train_perf[-2:], test_perf[-2:], np.mean(ep_rewards[-50:]),
                  len(ep_rewards))
            if 1:
                t1 = time.time()
                s = rp_s[:128]
                if 1:
                    loss = sumloss(model(s).mean())
                    with backpack(BatchGrad()):
                        loss.backward()
                    all_grads = torch.cat([
                        i.grad_batch.reshape((s.shape[0], -1))
                        for i in model.parameters()
                    ], 1)
                else:
                    all_grads = []
                    for k in range(len(s)):
                        Qsa = model(s[k][None, :])
                        grads = torch.autograd.grad(Qsa.max(),
                                                    model.parameters())
                        fg = torch.cat([i.reshape((-1, )) for i in grads])
                        all_grads.append(fg)
                dots = []
                for k in range(len(s)):
                    for j in range(k + 1, len(s)):
                        dots.append(all_grads[k].dot(all_grads[j]).item())
                all_dots.append(np.float32(dots).mean())
                opt.zero_grad()
                s = torch.stack(dataacc[:128])
                loss = sumloss(model(s).mean())
                with backpack(BatchGrad()):
                    loss.backward()
                all_grads = torch.cat([
                    i.grad_batch.reshape((s.shape[0], -1))
                    for i in model.parameters()
                ], 1)
                dots = []
                for k in range(len(s)):
                    for j in range(k + 1, len(s)):
                        dots.append(all_grads[k].dot(all_grads[j]).item())
                all_dots_test.append(np.float32(dots).mean())
                opt.zero_grad()
                if i and 0:
                    print(i, (cls_pi * torch.log(cls_pi)).mean().item(),
                          (mov_pi * torch.log(mov_pi)).mean().item())
                    print(cls_pi[0].data.cpu().numpy())

        o = model(obs[None, :])
        cls_act = Categorical(logits=o[0, :10]).sample().item()
        #cls_act = env.current_y
        mov_act = Categorical(logits=o[0, 10:]).sample().item()
        action = np.int32([cls_act, mov_act])

        #if np.random.uniform(0,1) < 0.4:
        #    action = env.current_y

        obsp, r, done, _ = env.step(action)
        rp_s[rp_idx] = obs
        rp_a[rp_idx] = torch.tensor(action)
        rp_idx = (rp_idx + 1) % rp_s.shape[0]
        rp_fill += 1
        ep_reward += r
        obs = obsp
        if done:
            rp_g[ep_start:i] = ep_reward
            ep_rewards.append(ep_reward)
            ep_reward = 0
            ep_start = i
            obs = env.reset(np.random.randint(0, n_train_seeds))
            if rp_idx > 250:
                rp_at = rp_idx
                rp_idx = 0
                rp_fill = 0
                s = rp_s[:rp_at]
                a = rp_a[:rp_at]
                g = rp_g[:rp_at]
                o = model(s)
                cls_pi = F.softmax(o[:, :10], 1).clamp(min=1e-5)
                mov_pi = F.softmax(o[:, 10:], 1).clamp(min=1e-5)
                cls_prob = torch.log(cls_pi[torch.arange(len(a)), a[:, 0]])
                mov_prob = torch.log(mov_pi[torch.arange(len(a)), a[:, 1]])
                #import pdb; pdb.set_trace()
                loss = -(g * (cls_prob + mov_prob)).mean()
                loss += 1e-4 * ((cls_pi * torch.log(cls_pi)).sum(1).mean() +
                                (mov_pi * torch.log(mov_pi)).sum(1).mean())
                loss.backward()
                #for p in model.parameters():
                #    p.grad.data.clamp_(-1, 1)
                opt.step()
                opt.zero_grad()

    s = rp_s[:200]
    all_grads = []
    for i in range(len(s)):
        Qsa = model(s[i][None, :])
        grads = torch.autograd.grad(Qsa.max(), model.parameters())
        fg = torch.cat([i.reshape((-1, )) for i in grads], 0)
        all_grads.append(fg)
    dots = []
    cosd = []
    for i in range(len(s)):
        for j in range(i + 1, len(s)):
            dots.append(all_grads[i].dot(all_grads[j]).item())
            cosd.append(all_grads[i].dot(all_grads[j]).item() / (np.sqrt(
                (all_grads[i]**2).sum().item()) * np.sqrt(
                    (all_grads[j]**2).sum().item())))
    dots = np.float32(dots)
    print(np.mean(train_perf[-5:]), np.mean(test_perf[-5:]))  #, all_dots[-5:])

    return {
        'dots': dots,
        'cosd': cosd,
        'all_dots': all_dots,
        'all_dots_test': all_dots_test,
        'train_perf': train_perf,
        'test_perf': test_perf,
    }
def run_exp(meta_seed, nhid, nlayers, n_train_seeds):
    torch.manual_seed(meta_seed)
    np.random.seed(meta_seed)
    gamma = 0.9
    train_x = _train_x[:n_train_seeds]
    train_y = _train_y[:n_train_seeds]

    ##nhid = 32
    act = torch.nn.LeakyReLU()
    #act = torch.nn.Tanh()
    model = torch.nn.Sequential(*([
        torch.nn.Conv2d(3, nhid, 5, stride=2), act,
        torch.nn.Conv2d(nhid, nhid * 2, 3), act,
        torch.nn.Conv2d(nhid * 2, nhid * 4, 3), act
    ] + sum([[torch.nn.Conv2d(nhid * 4, nhid * 4, 3, padding=1), act]
             for i in range(nlayers)], []) + [
                 torch.nn.Flatten(),
                 torch.nn.Linear(nhid * 4 * 10 * 10, nhid), act,
                 torch.nn.Linear(nhid, 10)
             ]))

    def init_weights(m):
        if isinstance(m, torch.nn.Linear):
            k = np.sqrt(6 / (np.sum(m.weight.shape)))
            m.weight.data.uniform_(-k, k)
            m.bias.data.fill_(0)

    model.to(device)
    model.apply(init_weights)
    if 0:
        model = extend(model)
    opt = torch.optim.Adam(model.parameters(), 1e-3)  #, weight_decay=1e-5)

    def compute_acc_mb(X, Y, mbs=1024):
        N = len(X)
        i = 0
        tot = 0
        while i < N:
            x = X[i:i + mbs]
            y = Y[i:i + mbs]
            tot += np.sum(
                model(x).argmax(1).cpu().data.numpy() == y.cpu().data.numpy())
            i += mbs
        return tot / N

    train_perf = []
    test_perf = []
    all_dots = []
    xent = torch.nn.CrossEntropyLoss()

    for i in tqdm(range(1000)):
        if not i % 20:
            train_perf.append(compute_acc_mb(train_x, train_y))
            test_perf.append(compute_acc_mb(test_x, test_y))
            s = train_x[:96]
            if 0:
                loss = sumloss(model(s).max(1).values)
                with backpack(BatchGrad()):
                    loss.backward()
                all_grads = torch.cat([
                    i.grad_batch.reshape((mbs, -1))
                    for i in model.parameters()
                ], 1)
            all_grads = []
            for i in range(len(s)):
                Qsa = model(s[i][None, :])
                grads = torch.autograd.grad(Qsa.max(), model.parameters())
                fg = torch.cat([i.reshape((-1, )) for i in grads])
                all_grads.append(fg)
            dots = []
            cosd = []
            cosd = []
            for i in range(len(s)):
                for j in range(i + 1, len(s)):
                    dots.append(all_grads[i].dot(all_grads[j]).item())
            all_dots.append(np.float32(dots).mean())
            opt.zero_grad()
        mbidx = np.random.randint(0, len(train_x), 32)
        x = train_x[mbidx]
        y = train_y[mbidx]
        pred = model(x)
        loss = xent(pred, y)
        loss.backward()
        opt.step()
        opt.zero_grad()

    s = train_x[:200]
    all_grads = []
    for i in range(len(s)):
        Qsa = model(s[i][None, :])
        grads = torch.autograd.grad(Qsa.max(), model.parameters())
        fg = torch.cat([i.reshape((-1, )) for i in grads], 0)
        all_grads.append(fg)
    dots = []
    cosd = []
    for i in range(len(s)):
        for j in range(i + 1, len(s)):
            dots.append(all_grads[i].dot(all_grads[j]).item())
            cosd.append(all_grads[i].dot(all_grads[j]).item() / (np.sqrt(
                (all_grads[i]**2).sum().item()) * np.sqrt(
                    (all_grads[j]**2).sum().item())))
    dots = np.float32(dots)
    print(train_perf[-5:], test_perf[-5:], all_dots[-5:])

    return {
        'dots': dots,
        'cosd': cosd,
        'all_dots': all_dots,
        'train_perf': train_perf,
        'test_perf': test_perf,
    }