예제 #1
0
 def forward(self, c=None):
     """Sample ``c`` from the model, or score a caller-supplied ``c``.

     Args:
         c: value to score. When None, ``self.sample(self.model)`` provides
             both the value and the model output it is scored against.

     Returns:
         ``Result(c, score)`` where ``score`` is the negated
         ``self.loss_op(c, model_output)``. The score is not divided by
         ``c.size(0)``.
     """
     if c is None:
         c, model_out = self.sample(self.model)
     else:
         # Scoring path: run the model on the given value without sampling.
         model_out = self.model(c, sample=False)
     return Result(c, -self.loss_op(c, model_out))
예제 #2
0
	def forward(self, inputs, c, z=None):
		"""Sample or score ``z`` under a Normal built from ``inputs``.

		Every element of ``inputs`` goes through ``self.embc``; the resulting
		embeddings are averaged and fed to ``self.conv_mu`` and
		``self.conv_sigma`` to parameterise the distribution.

		Args:
			inputs: iterable of items accepted by ``self.embc``.
			c: not read by this method.
			z: value to score; drawn with ``rsample()`` when None.

		Returns:
			``Result(z, score)`` where ``score`` is ``dist.log_prob(z)``
			reduced with ``sum(dim=1)`` three times in a row.
		"""
		embeddings = [self.embc(item) for item in inputs]
		mean_embedding = sum(embeddings) / len(embeddings)
		dist = Normal(self.conv_mu(mean_embedding), self.conv_sigma(mean_embedding))
		if z is None:
			z = dist.rsample()
		# Collapse dim 1 three times, leaving whatever dim 0 indexes.
		score = dist.log_prob(z)
		for _ in range(3):
			score = score.sum(dim=1)
		return Result(z, score)
예제 #3
0
 def forward(self, inputs, c=None):
     """Sample or score ``c`` under a Normal built from ``inputs``.

     The first two dimensions of ``inputs`` are swapped, each slice along
     the new leading dimension is encoded with ``self.enc``, and the
     encodings are averaged before ``self.mu_c`` / ``self.sigma_c`` are
     applied.

     Args:
         inputs: tensor with at least two dimensions.
         c: value to score; drawn with ``rsample()`` when None.

     Returns:
         ``Result(c, score)`` with ``score = dist.log_prob(c).sum(dim=1)``.
     """
     # Presumably (|D|, batch, ...) after the swap, per the original
     # author's note -- confirm against the caller.
     per_element = inputs.transpose(0, 1)
     encoded = [self.enc(element) for element in per_element]
     pooled = sum(encoded) / len(encoded)
     dist = Normal(self.mu_c(pooled), self.sigma_c(pooled))
     if c is None:
         c = dist.rsample()
     score = dist.log_prob(c).sum(dim=1)
     return Result(c, score)
예제 #4
0
 def forward(self, inputs, c, z=None):
     """Sample or score ``z`` under a Normal computed from ``inputs``.

     Args:
         inputs: tensor that is viewed as ``(-1, 1, 28, 28)`` before use.
         c: not read by this method.
         z: value to score; drawn with ``rsample()`` when None.

     Returns:
         ``Result(z, score)`` where ``score`` is ``dist.log_prob(z)``
         reduced with ``sum(dim=1)`` three times in a row.
     """
     # NOTE(review): the original author questioned this reshape; it looks
     # like it forces single-channel 28x28 images -- confirm it is needed.
     images = inputs.view(-1, 1, 28, 28)
     dist = Normal(self.localization_mu(images),
                   self.localization_sigma(images))
     if z is None:
         z = dist.rsample()
     log_prob = dist.log_prob(z)
     for _ in range(3):
         log_prob = log_prob.sum(dim=1)
     return Result(z, log_prob)
