Code example #1
0
File: mlp.py  Project: lim0606/pytorch-ardae-vae
    def glogprob(self, input, context, std=None):
        """Return d log p(input | context) / d input via autograd.

        Shapes (asserted below): `input` is bsz x ssz x x_dim, `context` is
        bsz x 1 x ctx_dim.  Returns a bsz x ssz x input_dim tensor.

        NOTE(review): `std` falls back to `self.std` but is never read
        afterwards in this body — presumably kept for interface symmetry
        with sibling methods; confirm.
        """
        # init
        assert input.dim() == 3  # bsz x ssz x x_dim
        assert context.dim() == 3  # bsz x 1 x ctx_dim
        std = self.std if std is None else std
        batch_size = input.size(0)
        sample_size = input.size(1)

        # reshape: flatten the sample dimension into the batch dimension
        input = input.view(batch_size * sample_size,
                           self.input_dim)  # bsz*ssz x xdim
        _, context = expand_tensor(context,
                                   sample_size=sample_size,
                                   do_unsqueeze=False)  # bsz*ssz x xdim
        #context = context.view(batch_size*sample_size, -1) # bsz*ssz x xdim

        # mark `input` as requiring grad so the scalar log-prob can be
        # differentiated w.r.t. it below
        input.requires_grad = True

        # encode
        ctx = self.ctx_encode(context)
        inp = self.inp_encode(input)

        # concat
        h = torch.cat([inp, ctx], dim=1)

        # (unnorm) logprob; summing to a scalar lets one backward pass
        # produce per-row gradients (rows are independent)
        logprob = -self.neglogprob(h)
        logprob = torch.sum(logprob)

        # gradient of the log-prob w.r.t. the input
        # NOTE(review): `grad` is presumably a project helper wrapping
        # torch.autograd.grad that returns a tensor (not a tuple) — confirm.
        glogprob = grad(logprob, input)

        return glogprob.view(batch_size, sample_size, self.input_dim)
Code example #2
0
File: mlp.py  Project: lim0606/pytorch-ardae-vae
    def glogprob(self, input, context, std=None, scale=None):
        """Conditional-DAE estimate of d log p(input | context) / d input.

        input:   bsz x ssz x x_dim tensor of samples.
        context: bsz x 1 x ctx_dim conditioning tensor.
        std:     per-sample noise-level tensor; defaults to zeros.
        scale:   accepted for interface compatibility (unused in the body).

        Returns a bsz x ssz x input_dim tensor.
        """
        assert input.dim() == 3  # bsz x ssz x x_dim
        assert context.dim() == 3  # bsz x 1 x ctx_dim
        bsz, ssz = input.size(0), input.size(1)
        if std is None:
            std = input.new_zeros(bsz * ssz, 1)
        else:
            assert torch.is_tensor(std)
        if scale is None:
            scale = 1.

        # flatten the sample dimension into the batch dimension
        flat_input = input.view(bsz * ssz, self.input_dim)  # bsz*ssz x xdim
        _, flat_context = expand_tensor(context,
                                        sample_size=ssz,
                                        do_unsqueeze=False)  # bsz*ssz rows
        flat_std = std.view(bsz * ssz, 1)

        # encode both streams and feed [inp | ctx | std] to the DAE head
        h = torch.cat([self.inp_encode(flat_input),
                       self.ctx_encode(flat_context),
                       flat_std], dim=1)
        return self.dae(h).view(bsz, ssz, self.input_dim)
Code example #3
0
File: mlp.py  Project: lim0606/pytorch-ardae-vae
    def forward(self, input, context, std=None):
        """Denoising loss for the conditional DAE.

        input:   bsz x ssz x x_dim tensor of samples.
        context: bsz x 1 x ctx_dim conditioning tensor.
        std:     corruption noise std; defaults to self.std.

        Returns (None, loss).
        """
        assert input.dim() == 3  # bsz x ssz x x_dim
        assert context.dim() == 3  # bsz x 1 x ctx_dim
        if std is None:
            std = self.std
        bsz, ssz = input.size(0), input.size(1)

        # flatten the sample dimension into the batch dimension
        flat_input = input.view(bsz * ssz, self.input_dim)  # bsz*ssz x xdim
        _, flat_context = expand_tensor(context,
                                        sample_size=ssz,
                                        do_unsqueeze=False)  # bsz*ssz rows

        # corrupt the input; eps is the noise actually drawn
        x_bar, eps = self.add_noise(flat_input, std)

        # predicted score from the concatenated [inp | ctx] encoding
        h = torch.cat([self.inp_encode(x_bar),
                       self.ctx_encode(flat_context)], dim=1)
        glogprob = self.dae(h)

        # DAE objective: std * score should match -eps
        #loss = (std**2)*self.loss(std*glogprob, -eps)
        loss = self.loss(std * glogprob, -eps)

        return None, loss
