def test_grad1():
    torch.manual_seed(1)
    model = Net()
    loss_fn = nn.CrossEntropyLoss()

    n = 4
    data = torch.rand(n, 1, 28, 28)
    targets = torch.LongTensor(n).random_(0, 10)

    autograd_hacks.add_hooks(model)
    output = model(data)
    loss_fn(output, targets).backward(retain_graph=True)
    autograd_hacks.compute_grad1(model)
    autograd_hacks.disable_hooks()

    # Compare values against autograd
    losses = torch.stack(
        [loss_fn(output[i:i + 1], targets[i:i + 1]) for i in range(len(data))])

    for layer in model.modules():
        if not autograd_hacks.is_supported(layer):
            continue
        for param in layer.parameters():
            assert torch.allclose(param.grad, param.grad1.mean(dim=0))
            assert torch.allclose(jacobian(losses, param), param.grad1)
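
# The `jacobian` helper used above is not defined in this snippet. A minimal
# sketch of what it is assumed to do: stack the gradient of each per-sample
# loss with respect to `param`, yielding a tensor of shape [n, *param.shape]
# that lines up with param.grad1.
def jacobian(ys, x):
    # ys: 1-D tensor of per-sample scalar losses, x: a parameter tensor
    rows = [torch.autograd.grad(y, x, retain_graph=True)[0] for y in ys]
    return torch.stack(rows)
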
def subtest_hess_type(hess_type):
    torch.manual_seed(1)
    model = TinyNet()

    def least_squares_loss(data_, targets_):
        assert len(data_) == len(targets_)
        err = data_ - targets_
        return torch.sum(err * err) / 2 / len(data_)

    n = 3
    data = torch.rand(n, 1, 28, 28)

    autograd_hacks.add_hooks(model)
    output = model(data)

    if hess_type == 'LeastSquares':
        targets = torch.rand(output.shape)
        loss_fn = least_squares_loss
    else:  # hess_type == 'CrossEntropy':
        targets = torch.LongTensor(n).random_(0, 10)
        loss_fn = nn.CrossEntropyLoss()

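    # backprop the Hessian quantities twice with a clear_backprops in between,
    # so compute_hess below only sees the second, clean pass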
    autograd_hacks.backprop_hess(output, hess_type=hess_type)
    autograd_hacks.clear_backprops(model)
    autograd_hacks.backprop_hess(output, hess_type=hess_type)

    autograd_hacks.compute_hess(model)
    autograd_hacks.disable_hooks()

    for layer in model.modules():
        if not autograd_hacks.is_supported(layer):
            continue
        for param in layer.parameters():
            loss = loss_fn(output, targets)
            hess_autograd = hessian(loss, param)
            hess = param.hess
            assert torch.allclose(hess, hess_autograd.reshape(hess.shape))
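
# The `hessian` helper is likewise not defined here. A rough sketch of what it
# is assumed to compute: the Hessian of a scalar loss with respect to a single
# parameter tensor, flattened to [param.numel(), param.numel()] so it can be
# reshaped against param.hess above.
def hessian(loss, param):
    grad = torch.autograd.grad(loss, param, create_graph=True)[0].reshape(-1)
    rows = [torch.autograd.grad(g, param, retain_graph=True)[0].reshape(-1)
            for g in grad]
    return torch.stack(rows)
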
ppsi.real_comp.zero_grad()

pars1 = list(ppsi.real_comp.parameters())

dw = 0.001  # sometimes less accurate when smaller than 1e-3
with torch.no_grad():
    pars1[0][0][0] = pars1[0][0][0] + dw

# Choose a specific s
s = torch.tensor(np.random.choice(evals, [1, L]), dtype=torch.float)
#s=torch.ones([1,L])

# First let's test the autodifferentiation:
if not hasattr(original_net.real_comp, 'autograd_hacks_hooks'):
    autograd_hacks.add_hooks(original_net.real_comp)
out_0 = original_net.real_comp(s)
out_0.mean().backward()
autograd_hacks.compute_grad1(original_net.real_comp)
autograd_hacks.clear_backprops(original_net.real_comp)
pars = list(original_net.real_comp.parameters())
grad0 = pars[0].grad1[0, :, :]

#out_0=original_net.real_comp(s)
#out_0.mean().backward()
#pars=list(original_net.real_comp.parameters())
#grad0=pars[0].grad #* (1/N_samples)

# Calculate the new and old wavefunctions for this s, numerical dln(Psi)
original_net.Autoregressive_pass(s, evals)
wvf0 = original_net.wvf
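
# A sketch (assumption, not part of the original snippet) of how the numerical
# side of this check could continue: evaluate the perturbed copy ppsi on the
# same s and form a finite-difference estimate of d ln(Psi)/dw to set against
# the autograd quantities gathered above.
ppsi.Autoregressive_pass(s, evals)
wvf1 = ppsi.wvf
dln_psi_dw = (wvf1 - wvf0) / (dw * wvf0)  # first-order estimate of d ln(Psi)/dw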

#model reduction by model manifold boundary
for epoch in trange(args['epochs']):
    time.sleep(1)
    net.train()
    criterion = nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(kloader):

        #clear gradients + add hooks + forward
        autograd_hacks.clear_backprops(net)
        optimizer.zero_grad()
        data, target = data.to(args['dev']), target.to(args['dev'])
        autograd_hacks.add_hooks(net)
        criterion(net(data), target).backward(retain_graph=True)

        #compute per sample gradient
        autograd_hacks.compute_grad1(net)
        autograd_hacks.disable_hooks()

        #Jacobian + kernel matrix + eigen decomposition of kernel matrix
        J = torch.cat([
            params.grad1.data.view(args['k_size'], -1)
            for params in net.parameters()
        ], 1)
        kernel = torch.matmul(J, J.T)
        w, v = torch.linalg.eigh(kernel)  # replaces deprecated torch.symeig; eigenvalues ascending
        #flipping to decreasing order
        w, v = torch.flip(w, dims=[0]), torch.flip(v, dims=[1])
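        # Shapes, for reference: J is [k_size, n_params] (one flattened
        # per-sample gradient per row), kernel = J @ J.T is [k_size, k_size],
        # w holds the k_size eigenvalues (now in decreasing order), and v holds
        # the matching eigenvectors in its columns.
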
    def forward(self, N_samples=None, x=None, requires_grad=True):

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if N_samples is None and x is None:
            raise ValueError('Must enter samples or the number of samples to' \
                              ' be generated')

        # if not sampling, just calculating wavefunction
        if N_samples is None and x is not None:
            N_samples, need_samples = x.shape[0], False

        # if sampling and calculating wavefunction
        if N_samples is not None and x is None:
            need_samples = True
            x = torch.zeros([N_samples, self.D], dtype=torch.float).to(device)

        # the full wavefunction is a product of the conditionals
        WAV = torch.ones([N_samples]).to(device)
        order = np.arange(0, self.D)  # sequential autoregressive ordering

        # for gradient tracking
        params = list(self.parameters())
        grads_per_param = []

        for d in range(self.D):

            # mask enforces the autoregressive property
            mask = torch.zeros_like(x)
            mask[:, order[0:(d)]] = 1

            # add autograd hooks for per-sample gradient calculation
            if not hasattr(self.model, 'autograd_hacks_hooks'):
                autograd_hacks.add_hooks(self.model)

            # L2 normalization of masked output
            out = F.normalize(self.model(mask * x), 2)

            # 'psi_pos' is positive bits, 'psi_neg' is negative bits
            psi_pos = out[:, 0].squeeze()
            psi_neg = out[:, 1].squeeze()

            if need_samples:

                # sampling routine according to psi**2:
                # convert bit values from 0 to -1
                m = torch.distributions.Bernoulli(psi_pos**2).sample()
                m = torch.where(m == 0,
                                torch.tensor(-1).to(device),
                                torch.tensor(1).to(device))

                # update sample tensor
                x[:, d] = m

                # Accumulate PSI based on which state (s) was sampled
                selected_wavs = torch.where(x[:, d] > 0, psi_pos, psi_neg)
                WAV = WAV * selected_wavs

            else:

                # if not sampling, m is a list of bits in column d
                m = x[:, d]

                # Accumulate PPSI based on which state (s) was sampled
                selected_wavs = torch.where(m > 0, psi_pos, psi_neg)
                WAV = WAV * selected_wavs

            if requires_grad:

                # eval_grads stores backpropagation values for out1 and out2.
                # eval_grads[0] are the out1 grads for all samples (per param),
                # eval_grads[1] are the out2 grads for all samples (per param).
                eval_grads = [[None] * len(params)
                              for _ in range(len(self.evals))]

                # Store the per-output grads in eval_grads
                for output in range(len(self.evals)):

                    # backpropagate the current output (out1 or out2)
                    out[:, output].mean(0).backward(retain_graph=True)

                    # compute gradients for all samples
                    autograd_hacks.compute_grad1(self.model)
                    autograd_hacks.clear_backprops(self.model)

                    # store the calculated gradients for all samples
                    for param in range(len(params)):
                        eval_grads[output][param] = params[param].grad1

                # allocate space for gradient accumulation
                if d == 0:
                    for param in range(len(params)):
                        grads_per_param.append(
                            torch.zeros_like(eval_grads[0][param]))

                # accumulate gradients per parameter based on sampled bits
                for param in range(len(params)):

                    # reshape m and wavs so they can be accumulated/divided properly
                    reshaped_m = m.reshape(m.shape + (1, ) *
                                           (grads_per_param[param].ndim - 1))
                    reshaped_wavs = selected_wavs.reshape(
                        selected_wavs.shape + (1, ) *
                        (grads_per_param[param].ndim - 1))

                    # select the proper gradient to accumulate based on m
                    grads_per_param[param][:] += torch.where(
                        reshaped_m[:] > 0,
                        eval_grads[0][param][:] / reshaped_wavs[:],
                        eval_grads[1][param][:] / reshaped_wavs[:])

        return WAV.detach(), x.detach(), grads_per_param
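
    # A minimal usage sketch (assumed calling convention, based only on the
    # signature and return values above; `net` is a hypothetical instance):
    #   wav, samples, grads_per_param = net.forward(N_samples=100)
    #   wav, _, _ = net.forward(x=samples, requires_grad=False)
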
    def dp_sgd(self, model, global_round, norm_bound, noise_scale):
        #################
        ## ALGORITHM 1 ##
        #################

        # Set mode to train model
        model.train()
        epoch_loss = []

        model_dummy = copy.deepcopy(model)

        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=self.args.lr,
                                        momentum=0.0)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=self.args.lr,
                                         weight_decay=1e-4)

        # for each epoch (1...E)
        for iter in range(self.args.local_ep):

            batch_loss = []

            # for each batch
            for batch_idx, (images, labels) in enumerate(self.trainloader):

                # add hooks for per sample gradients
                model.zero_grad()
                autograd_hacks.add_hooks(model)

                # Forward pass, compute loss, backwards pass
                log_probs = model(torch.FloatTensor(images))
                loss = self.criterion(log_probs, labels)
                loss.backward(retain_graph=True)

                # Per-sample gradients g_i
                autograd_hacks.compute_grad1(model)
                autograd_hacks.disable_hooks()

                # Compute L2^2 norm for each g_i
                g_norms = torch.zeros(labels.shape[0])

                for name, param in model.named_parameters():
                    g_norms += param.grad1.flatten(1).norm(2, dim=1)**2

                # Per-sample clipping factor: dividing g_i by max(1, ||g_i|| / C)
                # is equivalent to scaling it by min(1, C / ||g_i||)
                clip_factor = torch.clamp(g_norms**0.5 / norm_bound, min=1)
                #print(np.percentile(g_norms ** 0.5, [25, 50, 75]))

                # Clip each gradient
                for param in model.parameters():
                    for i in range(len(labels)):
                        param.grad1.data[i] /= clip_factor[i]

                # Noisy batch update
                for param in model.parameters():
                    # batch average of clipped gradients
                    param.grad = param.grad1.mean(dim=0)

                    # add noise
                    param.grad += torch.randn(
                        param.size()) * norm_bound * noise_scale / len(labels)
                    # debug print, left commented out: compares the averaged
                    # gradient against a fresh draw of the noise
                    # print(param.grad,
                    #       torch.randn(param.size()) * norm_bound * noise_scale / len(labels))
                    # update weights
                    param.data -= self.args.lr * param.grad.data

                # rebuild the model from a state_dict round-trip to drop the extra
                # per-sample gradient attributes and hooks added by autograd_hacks
                model_dummy.load_state_dict(model.state_dict())
                model = copy.deepcopy(model_dummy)

                # Record loss
                batch_loss.append(loss.item())

            # Append loss, go to next epoch...
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)
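
    # Summary of the per-batch update implemented above, with C = norm_bound,
    # sigma = noise_scale and lot size L = len(labels):
    #   g_bar = (1/L) * sum_i g_i / max(1, ||g_i||_2 / C)
    #   w    <- w - lr * (g_bar + N(0, I) * sigma * C / L)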