예제 #5
0
    def forward(self, c, z, x=None):
        """Sample ``x`` from the conditioned model, or score a given ``x``.

        ``self.stn(z, c)`` produces the conditioning input, which is passed
        through three pad-then-conv stages chained one after another. The
        three outputs are handed to the model in a dict keyed by
        ``(28, 28)``, ``(14, 14)`` and ``(7, 7)``.

        Args:
            c: second argument to ``self.stn``.
            z: first argument to ``self.stn``.
            x: value to score. When None, ``self.sample`` provides both the
                value and the model output it is scored against.

        Returns:
            ``Result(x, score)`` where ``score`` is the negated
            ``self.loss_op(x, model_output)``. The score is not divided by
            ``x.size(0)``.
        """
        cond = self.stn(z, c)

        # Each stage consumes the padded output of the previous one.
        block_28 = self.cond_conv_1(self.pad(cond))
        block_14 = self.cond_conv_2(self.pad(block_28))
        block_7 = self.cond_conv_3(self.pad(block_14))
        cond_blocks = {
            (28, 28): block_28,
            (14, 14): block_14,
            (7, 7): block_7,
        }

        if x is None:
            x, model_out = self.sample(self.model, cond_blocks=cond_blocks)
        else:
            model_out = self.model(x, cond_blocks=cond_blocks, sample=False)
        return Result(x, -self.loss_op(x, model_out))
예제 #6
0
	def forward(self, inputs, c=None):
		"""Sample or score ``c`` under a Normal built from ``inputs``.

		Every element of ``inputs`` goes through ``self.embc``; the resulting
		embeddings are averaged and fed to ``self.conv_mu`` and
		``self.conv_sigma`` to parameterise the distribution.

		Args:
			inputs: iterable of items accepted by ``self.embc``.
			c: value to score; drawn with ``rsample()`` when None.

		Returns:
			``Result(c, score)`` where ``score`` is ``dist.log_prob(c)``
			reduced with ``sum(dim=1)`` three times in a row.
		"""
		embedded = [self.embc(item) for item in inputs]
		pooled = sum(embedded) / len(embedded)
		dist = Normal(self.conv_mu(pooled), self.conv_sigma(pooled))
		if c is None:
			c = dist.rsample()
		log_prob = dist.log_prob(c)
		for _ in range(3):
			log_prob = log_prob.sum(dim=1)
		return Result(c, log_prob)
예제 #7
0
	def forward(self, c, z, x=None):
		"""Sample ``x`` from the conditioned model, or score a given ``x``.

		``c`` and ``z`` are concatenated along dim 1 and passed through three
		pad-then-conv stages chained one after another. The three outputs
		are handed to the model in a dict keyed by ``(28, 28)``,
		``(14, 14)`` and ``(7, 7)``.

		Args:
			c: first tensor in the concatenation.
			z: second tensor in the concatenation.
			x: value to score. When None, ``self.sample`` provides both the
				value and the model output it is scored against.

		Returns:
			``Result(x, score)`` where ``score`` is the negated
			``loss_op(x, model_output)`` divided by ``x.size(0)``.
		"""
		# NOTE(review): `loss_op` is resolved as a free name, not as
		# `self.loss_op` -- confirm it is defined at module scope.
		cond = torch.cat((c, z), 1)

		# Each stage consumes the padded output of the previous one.
		block_28 = self.cond_conv_1(self.pad(cond))
		block_14 = self.cond_conv_2(self.pad(block_28))
		block_7 = self.cond_conv_3(self.pad(block_14))
		cond_blocks = {
			(28, 28): block_28,
			(14, 14): block_14,
			(7, 7): block_7,
		}

		if x is None:
			x, model_out = self.sample(self.model, cond_blocks=cond_blocks)
		else:
			model_out = self.model(x, cond_blocks=cond_blocks, sample=False)
		return Result(x, -loss_op(x, model_out) / x.size(0))
예제 #8
0
    def forward(self, inputs, c=None):
        """Sample or score ``c`` under a Normal built from ``inputs``.

        Each slice ``inputs[:, i]`` along dim 1 is transformed by
        ``self.stn``; the transformed slices are then embedded with
        ``self.conv_post_stn`` and averaged before ``self.conv_mu`` /
        ``self.conv_sigma`` are applied.

        Args:
            inputs: tensor indexed with five subscripts, so it needs at
                least five dimensions.
            c: value to score; drawn with ``rsample()`` when None.

        Returns:
            ``Result(c, score)`` where ``score`` is ``dist.log_prob(c)``
            reduced with ``sum(dim=1)`` three times in a row.
        """
        slice_count = inputs.size(1)
        # All slices go through the transformer first, then the embedder.
        transformed = [
            self.stn(inputs[:, idx, :, :, :]) for idx in range(slice_count)
        ]
        embedded = [self.conv_post_stn(item) for item in transformed]
        pooled = sum(embedded) / len(embedded)

        dist = Normal(self.conv_mu(pooled), self.conv_sigma(pooled))
        if c is None:
            c = dist.rsample()

        score = dist.log_prob(c)
        for _ in range(3):
            score = score.sum(dim=1)
        return Result(c, score)
