class AutoEncoder(nn.Module):
    """Convolutional auto-encoder: a conv stack followed by a mirrored
    transposed-conv stack, with a sigmoid on the reconstruction.

    Args:
        inchans: number of channels of the input image; the final
            transposed-conv layer produces the same number of channels.
    """

    def __init__(self, inchans=3):
        super(AutoEncoder, self).__init__()
        self.inchans = inchans

        # NOTE(review): the three positional arguments to ``append`` are
        # presumably (out_channels, kernel_size, stride) -- confirm against
        # ConvolutionStack / TransposedConvolutionStack.
        self.convs = ConvolutionStack(self.inchans)
        self.convs.append(3, 3, 2)
        self.convs.append(6, 3, 1)
        self.convs.append(16, 3, 1)
        self.convs.append(32, 3, 2)

        # Decoder mirrors the encoder layer-by-layer and ends on `inchans`.
        self.tconvs = TransposedConvolutionStack(32, final_relu=False)
        self.tconvs.append(16, 3, 2)
        self.tconvs.append(6, 3, 1)
        self.tconvs.append(3, 3, 1)
        self.tconvs.append(self.inchans, 3, 2)

    def forward(self, x):
        """Encode ``x``, decode it again and return the sigmoid of the result.

        Args:
            x: input tensor; its ``size()`` is handed to the decoder as the
                last entry of the target-dims list.

        Returns:
            ``torch.sigmoid`` of the transposed-conv stack output.
        """
        input_dims = x.size()
        x = self.convs.forward(x)

        # TODO: this is a dumb way to get the output dims for the deconv
        # Drop the last recorded dims, reverse the rest so they run from the
        # deepest layer outwards, then finish with the original input size.
        output_dims = self.convs.get_output_dims()[:-1][::-1]
        output_dims.append(input_dims)

        # get outputs from conv and pass them back to deconv
        x = self.tconvs.forward(x, output_dims)

        # torch.sigmoid instead of the long-deprecated F.sigmoid; this also
        # matches what AttentionSegmenter in this file already uses.
        return torch.sigmoid(x)
