def glogprob(self, input, context, std=None):
    """Return d/d(input) of the model's (unnormalized) log-probability.

    The score is ``-self.neglogprob(cat([inp_encode(input), ctx_encode(context)]))``,
    summed over all rows. It is then differentiated with respect to ``input``
    using ``grad``. ``grad`` is defined outside this view; this code assumes it
    returns a single tensor shaped like ``input`` (TODO confirm).

    Args:
        input: 3-D tensor, ``bsz x ssz x input_dim``.
        context: 3-D tensor, presumably ``bsz x 1 x ctx_dim``. It is expanded
            to one row per sample by ``expand_tensor``.
        std: accepted for signature compatibility; not used by this method.

    Returns:
        Tensor of shape ``(bsz, ssz, input_dim)``.
    """
    assert input.dim() == 3  # bsz x ssz x x_dim
    assert context.dim() == 3  # bsz x 1 x ctx_dim
    batch_size = input.size(0)
    sample_size = input.size(1)

    # Flatten samples into rows. Work on a detached leaf so requires_grad can
    # always be set. Setting the flag directly on a view of a tensor that
    # already requires grad (a non-leaf) raises a RuntimeError.
    input = input.reshape(batch_size * sample_size, self.input_dim).detach()
    input.requires_grad_(True)
    # presumably bsz*ssz x ctx_dim after expansion -- confirm expand_tensor
    _, context = expand_tensor(context, sample_size=sample_size, do_unsqueeze=False)

    # The input gradient must be tracked even when the caller runs under
    # torch.no_grad() (e.g. inside a sampling loop).
    with torch.enable_grad():
        ctx = self.ctx_encode(context)
        inp = self.inp_encode(input)
        h = torch.cat([inp, ctx], dim=1)
        # (unnormalized) log-probability, summed so it is a scalar for grad
        logprob = torch.sum(-self.neglogprob(h))
        glogprob = grad(logprob, input)
    return glogprob.view(batch_size, sample_size, self.input_dim)
def glogprob(self, input, context, std=None, scale=None):
    """Run the denoising network on ``input`` conditioned on ``context`` and ``std``.

    The network output is returned as ``glogprob``. The name suggests an
    estimate of the gradient of the log-density (unverified from here).

    Args:
        input: 3-D tensor, ``bsz x ssz x input_dim``.
        context: 3-D tensor, presumably ``bsz x 1 x ctx_dim``. It is expanded
            per sample by ``expand_tensor``.
        std: optional tensor with ``bsz * ssz`` elements. It is appended as one
            extra feature column. Defaults to zeros.
        scale: accepted for signature compatibility; not used.

    Returns:
        Tensor of shape ``(bsz, ssz, input_dim)``.
    """
    assert input.dim() == 3  # bsz x ssz x x_dim
    assert context.dim() == 3  # bsz x 1 x ctx_dim
    bsz, ssz = input.size(0), input.size(1)
    n_rows = bsz * ssz

    if std is None:
        std = input.new_zeros(n_rows, 1)
    assert torch.is_tensor(std)

    # one row per (batch, sample) pair
    flat_input = input.view(n_rows, self.input_dim)
    _, flat_context = expand_tensor(context, sample_size=ssz, do_unsqueeze=False)
    flat_std = std.view(n_rows, 1)

    ctx_code = self.ctx_encode(flat_context)
    inp_code = self.inp_encode(flat_input)
    score = self.dae(torch.cat([inp_code, ctx_code, flat_std], dim=1))
    return score.view(bsz, ssz, self.input_dim)
def forward(self, input, context, std=None):
    """Compute the denoising loss for a noisy copy of ``input``.

    ``input`` is corrupted with ``self.add_noise``. The corrupted sample and
    the context are encoded and passed through ``self.dae``. The loss is
    ``self.loss(std * output, -eps)``, where ``eps`` is the second value
    returned by ``add_noise``.

    Args:
        input: 3-D tensor, ``bsz x ssz x input_dim``.
        context: 3-D tensor, presumably ``bsz x 1 x ctx_dim``.
        std: noise level; falls back to ``self.std`` when None.

    Returns:
        ``(None, loss)``.
    """
    assert input.dim() == 3  # bsz x ssz x x_dim
    assert context.dim() == 3  # bsz x 1 x ctx_dim
    if std is None:
        std = self.std
    bsz, ssz = input.size(0), input.size(1)

    flat_input = input.view(bsz * ssz, self.input_dim)
    _, flat_context = expand_tensor(context, sample_size=ssz, do_unsqueeze=False)

    noisy, eps = self.add_noise(flat_input, std)

    ctx_code = self.ctx_encode(flat_context)
    noisy_code = self.inp_encode(noisy)
    score = self.dae(torch.cat([noisy_code, ctx_code], dim=1))

    # the scaled network output is regressed onto the negated noise
    return None, self.loss(std * score, -eps)
