def forward(self, hist, nbrs, masks, op_mask, y, opt):
        """Run one training step of the hypernetwork-conditioned flow.

        ``self.hyper(hist, nbrs, masks)`` produces the weights that condition
        ``self.point_cnf``. The flow transforms ``y`` and returns the summed
        log-density change. Its output is scored under the standard
        Laplace/Normal base distribution selected by ``self.logprob_type``.
        The per-sample log-likelihood is multiplied by ``op_mask`` and
        averaged into the loss. The loss is back-propagated and
        ``opt.step()`` is applied.

        Args:
            hist: forwarded unchanged to ``self.hyper``.
            nbrs: forwarded unchanged to ``self.hyper``.
            masks: forwarded unchanged to ``self.hyper``.
            op_mask: multiplied into the per-sample log-likelihood of shape
                ``(batch, 1)``; presumably a 0/1 per-sample mask — TODO
                confirm against callers.
            y: target tensor; ``y.size(0)`` is the batch size and
                ``y.size(1)`` the number of points per sample.
            opt: optimizer whose ``zero_grad()`` and ``step()`` are called.

        Returns:
            A detached scalar tensor: the negated masked log-likelihood,
            summed over the batch and divided by the batch size.

        Raises:
            ValueError: if ``self.logprob_type`` is neither "Laplace" nor
                "Normal" (previously this surfaced as an UnboundLocalError).
        """
        # Fail fast, before zeroing grads or running any network.
        if self.logprob_type not in ("Laplace", "Normal"):
            raise ValueError(
                f"unsupported logprob_type {self.logprob_type!r}; "
                "expected 'Laplace' or 'Normal'")

        opt.zero_grad()
        batch_size = y.size(0)
        target_networks_weights = self.hyper(hist, nbrs, masks)

        # Loss. new_zeros gives the initial log-density change the same
        # dtype and device as y without a host-side allocation plus copy.
        y, delta_log_py = self.point_cnf(
            y, target_networks_weights,
            y.new_zeros(batch_size, y.size(1), 1))
        if self.logprob_type == "Laplace":
            base_logprob = standard_laplace_logprob
        else:
            base_logprob = standard_normal_logprob
        log_py = base_logprob(y).view(batch_size, -1).sum(1, keepdim=True)
        delta_log_py = delta_log_py.view(batch_size, y.size(1), 1).sum(1)
        # Data log-likelihood: base log-density minus the flow's summed
        # log-density change, both of shape (batch, 1).
        log_px = log_py - delta_log_py
        log_px = log_px * op_mask
        # NOTE(review): masked-out samples still count in the denominator of
        # this mean — confirm that scaling is intended.
        loss = -log_px.mean()

        loss.backward()
        opt.step()
        # Detach so the returned metric does not keep autograd nodes alive,
        # e.g. when callers accumulate it across iterations.
        recon = -log_px.detach().sum()
        recon_nats = recon / float(y.size(0))
        return recon_nats
    def forward(self, x, y, opt, step, writer=None):
        """Run one training step of the flow conditioned on ``x``.

        ``self.flownet(x)`` produces the conditioning passed to
        ``self.point_cnf``. The flow transforms ``y`` and returns the summed
        log-density change. Its output is scored under the standard
        Laplace/Normal base distribution selected by ``self.logprob_type``.
        The mean negative log-likelihood is back-propagated and
        ``opt.step()`` is applied.

        Args:
            x: conditioning input; ``x.size(0)`` is used as the batch size.
            y: target tensor; ``y.size(1)`` is the number of points per
                sample.
            opt: optimizer whose ``zero_grad()`` and ``step()`` are called.
            step: not read by this method.
            writer: not read by this method.

        Returns:
            A detached scalar tensor: the negated log-likelihood summed over
            the batch and divided by ``y.size(0)``.

        Raises:
            ValueError: if ``self.logprob_type`` is neither "Laplace" nor
                "Normal" (previously this surfaced as an UnboundLocalError).
        """
        # Fail fast, before zeroing grads or running any network.
        if self.logprob_type not in ("Laplace", "Normal"):
            raise ValueError(
                f"unsupported logprob_type {self.logprob_type!r}; "
                "expected 'Laplace' or 'Normal'")

        opt.zero_grad()
        batch_size = x.size(0)
        z = self.flownet(x)

        # Loss. new_zeros matches y's dtype and device directly.
        y, delta_log_py = self.point_cnf(
            y, z,
            y.new_zeros(batch_size, y.size(1), 1))
        if self.logprob_type == "Laplace":
            base_logprob = standard_laplace_logprob
        else:
            base_logprob = standard_normal_logprob
        log_py = base_logprob(y).view(batch_size, -1).sum(1, keepdim=True)
        delta_log_py = delta_log_py.view(batch_size, y.size(1), 1).sum(1)
        # Data log-likelihood: base log-density minus the flow's summed
        # log-density change, both of shape (batch, 1).
        log_px = log_py - delta_log_py

        loss = -log_px.mean()

        loss.backward()
        opt.step()
        # Detach so the returned metric does not keep autograd nodes alive.
        recon = -log_px.detach().sum()
        recon_nats = recon / float(y.size(0))
        return recon_nats
    def get_logprob(self, hist, nbrs, masks, y_in):
        """Compute per-sample log-likelihoods of ``y_in`` (no optimizer step).

        ``self.hyper(hist, nbrs, masks)`` produces the weights that condition
        ``self.point_cnf``. The flow's output for ``y_in`` is scored under
        the standard Laplace/Normal base distribution selected by
        ``self.logprob_type``.

        Args:
            hist: forwarded unchanged to ``self.hyper``.
            nbrs: forwarded unchanged to ``self.hyper``.
            masks: forwarded unchanged to ``self.hyper``.
            y_in: target tensor; ``y_in.size(0)`` is the batch size and
                ``y_in.size(1)`` the number of points per sample.

        Returns:
            ``(log_py, log_px)``, each of shape ``(batch, 1)``:
            - ``log_py``: base log-density of the transformed points, summed
              per sample.
            - ``log_px``: ``log_py`` minus the flow's summed log-density
              change.

        Raises:
            ValueError: if ``self.logprob_type`` is neither "Laplace" nor
                "Normal" (previously this surfaced as an UnboundLocalError).
        """
        # Fail fast, before running any network.
        if self.logprob_type not in ("Laplace", "Normal"):
            raise ValueError(
                f"unsupported logprob_type {self.logprob_type!r}; "
                "expected 'Laplace' or 'Normal'")

        batch_size = y_in.size(0)
        target_networks_weights = self.hyper(hist, nbrs, masks)

        # new_zeros matches y_in's dtype and device directly.
        y, delta_log_py = self.point_cnf(
            y_in, target_networks_weights,
            y_in.new_zeros(batch_size, y_in.size(1), 1))
        if self.logprob_type == "Laplace":
            base_logprob = standard_laplace_logprob
        else:
            base_logprob = standard_normal_logprob
        log_py = base_logprob(y).view(batch_size, -1).sum(1, keepdim=True)

        delta_log_py = delta_log_py.view(batch_size, y.size(1), 1).sum(1)
        log_px = log_py - delta_log_py

        return log_py, log_px
    def get_logprob(self, x, y_in):
        """Compute per-sample and per-point log-likelihoods of ``y_in``.

        ``self.flownet(x)`` produces the conditioning passed to
        ``self.point_cnf``. The flow's output for ``y_in`` is scored under
        the standard Laplace/Normal base distribution selected by
        ``self.logprob_type``.

        Args:
            x: conditioning input; ``x.size(0)`` is used as the batch size.
            y_in: target tensor; ``y_in.size(1)`` is the number of points per
                sample.

        Returns:
            ``(log_py, log_px, (batch_log_py, batch_log_px))``:
            - ``log_py``, ``log_px``: summed per sample, shape
              ``(batch, 1)``.
            - ``batch_log_py``, ``batch_log_px``: reduced over dim 2 only,
              so they keep the point dimension.
            ``*_px = *_py`` minus the flow's log-density change.

        Raises:
            ValueError: if ``self.logprob_type`` is neither "Laplace" nor
                "Normal" (previously this surfaced as an UnboundLocalError).
        """
        # Fail fast, before running any network.
        if self.logprob_type not in ("Laplace", "Normal"):
            raise ValueError(
                f"unsupported logprob_type {self.logprob_type!r}; "
                "expected 'Laplace' or 'Normal'")

        batch_size = x.size(0)
        w = self.flownet(x)

        # new_zeros matches y_in's dtype and device directly.
        y, delta_log_py = self.point_cnf(
            y_in, w,
            y_in.new_zeros(batch_size, y_in.size(1), 1))
        if self.logprob_type == "Laplace":
            log_py = standard_laplace_logprob(y)
        else:
            log_py = standard_normal_logprob(y)

        # Per-point terms: reduce dim 2 only. Assumes log_py is
        # (batch, points, dims) — TODO confirm.
        batch_log_py = log_py.sum(dim=2)
        batch_log_px = batch_log_py - delta_log_py.sum(dim=2)
        # Per-sample terms: reduce everything except the batch dimension.
        log_py = log_py.view(batch_size, -1).sum(1, keepdim=True)
        delta_log_py = delta_log_py.view(batch_size, y.size(1), 1).sum(1)
        log_px = log_py - delta_log_py

        return log_py, log_px, (batch_log_py, batch_log_px)