Пример #1
0
def discretize(F, Q, t, device):
    """
    Discretize matrices for continuous update step with matrix exponential and matrix fraction decomposition

    For each time delta this computes

        A   = expm(F t)
        Phi = expm([[F, Q], [0, -F^T]] t),  C = Phi[:n, n:],  D = Phi[n:, n:]
        L   = C D^{-1}

    i.e. the discrete transition matrix and the discretized process-noise term.

    Keyword arguments:
    F, Q -- KF matrices for state update (torch.Tensor); either (n, n), shared by
            every time delta, or (m, n, n) with one system per batch entry
    t -- time delta to last observation (torch.Tensor); flattened to one delta per
         output matrix when F is 2-D, assumed to be (m,) when F is 3-D
    device -- device the returned tensors are moved to (torch.device)

    Returns:
    A, L -- discretized update matrices, each (batch, n, n), on ``device``
    """
    n = F.shape[-1]
    t = t.to(F.device)
    if F.dim() == 3:
        t_scale = t.unsqueeze(-1).unsqueeze(-1)
    else:
        t_scale = t.view(-1, 1, 1)

    A = torch.matrix_exp(F * t_scale)

    # Block matrix [[F, Q], [0, -F^T]] built with F's dtype/device so float64 or
    # GPU inputs are not silently downcast to float32 or moved to the CPU.
    M = F.new_zeros(*F.shape[:-2], 2 * n, 2 * n)
    M[..., :n, :n] = F
    M[..., :n, n:] = Q
    M[..., n:, n:] = -F.transpose(-2, -1)
    Phi = torch.matrix_exp(M * t_scale)

    # Phi @ [0; I] is exactly the last n columns of Phi.
    C = Phi[..., :n, n:]
    D = Phi[..., n:, n:]
    # L = C D^{-1}, computed as a solve (L D = C  <=>  D^T L^T = C^T) rather
    # than forming the explicit inverse.
    L = torch.linalg.solve(D.transpose(-2, -1),
                           C.transpose(-2, -1)).transpose(-2, -1)
    return A.to(device), L.to(device)
Пример #2
0
    def params2orb(params, coeffs, with_penalty):
        """Rotate ``coeffs`` by the exponential of an anti-Hermitian generator.

        ``params`` (*, nparams) fill the first ``nparams`` strictly upper
        triangular positions (in ``torch.triu_indices`` row-major order) of an
        ``nao x nao`` matrix R, where ``nao = coeffs.shape[-2]``. Then
        ``expm(R - R^H) @ coeffs`` is returned; ``coeffs`` is (*, nao, norb).

        When ``with_penalty`` is true, a one-element zero ``penalty`` tensor
        (same dtype/device as ``params``) is returned alongside.
        """
        n_ao = coeffs.shape[-2]
        batch_shape = params.shape[:-1]
        upper = torch.triu_indices(n_ao, n_ao, offset=1)[..., :params.shape[-1]]

        generator = params.new_zeros((*batch_shape, n_ao, n_ao))
        generator[..., upper[0], upper[1]] = params
        # Anti-Hermitian generator -> its exponential is unitary (orthogonal if real).
        generator = generator - generator.transpose(-2, -1).conj()

        ortho_orb = torch.matrix_exp(generator) @ coeffs
        if not with_penalty:
            return ortho_orb
        return ortho_orb, params.new_zeros((1, ))