class VAE(nn.Module):
    """Convolutional variational auto-encoder with lazily-created FC layers.

    The fully-connected weights that map the conv feature map to
    ``mu`` / ``logvar`` (and the code back to a feature map) depend on the
    spatial size of the encoder output, so they are created on the first
    call to :meth:`encode`.

    NOTE: because those parameters do not exist until the first forward
    pass, an optimizer built from ``parameters()`` before that pass will not
    include them.

    Args:
        encoding_size: dimensionality of the latent code.
        training: initial value for ``self.training``. This overwrites the
            flag ``nn.Module.__init__`` sets; ``train()`` / ``eval()`` still
            change it afterwards.
    """

    def __init__(self, encoding_size=128, training=True):
        super(VAE, self).__init__()
        self.training = training
        self.encoding_size = encoding_size
        self.outchannel_size = 256

        # NOTE(review): the positional ``append`` arguments are presumably
        # (out_channels, kernel_size, stride) -- confirm against the stacks.
        # encoding conv
        self.encoder = ConvolutionStack(3, final_relu=False, padding=0)
        self.encoder.append(16, 3, 2)
        self.encoder.append(32, 3, 1)
        self.encoder.append(64, 3, 2)
        self.encoder.append(128, 3, 2)
        self.encoder.append(self.outchannel_size, 3, 1)

        # decode
        self.decoder = TransposedConvolutionStack(
            self.outchannel_size, final_relu=False, padding=0)
        self.decoder.append(128, 3, 1)
        self.decoder.append(64, 3, 2)
        self.decoder.append(32, 3, 2)
        self.decoder.append(16, 3, 1)
        self.decoder.append(3, 3, 2)

        # Placeholders; the real parameters are created in
        # initialize_linear_params once the feature-map size is known.
        self.register_parameter('linear_mu_weights', None)
        self.register_parameter('linear_logvar_weights', None)
        self.register_parameter('linear_decode_weights', None)

    @staticmethod
    def _uniform_weight(out_features, in_features):
        """Return an (out_features x in_features) Parameter drawn from
        U(-1/sqrt(in_features), 1/sqrt(in_features))."""
        weight = nn.Parameter(torch.Tensor(out_features, in_features))
        stdv = 1. / math.sqrt(weight.size(1))
        weight.data.uniform_(-stdv, stdv)
        return weight

    def initialize_linear_params(self, is_cuda):
        """Create the three bias-free linear weight matrices.

        Requires ``self.linear_size`` to be set (done by :meth:`encode`).

        Args:
            is_cuda: when true, the whole module is moved with ``self.cuda()``
                after the parameters are created.
        """
        # linear op y = x*A_T + b
        # so here the dims are [b x c] * [c x s], then the weights need to
        # have dims (s x c) where s is the encoding size and b is the batch
        # size. Creation order is kept (mu, logvar, decode) so seeded runs
        # draw the same values as before.
        self.linear_mu_weights = self._uniform_weight(
            self.encoding_size, self.linear_size)
        self.linear_logvar_weights = self._uniform_weight(
            self.encoding_size, self.linear_size)
        self.linear_decode_weights = self._uniform_weight(
            self.linear_size, self.encoding_size)
        if is_cuda:
            self.cuda()

    def encode(self, x):
        """Run the conv encoder and project to the latent Gaussian parameters.

        Side effects: caches ``encoding_feature_map``, ``conv_output_dims``,
        ``conv_out_spatial`` and ``linear_size`` on ``self`` for use by
        :meth:`decode`, and creates the linear weights on first use.

        Returns:
            ``(mu, logvar)``, each of shape ``[-1, encoding_size]``.
        """
        input_dims = x.size()
        conv_out = F.relu(self.encoder.forward(x))
        self.encoding_feature_map = conv_out

        # Target sizes for the decoder: encoder dims minus the last one,
        # reversed, ending with the original input size.
        self.conv_output_dims = self.encoder.get_output_dims()[:-1][::-1]
        self.conv_output_dims.append(input_dims)

        # Direct FC on the flattened feature map (no pooling).
        self.conv_out_spatial = [conv_out.size(2), conv_out.size(3)]
        self.linear_size = (
            self.outchannel_size * conv_out.size(2) * conv_out.size(3))

        if self.linear_mu_weights is None:
            self.initialize_linear_params(x.is_cuda)

        flat = conv_out.view(-1, self.linear_size)
        mu = F.linear(flat, self.linear_mu_weights)
        logvar = F.linear(flat, self.linear_logvar_weights)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` in training mode; return
        ``mu`` unchanged otherwise."""
        if not self.training:
            return mu
        std = logvar.mul(0.5).exp_()
        # randn_like gives standard-normal noise with std's dtype/device,
        # replacing the legacy Variable(std.data.new(...).normal_()) idiom.
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Map a latent code back to image space.

        Relies on state cached by a previous :meth:`encode` call
        (``linear_size``, ``conv_out_spatial``, ``conv_output_dims``); calling
        it before any ``encode`` fails with ``AttributeError``.

        Args:
            z: latent codes, expected as ``[b x encoding_size]``.

        Returns:
            ``torch.sigmoid`` of the transposed-conv decoder output.
        """
        # Direct FC back up to the flattened feature-map size.
        if self.linear_decode_weights is None:
            self.initialize_linear_params(z.is_cuda)
        h2 = F.relu(F.linear(z, self.linear_decode_weights))
        h3 = h2.view(-1, self.outchannel_size,
                     self.conv_out_spatial[0], self.conv_out_spatial[1])
        h4 = self.decoder.forward(h3, self.conv_output_dims)
        # torch.sigmoid instead of the long-deprecated F.sigmoid.
        return torch.sigmoid(h4)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for input ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def get_encoder(self):
        """Return the convolutional encoder stack."""
        return self.encoder

    def get_encoding_feature_map(self):
        """Return the ReLU'd conv feature map cached by the last
        :meth:`encode` call."""
        return self.encoding_feature_map
class AttentionSegmenter(nn.Module):
    """Recurrent attention segmenter.

    For a fixed number of timesteps the model decodes attention parameters
    from the RNN hidden state, reads a glimpse of the input with them, runs
    the glimpse through a conv encoder / transposed-conv decoder to get a
    partial mask, and writes that mask additively onto a running canvas.

    Args:
        num_classes: number of output channels of the mask canvas.
        inchans: number of input image channels.
        att_encoding_size: hidden-state size passed to ``BasicRNN``.
        timesteps: number of read/write iterations per forward pass.
        attn_grid_size: side length passed to the attention reader; also the
            spatial size of the dummy glimpse used for shape inference.
    """

    # Length of the attention-parameter vector decoded from the hidden state
    # and handed to the attention reader and writer.
    _NUM_ATTN_PARAMS = 5

    def __init__(self,
                 num_classes,
                 inchans=3,
                 att_encoding_size=128,
                 timesteps=10,
                 attn_grid_size=50):
        super(AttentionSegmenter, self).__init__()
        self.num_classes = num_classes
        self.att_encoding_size = att_encoding_size
        self.timesteps = timesteps
        self.attn_grid_size = attn_grid_size

        # NOTE(review): the positional ``append`` arguments are presumably
        # (out_channels, kernel_size, stride) -- confirm against the stacks.
        self.encoder = ConvolutionStack(inchans, final_relu=False, padding=0)
        self.encoder.append(32, 3, 1)
        self.encoder.append(32, 3, 2)
        self.encoder.append(64, 3, 1)
        self.encoder.append(64, 3, 2)
        self.encoder.append(96, 3, 1)
        self.encoder.append(96, 3, 2)

        self.decoder = TransposedConvolutionStack(
            96, final_relu=False, padding=0)
        self.decoder.append(96, 3, 2)
        self.decoder.append(64, 3, 1)
        self.decoder.append(64, 3, 2)
        self.decoder.append(32, 3, 1)
        self.decoder.append(32, 3, 2)
        self.decoder.append(self.num_classes, 3, 1)

        self.attn_reader = GaussianAttentionReader()
        self.attn_writer = GaussianAttentionWriter()
        self.att_rnn = BasicRNN(hstate_size=att_encoding_size, output_size=5)

        # Created lazily in init_weights once the hidden-state size is known.
        self.register_parameter('att_decoder_weights', None)

    def init_weights(self, hstate):
        """Create ``att_decoder_weights`` on first use; no-op afterwards.

        The weight has shape ``(5, hstate.nelement() // hstate.size(0))`` and
        is drawn from U(-1/sqrt(cols), 1/sqrt(cols)). If ``hstate`` lives on
        CUDA the whole module is moved with ``self.cuda()``.

        Args:
            hstate: RNN hidden state whose first dimension is the batch.
        """
        if self.att_decoder_weights is not None:
            return
        batch_size = hstate.size(0)
        # Integer floor division: identical to old_div for int operands.
        state_features = hstate.nelement() // batch_size
        self.att_decoder_weights = nn.Parameter(
            torch.Tensor(self._NUM_ATTN_PARAMS, state_features))
        stdv = 1. / math.sqrt(self.att_decoder_weights.size(1))
        self.att_decoder_weights.data.uniform_(-stdv, stdv)
        if hstate.is_cuda:
            self.cuda()

    def forward(self, x):
        """Run the attention loop and return the canvas after every step.

        Args:
            x: input batch of size ``(batch, chans, height, width)``.

        Returns:
            A list of ``timesteps + 1`` tensors: the sigmoid of the initial
            all-zero canvas followed by the sigmoid of the accumulated canvas
            after each timestep.
        """
        batch_size, chans, height, width = x.size()
        on_cuda = x.is_cuda

        # need to first determine the hidden state size, which is tied to the
        # cnn feature size. Zeros (not uninitialised memory) are used so that
        # no NaN/Inf garbage is pushed through the encoder and the RNN.
        dummy_glimpse = torch.zeros(batch_size, chans, self.attn_grid_size,
                                    self.attn_grid_size)
        if on_cuda:
            dummy_glimpse = dummy_glimpse.cuda()
        dummy_feature_map = self.encoder.forward(dummy_glimpse)
        self.att_rnn.forward(dummy_feature_map.view(batch_size, -1))
        self.att_rnn.reset_hidden_state(batch_size, on_cuda)

        canvas = torch.zeros(batch_size, self.num_classes, height, width)
        if on_cuda:
            canvas = canvas.cuda()
        outputs = [canvas]

        self.init_weights(self.att_rnn.get_hidden_state())

        for t in range(self.timesteps):
            # 1) decode hidden state to generate gaussian attention parameters
            state = self.att_rnn.get_hidden_state()
            gauss_attn_params = torch.tanh(
                F.linear(state, self.att_decoder_weights))

            # 2) extract glimpse
            glimpse = self.attn_reader.forward(x, gauss_attn_params,
                                               self.attn_grid_size)

            # visualize first glimpse in batch for all t
            torch_glimpses = torch.chunk(glimpse, batch_size, dim=0)
            ImageVisualizer().set_image(
                PTImage.from_cwh_torch(torch_glimpses[0].squeeze().data),
                'zGlimpse {}'.format(t))

            # 3) use conv stack or resnet to extract features
            feature_map = self.encoder.forward(glimpse)
            conv_output_dims = self.encoder.get_output_dims()[:-1][::-1]
            conv_output_dims.append(glimpse.size())

            # 4) update hidden state
            # think about this connection a bit more
            self.att_rnn.forward(feature_map.view(batch_size, -1))

            # 5) use deconv network to get partial masks
            partial_mask = self.decoder.forward(feature_map, conv_output_dims)

            # 6) write masks additively to mask canvas
            partial_canvas = self.attn_writer.forward(
                partial_mask, gauss_attn_params, (height, width))
            outputs.append(torch.add(outputs[-1], partial_canvas))

        # return the sigmoided versions
        return [torch.sigmoid(step_canvas) for step_canvas in outputs]