예제 #9
0
    def forward(self, inputs, c=None):
        """Sample or score ``c`` under a Normal built from ``inputs``.

        Each slice ``inputs[:, i]`` along dim 1 is embedded with
        ``self.embc``. The embeddings are averaged, passed through a ReLU,
        and fed to ``self.conv_mu`` / ``self.conv_sigma``.

        Args:
            inputs: tensor indexed with five subscripts, so it needs at
                least five dimensions.
            c: value to score; drawn with ``rsample()`` when None.

        Returns:
            ``Result(c, score)`` where ``score`` is ``dist.log_prob(c)``
            reduced with ``sum(dim=1)`` three times in a row.
        """
        slice_count = inputs.size(1)
        # Averaging the per-slice embeddings makes the result independent of
        # the order of slices along dim 1.
        embedded = [
            self.embc(inputs[:, idx, :, :, :]) for idx in range(slice_count)
        ]
        pooled = sum(embedded) / len(embedded)

        activation = nn.ReLU()
        pooled = activation(pooled)

        dist = Normal(self.conv_mu(pooled), self.conv_sigma(pooled))
        if c is None:
            c = dist.rsample()

        score = dist.log_prob(c)
        for _ in range(3):
            score = score.sum(dim=1)
        return Result(c, score)
예제 #10
0
 def forward(self, inputs, c, z=None):
     """Sample or score ``z`` under a Normal computed from ``inputs[:, 0]``.

     Args:
         inputs: indexable as ``inputs[:, 0]``; only that slice is used.
         c: not read by this method.
         z: value to score; drawn with ``rsample()`` when None.

     Returns:
         ``Result(z, score)`` with ``score = dist.log_prob(z).sum(dim=1)``.
     """
     first = inputs[:, 0]
     dist = Normal(self.mu_z(first), self.sigma_z(first))
     if z is None:
         z = dist.rsample()
     score = dist.log_prob(z).sum(dim=1)
     return Result(z, score)
예제 #11
0
 def forward(self, c, z, x=None):
     """Sample or score ``x`` under a Normal conditioned on ``c`` and ``z``.

     ``c`` and ``z`` are concatenated along dim 1; ``self.mu`` and
     ``self.sigma`` map the result to the distribution parameters.

     Args:
         c: first tensor in the concatenation.
         z: second tensor in the concatenation.
         x: value to score; drawn with ``rsample()`` when None.

     Returns:
         ``Result(x, score)`` with ``score = dist.log_prob(x).sum(dim=1)``.
     """
     joint = torch.cat([c, z], dim=1)
     loc = self.mu(joint)
     scale = self.sigma(joint)
     dist = Normal(loc, scale)
     if x is None:
         x = dist.rsample()
     score = dist.log_prob(x).sum(dim=1)
     return Result(x, score)
예제 #12
0
 def forward(self, inputs, c=None):
     """Sample ``c`` from ``self.net`` given ``inputs``, or score a given ``c``.

     Args:
         inputs: passed unchanged to ``self.net``.
         c: value to score. When None, ``self.net.sampleAndScore(inputs)``
             supplies both the value and its score.

     Returns:
         ``Result(value=c, reinforce_log_prob=score)``.
     """
     if c is not None:
         score = self.net.score(inputs, c, autograd=True)
     else:
         c, score = self.net.sampleAndScore(inputs)
     return Result(value=c, reinforce_log_prob=score)
예제 #13
0
 def forward(self, c, x=None):
     """Sample ``x`` from ``self.net`` given ``c``, or score a given ``x``.

     Args:
         c: iterable; each element is wrapped in its own one-item list
             before being handed to ``self.net``.
         x: value to score. When None, ``self.net.sampleAndScore`` supplies
             both the value and its score.

     Returns:
         ``Result(value=x, reinforce_log_prob=score)``.
     """
     # One single-element list per item of c; the original note calls each
     # a robustfill 'example'.
     wrapped = [[example] for example in c]
     if x is not None:
         score = self.net.score(wrapped, x, autograd=True)
     else:
         x, score = self.net.sampleAndScore(wrapped)
     return Result(value=x, reinforce_log_prob=score)