Example #1
class DRAW(nn.Module):
    def __init__(self,
                 q_size=10,
                 encoding_size=128,
                 timesteps=10,
                 training=True,
                 use_attention=False):
        super(DRAW, self).__init__()
        self.training = training
        self.encoding_size = encoding_size
        self.q_size = q_size
        self.use_attention = use_attention
        self.timesteps = timesteps
        # use equal encoding and decoding size
        self.encoder_rnn = BasicRNN(output_size=self.encoding_size)
        self.decoder_rnn = BasicRNN(output_size=self.encoding_size)
        self.register_parameter('decoder_linear_weights', None)
        self.register_parameter('encoding_mu_weights', None)
        self.register_parameter('encoding_logvar_weights', None)

    def initialize(self, x):
        batch_size = x.size(0)
        self.decoder_linear_weights = nn.Parameter(
            torch.Tensor(x.nelement() // batch_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.decoder_linear_weights.size(1))
        self.decoder_linear_weights.data.uniform_(-stdv, stdv)

        self.encoding_mu_weights = nn.Parameter(
            torch.Tensor(self.q_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.encoding_mu_weights.size(1))
        self.encoding_mu_weights.data.uniform_(-stdv, stdv)

        self.encoding_logvar_weights = nn.Parameter(
            torch.Tensor(self.q_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.encoding_logvar_weights.size(1))
        self.encoding_logvar_weights.data.uniform_(-stdv, stdv)
        if x.data.is_cuda:
            self.cuda()

    # read selects what to take from the input image; this no-attention
    # version simply concatenates x and x_hat, so the output dims are 2*W*H
    def read(self, x, x_hat, dec_state):
        return torch.cat((x, x_hat), 1)

    # write takes us from "encoding space" to image space
    def write(self, decoding):
        return F.linear(decoding, self.decoder_linear_weights)

    # this converts the encoding into both a mu and logvar vector
    def sampleZ(self, encoding):
        mu = F.linear(encoding, self.encoding_mu_weights)
        logvar = F.linear(encoding, self.encoding_logvar_weights)
        return self.reparameterize(mu, logvar), mu, logvar

    def reparameterize(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = Variable(std.data.new(std.size()).normal_())
            return eps.mul(std).add_(mu)
        else:
            return mu

    # takes an input, returns the sequence of outputs, mus, and logvars
    def forward(self, x):
        # flatten x to 1-d, except for batch dimension
        xview = x.view(x.size()[0], x.nelement() // x.size()[0])
        batch_size = x.size()[0]

        if self.decoder_linear_weights is None:
            self.initialize(xview)

        # zero out initial states
        self.encoder_rnn.reset_hidden_state(batch_size)
        self.decoder_rnn.reset_hidden_state(batch_size)
        outputs, mus, logvars = [], [], []

        init_canvas = torch.zeros(x.size())
        if x.data.is_cuda:
            init_canvas = init_canvas.cuda()
        outputs.append(Variable(init_canvas))

        for t in range(0, self.timesteps):
            # Step 1: diff the input against the prev output
            x_hat = xview - torch.sigmoid(outputs[t].view(xview.size()))
            # Step 2: read
            rvec = self.read(xview, x_hat, self.decoder_rnn.get_hidden_state())
            # Step 3: encoder rnn
            # note that the dimensions of r don't have to match the decoding size
            # because we are just concatenating the two tensors along dim 1
            cat = torch.cat((rvec, self.decoder_rnn.get_hidden_state().view(
                batch_size, self.encoding_size)), 1)
            encoding = self.encoder_rnn.forward(cat)
            # Step 4: sample z
            z, mu, logvar = self.sampleZ(encoding)
            # store the mu and logvar for the loss function
            mus.append(mu)
            logvars.append(logvar)

            # Step 5: decoder rnn
            decoding = self.decoder_rnn.forward(z)
            # Step 6: write to canvas, (in the original dimensions of the input)
            outputs.append(
                torch.add(outputs[-1],
                          self.write(decoding).view(x.size())))

        return outputs, mus, logvars
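
The mus and logvars collected above are intended for the variational loss. For reference, here is a minimal sketch of the usual DRAW objective (Bernoulli reconstruction of the final canvas plus a per-timestep KL term); it is not part of the original listing and assumes the inputs are scaled to [0, 1]:

import torch
import torch.nn.functional as F

def draw_loss(x, outputs, mus, logvars):
    # x: input batch in [0, 1]; outputs/mus/logvars: values returned by DRAW.forward()
    batch_size = x.size(0)
    target = x.view(batch_size, -1)
    # the final canvas in Example #1 is pre-sigmoid, so squash it first
    recon_x = torch.sigmoid(outputs[-1].view(batch_size, -1))
    # reconstruction term: Bernoulli negative log-likelihood of the final canvas
    recon_loss = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # latent term: KL(Q(z_t | x) || N(0, I)) accumulated over all timesteps
    kl = 0.0
    for mu, logvar in zip(mus, logvars):
        kl = kl - 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (recon_loss + kl) / batch_size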
Example #2
class DRAW(nn.Module):
    def __init__(self,
                 q_size=10,
                 encoding_size=128,
                 timesteps=10,
                 training=True,
                 use_attention=False,
                 grid_size=5):
        super(DRAW, self).__init__()
        self.training = training
        self.encoding_size = encoding_size
        self.q_size = q_size
        self.use_attention = use_attention
        self.timesteps = timesteps
        # use equal encoding and decoding size
        self.encoder_rnn = BasicRNN(hstate_size=self.encoding_size,
                                    output_size=self.encoding_size)
        self.decoder_rnn = BasicRNN(hstate_size=self.encoding_size,
                                    output_size=self.encoding_size)
        self.register_parameter('decoder_linear_weights', None)
        self.register_parameter('encoding_mu_weights', None)
        self.register_parameter('encoding_logvar_weights', None)
        self.filter_linear_layer = nn.Linear(self.encoding_size, 5)
        self.grid_size = grid_size
        self.minclamp = 1e-8
        self.maxclamp = 1e8

    def initialize(self, x):
        batch_size = x.size(0)
        # if we use attention, the decoder produces a patch of grid_size x grid_size;
        # otherwise it produces an output of the original image size
        if self.use_attention:
            self.decoder_linear_weights = nn.Parameter(
                torch.Tensor(self.grid_size * self.grid_size,
                             self.encoding_size))
        else:
            self.decoder_linear_weights = nn.Parameter(
                torch.Tensor(old_div(x.nelement(), batch_size),
                             self.encoding_size))

        stdv = 1. / math.sqrt(self.decoder_linear_weights.size(1))
        self.decoder_linear_weights.data.uniform_(-stdv, stdv)

        self.encoding_mu_weights = nn.Parameter(
            torch.Tensor(self.q_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.encoding_mu_weights.size(1))
        self.encoding_mu_weights.data.uniform_(-stdv, stdv)

        self.encoding_logvar_weights = nn.Parameter(
            torch.Tensor(self.q_size, self.encoding_size))
        stdv = 1. / math.sqrt(self.encoding_logvar_weights.size(1))
        self.encoding_logvar_weights.data.uniform_(-stdv, stdv)
        if x.data.is_cuda:
            self.cuda()

    # read selects what to take from the input image; this no-attention
    # version simply concatenates x and x_hat, so the output dims are 2*W*H
    def read(self, x, x_hat, dec_state):
        return torch.cat((x, x_hat), 1)

    # generate two sets of filterbanks
    # 1) batch x N x W (Fx)
    # 2) batch x N x H (Fy)
    def generate_filter_matrices(self, gx, gy, sigma2, delta):
        N = self.grid_size
        grid_points = torch.arange(0, N).view((1, N, 1))
        a = torch.arange(0, self.image_w).view((1, 1, -1))
        b = torch.arange(0, self.image_h).view((1, 1, -1))
        if gx.data.is_cuda:
            grid_points = grid_points.cuda()
            a = a.cuda()
            b = b.cuda()

        # gx is Bx1, grid is (1xNx1), so this is a broadcast op -> BxNx1
        mux = gx.view(
            (-1, 1,
             1)) + (grid_points.float() - old_div(N, 2) - 0.5) * delta.view(
                 (-1, 1, 1))
        muy = gy.view(
            (-1, 1,
             1)) + (grid_points.float() - old_div(N, 2) - 0.5) * delta.view(
                 (-1, 1, 1))

        s2 = sigma2.view((-1, 1, 1))
        fx = torch.exp(old_div(-(a.float() - mux).pow(2), (2 * s2)))
        fy = torch.exp(old_div(-(b.float() - muy).pow(2), (2 * s2)))
        # normalize
        fx = old_div(
            fx,
            torch.clamp(torch.sum(fx, 2, keepdim=True), self.minclamp,
                        self.maxclamp))
        fy = old_div(
            fy,
            torch.clamp(torch.sum(fy, 2, keepdim=True), self.minclamp,
                        self.maxclamp))
        return fx, fy

    def generate_filter_params(self, state):
        filter_vector = self.filter_linear_layer(state)
        _gx, _gy, log_sigma2, log_delta, loggamma = filter_vector.split(1, 1)
        gx = old_div((self.image_w + 1), 2) * (_gx + 1)
        gy = old_div((self.image_h + 1), 2) * (_gy + 1)
        sigma2 = torch.exp(log_sigma2)
        delta = old_div((max(self.image_w, self.image_h) - 1),
                        (self.grid_size - 1)) * torch.exp(log_delta)
        gamma = torch.exp(loggamma)
        return gx, gy, sigma2, delta, gamma

    def read_w_att(self, x, x_hat, dec_state):
        batch_size = x.size()[0]

        # 1) linear to convert dec_state into batchx5 params gx,gy,logsigma2,logdelta,loggamma
        # 2) convert to gaussian parameters
        gx, gy, sigma2, delta, gamma = self.generate_filter_params(dec_state)

        # 3) generate filter matrices
        fx, fy = self.generate_filter_matrices(gx, gy, sigma2, delta)

        # 4) apply filter matrices to get glimpses
        output = gamma.view(-1, 1, 1) * torch.bmm(
            torch.bmm(fy, x.view(batch_size, self.image_h, self.image_w)),
            torch.transpose(fx, 1, 2))
        output_hat = gamma.view(-1, 1, 1) * torch.bmm(
            torch.bmm(fy, x_hat.view(batch_size, self.image_h, self.image_w)),
            torch.transpose(fx, 1, 2))
        output_total = torch.cat(
            (output.view(batch_size, self.grid_size * self.grid_size),
             output_hat.view(batch_size, self.grid_size * self.grid_size)), 1)
        return output_total

    # write takes us from "encoding space" to image space
    def write(self, decoding):
        return F.linear(decoding, self.decoder_linear_weights)

    def write_w_att(self, decoding):
        batch_size = decoding.size()[0]
        write_patch = F.linear(decoding, self.decoder_linear_weights).view(
            batch_size, self.grid_size, self.grid_size)
        gx, gy, sigma2, delta, gamma = self.generate_filter_params(decoding)
        fx, fy = self.generate_filter_matrices(gx, gy, sigma2, delta)
        output = (old_div(1, gamma)).view(-1, 1, 1) * torch.bmm(
            torch.bmm(fy.transpose(1, 2), write_patch), fx)
        return output

    # this converts the encoding into both a mu and logvar vector
    def sampleZ(self, encoding):
        mu = F.linear(encoding, self.encoding_mu_weights)
        logvar = F.linear(encoding, self.encoding_logvar_weights)
        return self.reparameterize(mu, logvar), mu, logvar

    def reparameterize(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu

    # takes an input, returns the sequence of outputs, mus, and logvars
    def forward(self, x):
        # flatten x to 1-d, except for batch dimension
        xview = x.view(x.size()[0], old_div(x.nelement(), x.size()[0]))
        # assume bchw dims
        self.image_w = x.size(3)
        self.image_h = x.size(2)
        batch_size = x.size()[0]

        if self.decoder_linear_weights is None:
            self.initialize(xview)

        # zero out initial states
        self.encoder_rnn.reset_hidden_state(batch_size, x.data.is_cuda)
        self.decoder_rnn.reset_hidden_state(batch_size, x.data.is_cuda)
        outputs, mus, logvars = [], [], []

        init_tensor = torch.zeros(x.size())
        if x.data.is_cuda:
            init_tensor = init_tensor.cuda()
        outputs.append(init_tensor)

        if self.use_attention:
            read_fn = self.read_w_att
            write_fn = self.write_w_att
        else:
            read_fn = self.read
            write_fn = self.write

        for t in range(0, self.timesteps):
            # Step 1: diff the input against the prev output
            x_hat = xview - torch.sigmoid(outputs[t].view(xview.size()))
            # Step 2: read
            rvec = read_fn(xview, x_hat, self.decoder_rnn.get_hidden_state())
            # Step 3: encoder rnn
            # note that the dimensions of r don't have to match the decoding size
            # because we are just concatenating the two tensors along dim 1
            cat = torch.cat((rvec, self.decoder_rnn.get_hidden_state().view(
                batch_size, self.encoding_size)), 1)
            encoding = self.encoder_rnn.forward(cat)
            # Step 4: sample z
            z, mu, logvar = self.sampleZ(encoding)
            # store the mu and logvar for the loss function
            mus.append(mu)
            logvars.append(logvar)

            # Step 5: decoder rnn
            decoding = self.decoder_rnn.forward(z)
            # Step 6: write to canvas, (in the original dimensions of the input)
            outputs.append(
                torch.add(outputs[-1],
                          write_fn(decoding).view(x.size())))

        # return the sigmoided versions
        for i in range(len(outputs)):
            outputs[i] = torch.sigmoid(outputs[i])
        return outputs, mus, logvars
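
For reference, here is a self-contained sketch of the one-dimensional Gaussian filterbank that generate_filter_matrices builds per axis; this re-derivation follows the DRAW-paper parameterization and is not part of the original listing:

import torch

def gaussian_filterbank(g, delta, sigma2, N, extent):
    # g, delta, sigma2: (batch, 1) tensors; N: grid size; extent: image W or H
    grid = torch.arange(N, dtype=torch.float32).view(1, N, 1)
    coords = torch.arange(extent, dtype=torch.float32).view(1, 1, extent)
    # grid-point centres: mu_i = g + (i - N/2 - 0.5) * delta
    mu = g.view(-1, 1, 1) + (grid - N / 2 - 0.5) * delta.view(-1, 1, 1)
    f = torch.exp(-(coords - mu).pow(2) / (2 * sigma2.view(-1, 1, 1)))
    # normalize so that each grid point's filter sums to 1
    return f / f.sum(2, keepdim=True).clamp(min=1e-8)

# A read glimpse is then gamma * Fy @ image @ Fx^T, and a write places a patch
# back with (1 / gamma) * Fy^T @ patch @ Fx, as in read_w_att and write_w_att.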
Example #3
class AttentionSegmenter(nn.Module):
    def __init__(self,
                 num_classes,
                 inchans=3,
                 att_encoding_size=128,
                 timesteps=10,
                 attn_grid_size=50):
        super(AttentionSegmenter, self).__init__()
        self.num_classes = num_classes
        self.att_encoding_size = att_encoding_size
        self.timesteps = timesteps
        self.attn_grid_size = attn_grid_size
        self.encoder = ConvolutionStack(inchans, final_relu=False, padding=0)
        self.encoder.append(32, 3, 1)
        self.encoder.append(32, 3, 2)
        self.encoder.append(64, 3, 1)
        self.encoder.append(64, 3, 2)
        self.encoder.append(96, 3, 1)
        self.encoder.append(96, 3, 2)

        self.decoder = TransposedConvolutionStack(96,
                                                  final_relu=False,
                                                  padding=0)
        self.decoder.append(96, 3, 2)
        self.decoder.append(64, 3, 1)
        self.decoder.append(64, 3, 2)
        self.decoder.append(32, 3, 1)
        self.decoder.append(32, 3, 2)
        self.decoder.append(self.num_classes, 3, 1)

        self.attn_reader = GaussianAttentionReader()
        self.attn_writer = GaussianAttentionWriter()
        self.att_rnn = BasicRNN(hstate_size=att_encoding_size, output_size=5)
        self.register_parameter('att_decoder_weights', None)

    def init_weights(self, hstate):
        if self.att_decoder_weights is None:
            batch_size = hstate.size(0)
            self.att_decoder_weights = nn.Parameter(
                torch.Tensor(5, old_div(hstate.nelement(), batch_size)))
            stdv = 1. / math.sqrt(self.att_decoder_weights.size(1))
            self.att_decoder_weights.data.uniform_(-stdv, stdv)
        if hstate.data.is_cuda:
            self.cuda()

    def forward(self, x):
        batch_size, chans, height, width = x.size()

        # need to first determine the hidden state size, which is tied to the cnn feature size
        dummy_glimpse = torch.Tensor(batch_size, chans, self.attn_grid_size,
                                     self.attn_grid_size)
        if x.is_cuda:
            dummy_glimpse = dummy_glimpse.cuda()
        dummy_feature_map = self.encoder.forward(dummy_glimpse)
        self.att_rnn.forward(
            dummy_feature_map.view(
                batch_size, old_div(dummy_feature_map.nelement(), batch_size)))
        self.att_rnn.reset_hidden_state(batch_size, x.data.is_cuda)

        outputs = []
        init_tensor = torch.zeros(batch_size, self.num_classes, height, width)
        if x.data.is_cuda:
            init_tensor = init_tensor.cuda()
        outputs.append(init_tensor)

        self.init_weights(self.att_rnn.get_hidden_state())

        for t in range(self.timesteps):
            # 1) decode hidden state to generate gaussian attention parameters
            state = self.att_rnn.get_hidden_state()
            gauss_attn_params = torch.tanh(
                F.linear(state, self.att_decoder_weights))

            # 2) extract glimpse
            glimpse = self.attn_reader.forward(x, gauss_attn_params,
                                               self.attn_grid_size)

            # visualize first glimpse in batch for all t
            torch_glimpses = torch.chunk(glimpse, batch_size, dim=0)
            ImageVisualizer().set_image(
                PTImage.from_cwh_torch(torch_glimpses[0].squeeze().data),
                'zGlimpse {}'.format(t))

            # 3) use conv stack or resnet to extract features
            feature_map = self.encoder.forward(glimpse)
            conv_output_dims = self.encoder.get_output_dims()[:-1][::-1]
            conv_output_dims.append(glimpse.size())

            # 4) update hidden state (TODO: think about this connection a bit more)
            self.att_rnn.forward(
                feature_map.view(batch_size,
                                 old_div(feature_map.nelement(), batch_size)))

            # 5) use deconv network to get partial masks
            partial_mask = self.decoder.forward(feature_map, conv_output_dims)

            # 6) write masks additively to mask canvas
            partial_canvas = self.attn_writer.forward(partial_mask,
                                                      gauss_attn_params,
                                                      (height, width))
            outputs.append(torch.add(outputs[-1], partial_canvas))

        # return the sigmoided versions
        for i in range(len(outputs)):
            outputs[i] = torch.sigmoid(outputs[i])
        return outputs
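
The forward pass returns a list of sigmoided per-class canvases. A minimal, hypothetical way to train the segmenter is a per-pixel binary cross-entropy between the final canvas and binary target masks; the mask tensor below is an assumption, not part of the original listing:

import torch.nn.functional as F

def segmentation_loss(outputs, target_masks):
    # outputs: list of (batch, num_classes, H, W) canvases already in (0, 1)
    # target_masks: assumed (batch, num_classes, H, W) binary masks
    return F.binary_cross_entropy(outputs[-1], target_masks)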