Пример #3
0
def expm(x, eps, algo='torch'):
    """Matrix exponential of ``x`` using the selected backend.

    Args:
        x: matrix (or batch of matrices) to exponentiate.
        eps: forwarded to ``exp_from_paper``; ignored by the ``'torch'`` backend
            (presumably a tolerance — confirm against ``exp_from_paper``).
        algo: ``'torch'`` for ``torch.matrix_exp`` or ``'original'`` for
            ``exp_from_paper``.

    Raises:
        ValueError: if ``algo`` is not one of the supported backends.
    """
    if algo == 'torch':
        return torch.matrix_exp(x)
    if algo == 'original':
        return exp_from_paper(x, eps)
    # ValueError subclasses Exception, so existing `except Exception` callers still work.
    raise ValueError(f'Invalid expm algo: {algo!r}')
    def train_exp(self, x, lie_alg_basis_ls, mat_dim, lie_var_ls,
                  lie_alg_init_type):
        lie_alg_basis_ls = [p * 1. for p in lie_alg_basis_ls
                            ]  # For torch.cat, convert param to tensor.
        lie_alg_basis = torch.cat(lie_alg_basis_ls,
                                  dim=0)[np.newaxis,
                                         ...]  # [1, lat_dim, mat_dim, mat_dim]
        if lie_alg_init_type == 'oth':
            lie_alg_basis = lie_alg_basis - lie_alg_basis.transpose(-2, -1)
        if self.normalize_alg:
            norm_v = torch.linalg.norm(lie_alg_basis, dim=[-2, -1])
            lie_alg_basis = lie_alg_basis / norm_v
        if self.use_alg_var:
            lie_var = torch.cat(lie_var_ls, dim=1)  # [1, lat_dim]
            lie_alg_basis = lie_alg_basis * lie_var[..., np.newaxis,
                                                    np.newaxis]

        lie_group = torch.eye(mat_dim, dtype=x.dtype).to(
            x.device)[np.newaxis, ...]  # [1, mat_dim, mat_dim]
        lie_alg = 0.
        latents_in_cut_ls = self.split_latents(x)  # [x0, x1, ...]
        for masked_latent in latents_in_cut_ls:
            lie_alg_sum_tmp = torch.sum(
                masked_latent[..., np.newaxis, np.newaxis] * lie_alg_basis,
                dim=1)
            lie_alg += lie_alg_sum_tmp  # [b, mat_dim, mat_dim]
            lie_group_tmp = torch.matrix_exp(lie_alg_sum_tmp)
            lie_group = torch.matmul(lie_group,
                                     lie_group_tmp)  # [b, mat_dim, mat_dim]
        return lie_group
def benchmark_mexp(b, d, r, repeats=30, do_backward=True):
    """Benchmark ``torch.matrix_exp`` on random skew-symmetric CUDA matrices.

    Every iteration samples ``b`` random ``d x d`` skew-symmetric matrices,
    exponentiates them and keeps the first ``r`` columns. If ``do_backward``
    is set, it also backpropagates ``out.abs().sum()``. Iteration 0 is an
    untimed warm-up.

    Args:
        b: batch size.
        d: matrix size.
        r: number of leading columns kept (``r <= d``).
        repeats: number of timed iterations (must be >= 1).
        do_backward: also run (and time) the backward pass.

    Returns:
        ``(avg_ms, avg_err)``: mean milliseconds per timed iteration, and the mean
        ``tensor_diff(out^T out, I_r)`` over timed iterations. The exponential of
        a skew-symmetric matrix is orthogonal, so the kept columns should be
        orthonormal (``tensor_diff`` is defined elsewhere).

    Raises:
        ValueError: if ``repeats < 1`` (nothing would be timed).
        RuntimeError: if CUDA is not available.
    """
    assert d >= r
    if repeats < 1:
        raise ValueError('repeats must be at least 1')
    if not torch.cuda.is_available():
        raise RuntimeError('Benchmarking requires CUDA')
    eye = torch.eye(d, device='cuda:0').unsqueeze(0).repeat(b, 1, 1)

    start_sec = None
    discrepancies = []

    for i in range(repeats + 1):
        if i == 1:
            # Start the clock only once the warm-up iteration has finished on the GPU.
            torch.cuda.synchronize()
            start_sec = time.perf_counter()

        param = torch.randn(b, d, d, device='cuda:0').tril(diagonal=-1)
        param = param - param.permute(0, 2, 1)  # skew-symmetric
        param = torch.nn.Parameter(param)

        out = torch.matrix_exp(param)[:, :, :r]

        if do_backward:
            loss = out.abs().sum()
            loss.backward()

        if i >= 1:
            # NOTE: this check runs inside the timed region, so it counts toward avg_ms.
            with torch.no_grad():
                discrepancies.append(tensor_diff(out.permute(0, 2, 1).bmm(out), eye[:, :r, :r]))

    # Wait for all queued kernels before stopping the clock.
    torch.cuda.synchronize()
    end_sec = time.perf_counter()

    avg_ms = (end_sec - start_sec) * 1000 / repeats
    avg_err = torch.tensor(discrepancies).mean().item()
    return avg_ms, avg_err
 def get_act_repr_simple(self, act_param):
     b, n_lats = list(act_param.size())
     alg_tmp = torch.triu(
         torch.ones(b, n_lats, 2, 2, device=act_param.device))
     alg_tmp = alg_tmp - alg_tmp.transpose(-2, -1)
     act_alg = act_param.view(b, n_lats, 1, 1) * alg_tmp
     act_repr = torch.matrix_exp(act_alg.view(-1, 2, 2))  # [b*n_lats, 2, 2]
     act_repr = act_repr.view(b, n_lats, 2, 2)
     return act_repr
