def compute_hessian(y, params):
    """Compute the dense Hessian of a scalar ``y`` with respect to ``params``.

    Args:
        y: Scalar tensor to differentiate. ``torch.autograd.grad`` is called
            without ``grad_outputs``, so ``y`` must be a scalar.
        params: A tensor, or a sequence of tensors, to differentiate against.

    Returns:
        A detached ``(n, n)`` tensor, where ``n`` is the total number of
        elements of the flattened first-order gradient. Row ``i`` is the
        flattened gradient of the ``i``-th entry of that first-order gradient.

    Parameters that ``y`` does not use, or uses only linearly, contribute
    zero rows/columns rather than making autograd raise.
    """
    # ``params`` is differentiated against several times, so normalise it
    # once. A one-shot iterable must not be consumed by the first call, and a
    # single tensor must stay a single input (not be iterated row by row).
    inputs = (params,) if torch.is_tensor(params) else tuple(params)

    def _grads_or_zeros(output, **kwargs):
        # Differentiate ``output`` w.r.t. ``inputs``, mapping unused inputs to zeros.
        # With allow_unused=True, autograd returns None for inputs that are not
        # in ``output``'s graph. Their derivative is zero, so materialise it.
        grads = torch.autograd.grad([output], inputs, allow_unused=True, **kwargs)
        return [
            torch.zeros_like(p) if g is None else g for g, p in zip(grads, inputs)
        ]

    # create_graph=True so the first-order gradient can be differentiated again.
    flat_grads = trpo._flatten_and_concat_variables(
        _grads_or_zeros(y, create_graph=True)
    )
    n = flat_grads.numel()
    if not flat_grads.requires_grad:
        # The gradient is constant (y is at most linear in params). There is
        # no second-order graph to differentiate, and the Hessian is zero.
        return torch.zeros(n, n, dtype=flat_grads.dtype, device=flat_grads.device)

    hessian_rows = [
        # retain_graph=True keeps the first-order graph alive for later rows.
        trpo._flatten_and_concat_variables(
            _grads_or_zeros(grad_i, retain_graph=True)
        ).detach()
        for grad_i in flat_grads
    ]
    return torch.stack(hessian_rows)
def compute_hessian_vector_product(y, params, vec):
    """Compute a Hessian-vector product of scalar ``y`` via ``trpo``'s helper.

    The first-order gradient of ``y`` w.r.t. ``params`` is built with
    ``create_graph=True``, so that ``trpo._hessian_vector_product`` can
    differentiate through it a second time.

    Args:
        y: Scalar tensor to differentiate.
        params: Tensor(s) to differentiate against.
        vec: Vector to multiply by. Presumably a flat tensor matching the
            flattened size of ``params``; TODO confirm against
            ``trpo._hessian_vector_product``.

    Returns:
        Whatever ``trpo._hessian_vector_product`` returns for these inputs.
    """
    first_order = torch.autograd.grad([y], params, create_graph=True)
    return trpo._hessian_vector_product(
        trpo._flatten_and_concat_variables(first_order), params, vec
    )