예제 #1
0
파일: flows.py 프로젝트: rikrd/torchkit
    def forward(self, x, logdet, dsparams, mollify=0.0, delta=nn_.delta):
        """Push ``x`` through a weighted mixture of sigmoids followed by a logit.

        The last axis of ``dsparams`` is cut into three consecutive groups of
        ``self.num_ds_dim`` entries, which are passed to ``self.act_a``,
        ``self.act_b`` and ``self.act_w`` to obtain the slope, shift and
        mixture weight of each sigmoid unit.

        Args:
            x: input tensor; it is indexed as ``x[:, :, None]``, so it needs at
                least two dimensions (presumably ``(batch, dim)`` - confirm).
            logdet: running log-determinant; the per-sample contribution
                (summed over axis 1) is added to it out-of-place.
            dsparams: parameter tensor indexed with three indices; its last
                axis must hold at least ``3 * self.num_ds_dim`` entries.
            mollify: blend factor. ``0`` uses the activated slope/shift as-is,
                ``1`` replaces them with slope ``1`` and shift ``0``.
            delta: shrink factor applied as ``p * (1 - delta) + delta * 0.5``
                before the logit, keeping its argument away from 0 and 1.

        Returns:
            Tuple ``(xnew, logdet)``: the transformed input and the updated
            log-determinant.
        """
        k = self.num_ds_dim
        raw_w = dsparams[:, :, 2 * k:3 * k]

        # Interpolate between the learned parameters and the fixed pair
        # (slope=1, shift=0) according to ``mollify``.
        keep = 1 - mollify
        slope = self.act_a(dsparams[:, :, 0 * k:1 * k]) * keep + 1.0 * mollify
        shift = self.act_b(dsparams[:, :, 1 * k:2 * k]) * keep + 0.0 * mollify
        weights = self.act_w(raw_w)

        # Weighted sum of sigmoids over the unit axis, then squeeze into the
        # open interval and apply the logit.
        z = slope * x[:, :, None] + shift
        mixed = torch.sum(weights * torch.sigmoid(z), dim=2)
        p = mixed * (1 - delta) + delta * 0.5
        xnew = log(p) - log(1 - p)

        # Per-unit log-derivative terms. NOTE(review): this equals
        # log(w * sigmoid'(z) * slope) only if act_w is a softmax over the
        # unit axis and nn_.logsigmoid is log(sigmoid(.)) - confirm.
        unit_terms = (
            F.log_softmax(raw_w, dim=2)
            + nn_.logsigmoid(z)
            + nn_.logsigmoid(-z)
            + log(slope)
        )
        # presumably log_sum_exp keeps the reduced axis, hence the .sum(2)
        # to drop it - verify against utils.
        mix_logderiv = utils.log_sum_exp(unit_terms, 2).sum(2)

        clip_term = np.log(1 - delta)
        logit_term = log(p) + log(-p + 1)
        per_dim = mix_logderiv + clip_term - logit_term

        logdet = per_dim.sum(1) + logdet
        return xnew, logdet
예제 #2
0
파일: flows.py 프로젝트: rikrd/torchkit
    def forward(self, inputs):
        """Apply the logit map ``log(x) - log(1 - x)`` and update the log-determinant.

        Args:
            inputs: either a 2-sequence ``(x, logdet)`` or a 3-sequence
                ``(x, logdet, context)``. ``context`` is passed through
                untouched.

        Returns:
            A tuple of the same length as ``inputs``: ``(output, logdet)`` or
            ``(output, logdet, context)``, where ``logdet`` has had
            ``sum_from_one(log(x) + log(1 - x))`` subtracted from it.

        Raises:
            ValueError: if ``inputs`` does not have length 2 or 3. (A subclass
                of ``Exception``, so existing ``except Exception`` handlers
                still catch it.)
        """
        # Validate and unpack once; ``extras`` carries the optional context so
        # the return statement needs no second length check.
        if len(inputs) == 2:
            x, logdet = inputs
            extras = ()
        elif len(inputs) == 3:
            x, logdet, context = inputs
            extras = (context,)
        else:
            raise ValueError('inputs length not correct')

        # Each log is needed for both the output and the log-determinant.
        log_x = log(x)
        log_one_minus_x = log(1 - x)

        output = log_x - log_one_minus_x
        # Out-of-place on purpose: ``logdet -= ...`` would modify the caller's
        # tensor in place, which leaks a side effect to the caller and can
        # break autograd if that tensor is needed for a backward pass.
        logdet = logdet - sum_from_one(log_x + log_one_minus_x)

        return (output, logdet) + extras