Пример #7
0
    def _compute_h(self):
        """
        Compute DAG penalty.

        Return
        ------
        torch.Tensor: DAG penalty term (scalar-valued).
        """
        return torch.trace(torch.matrix_exp(self.B * self.B)) - self.d
Пример #8
0
def compute_acyclicity(w):
    """
    NOTEARS-style acyclicity score ``tr(exp(w * w)) - d`` (``*`` elementwise).

    Parameters
    ----------
    w: torch.Tensor
        square (d, d) matrix, presumably a weighted adjacency matrix.

    Returns
    -------
    torch.Tensor
        0-d tensor. For real ``w``, ``w * w`` is non-negative, so the score is
        >= 0 and is zero exactly when the graph of non-zero entries is acyclic.
    """
    num_nodes = w.shape[0]
    squared = w * w
    return torch.trace(torch.matrix_exp(squared)) - num_nodes
Пример #9
0
    def forward(self, source_noise, context_enc, x_prev, dx):
        """Generate outputs from a latent code (inverse normalizing flow).

        Args:
            source_noise (FloatTensor): source noise (e.g. Gaussian) to transform, (A, 2).
            context_enc (FloatTensor): fused past-trajectory + local-scene feature, (A, D_{lc}).
            x_prev (FloatTensor): initial agent positions, (A, 2).
            dx (FloatTensor): initial agent velocities, (A, 2).

        Returns:
            x (A, 2): ``sigma @ source_noise + mu``.
            mu (A, 2): ``x_prev + velocity_const * dx + mu_hat`` (Verlet integration).
            sigma (A, 2, 2): ``matrix_exp(S + S^T)`` of the predicted 2x2 block ``S``,
                hence symmetric positive definite.
        """
        n_agents = source_noise.size(0)

        prediction = self.projection(self.mlp(context_enc))  # (A, 6)
        mu_hat = prediction[..., :2]
        raw_sigma = prediction[..., 2:].reshape((n_agents, 2, 2))

        mu = x_prev + self.velocity_const * dx + mu_hat
        # The exponential of a symmetric matrix is symmetric positive definite
        # (torch.matrix_exp requires PyTorch >= 1.7.0).
        symmetric = raw_sigma + raw_sigma.transpose(-1, -2)
        sigma = torch.matrix_exp(symmetric)

        x = torch.matmul(sigma, source_noise.unsqueeze(-1)).squeeze(-1) + mu
        return x, mu, sigma
