Example #1
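These snippets are excerpts from a larger project, so they assume the usual PyTorch imports plus several project-local helpers. A minimal header might look like the following; the module paths for the project-local helpers are assumptions, not part of the original code.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from past.utils import old_div  # Python 2/3 division helper used in example #3

# Project-local helpers (paths are hypothetical):
# from conv_stacks import ConvolutionStack, TransposedConvolutionStack
# from attention import GaussianAttentionReader, GaussianAttentionWriter
# from rnn import BasicRNN
# from visualization import ImageVisualizer, PTImage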
class AutoEncoder(nn.Module):
    def __init__(self, inchans=3):
        super(AutoEncoder, self).__init__()
        # conv, deconv
        self.inchans = inchans
        self.convs = ConvolutionStack(self.inchans)
        self.convs.append(3, 3, 2)
        self.convs.append(6, 3, 1)
        self.convs.append(16, 3, 1)
        self.convs.append(32, 3, 2)

        self.tconvs = TransposedConvolutionStack(32, final_relu=False)
        self.tconvs.append(16, 3, 2)
        self.tconvs.append(6, 3, 1)
        self.tconvs.append(3, 3, 1)
        self.tconvs.append(self.inchans, 3, 2)

    def forward(self, x):
        input_dims = x.size()
        x = self.convs.forward(x)
        # TODO: this is a dumb way to get the output dims for the deconv
        output_dims = self.convs.get_output_dims()[:-1][::-1]
        output_dims.append(input_dims)
        # print output_dims
        # get outputs from conv and pass them back to deconv
        x = self.tconvs.forward(x,output_dims)
        return torch.sigmoid(x)
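A minimal usage sketch for the autoencoder above, assuming ConvolutionStack registers its layers as submodules so model.parameters() picks them up; the input size and loss choice are illustrative assumptions, not part of the original code.

# Hypothetical training step for the AutoEncoder above.
model = AutoEncoder(inchans=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

images = torch.rand(8, 3, 100, 100)            # dummy batch with values in [0, 1]
recon = model(images)                          # reconstruction at the input resolution
loss = F.binary_cross_entropy(recon, images)   # pixel-wise reconstruction loss
optimizer.zero_grad()
loss.backward()
optimizer.step()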
Example #2
class VAE(nn.Module):
    def __init__(self, encoding_size=128, training=True):
        super(VAE, self).__init__()
        self.training = training
        self.encoding_size = encoding_size
        self.outchannel_size = 256
        # encoding conv
        self.encoder = ConvolutionStack(3, final_relu=False, padding=0)
        self.encoder.append(16, 3, 2)
        self.encoder.append(32, 3, 1)
        self.encoder.append(64, 3, 2)
        self.encoder.append(128, 3, 2)
        self.encoder.append(self.outchannel_size, 3, 1)

        # decode
        self.decoder = TransposedConvolutionStack(self.outchannel_size,
                                                  final_relu=False,
                                                  padding=0)
        self.decoder.append(128, 3, 1)
        self.decoder.append(64, 3, 2)
        self.decoder.append(32, 3, 2)
        self.decoder.append(16, 3, 1)
        self.decoder.append(3, 3, 2)

        self.register_parameter('linear_mu_weights', None)
        self.register_parameter('linear_logvar_weights', None)
        self.register_parameter('linear_decode_weights', None)

    def initialize_linear_params(self, is_cuda):
        # F.linear computes y = x*A^T + b, where the weight A has shape (out_features, in_features)
        # here x is [b x c] (b = batch size, c = flattened conv feature size) and y is [b x s]
        # (s = encoding size), so each weight tensor needs shape (s x c)
        self.linear_mu_weights = nn.Parameter(
            torch.Tensor(self.encoding_size, self.linear_size))
        stdv = 1. / math.sqrt(self.linear_mu_weights.size(1))
        self.linear_mu_weights.data.uniform_(-stdv, stdv)

        self.linear_logvar_weights = nn.Parameter(
            torch.Tensor(self.encoding_size, self.linear_size))
        stdv = 1. / math.sqrt(self.linear_logvar_weights.size(1))
        self.linear_logvar_weights.data.uniform_(-stdv, stdv)

        self.linear_decode_weights = nn.Parameter(
            torch.Tensor(self.linear_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.linear_decode_weights.size(1))
        self.linear_decode_weights.data.uniform_(-stdv, stdv)
        if is_cuda:
            self.cuda()

    def encode(self, x):
        input_dims = x.size()
        conv_out = self.encoder.forward(x)
        conv_out = F.relu(conv_out)
        self.encoding_feature_map = conv_out
        self.conv_output_dims = self.encoder.get_output_dims()[:-1][::-1]
        self.conv_output_dims.append(input_dims)

        # print conv_out.size()
        # OPTION A -- AVERAGE POOL -> FC
        # assume bchw format [1,C,7,7] for inputs of size 100x100
        # self.pool_size = conv_out.size(2)
        # h1 = F.avg_pool2d(conv_out,kernel_size=self.pool_size,stride=self.pool_size)
        # assert that h1 has dimensions b x c x 1 x 1 (squeeze to b x c)

        # OPTION B -- DIRECT FC
        self.conv_out_spatial = [conv_out.size(2), conv_out.size(3)]
        self.linear_size = self.outchannel_size * conv_out.size(
            2) * conv_out.size(3)

        if self.linear_mu_weights is None:
            self.initialize_linear_params(x.data.is_cuda)
        mu = F.linear(conv_out.view(-1, self.linear_size),
                      self.linear_mu_weights)
        logvar = F.linear(conv_out.view(-1, self.linear_size),
                          self.linear_logvar_weights)
        # mu = self.linear_mu(conv_out.view(-1,linear_size))
        # logvar = self.linear_logvar(conv_out.view(-1,linear_size))
        return mu, logvar

    def reparameterize(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = Variable(std.data.new(std.size()).normal_())
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        # the output dims here should be [b x c]

        # OPTION A -- upsample
        # assert self.pool_size is not None
        # next upsample here to dimensions of conv_out from the encoder
        # h3 = F.upsample(h2.view(-1,self.outchannel_size,1,1),scale_factor=self.pool_size)

        # OPTION B -- Direct FC
        if self.linear_decode_weights is None:
            self.initialize_linear_params(z.data.is_cuda)

        h2 = F.relu(F.linear(z, self.linear_decode_weights))

        h3 = h2.view(-1, self.outchannel_size, self.conv_out_spatial[0],
                     self.conv_out_spatial[1])
        h4 = self.decoder.forward(h3, self.conv_output_dims)
        return torch.sigmoid(h4)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def get_encoder(self):
        return self.encoder

    def get_encoding_feature_map(self):
        return self.encoding_feature_map
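forward returns the reconstruction together with mu and logvar, which is exactly what the usual VAE objective needs: a reconstruction term plus a KL penalty pulling the latent distribution toward a standard normal. A minimal sketch of that loss, assuming sum-reduced binary cross-entropy and an unweighted sum of the two terms:

# Hypothetical VAE objective; the loss weighting and input size are assumptions.
def vae_loss(recon_x, x, mu, logvar):
    # Pixel-wise reconstruction term (the decoder output is already sigmoided).
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between N(mu, sigma^2) and N(0, I):
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

model = VAE(encoding_size=128, training=True)
images = torch.rand(4, 3, 100, 100)
recon, mu, logvar = model(images)   # first pass also creates the lazily-registered linear weights
loss = vae_loss(recon, images, mu, logvar)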
Example #3
class AttentionSegmenter(nn.Module):
    def __init__(self,
                 num_classes,
                 inchans=3,
                 att_encoding_size=128,
                 timesteps=10,
                 attn_grid_size=50):
        super(AttentionSegmenter, self).__init__()
        self.num_classes = num_classes
        self.att_encoding_size = att_encoding_size
        self.timesteps = timesteps
        self.attn_grid_size = attn_grid_size
        self.encoder = ConvolutionStack(inchans, final_relu=False, padding=0)
        self.encoder.append(32, 3, 1)
        self.encoder.append(32, 3, 2)
        self.encoder.append(64, 3, 1)
        self.encoder.append(64, 3, 2)
        self.encoder.append(96, 3, 1)
        self.encoder.append(96, 3, 2)

        self.decoder = TransposedConvolutionStack(96,
                                                  final_relu=False,
                                                  padding=0)
        self.decoder.append(96, 3, 2)
        self.decoder.append(64, 3, 1)
        self.decoder.append(64, 3, 2)
        self.decoder.append(32, 3, 1)
        self.decoder.append(32, 3, 2)
        self.decoder.append(self.num_classes, 3, 1)

        self.attn_reader = GaussianAttentionReader()
        self.attn_writer = GaussianAttentionWriter()
        self.att_rnn = BasicRNN(hstate_size=att_encoding_size, output_size=5)
        self.register_parameter('att_decoder_weights', None)

    def init_weights(self, hstate):
        if self.att_decoder_weights is None:
            batch_size = hstate.size(0)
            self.att_decoder_weights = nn.Parameter(
                torch.Tensor(5, old_div(hstate.nelement(), batch_size)))
            stdv = 1. / math.sqrt(self.att_decoder_weights.size(1))
            self.att_decoder_weights.data.uniform_(-stdv, stdv)
        if hstate.data.is_cuda:
            self.cuda()

    def forward(self, x):
        batch_size, chans, height, width = x.size()

        # need to first determine the hidden state size, which is tied to the cnn feature size
        dummy_glimpse = torch.Tensor(batch_size, chans, self.attn_grid_size,
                                     self.attn_grid_size)
        if x.is_cuda:
            dummy_glimpse = dummy_glimpse.cuda()
        dummy_feature_map = self.encoder.forward(dummy_glimpse)
        self.att_rnn.forward(
            dummy_feature_map.view(
                batch_size, old_div(dummy_feature_map.nelement(), batch_size)))
        self.att_rnn.reset_hidden_state(batch_size, x.data.is_cuda)

        outputs = []
        init_tensor = torch.zeros(batch_size, self.num_classes, height, width)
        if x.data.is_cuda:
            init_tensor = init_tensor.cuda()
        outputs.append(init_tensor)

        self.init_weights(self.att_rnn.get_hidden_state())

        for t in range(self.timesteps):
            # 1) decode hidden state to generate gaussian attention parameters
            state = self.att_rnn.get_hidden_state()
            gauss_attn_params = torch.tanh(
                F.linear(state, self.att_decoder_weights))

            # 2) extract glimpse
            glimpse = self.attn_reader.forward(x, gauss_attn_params,
                                               self.attn_grid_size)

            # visualize first glimpse in batch for all t
            torch_glimpses = torch.chunk(glimpse, batch_size, dim=0)
            ImageVisualizer().set_image(
                PTImage.from_cwh_torch(torch_glimpses[0].squeeze().data),
                'zGlimpse {}'.format(t))

            # 3) use conv stack or resnet to extract features
            feature_map = self.encoder.forward(glimpse)
            conv_output_dims = self.encoder.get_output_dims()[:-1][::-1]
            conv_output_dims.append(glimpse.size())

            # 4) update hidden state (TODO: think about this connection a bit more)
            self.att_rnn.forward(
                feature_map.view(batch_size,
                                 old_div(feature_map.nelement(), batch_size)))

            # 5) use deconv network to get partial masks
            partial_mask = self.decoder.forward(feature_map, conv_output_dims)

            # 6) write masks additively to mask canvas
            partial_canvas = self.attn_writer.forward(partial_mask,
                                                      gauss_attn_params,
                                                      (height, width))
            outputs.append(torch.add(outputs[-1], partial_canvas))

        # return the sigmoided versions
        for i in range(len(outputs)):
            outputs[i] = torch.sigmoid(outputs[i])
        return outputs
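A minimal usage sketch, assuming the project-local attention, RNN, and visualization helpers are importable and treating single-class binary segmentation. Because att_decoder_weights and the RNN hidden state are created lazily during the first forward pass, the optimizer is built after that pass; the image size, loss, and per-timestep weighting are assumptions.

# Hypothetical training step for the AttentionSegmenter above.
model = AttentionSegmenter(num_classes=1, inchans=3, timesteps=10, attn_grid_size=50)

images = torch.rand(2, 3, 200, 200)
masks = (torch.rand(2, 1, 200, 200) > 0.5).float()   # dummy binary targets

canvases = model(images)   # one sigmoided canvas per timestep (index 0 is the initial canvas)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Score every intermediate canvas against the target so each glimpse receives a gradient.
loss = sum(F.binary_cross_entropy(c, masks) for c in canvases[1:])
optimizer.zero_grad()
loss.backward()
optimizer.step()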