예제 #3
0
파일: flows.py 프로젝트: rikrd/torchkit
    def forward(self, x, logdet, dsparams):
        """Apply a dense sigmoid layer to ``x`` and propagate ``logdet``.

        ``dsparams`` supplies per-sample parameters along its last axis: the
        first three groups of ``self.hidden_dim`` entries feed ``act_a``,
        ``act_b`` and (added to ``self.w_``) ``act_w``; the trailing
        ``self.in_dim`` entries are added to ``self.u_`` and fed to ``act_u``.

        Args:
            x: input tensor, indexed as ``x[:, :, None, :]`` (so at least
                three dimensions).
            logdet: running log-Jacobian term, indexed as
                ``logdet[:, :, None, :, :]`` (so at least four dimensions).
                See the inline shape comments for the intended axis layout.
            dsparams: parameter tensor indexed with three indices.

        Returns:
            Tuple ``(xnew, logdet)``: the transformed tensor and the log-term
            obtained by combining this layer's contribution with the incoming
            ``logdet`` through ``utils.log_sum_exp``.
        """
        # inv solves log(1 + exp(inv)) == 1 - nn_.delta.
        # NOTE(review): presumably act_a is a softplus, so a zero raw
        # parameter yields a slope of 1 - nn_.delta - confirm.
        inv = np.log(np.exp(1 - nn_.delta) - 1)
        ndim = self.hidden_dim
        # Learned u_/w_ are broadcast over the two leading axes and shifted by
        # the matching slice of dsparams.
        pre_u = self.u_[None, None, :, :] + dsparams[:, :,
                                                     -self.in_dim:][:, :,
                                                                    None, :]
        pre_w = self.w_[None, None, :, :] + dsparams[:, :, 2 * ndim:3 *
                                                     ndim][:, :, None, :]
        a = self.act_a(dsparams[:, :, 0 * ndim:1 * ndim] + inv)
        b = self.act_b(dsparams[:, :, 1 * ndim:2 * ndim])
        w = self.act_w(pre_w)
        u = self.act_u(pre_u)

        # Mix the inputs with u (scaled by a), add the shift b, squash with a
        # sigmoid, then mix the hidden units with w.
        pre_sigm = torch.sum(u * a[:, :, :, None] * x[:, :, None, :], 3) + b
        sigm = torch.sigmoid(pre_sigm)
        x_pre = torch.sum(w * sigm[:, :, None, :], dim=3)
        # Shrink towards 0.5 so the logit below never sees exactly 0 or 1.
        x_pre_clipped = x_pre * (1 - nn_.delta) + nn_.delta * 0.5
        x_ = log(x_pre_clipped) - log(1 - x_pre_clipped)
        xnew = x_

        # Log of the per-unit derivative factors.
        # NOTE(review): assumes act_w / act_u are softmaxes over the last axis
        # (matching the log_softmax calls here) - confirm.
        logj = F.log_softmax(pre_w, dim=3) + \
            nn_.logsigmoid(pre_sigm[:,:,None,:]) + \
            nn_.logsigmoid(-pre_sigm[:,:,None,:]) + log(a[:,:,None,:])
        # n, d, d2, dh

        logj = logj[:, :, :, :, None] + F.log_softmax(pre_u, dim=3)[:, :,
                                                                    None, :, :]
        # n, d, d2, dh, d1

        # presumably log_sum_exp keeps the reduced axis, which .sum(3) then
        # removes - verify against utils.
        logj = utils.log_sum_exp(logj, 3).sum(3)
        # n, d, d2, d1

        # Add the constant from the clipping step and subtract the log terms
        # of the final logit.
        logdet_ = logj + np.log(1-nn_.delta) - \
            (log(x_pre_clipped) + log(-x_pre_clipped+1))[:,:,:,None]

        # Combine with the incoming logdet by a log-sum-exp over axis 3.
        logdet = utils.log_sum_exp(
            logdet_[:, :, :, :, None] + logdet[:, :, None, :, :], 3).sum(3)
        # n, d, d2, d1, d0 -> n, d, d2, d0

        return xnew, logdet