import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# BasicRNN, ConvolutionStack, TransposedConvolutionStack,
# GaussianAttentionReader, GaussianAttentionWriter, ImageVisualizer, and
# PTImage are repo-local modules; import them from wherever they live in
# this project.


class DRAW(nn.Module):
    def __init__(self,
                 q_size=10,
                 encoding_size=128,
                 timesteps=10,
                 training=True,
                 use_attention=False):
        super(DRAW, self).__init__()
        self.training = training
        self.encoding_size = encoding_size
        self.q_size = q_size
        self.use_attention = use_attention
        self.timesteps = timesteps
        # use equal encoding and decoding size
        self.encoder_rnn = BasicRNN(output_size=self.encoding_size)
        self.decoder_rnn = BasicRNN(output_size=self.encoding_size)
        self.register_parameter('decoder_linear_weights', None)
        self.register_parameter('encoding_mu_weights', None)
        self.register_parameter('encoding_logvar_weights', None)

    # lazily initialize the linear weights once the input size is known
    def initialize(self, x):
        batch_size = x.size(0)
        self.decoder_linear_weights = nn.Parameter(
            torch.Tensor(x.nelement() // batch_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.decoder_linear_weights.size(1))
        self.decoder_linear_weights.data.uniform_(-stdv, stdv)
        self.encoding_mu_weights = nn.Parameter(
            torch.Tensor(self.q_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.encoding_mu_weights.size(1))
        self.encoding_mu_weights.data.uniform_(-stdv, stdv)
        self.encoding_logvar_weights = nn.Parameter(
            torch.Tensor(self.q_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.encoding_logvar_weights.size(1))
        self.encoding_logvar_weights.data.uniform_(-stdv, stdv)
        if x.is_cuda:
            self.cuda()

    # selects where to sample from the input image, no-attention version
    # output dims are batch x 2*W*H
    def read(self, x, x_hat, dec_state):
        return torch.cat((x, x_hat), 1)

    # write takes us from "encoding space" to image space
    def write(self, decoding):
        return F.linear(decoding, self.decoder_linear_weights)

    # converts the encoding into both a mu and a logvar vector, then samples
    def sampleZ(self, encoding):
        mu = F.linear(encoding, self.encoding_mu_weights)
        logvar = F.linear(encoding, self.encoding_logvar_weights)
        return self.reparameterize(mu, logvar), mu, logvar

    def reparameterize(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    # takes an input, returns the sequence of outputs, mus, and logvars
    def forward(self, x):
        # flatten x to 1-d, except for the batch dimension
        batch_size = x.size(0)
        xview = x.view(batch_size, x.nelement() // batch_size)
        if self.decoder_linear_weights is None:
            self.initialize(xview)

        # zero out initial states
        self.encoder_rnn.reset_hidden_state(batch_size)
        self.decoder_rnn.reset_hidden_state(batch_size)

        outputs, mus, logvars = [], [], []
        init_tensor = torch.zeros(x.size())
        if x.is_cuda:
            init_tensor = init_tensor.cuda()
        outputs.append(init_tensor)

        for t in range(self.timesteps):
            # Step 1: diff the input against the previous output
            x_hat = xview - torch.sigmoid(outputs[t].view(xview.size()))
            # Step 2: read
            rvec = self.read(xview, x_hat, self.decoder_rnn.get_hidden_state())
            # Step 3: encoder rnn
            # note: the dimensions of rvec don't have to match the decoding
            # size, because we are just concatenating two dim-1 tensors
            cat = torch.cat((rvec, self.decoder_rnn.get_hidden_state().view(
                batch_size, self.encoding_size)), 1)
            encoding = self.encoder_rnn.forward(cat)
            # Step 4: sample z
            z, mu, logvar = self.sampleZ(encoding)
            # store the mu and logvar for the loss function
            mus.append(mu)
            logvars.append(logvar)
            # Step 5: decoder rnn
            decoding = self.decoder_rnn.forward(z)
            # Step 6: write to canvas (in the original dimensions of the input)
            outputs.append(
                torch.add(outputs[-1], self.write(decoding).view(x.size())))
        return outputs, mus, logvars
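# --- Usage sketch (illustrative, not from the original repo) ---
# A minimal forward pass through the no-attention DRAW model above. The
# input shape is an assumption (MNIST-sized images), and BasicRNN is a
# repo-local module, so whether this runs depends on its interface. Note
# that the attention-enabled variant below also binds the name DRAW, so in
# the repo the two versions presumably live in separate files.
def _demo_draw_no_attention():
    model = DRAW(q_size=10, encoding_size=128, timesteps=10)
    x = torch.rand(16, 1, 28, 28)  # batch of 16 single-channel 28x28 images
    outputs, mus, logvars = model(x)
    # forward returns the initial blank canvas plus one canvas per timestep
    assert len(outputs) == model.timesteps + 1
    assert len(mus) == len(logvars) == model.timesteps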
class DRAW(nn.Module):
    def __init__(self,
                 q_size=10,
                 encoding_size=128,
                 timesteps=10,
                 training=True,
                 use_attention=False,
                 grid_size=5):
        super(DRAW, self).__init__()
        self.training = training
        self.encoding_size = encoding_size
        self.q_size = q_size
        self.use_attention = use_attention
        self.timesteps = timesteps
        # use equal encoding and decoding size
        self.encoder_rnn = BasicRNN(hstate_size=self.encoding_size,
                                    output_size=self.encoding_size)
        self.decoder_rnn = BasicRNN(hstate_size=self.encoding_size,
                                    output_size=self.encoding_size)
        self.register_parameter('decoder_linear_weights', None)
        self.register_parameter('encoding_mu_weights', None)
        self.register_parameter('encoding_logvar_weights', None)
        # maps the decoder state to the five attention parameters
        self.filter_linear_layer = nn.Linear(self.encoding_size, 5)
        self.grid_size = grid_size
        self.minclamp = 1e-8
        self.maxclamp = 1e8

    def initialize(self, x):
        batch_size = x.size(0)
        # with attention, the decoder produces a patch of
        # grid_size x grid_size; otherwise it produces an output of the
        # original image size
        if self.use_attention:
            self.decoder_linear_weights = nn.Parameter(
                torch.Tensor(self.grid_size * self.grid_size,
                             self.encoding_size))
        else:
            self.decoder_linear_weights = nn.Parameter(
                torch.Tensor(x.nelement() // batch_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.decoder_linear_weights.size(1))
        self.decoder_linear_weights.data.uniform_(-stdv, stdv)
        self.encoding_mu_weights = nn.Parameter(
            torch.Tensor(self.q_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.encoding_mu_weights.size(1))
        self.encoding_mu_weights.data.uniform_(-stdv, stdv)
        self.encoding_logvar_weights = nn.Parameter(
            torch.Tensor(self.q_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.encoding_logvar_weights.size(1))
        self.encoding_logvar_weights.data.uniform_(-stdv, stdv)
        if x.is_cuda:
            self.cuda()

    # selects where to sample from the input image, no-attention version
    # output dims are batch x 2*W*H
    def read(self, x, x_hat, dec_state):
        return torch.cat((x, x_hat), 1)

    # generate two sets of filterbanks:
    # 1) batch x N x W (fx)
    # 2) batch x N x H (fy)
    def generate_filter_matrices(self, gx, gy, sigma2, delta):
        N = self.grid_size
        grid_points = torch.arange(0, N).view((1, N, 1))
        a = torch.arange(0, self.image_w).view((1, 1, -1))
        b = torch.arange(0, self.image_h).view((1, 1, -1))
        if gx.is_cuda:
            grid_points = grid_points.cuda()
            a = a.cuda()
            b = b.cuda()
        # gx is Bx1 and grid_points is 1xNx1, so this broadcasts to BxNx1
        mux = gx.view((-1, 1, 1)) + \
            (grid_points.float() - N // 2 - 0.5) * delta.view((-1, 1, 1))
        muy = gy.view((-1, 1, 1)) + \
            (grid_points.float() - N // 2 - 0.5) * delta.view((-1, 1, 1))
        s2 = sigma2.view((-1, 1, 1))
        fx = torch.exp(-(a.float() - mux).pow(2) / (2 * s2))
        fy = torch.exp(-(b.float() - muy).pow(2) / (2 * s2))
        # normalize, clamping the sums away from zero to avoid division by it
        fx = fx / torch.clamp(
            torch.sum(fx, 2, keepdim=True), self.minclamp, self.maxclamp)
        fy = fy / torch.clamp(
            torch.sum(fy, 2, keepdim=True), self.minclamp, self.maxclamp)
        return fx, fy

    def generate_filter_params(self, state):
        filter_vector = self.filter_linear_layer(state)
        _gx, _gy, log_sigma2, log_delta, log_gamma = filter_vector.split(1, 1)
        gx = (self.image_w + 1) // 2 * (_gx + 1)
        gy = (self.image_h + 1) // 2 * (_gy + 1)
        sigma2 = torch.exp(log_sigma2)
        delta = (max(self.image_w, self.image_h) - 1) // \
            (self.grid_size - 1) * torch.exp(log_delta)
        gamma = torch.exp(log_gamma)
        return gx, gy, sigma2, delta, gamma

    def read_w_att(self, x, x_hat, dec_state):
        batch_size = x.size(0)
        # 1) linear layer converts dec_state into batch x 5 params:
        #    gx, gy, log_sigma2, log_delta, log_gamma
        # 2) convert to gaussian parameters
        gx, gy, sigma2, delta, gamma = self.generate_filter_params(dec_state)
        # 3) generate filter matrices
        fx, fy = self.generate_filter_matrices(gx, gy, sigma2, delta)
        # 4) apply filter matrices to get glimpses
        output = gamma.view(-1, 1, 1) * torch.bmm(
            torch.bmm(fy, x.view(batch_size, self.image_h, self.image_w)),
            torch.transpose(fx, 1, 2))
        output_hat = gamma.view(-1, 1, 1) * torch.bmm(
            torch.bmm(fy, x_hat.view(batch_size, self.image_h, self.image_w)),
            torch.transpose(fx, 1, 2))
        output_total = torch.cat(
            (output.view(batch_size, self.grid_size * self.grid_size),
             output_hat.view(batch_size, self.grid_size * self.grid_size)), 1)
        return output_total

    # write takes us from "encoding space" to image space
    def write(self, decoding):
        return F.linear(decoding, self.decoder_linear_weights)

    def write_w_att(self, decoding):
        batch_size = decoding.size(0)
        write_patch = F.linear(decoding, self.decoder_linear_weights).view(
            batch_size, self.grid_size, self.grid_size)
        # note: generate_filter_params returns (gx, gy, sigma2, delta, gamma);
        # the original unpacking swapped delta and gamma here
        gx, gy, sigma2, delta, gamma = self.generate_filter_params(decoding)
        fx, fy = self.generate_filter_matrices(gx, gy, sigma2, delta)
        output = (1.0 / gamma).view(-1, 1, 1) * torch.bmm(
            torch.bmm(fy.transpose(1, 2), write_patch), fx)
        return output

    # converts the encoding into both a mu and a logvar vector, then samples
    def sampleZ(self, encoding):
        mu = F.linear(encoding, self.encoding_mu_weights)
        logvar = F.linear(encoding, self.encoding_logvar_weights)
        return self.reparameterize(mu, logvar), mu, logvar

    def reparameterize(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    # takes an input, returns the sequence of outputs, mus, and logvars
    def forward(self, x):
        # flatten x to 1-d, except for the batch dimension
        batch_size = x.size(0)
        xview = x.view(batch_size, x.nelement() // batch_size)
        # assume BCHW dims
        self.image_w = x.size(3)
        self.image_h = x.size(2)
        if self.decoder_linear_weights is None:
            self.initialize(xview)

        # zero out initial states
        self.encoder_rnn.reset_hidden_state(batch_size, x.is_cuda)
        self.decoder_rnn.reset_hidden_state(batch_size, x.is_cuda)

        outputs, mus, logvars = [], [], []
        init_tensor = torch.zeros(x.size())
        if x.is_cuda:
            init_tensor = init_tensor.cuda()
        outputs.append(init_tensor)

        if self.use_attention:
            read_fn = self.read_w_att
            write_fn = self.write_w_att
        else:
            read_fn = self.read
            write_fn = self.write

        for t in range(self.timesteps):
            # Step 1: diff the input against the previous output
            x_hat = xview - torch.sigmoid(outputs[t].view(xview.size()))
            # Step 2: read
            rvec = read_fn(xview, x_hat, self.decoder_rnn.get_hidden_state())
            # Step 3: encoder rnn
            # note: the dimensions of rvec don't have to match the decoding
            # size, because we are just concatenating two dim-1 tensors
            cat = torch.cat((rvec, self.decoder_rnn.get_hidden_state().view(
                batch_size, self.encoding_size)), 1)
            encoding = self.encoder_rnn.forward(cat)
            # Step 4: sample z
            z, mu, logvar = self.sampleZ(encoding)
            # store the mu and logvar for the loss function
            mus.append(mu)
            logvars.append(logvar)
            # Step 5: decoder rnn
            decoding = self.decoder_rnn.forward(z)
            # Step 6: write to canvas (in the original dimensions of the input)
            outputs.append(
                torch.add(outputs[-1], write_fn(decoding).view(x.size())))

        # return the sigmoided versions
        for i in range(len(outputs)):
            outputs[i] = torch.sigmoid(outputs[i])
        return outputs, mus, logvars
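# --- Loss sketch (illustrative, not from the original repo) ---
# The forward pass stores mu and logvar at every timestep "for the loss
# function", which is not shown here. This is a minimal version of the DRAW
# objective assuming Bernoulli pixels in [0, 1]: binary cross-entropy on the
# final (already sigmoided) canvas, plus the standard VAE KL term
# KL(N(mu, sigma^2) || N(0, I)) summed over timesteps. The function name and
# the 'sum' reduction are our choices, not the repo's.
def draw_loss(x, outputs, mus, logvars):
    batch_size = x.size(0)
    recon = F.binary_cross_entropy(outputs[-1].view(batch_size, -1),
                                   x.view(batch_size, -1),
                                   reduction='sum')
    kl = 0.
    for mu, logvar in zip(mus, logvars):
        kl = kl - 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (recon + kl) / batch_size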
class AttentionSegmenter(nn.Module):
    def __init__(self,
                 num_classes,
                 inchans=3,
                 att_encoding_size=128,
                 timesteps=10,
                 attn_grid_size=50):
        super(AttentionSegmenter, self).__init__()
        self.num_classes = num_classes
        self.att_encoding_size = att_encoding_size
        self.timesteps = timesteps
        self.attn_grid_size = attn_grid_size
        self.encoder = ConvolutionStack(inchans, final_relu=False, padding=0)
        self.encoder.append(32, 3, 1)
        self.encoder.append(32, 3, 2)
        self.encoder.append(64, 3, 1)
        self.encoder.append(64, 3, 2)
        self.encoder.append(96, 3, 1)
        self.encoder.append(96, 3, 2)
        self.decoder = TransposedConvolutionStack(96, final_relu=False,
                                                  padding=0)
        self.decoder.append(96, 3, 2)
        self.decoder.append(64, 3, 1)
        self.decoder.append(64, 3, 2)
        self.decoder.append(32, 3, 1)
        self.decoder.append(32, 3, 2)
        self.decoder.append(self.num_classes, 3, 1)
        self.attn_reader = GaussianAttentionReader()
        self.attn_writer = GaussianAttentionWriter()
        self.att_rnn = BasicRNN(hstate_size=att_encoding_size, output_size=5)
        self.register_parameter('att_decoder_weights', None)

    # lazily initialize the attention decoder weights once the hidden-state
    # size is known
    def init_weights(self, hstate):
        if self.att_decoder_weights is None:
            batch_size = hstate.size(0)
            self.att_decoder_weights = nn.Parameter(
                torch.Tensor(5, hstate.nelement() // batch_size))
            stdv = 1. / math.sqrt(self.att_decoder_weights.size(1))
            self.att_decoder_weights.data.uniform_(-stdv, stdv)
            if hstate.is_cuda:
                self.cuda()

    def forward(self, x):
        batch_size, chans, height, width = x.size()
        # first determine the hidden-state size, which is tied to the cnn
        # feature size, by pushing a dummy glimpse through the encoder
        dummy_glimpse = torch.Tensor(batch_size, chans, self.attn_grid_size,
                                     self.attn_grid_size)
        if x.is_cuda:
            dummy_glimpse = dummy_glimpse.cuda()
        dummy_feature_map = self.encoder.forward(dummy_glimpse)
        self.att_rnn.forward(
            dummy_feature_map.view(
                batch_size, dummy_feature_map.nelement() // batch_size))
        self.att_rnn.reset_hidden_state(batch_size, x.is_cuda)

        outputs = []
        init_tensor = torch.zeros(batch_size, self.num_classes, height, width)
        if x.is_cuda:
            init_tensor = init_tensor.cuda()
        outputs.append(init_tensor)

        self.init_weights(self.att_rnn.get_hidden_state())
        for t in range(self.timesteps):
            # 1) decode hidden state to generate gaussian attention parameters
            state = self.att_rnn.get_hidden_state()
            gauss_attn_params = torch.tanh(
                F.linear(state, self.att_decoder_weights))
            # 2) extract glimpse
            glimpse = self.attn_reader.forward(x, gauss_attn_params,
                                               self.attn_grid_size)
            # visualize the first glimpse in the batch for every t
            torch_glimpses = torch.chunk(glimpse, batch_size, dim=0)
            ImageVisualizer().set_image(
                PTImage.from_cwh_torch(torch_glimpses[0].squeeze().data),
                'zGlimpse {}'.format(t))
            # 3) use the conv stack (or a resnet) to extract features
            feature_map = self.encoder.forward(glimpse)
            conv_output_dims = self.encoder.get_output_dims()[:-1][::-1]
            conv_output_dims.append(glimpse.size())
            # 4) update hidden state
            # TODO: think about this connection a bit more
            self.att_rnn.forward(
                feature_map.view(batch_size,
                                 feature_map.nelement() // batch_size))
            # 5) use the deconv network to get partial masks
            partial_mask = self.decoder.forward(feature_map, conv_output_dims)
            # 6) write masks additively to the mask canvas
            partial_canvas = self.attn_writer.forward(partial_mask,
                                                      gauss_attn_params,
                                                      (height, width))
            outputs.append(torch.add(outputs[-1], partial_canvas))

        # return the sigmoided versions
        for i in range(len(outputs)):
            outputs[i] = torch.sigmoid(outputs[i])
        return outputs
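# --- Usage sketch (illustrative, not from the original repo) ---
# A minimal forward pass through AttentionSegmenter. ConvolutionStack,
# TransposedConvolutionStack, GaussianAttentionReader/Writer, BasicRNN, and
# ImageVisualizer are repo-local modules whose exact interfaces are assumed
# here; the input size is also an assumption.
def _demo_attention_segmenter():
    model = AttentionSegmenter(num_classes=2, inchans=3, timesteps=4,
                               attn_grid_size=50)
    x = torch.rand(2, 3, 128, 128)  # batch of 2 RGB images
    outputs = model(x)
    # each canvas accumulates the written glimpse masks, then is sigmoided
    assert outputs[-1].size() == (2, 2, 128, 128)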