Code example #4
0
    def forward(self, input, context, std=None, scale=None):
        """Denoising score-matching loss with an autograd-computed score.

        Shapes (asserted below): `input` is bsz x ssz x x_dim, `context` is
        bsz x csz x ctx_dim.  Returns a pair: the detached score estimate of
        shape bsz x ssz x input_dim, and the loss from self.loss.

        NOTE(review): `scale` is defaulted but never read afterwards in
        this body — presumably kept for interface symmetry; confirm.
        """
        # init
        assert input.dim() == 3  # bsz x ssz x x_dim
        assert context.dim() == 3  # bsz x csz x ctx_dim
        batch_size = input.size(0)
        sample_size = input.size(1)
        if std is None:
            # zero noise level for every sample when not provided
            std = input.new_zeros(batch_size, sample_size, 1)
        else:
            assert torch.is_tensor(std)
        if scale is None:
            scale = 1.

        # encode (context): encode each of the csz context samples, then
        # mean-pool over the context-sample axis (keepdim for later expand)
        csz = context.size(1)  # context sample size
        ctx = self.ctx_encode(context.view(batch_size * csz,
                                           self.context_dim)).view(
                                               batch_size, csz,
                                               self.ctx_dim).mean(dim=1,
                                                                  keepdim=True)

        # reshape: flatten the sample dimension into the batch dimension
        input = input.view(batch_size * sample_size,
                           self.input_dim)  # bsz*ssz x xdim
        std = std.view(batch_size * sample_size, 1)

        # corrupt the input; eps is the noise actually drawn
        x_bar, eps = self.add_noise(input, std)

        # mark x_bar as requiring grad so the score can be taken w.r.t.
        # the corrupted input below
        x_bar.requires_grad = True

        # encode
        _, ctx = expand_tensor(ctx,
                               sample_size=sample_size,
                               do_unsqueeze=False)  # bsz*ssz x xdim
        inp = self.inp_encode(x_bar)

        # concat
        h = torch.cat([inp, ctx, std], dim=1)

        # (unnorm) logprob; summing to a scalar lets one backward pass
        # produce per-row gradients (rows are independent)
        logprob = -self.neglogprob(h)
        logprob = torch.sum(logprob)

        # gradient of the log-prob w.r.t. the corrupted input
        # NOTE(review): `grad` is presumably a project helper wrapping
        # torch.autograd.grad that returns a tensor (not a tuple) — confirm.
        glogprob = grad(logprob, x_bar)
        ''' get loss '''
        #loss = (std**2)*self.loss(std*glogprob, -eps)
        loss = self.loss(std * glogprob, -eps)

        # return
        return glogprob.view(batch_size, sample_size,
                             self.input_dim).detach(), loss
Code example #5
0
    def logprob(self, input, context, std=None, scale=None):
        """(Unnormalized) log-probability of `input` given `context`.

        input:   bsz x ssz x x_dim tensor of samples.
        context: bsz x csz x ctx_dim tensor; the csz context samples are
                 encoded individually and mean-pooled.
        std:     per-sample noise-level tensor; defaults to zeros.
        scale:   accepted for interface compatibility (unused in the body).

        Returns a bsz x ssz x 1 tensor of log-probabilities.
        """
        assert input.dim() == 3  # bsz x ssz x x_dim
        assert context.dim() == 3  # bsz x csz x ctx_dim
        bsz, ssz = input.size(0), input.size(1)
        if std is None:
            std = input.new_zeros(bsz * ssz, 1)
        else:
            assert torch.is_tensor(std)
        if scale is None:
            scale = 1.

        # encode each context sample, then mean-pool over the csz axis
        csz = context.size(1)  # context sample size
        ctx = self.ctx_encode(context.view(bsz * csz, self.context_dim))
        ctx = ctx.view(bsz, csz, self.ctx_dim).mean(dim=1, keepdim=True)

        # flatten the sample dimension into the batch dimension
        flat_input = input.view(bsz * ssz, self.input_dim)  # bsz*ssz x xdim
        flat_std = std.view(bsz * ssz, 1)
        _, ctx = expand_tensor(ctx,
                               sample_size=ssz,
                               do_unsqueeze=False)  # bsz*ssz rows

        # energy head: negate to obtain the (unnormalized) log-prob
        h = torch.cat([self.inp_encode(flat_input), ctx, flat_std], dim=1)
        logprob = -self.neglogprob(h)

        return logprob.view(bsz, ssz, 1)
Code example #6
0
def convert_2d_3d_tensor(input, sample_size):
    """Expand a 2d tensor (bsz x dim) into 3d (bsz x sample_size x dim)."""
    assert input.dim() == 2
    expanded, _ = expand_tensor(input, sample_size, do_unsqueeze=True)
    return expanded