Пример #10
0
 def blas_lapack_ops(self):
     """Call a broad set of BLAS/LAPACK-backed ``torch`` ops and return all results.

     Presumably this exists to exercise operator coverage; the results are not
     checked here. Inputs are random, so several calls receive matrices that do
     not meet the op's mathematical preconditions and may return NaN. For
     example, ``cholesky_inverse``/``cholesky_solve`` expect Cholesky factors,
     and ``logdet`` needs a positive determinant.

     NOTE(review): ``torch.eig``, ``torch.lstsq`` and ``torch.solve`` are
     deprecated and removed in recent PyTorch releases. ``torch.chain_matmul``
     and ``torch.lu`` are deprecated in favour of ``torch.linalg`` equivalents.
     Confirm against the PyTorch version this targets.
     """
     # Shared random operands: m (3x3), batched a (10x3x4) / b (10x4x3), vector v (3).
     m = torch.randn(3, 3)
     a = torch.randn(10, 3, 4)
     b = torch.randn(10, 4, 3)
     v = torch.randn(3)
     # Call order fixes the RNG consumption for the inline torch.randn inputs.
     return (
         torch.addbmm(m, a, b),
         torch.addmm(torch.randn(2, 3), torch.randn(2, 3),
                     torch.randn(3, 3)),
         torch.addmv(torch.randn(2), torch.randn(2, 3), torch.randn(3)),
         torch.addr(torch.zeros(3, 3), v, v),
         torch.baddbmm(m, a, b),
         torch.bmm(a, b),
         torch.chain_matmul(torch.randn(3, 3), torch.randn(3, 3),
                            torch.randn(3, 3)),
         # torch.cholesky(a), # deprecated
         torch.cholesky_inverse(torch.randn(3, 3)),
         torch.cholesky_solve(torch.randn(3, 3), torch.randn(3, 3)),
         torch.dot(v, v),
         torch.eig(m),
         torch.geqrf(a),
         torch.ger(v, v),
         torch.inner(m, m),
         torch.inverse(m),
         torch.det(m),
         torch.logdet(m),
         torch.slogdet(m),
         torch.lstsq(m, m),
         torch.lu(m),
         torch.lu_solve(m, *torch.lu(m)),
         torch.lu_unpack(*torch.lu(m)),
         torch.matmul(m, m),
         torch.matrix_power(m, 2),
         # torch.matrix_rank(m),
         torch.matrix_exp(m),
         torch.mm(m, m),
         torch.mv(m, v),
         # torch.orgqr(a, m),
         # torch.ormqr(a, m, v),
         torch.outer(v, v),
         torch.pinverse(m),
         # torch.qr(a),
         torch.solve(m, m),
         torch.svd(a),
         # torch.svd_lowrank(a),
         # torch.pca_lowrank(a),
         # torch.symeig(a), # deprecated
         # torch.lobpcg(a, b), # not supported
         torch.trapz(m, m),
         torch.trapezoid(m, m),
         torch.cumulative_trapezoid(m, m),
         # torch.triangular_solve(m, m),
         torch.vdot(v, v),
     )
Пример #11
0
def _expm_torch(X, basis=None):
    """Matrix exponential via ``torch.matrix_exp``.

    If ``basis`` is given, ``X`` holds Lie-algebra coordinates (X.shape = [..., F],
    basis.shape = [..., F, D, D]). The matrix ``sum_f X[..., f] * basis[..., f, :, :]``
    is exponentiated instead. Both inputs are converted with ``utils.as_tensor``
    using the backend chosen by ``utils.max_backend``.
    """
    backend = utils.max_backend(X, basis)
    X = utils.as_tensor(X, **backend)
    if basis is None:
        return torch.matrix_exp(X)

    basis = utils.as_tensor(basis, **backend)
    coords = X[..., None, None]
    # Reconstruct the Lie-algebra matrix, then exponentiate it.
    return torch.matrix_exp((basis * coords).sum(dim=-3))
Пример #12
0
    def forward(ctx, input):
        """Return tr(expm(input)) and stash expm(input) for the backward pass.

        The exponential is taken on a detached copy of ``input``, and the trace
        is returned as a tensor with ``input``'s dtype.
        """
        expm_val = torch.matrix_exp(input.detach())
        ctx.save_for_backward(expm_val)
        return torch.as_tensor(torch.trace(expm_val), dtype=input.dtype)
Пример #13
0
    def forward(self, x, rho, alpha, temperature) -> tuple:
        """Augmented-Lagrangian loss for the masked, preprocessed graph.

        Returns ``(loss, h, self.w)``:
          * ``h = tr(exp(W' * W')) - n_nodes`` is the acyclicity term, where
            ``W' = pns_mask * _preprocess_graph(w)``.
          * ``loss = 0.5 / n_samples * mse + l1_graph_penalty * ||W'||_1
            + alpha * h + 0.5 * rho * h^2``.
        """
        sampled_w = self._preprocess_graph(self.w, tau=temperature,
                                           seed=self.seed)
        sampled_w = self.pns_mask * sampled_w

        sq_err = self._get_mse_loss(x, sampled_w)
        acyc = torch.trace(torch.matrix_exp(sampled_w * sampled_w)) - self.n_nodes

        data_term = 0.5 / self.n_samples * sq_err
        sparsity = self.l1_graph_penalty * torch.linalg.norm(sampled_w, ord=1)
        loss = data_term + sparsity + alpha * acyc + 0.5 * rho * acyc * acyc

        return loss, acyc, self.w
