def test_bivariate(self):
    """Check that the bivariate net's forward-pass derivatives match jacobian_second_order.

    The two derivative tensors returned by
    ``self.net_bivariate.forward_and_derivatives`` (presumably first- and
    second-order derivatives -- confirm against the network class) must be
    ``torch.allclose`` to the pair returned by ``jacobian_second_order`` for
    the same input and output.
    """
    # derivatives obtained directly from the network's forward pass
    outputs, first_fwd, second_fwd = self.net_bivariate.forward_and_derivatives(
        self.tensor
    )
    # reference derivatives computed from the input and the network output
    first_ref, second_ref = jacobian_second_order(self.tensor, outputs)

    assert torch.allclose(first_fwd, first_ref)
    assert torch.allclose(second_fwd, second_ref)
def test_equality_Hessian_jacobian_second_bn_no_affine(self):
    """Check jacobian_second_order against jacobian_hessian on the no-affine batch-norm net.

    Both helpers are evaluated on the same input/output pair. Their first
    returned tensors must agree, and the second tensor from
    ``jacobian_second_order`` must agree with the entries of the tensor from
    ``jacobian_hessian`` whose second and third indices coincide.
    """
    outputs = self.net_batchnorm_no_affine(self.data)

    first_direct, second_direct = jacobian_second_order(self.data, outputs)
    first_via_hessian, hessian = jacobian_hessian(self.data, outputs)

    # keep only the entries H[b, i, i, k], i.e. the diagonal over axes 1 and 2
    second_via_hessian = torch.einsum('biik->bik', hessian)

    assert torch.allclose(first_direct, first_via_hessian)
    assert torch.allclose(second_direct, second_via_hessian)
def test_running_time(self):
    """Print wall-clock timings of jacobian_second_order for each network variant.

    For each of ``self.net``, ``self.net_batchnorm`` and
    ``self.net_batchnorm_no_affine``, runs 100 iterations of a forward pass
    on ``self.data`` followed by ``jacobian_second_order`` and prints the
    elapsed time. This test makes no assertions; it only reports timings.
    """

    def _time_network(network, repetitions=100):
        # Elapsed wall-clock seconds for `repetitions` forward + derivative passes.
        start = time()
        for _ in range(repetitions):
            outputs = network(self.data)
            jacobian_second_order(self.data, outputs)
        return time() - start

    time_no_bn = _time_network(self.net)
    time_bn = _time_network(self.net_batchnorm)
    # Fix: this measurement used to re-run self.net_batchnorm, so the value
    # reported as "bn no affine" was a second timing of the affine network.
    time_bn_no_affine = _time_network(self.net_batchnorm_no_affine)

    print("No bn: {:.4f}, bn: {:.4f}, bn no affine: {:.4f}".format(
        time_no_bn, time_bn, time_bn_no_affine))
def batch_Fisher_div_with_c_x(net, samples, etas, lam=0):
    """Return Fisher_divergence_loss_with_c_x on a batch, divided by the batch size.

    Args:
        net: callable applied to ``samples`` in a single forward pass.
        samples: input batch; ``samples.shape[0]`` is used as the divisor.
        etas: forwarded unchanged to ``Fisher_divergence_loss_with_c_x``.
        lam: forwarded as the ``lam`` keyword of
            ``Fisher_divergence_loss_with_c_x`` (default 0).

    Returns:
        The loss value divided by ``samples.shape[0]``.
    """
    # one forward pass over the whole batch
    net_outputs = net(samples)
    first_der, second_der = jacobian_second_order(samples, net_outputs)

    # reshape both to 3 dims, keeping the sizes of axes 1 and 2
    first_der = first_der.reshape(-1, first_der.shape[1], first_der.shape[2])
    second_der = second_der.reshape(-1, second_der.shape[1], second_der.shape[2])

    total_loss = Fisher_divergence_loss_with_c_x(first_der, second_der, etas, lam=lam)
    return total_loss / (samples.shape[0])
def batch_Fisher_div(net, samples, etas):
    """Return Fisher_divergence_loss on a batch, divided by the batch size.

    Args:
        net: callable applied to ``samples`` in a single forward pass.
        samples: input batch; ``samples.shape[0]`` is used as the divisor.
        etas: forwarded unchanged to ``Fisher_divergence_loss``.

    Returns:
        The loss value divided by ``samples.shape[0]``.
    """
    # one forward pass over the whole batch
    net_outputs = net(samples)
    first_der, second_der = jacobian_second_order(samples, net_outputs)

    # reshape both to 3 dims, keeping the sizes of axes 1 and 2
    first_der = first_der.reshape(-1, first_der.shape[1], first_der.shape[2])
    second_der = second_der.reshape(-1, second_der.shape[1], second_der.shape[2])

    total_loss = Fisher_divergence_loss(first_der, second_der, etas)
    return total_loss / (samples.shape[0])
# if you have saved the files before and only want to create plots: # times_forward = np.load("results/times_forward.npy") # times_jac = np.load("results/times_jac.npy") data = torch.randn(5000, 100) data.requires_grad = True # study time complexity wrt number outputs in the NN: for out_size in n_outputs: print(out_size) net = createDefaultNNWithDerivatives(100, out_size)() start = time() y1 = net(data) f1, s1 = jacobian_second_order(data, y1) times_jac.append(time() - start) start = time() y2, f2, s2 = net.forward_and_derivatives(data) times_forward.append(time() - start) print(times_jac) print(times_forward) if len(times_jac) > 0: np.save("results/times_jac", np.array(times_jac)) if len(times_forward) > 0: np.save("results/times_forward", np.array(times_forward)) fig, ax = plt.subplots(1, 1) ax.plot(n_outputs, times_jac)