def forward(self, input):
    # init
    batch_size = input.size(0)
    x = input.view(batch_size, self.input_dim)

    # forward
    hidden = x
    for i in range(self.num_hidden_layers):
        hidden = get_nonlinear_func(self.nonlinearity)(self.layers[i](hidden))
    output = self.fc(hidden)
    if self.use_nonlinearity_output:
        output = get_nonlinear_func(self.nonlinearity)(output)
    return output
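# `get_nonlinear_func` is used above but not defined in this section. A
# minimal sketch of the assumed helper, mapping an activation name to its
# functional form; the exact name-to-function table is an assumption.
import torch
import torch.nn.functional as F

def get_nonlinear_func(name):
    # look up the elementwise activation matching `name` (assumed table)
    table = {
        'softplus': F.softplus,
        'relu': F.relu,
        'elu': F.elu,
        'tanh': torch.tanh,
        'sigmoid': torch.sigmoid,
    }
    if name not in table:
        raise ValueError('unknown nonlinearity: %s' % name)
    return table[name]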
def __init__(
        self,
        input_height=28,
        input_channels=1,
        z0_dim=100,
        z_dim=32,
        nonlinearity='softplus',
        ):
    super().__init__()
    self.input_height = input_height
    self.input_channels = input_channels
    self.z0_dim = z0_dim
    self.z_dim = z_dim
    self.nonlinearity = nonlinearity

    # spatial size after each stride-2 conv (kernel 5, padding 2)
    s_h = input_height
    s_h2 = conv_out_size(s_h, 5, 2, 2)
    s_h4 = conv_out_size(s_h2, 5, 2, 2)
    s_h8 = conv_out_size(s_h4, 5, 2, 2)

    self.afun = get_nonlinear_func(nonlinearity)
    self.conv1 = nn.Conv2d(self.input_channels, 16, 5, 2, 2, bias=True)
    self.conv2 = nn.Conv2d(16, 32, 5, 2, 2, bias=True)
    self.conv3 = nn.Conv2d(32, 32, 5, 2, 2, bias=True)
    # fuse the flattened conv features with the auxiliary code z0
    self.fc = nn.Linear(s_h8 * s_h8 * 32 + z0_dim, 800, bias=True)
    self.reparam = NormalDistributionLinear(800, z_dim)
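# `conv_out_size` is assumed to implement the standard nn.Conv2d output-size
# arithmetic; a sketch (the original helper is defined elsewhere):
def conv_out_size(size, kernel, stride, padding):
    # floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

# For input_height=28 this gives 28 -> 14 -> 7 -> 4, so the flattened
# features entering self.fc have 4 * 4 * 32 = 512 dimensions plus z0_dim.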
def __init__(
        self,
        input_height=28,
        input_channels=1,
        noise_dim=100,
        z_dim=32,
        nonlinearity='softplus',
        enc_noise=False,
        ):
    super().__init__()
    self.input_height = input_height
    self.input_channels = input_channels
    self.noise_dim = noise_dim
    self.z_dim = z_dim
    self.nonlinearity = nonlinearity
    self.enc_noise = enc_noise
    h_dim = 256
    # if the noise is pre-encoded, the fusion layer sees h_dim features
    nos_dim = noise_dim if not enc_noise else h_dim

    # spatial size after each stride-2 conv (kernel 5, padding 2)
    s_h = input_height
    s_h2 = conv_out_size(s_h, 5, 2, 2)
    s_h4 = conv_out_size(s_h2, 5, 2, 2)
    s_h8 = conv_out_size(s_h4, 5, 2, 2)

    self.afun = get_nonlinear_func(nonlinearity)
    self.conv1 = nn.Conv2d(self.input_channels, 16, 5, 2, 2, bias=True)
    self.conv2 = nn.Conv2d(16, 32, 5, 2, 2, bias=True)
    self.conv3 = nn.Conv2d(32, 32, 5, 2, 2, bias=True)
    self.fc4 = nn.Linear(s_h8 * s_h8 * 32 + nos_dim, 800, bias=True)
    self.fc5 = nn.Linear(800, z_dim, bias=True)
    # optionally encode the raw noise with a small MLP before fusion
    self.nos_encode = (
        Identity() if not enc_noise
        else MLP(input_dim=noise_dim,
                 hidden_dim=h_dim,
                 output_dim=h_dim,
                 nonlinearity=nonlinearity,
                 num_hidden_layers=2,
                 use_nonlinearity_output=True)
    )
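# `Identity` above is assumed to be a simple pass-through module, used when
# the raw noise is fused without a separate encoder. Recent PyTorch ships
# nn.Identity; a one-line sketch for completeness:
class Identity(nn.Module):
    def forward(self, x):
        return x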
def forward(self, input, context):
    # init
    batch_size = input.size(0)
    x = input.view(batch_size, self.input_dim)
    ctx = context.view(batch_size, self.context_dim)

    # forward: the context is re-concatenated at every layer
    hidden = x
    for i in range(self.num_hidden_layers):
        _hidden = torch.cat([hidden, ctx], dim=1)
        hidden = get_nonlinear_func(self.nonlinearity)(self.layers[i](_hidden))
    _hidden = torch.cat([hidden, ctx], dim=1)
    output = self.fc(_hidden)
    if self.use_nonlinearity_output:
        output = get_nonlinear_func(self.nonlinearity)(output)
    return output
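# A hypothetical shape check for the context-conditioned forward above. The
# enclosing class name and its constructor signature are not shown in this
# section, so `ContextMLP` and its keyword arguments are illustrative only.
def _demo_context_mlp():
    model = ContextMLP(input_dim=32, context_dim=100, hidden_dim=64,
                       output_dim=10, nonlinearity='softplus',
                       num_hidden_layers=2, use_nonlinearity_output=False)
    x = torch.randn(8, 32)     # batch of inputs
    ctx = torch.randn(8, 100)  # per-example context
    return model(x, ctx)       # expected shape: (8, 10)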
def __init__(
        self,
        input_height=28,
        input_channels=1,
        z_dim=32,
        nonlinearity='softplus',
        ):
    super().__init__()
    self.input_height = input_height
    self.input_channels = input_channels
    self.z_dim = z_dim
    self.nonlinearity = nonlinearity

    # spatial sizes mirrored from the encoder's stride-2 convs
    s_h = input_height
    s_h2 = conv_out_size(s_h, 5, 2, 2)
    s_h4 = conv_out_size(s_h2, 5, 2, 2)
    s_h8 = conv_out_size(s_h4, 5, 2, 2)
    self.s_h8 = s_h8

    self.afun = get_nonlinear_func(nonlinearity)
    self.fc = MLP(input_dim=z_dim,
                  hidden_dim=300,
                  output_dim=s_h8 * s_h8 * 32,
                  nonlinearity=nonlinearity,
                  num_hidden_layers=1,
                  use_nonlinearity_output=True)
    self.deconv1 = nn.ConvTranspose2d(32, 32, 5, 2, 2, 0, bias=True)
    # pad1 grows the intermediate map by one pixel on the right and bottom;
    # padr crops one pixel back off so the final output matches input_height
    self.pad1 = nn.ZeroPad2d((0, 1, 0, 1))
    self.deconv2 = nn.ConvTranspose2d(32, 16, 5, 2, 2, 0, bias=True)
    self.reparam = BernoulliDistributionConvTranspose2d(
        16, self.input_channels, 5, 2, 2, 0, bias=True)
    self.padr = nn.ZeroPad2d((0, -1, 0, -1))
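# The commented-out scratch work in the original checked these sizes with a
# `deconv_out_size` helper; a sketch of the standard nn.ConvTranspose2d
# output arithmetic it is assumed to implement:
def deconv_out_size(size, kernel, stride, padding, output_padding):
    # (size - 1)*stride - 2*padding + kernel + output_padding
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# For input_height=28: deconv1 maps 4 -> 7, pad1 makes it 8, deconv2 maps
# 8 -> 15, the output deconv maps 15 -> 29, and padr crops back to 28.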
def forward(self, input1, input2):
    # joint trunk over both inputs, then a linear head
    afunc = get_nonlinear_func(self.nonlinearity)
    hid = afunc(self.main(input1, input2))
    out = self.fc(hid)
    return out
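# For reference, a sketch of the `NormalDistributionLinear` head used by the
# encoder earlier in this section, assuming a standard Gaussian
# reparameterization layer (two linear heads for mean and log-variance);
# the original implementation may differ in its return convention.
class NormalDistributionLinear(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.mean_fn = nn.Linear(input_dim, output_dim)    # predicts mu
        self.logvar_fn = nn.Linear(input_dim, output_dim)  # predicts log sigma^2

    def sample_gaussian(self, mu, logvar):
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        mu = self.mean_fn(x)
        logvar = self.logvar_fn(x)
        return self.sample_gaussian(mu, logvar), mu, logvar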