Пример #14
0
    def infer(self, pred_traj, context_enc, x_prev, dx):
        """Infer the latent code given a trajectory (normalizing flow).

        Inverts ``x = sigma @ z + mu``. ``mu`` and ``sigma`` are predicted from
        the context exactly as in generation, and ``sigma @ z = pred_traj - mu``
        is solved for ``z``.

        Args:
            pred_traj (FloatTensor): trajectory to encode, (A, T, 2).
            context_enc (FloatTensor): context encoding, (A, T, D_{local}).
            x_prev (FloatTensor): decoding position at the previous time, (A, T, 2).
            dx (FloatTensor): velocities at the previous time, (A, T, 2).

        Returns:
            z (A, T, 2): latent code.
            mu (A, T, 2): ``x_prev + velocity_const * dx + mu_hat``.
            sigma (A, T, 2, 2): ``matrix_exp(S + S^T)`` of the predicted 2x2 block ``S``.
        """
        total_agents = pred_traj.size(0)  # The number of agents A.
        pred_steps = pred_traj.size(1)  # The prediction steps T.

        output = self.mlp(context_enc)
        prediction = self.projection(output)  # (A, T, 6)

        mu_hat = prediction[..., :2]
        sigma_hat = prediction[..., 2:].reshape(
            (total_agents, pred_steps, 2, 2))

        # Verlet integration
        mu = x_prev + self.velocity_const * dx + mu_hat
        # exp of a symmetric matrix is symmetric positive definite, hence invertible.
        sigma_sym = sigma_hat + sigma_hat.transpose(-1, -2)
        sigma = torch.matrix_exp(sigma_sym)  # (A, T, 2, 2)

        # Solve sigma @ z = (x - mu). Tensor.solve/torch.solve are deprecated and
        # removed in recent PyTorch. Note that torch.linalg.solve takes (A, B),
        # the reverse of the old argument order.
        x_mu = (pred_traj - mu).unsqueeze(-1)  # (A, T, 2, 1)
        z = torch.linalg.solve(sigma, x_mu).squeeze(-1)  # (A, T, 2)

        return z, mu, sigma
Пример #15
0
 def get_act_repr(self, act_param, lie_alg_basis_ls):
     """Block-diagonal group representation of the actions.

     For subgroup ``i`` (size ``self.subgroup_sizes_ls[i]`` = mat_dim**2), the
     basis ``B_i`` ([1, mat_dim, mat_dim]) is antisymmetrized to ``B_i - B_i^T``
     and scaled by ``act_param[:, i]``. It is then exponentiated and written
     into the next diagonal block of a [b, full_mat_dim, full_mat_dim] tensor
     with ``act_param``'s dtype and device.
     """
     batch, n_acts = act_param.shape
     # Assume each group has 1 latent dim.
     assert n_acts == len(lie_alg_basis_ls)
     act_repr = act_param.new_zeros(batch, self.full_mat_dim, self.full_mat_dim)

     start = 0
     for i, subgroup_size in enumerate(self.subgroup_sizes_ls):
         dim = int(math.sqrt(subgroup_size))
         assert dim * dim == subgroup_size
         stop = start + dim
         basis = lie_alg_basis_ls[i]
         assert list(basis.size()) == [1, dim, dim]
         # Each latent subspace is 1-D and the basis is made skew-symmetric.
         skew = basis - basis.transpose(-2, -1)
         act_repr[:, start:stop, start:stop] = torch.matrix_exp(
             act_param[:, i, None, None] * skew)  # [b, dim, dim]
         start = stop
     return act_repr