def forward(self, input, context, std=None, scale=None):
    """Compute the denoising loss using the input-gradient of the model's log-probability.

    Steps:
        1. Encode every context element and average the encodings over the
           context-sample axis.
        2. Corrupt ``input`` with ``self.add_noise``.
        3. Score the noisy sample with ``-self.neglogprob``.
        4. Differentiate that score w.r.t. the noisy sample via ``grad``.
           ``grad`` is defined outside this view; this code assumes it returns
           one tensor and keeps the graph so ``loss`` can be backpropagated
           (TODO confirm).

    Args:
        input: 3-D tensor, ``bsz x ssz x input_dim``.
        context: 3-D tensor, ``bsz x csz x context_dim``.
        std: optional tensor with ``bsz * ssz`` elements. It is used as the
            noise level and appended as a feature column. Defaults to zeros;
            NOTE(review): with zeros, ``std * glogprob`` is identically zero.
        scale: accepted for signature compatibility; not used.

    Returns:
        ``(glogprob, loss)``. ``glogprob`` is detached and shaped
        ``(bsz, ssz, input_dim)``.
    """
    assert input.dim() == 3  # bsz x ssz x x_dim
    assert context.dim() == 3  # bsz x csz x ctx_dim
    batch_size = input.size(0)
    sample_size = input.size(1)
    if std is None:
        std = input.new_zeros(batch_size, sample_size, 1)
    else:
        assert torch.is_tensor(std)

    # Encode each context element, then mean-pool over the csz axis.
    csz = context.size(1)
    ctx = self.ctx_encode(
        context.reshape(batch_size * csz, self.context_dim)
    ).view(batch_size, csz, self.ctx_dim).mean(dim=1, keepdim=True)

    # one row per (batch, sample) pair
    input = input.reshape(batch_size * sample_size, self.input_dim)
    std = std.reshape(batch_size * sample_size, 1)

    x_bar, eps = self.add_noise(input, std)
    # Differentiate w.r.t. a fresh leaf. Assigning requires_grad on x_bar
    # directly raises when add_noise returns a non-leaf (e.g. input or std is
    # part of an autograd graph). When x_bar is already a leaf this is
    # equivalent to setting the flag.
    x_bar = x_bar.detach().requires_grad_(True)

    _, ctx = expand_tensor(ctx, sample_size=sample_size, do_unsqueeze=False)
    inp = self.inp_encode(x_bar)
    h = torch.cat([inp, ctx, std], dim=1)

    # (unnormalized) log-probability, summed to a scalar for grad
    logprob = torch.sum(-self.neglogprob(h))
    glogprob = grad(logprob, x_bar)

    loss = self.loss(std * glogprob, -eps)
    return glogprob.view(batch_size, sample_size, self.input_dim).detach(), loss
def logprob(self, input, context, std=None, scale=None):
    """Return the model's unnormalized log-probability ``-self.neglogprob(...)``.

    Args:
        input: 3-D tensor, ``bsz x ssz x input_dim``.
        context: 3-D tensor, ``bsz x csz x context_dim``. Its encodings are
            averaged over ``csz``.
        std: optional tensor with ``bsz * ssz`` elements. It is appended as a
            feature column. Defaults to zeros.
        scale: accepted for signature compatibility; not used.

    Returns:
        Tensor of shape ``(bsz, ssz, 1)``.
    """
    assert input.dim() == 3  # bsz x ssz x x_dim
    assert context.dim() == 3  # bsz x csz x ctx_dim
    bsz, ssz = input.size(0), input.size(1)
    n_rows = bsz * ssz
    if std is None:
        std = input.new_zeros(n_rows, 1)
    assert torch.is_tensor(std)

    # Encode each context element, then mean-pool over the csz axis.
    csz = context.size(1)
    flat_context = context.view(bsz * csz, self.context_dim)
    per_elem = self.ctx_encode(flat_context).view(bsz, csz, self.ctx_dim)
    pooled_ctx = per_elem.mean(dim=1, keepdim=True)

    flat_input = input.view(n_rows, self.input_dim)
    flat_std = std.view(n_rows, 1)
    _, pooled_ctx = expand_tensor(pooled_ctx, sample_size=ssz, do_unsqueeze=False)

    features = torch.cat([self.inp_encode(flat_input), pooled_ctx, flat_std], dim=1)
    neg_score = self.neglogprob(features)
    return (-neg_score).view(bsz, ssz, 1)
def convert_2d_3d_tensor(input, sample_size):
    """Turn a 2-D tensor into its expanded 3-D form via ``expand_tensor``.

    Only the first value returned by ``expand_tensor(..., do_unsqueeze=True)``
    is kept. Presumably it inserts a sample axis of length ``sample_size``;
    confirm against ``expand_tensor``.
    """
    assert input.dim() == 2
    return expand_tensor(input, sample_size, do_unsqueeze=True)[0]