Пример #16
0
def sym_expm(x: torch.Tensor, using_native=False) -> torch.Tensor:
    r"""Matrix exponential of a symmetric matrix.

    Parameters
    ----------
    x : torch.Tensor
        symmetric matrix
    using_native : bool, optional
        when True, use ``torch.matrix_exp``; otherwise apply ``torch.exp``
        through ``sym_funcm``. By default False.

    Returns
    -------
    torch.Tensor
        :math:`\exp(x)`
    """
    return torch.matrix_exp(x) if using_native else sym_funcm(x, torch.exp)
Пример #17
0
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Map the unconstrained matrix ``X`` to an orthogonal (or unitary) matrix.

        ``X`` has shape ``(..., n, k)`` and the result has the same shape. Wide
        inputs are handled by working on ``X.mT`` and transposing the result back.

        For ``_OrthMaps.matrix_exp`` / ``_OrthMaps.cayley``, the lower triangle of
        ``X`` (zero-padded to a square ``n x n``) defines the skew-symmetric
        (skew-Hermitian) ``A = X - X^H``. ``A`` is mapped through
        ``matrix_exp(A)`` or the Cayley transform, and the first ``k`` columns
        are kept.

        Otherwise, the strictly lower triangle of ``X`` supplies Householder
        vectors whose product forms ``Q``. The column signs come from the
        diagonal of ``X``.

        If the module has a ``base`` attribute, the result is left-multiplied by it.
        """
        n, k = X.size(-2), X.size(-1)
        transposed = n < k
        if transposed:
            # Work on a tall matrix; the result is transposed back at the end.
            X = X.mT
            n, k = k, n
        # Here n >= k and X is a tall (or square) matrix
        if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
            # We just need n x k - k(k-1)/2 parameters
            X = X.tril()
            if n != k:
                # Embed into a square matrix
                X = torch.cat(
                    [X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)],
                    dim=-1)
            A = X - X.mH
            # A is skew-symmetric (or skew-hermitian)
            if self.orthogonal_map == _OrthMaps.matrix_exp:
                Q = torch.matrix_exp(A)
            elif self.orthogonal_map == _OrthMaps.cayley:
                # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
                # (I-A/2) and (I+A/2) commute, so solving (I-A/2) Q = (I+A/2)
                # gives the same matrix without forming an inverse.
                Id = torch.eye(n, dtype=A.dtype, device=A.device)
                Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5),
                                       torch.add(Id, A, alpha=0.5))
            # Q is now orthogonal (or unitary) of size (..., n, n)
            if n != k:
                Q = Q[..., :k]
            # Q is now the size of the X (albeit perhaps transposed)
        else:
            # X is real here, as we do not support householder with complex numbers
            # The strictly lower triangle holds the Householder vectors.
            A = X.tril(diagonal=-1)
            # householder_product treats each vector as having an implicit unit
            # entry on the diagonal, so ||v||^2 = 1 + sum of squares below it.
            # tau = 2 / ||v||^2 makes each I - tau v v^T a reflection.
            tau = 2. / (1. + (A * A).sum(dim=-2))
            Q = torch.linalg.householder_product(A, tau)
            # The diagonal of X is 1's and -1's
            # We do not want to differentiate through this or update the diagonal of X hence the casting
            # (.int() truncates toward zero, so this relies on the entries being exactly +-1.)
            Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

        if hasattr(self, "base"):
            # Optional left factor stored on the module (set elsewhere).
            Q = self.base @ Q
        if transposed:
            Q = Q.mT
        return Q
    def val_exp(self, x, lie_alg_basis_ls, lie_var_ls, lie_alg_init_type):
        lie_alg_basis_ls = [p * 1. for p in lie_alg_basis_ls
                            ]  # For torch.cat, convert param to tensor.
        lie_alg_basis = torch.cat(lie_alg_basis_ls,
                                  dim=0)[np.newaxis,
                                         ...]  # [1, lat_dim, mat_dim, mat_dim]
        if lie_alg_init_type == 'oth':
            lie_alg_basis = lie_alg_basis - lie_alg_basis.transpose(-2, -1)
        if self.normalize_alg:
            norm_v = torch.linalg.norm(lie_alg_basis, dim=[-2, -1])
            lie_alg_basis = lie_alg_basis / norm_v
        if self.use_alg_var:
            lie_var = torch.cat(lie_var_ls, dim=1)  # [1, lat_dim]
            lie_alg_basis = lie_alg_basis * lie_var[..., np.newaxis,
                                                    np.newaxis]

        lie_alg_mul = x[
            ..., np.newaxis,
            np.newaxis] * lie_alg_basis  # [b, lat_dim, mat_dim, mat_dim]
        lie_alg = torch.sum(lie_alg_mul, dim=1)  # [b, mat_dim, mat_dim]
        lie_group = torch.matrix_exp(lie_alg)  # [b, mat_dim, mat_dim]
        return lie_group
Пример #19
0
 def loss_torch(p):
     """Squared Frobenius distance between ``matrix_exp(p)`` and the identity.

     The identity matches the size of ``p``'s trailing matrix dimension and
     ``p``'s dtype/device, so non-3x3 and GPU inputs work. It broadcasts over
     any leading batch dims.
     """
     eye = torch.eye(p.shape[-1], dtype=p.dtype, device=p.device)
     return torch.sum((torch.matrix_exp(p) - eye) ** 2)
Пример #20
0
 def forward(self, X):
     """Return the matrix exponential of ``X`` (batched over leading dims)."""
     return X.matrix_exp()
Пример #21
0
    def forward(self,
                warp_inputs: List[List[torch.Tensor]],
                state: List[List[torch.Tensor]] = None):
        """Predict Kronecker-factor matrices for every warped parameter.

        Each ``warp_inputs[i]`` is a pair whose elements are embedded by
        ``self.input_modules[i][0]`` and ``[1]``. All embeddings are concatenated,
        projected (linear, ReLU, dropout), passed through the transformer cascade
        and summed over dim -2. The output head then yields, for every dim in
        ``self.output_shapes[i]`` and ``self.input_shapes[i]``, the generators
        of a symmetric matrix. The matrix exponential of each becomes one
        Kronecker factor.

        Args:
            warp_inputs: per-parameter pairs of input tensors.
            state: not used by this method.

        Returns:
            ``[[input_matrices, output_matrices], ...]``, one entry per element of
            ``self.warp_parameters``. Each matrix has shape
            ``output_head(...).shape[:-1] + (dim, dim)``.
        """
        embeddings = []

        # Get all inputs to be same dimension so we can pass through the
        # transformer cascade
        for i, warp_input in enumerate(warp_inputs):
            embeddings_1 = self.input_modules[i][0](warp_input[0])
            embeddings_2 = self.input_modules[i][1](warp_input[1])
            embeddings += [embeddings_1, embeddings_2]

        embeddings = torch.cat(embeddings, dim=-1)
        embeddings = self.input_linear(embeddings)
        embeddings = self.relu(embeddings)
        embeddings = self.dropout(embeddings)
        embeddings = embeddings.transpose(0, 1)

        for i in range(self.transformer_layers):
            embeddings = self.transformer_encoder_layers[i](embeddings)
            embeddings = self.layer_norms[i](embeddings)

        embeddings = embeddings.transpose(0, 1)
        embeddings = embeddings.sum(dim=-2)

        output_raw = self.output_head(embeddings)
        batch_shape = list(output_raw.size())[:-1]
        output_raw = output_raw.reshape([-1, output_raw.size(-1)])
        n_rows = output_raw.size(0)
        n_terms = self.num_sym_tensor_products

        def sym_exp_factor(start, dim):
            # Read 2*dim*n_terms generator columns from `start`, build
            # sum_j 0.5 * (a_j b_j^T + b_j a_j^T) (symmetric) per row, and return
            # its matrix exponential (shape batch_shape + [dim, dim]) together
            # with the next column offset.
            length = 2 * dim * n_terms
            matrix_gen = output_raw[:, start:start + length].reshape(-1, 2 * dim)
            matrix_gen_a = matrix_gen[:, :dim]
            matrix_gen_b = matrix_gen[:, dim:]
            ab = torch.bmm(matrix_gen_a.unsqueeze(-1),
                           matrix_gen_b.unsqueeze(-2))
            ba = torch.bmm(matrix_gen_b.unsqueeze(-1),
                           matrix_gen_a.unsqueeze(-2))
            matrix = 0.5 * (ab + ba)
            matrix = matrix.reshape(n_rows, -1, dim, dim).sum(dim=-3)
            matrix = matrix.reshape(batch_shape + 2 * [dim])
            exp_matrix = torch.matrix_exp(
                matrix.reshape((-1, matrix.size(-2), matrix.size(-1))))
            return exp_matrix.reshape(matrix.size()), start + length

        range_start = 0
        kronecker_matrices = []
        # Output factors are read before input factors for each parameter; the
        # column layout of output_head depends on this order.
        for i, _ in enumerate(self.warp_parameters):
            output_matrices = []
            for dim in self.output_shapes[i]:
                exp_matrix, range_start = sym_exp_factor(range_start, dim)
                output_matrices.append(exp_matrix)

            input_matrices = []
            for dim in self.input_shapes[i]:
                exp_matrix, range_start = sym_exp_factor(range_start, dim)
                input_matrices.append(exp_matrix)

            kronecker_matrices.append([input_matrices, output_matrices])

        return kronecker_matrices
Пример #22
0
    transform=transform
)

test_loader = DataLoader(test_dataset, batch_size=1, num_workers=6, shuffle=False)
real_test_loader = DataLoader(real_test_dataset, batch_size=1, num_workers=6, shuffle=False)

# iterating over all test and real test images, appending submission
# batch_size=1 and shuffle=False matter: batch i is matched to
# dataset.sample_ids[i] purely by position.
# Inference only: torch.no_grad() keeps autograd from recording a graph for
# every forward pass.
with torch.no_grad():
    for i, data in enumerate(test_loader):
        filename = test_loader.dataset.sample_ids[i]
        img, target, K = data
        img = img.cuda()

        t, att = val_net.forward(img)

        # Conversion: prv --> q (rotation matrix as the exponential of the
        # skew-symmetric matrix built from att; presumably a principal rotation vector)
        R_pred = torch.matrix_exp(skew_symmetric(att))
        q = dcm_to_ep(R_pred)
        q = q.detach().cpu().numpy()
        if np.any(np.isnan(q)):
            print("Test: Nans found in q!")
            q = np.array([[1.0, 0.0, 0.0, 0.0]])
            print(q[0])

        # Conversion: [delta u, delta v, tz] -> r = [tx, ty, tz]
        r = origin_reg_conversion(target.detach().cpu().numpy()[0], K.detach().cpu().numpy()[0], t.detach().cpu().numpy()[0])
        if np.any(np.isnan(r)):
            # NOTE(review): unlike q, r has no fallback and is submitted with NaNs.
            print("Test: Nans found in r!")
        submission.append_test(filename, q[0], r.tolist())

for i, data in enumerate(real_test_loader):
    filename = real_test_loader.dataset.sample_ids[i]
Пример #23
0
def compute_h(w_adj):
    """Acyclicity penalty h(W) = tr(exp(W * W)) - d for a square (d, d) ``w_adj``.

    ``*`` is elementwise. For real ``w_adj`` the penalty is >= 0 and is zero
    exactly when the graph of non-zero entries has no directed cycles.
    """
    num_nodes = w_adj.shape[0]
    return torch.matrix_exp(torch.mul(w_adj, w_adj)).trace() - num_nodes
Пример #24
0
    while (epoch <= maxEpochs):

        if epoch % update_print_ct == 0:
            if np.abs(prev_loss - curr_loss) < eps:
                counteps += 1
                if counteps == 3:
                    print('The network has converged, eps = ' + str(eps))
                    break
            prev_loss = curr_loss

        for i in range(0, trainXp.shape[0]):

            dt = dt_list[i]

            K = torch.matrix_exp(net.Lgen * dt)
            # eL = torch.diag_embed(torch.exp(net.L*dt)) # exponential of the eigs, then embedded into a diagonal matrix
            # K = torch.matmul(torch.matmul(net.V,eL),torch.inverse(net.V)) # matrix representation of Koopman operator

            Kpsixp = torch.matmul(net(trainXp[i:i + 1]), K)
            psixf = net(trainXf[i:i + 1])
            loss = loss_func(psixf, Kpsixp)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        curr_loss = loss.item()

        if epoch % update_print_ct == 0:
            print('[' + str(epoch) + ']' + ' loss = ' + str(curr_loss))