Example #1
def __init__(self, loc, scale, validate_args=None):
    base_dist = Normal(loc, scale)
    if not base_dist.batch_shape:
        base_dist = base_dist.expand([1])
    super(LogisticNormal, self).__init__(base_dist,
                                         StickBreakingTransform(),
                                         validate_args=validate_args)
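
A minimal usage sketch of the distribution defined above (assuming PyTorch's torch.distributions package, which ships a LogisticNormal built exactly this way): the stick-breaking transform maps a K-dimensional Normal onto the (K+1)-dimensional simplex, so samples are non-negative and sum to 1.

import torch
from torch.distributions import LogisticNormal

# Base Normal has batch size 3, so samples lie on the 4-dimensional simplex.
d = LogisticNormal(torch.zeros(3), torch.ones(3))
x = d.sample()
print(x.shape, x.sum())  # torch.Size([4]); entries sum to 1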
Example #2
class AIR(BaseGenerativeModel):
    """
    AIR model. Default settings are from the pyro tutorial. With those settings
    we can reproduce results from the original paper (although about 1/10 times
    it doesn't converge to a good solution).
    """

    z_where_dim = 3
    z_pres_dim = 1
    
    def __init__(self,
                 img_size,
                 object_size,
                 max_steps,
                 color_channels,
                 likelihood=None,
                 z_what_dim=50,
                 lstm_hidden_dim=256,
                 baseline_hidden_dim=256,
                 encoder_hidden_dim=200,
                 decoder_hidden_dim=200,
                 scale_prior_mean=3.0,
                 scale_prior_std=0.2,
                 pos_prior_mean=0.0,
                 pos_prior_std=1.0,
                 ):
        super().__init__()

        #### Settings

        self.max_steps = max_steps

        self.img_size = img_size
        self.object_size = object_size
        self.color_channels = color_channels
        self.z_what_dim = z_what_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.baseline_hidden_dim = baseline_hidden_dim
        self.encoder_hidden_dim = encoder_hidden_dim
        self.decoder_hidden_dim = decoder_hidden_dim

        self.z_pres_prob_prior = nograd_param(0.01)
        self.z_where_loc_prior = nograd_param(
            [scale_prior_mean, pos_prior_mean, pos_prior_mean])
        self.z_where_scale_prior = nograd_param(
            [scale_prior_std, pos_prior_std, pos_prior_std])
        self.z_what_loc_prior = nograd_param(0.0)
        self.z_what_scale_prior = nograd_param(1.0)

        ####

        self.img_numel = color_channels * (img_size ** 2)

        lstm_input_size = (self.img_numel + self.z_what_dim
                           + self.z_where_dim + self.z_pres_dim)
        self.lstm = LSTMCell(lstm_input_size, self.lstm_hidden_dim)

        # Infer presence and location from LSTM hidden state
        self.predictor = Predictor(self.lstm_hidden_dim)

        # Infer z_what given an image crop around the object
        self.encoder = AppearanceEncoder(object_size, color_channels,
                                         encoder_hidden_dim, z_what_dim)

        # Generate pixel representation of an object given its z_what
        self.decoder = AppearanceDecoder(z_what_dim, decoder_hidden_dim,
                                         object_size, color_channels)
        
        # Spatial transformer (does both forward and inverse)
        self.spatial_transf = SpatialTransformer(
            (self.object_size, self.object_size),
            (self.img_size, self.img_size))

        # Baseline LSTM
        self.bl_lstm = LSTMCell(lstm_input_size, self.baseline_hidden_dim)

        # Baseline regressor
        self.bl_regressor = nn.Sequential(
            nn.Linear(self.baseline_hidden_dim, 200),
            nn.ReLU(),
            nn.Linear(200, 1)
        )

        # Prior distributions
        self.pres_prior = Bernoulli(probs=self.z_pres_prob_prior)
        self.where_prior = Normal(loc=self.z_where_loc_prior,
                                  scale=self.z_where_scale_prior)
        self.what_prior = Normal(loc=self.z_what_loc_prior,
                                 scale=self.z_what_scale_prior)

        # Data likelihood
        self.likelihood = likelihood

    @staticmethod
    def _module_list_to_params(modules):
        params = []
        for module in modules:
            params.extend(module.parameters())
        return params

    def air_params(self):
        air_modules = [self.predictor, self.lstm, self.encoder, self.decoder]
        return self._module_list_to_params(air_modules) + [self.z_pres_prob_prior]

    def baseline_params(self):
        baseline_modules = [self.bl_regressor, self.bl_lstm]
        return self._module_list_to_params(baseline_modules)

    def get_output_dist(self, mean):
        if self.likelihood == 'original':
            std = torch.tensor(0.3).to(self.get_device())
            dist = Normal(mean, std.expand_as(mean))
        elif self.likelihood == 'bernoulli':
            dist = Bernoulli(probs=mean)
        else:
            msg = "Unrecognized likelihood '{}'".format(self.likelihood)
            raise RuntimeError(msg)
        return dist
        
    def forward(self, x):
        bs = x.size(0)

        # Init model state
        state = State(
            h=torch.zeros(bs, self.lstm_hidden_dim, device=x.device),
            c=torch.zeros(bs, self.lstm_hidden_dim, device=x.device),
            bl_h=torch.zeros(bs, self.baseline_hidden_dim, device=x.device),
            bl_c=torch.zeros(bs, self.baseline_hidden_dim, device=x.device),
            z_pres=torch.ones(bs, 1, device=x.device),
            z_where=torch.zeros(bs, 3, device=x.device),
            z_what=torch.zeros(bs, self.z_what_dim, device=x.device),
        )

        # KL divergence for each step
        kl = torch.zeros(bs, self.max_steps, device=x.device)

        # Store KL for pres, where, and what separately
        kl_pres = torch.zeros(bs, self.max_steps, device=x.device)
        kl_where = torch.zeros(bs, self.max_steps, device=x.device)
        kl_what = torch.zeros(bs, self.max_steps, device=x.device)

        # Baseline value for each step
        baseline_value = torch.zeros(bs, self.max_steps, device=x.device)

        # Log likelihood for each step, with shape (B, T):
        # log q(z_pres[t] | x, z_{<t}), but only for t <= n+1
        z_pres_likelihood = torch.zeros(bs, self.max_steps, device=x.device)

        # Baseline target for each step
        baseline_target = torch.zeros(bs, self.max_steps, device=x.device)

        # signal_mask (prev.z_pres) for each step
        mask_prev = torch.ones(bs, self.max_steps, device=x.device)

        # Mask (z_pres) for each step
        mask_curr = torch.ones(bs, self.max_steps, device=x.device)

        # Output canvas
        h = w = self.img_size
        ch = self.color_channels
        canvas = torch.zeros(bs, ch, h, w, device=x.device)

        # Save z_where to visualize bounding boxes
        all_z_where = torch.zeros(bs, self.max_steps, 3, device=x.device)

        for t in range(self.max_steps):
            # This is the previous z_pres, so at step i=0 this mask is all 1s.
            # It is used to zero out all time steps after the first z_pres=0.
            # The first z_pres=0 is NOT masked.
            mask_prev[:, t] = state.z_pres.squeeze()
            
            # Do one inference step and save results
            result = self.inference_step(state, x)
            state = result['state']
            kl[:, t] = result['kl']
            kl_pres[:, t] = result['kl_pres']
            kl_where[:, t] = result['kl_where']
            kl_what[:, t] = result['kl_what']
            baseline_value[:, t] = result['baseline_value']
            z_pres_likelihood[:, t] = result['z_pres_likelihood']

            # Add KL at timestep t to baseline_target for timesteps 0 to t
            # At the end of the loop: baseline_target[t] = sum_{i=t}^T KL[i]
            for j in range(t + 1):
                baseline_target[:, j] += result['kl']
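            # (After the full loop over t, baseline_target is a reversed
            # cumulative sum of the per-step KL over time.)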
                
            # Decode z_what to object appearance
            sprite = self.decoder(state.z_what)

            # Spatial-transform it to image with shape (B, 1, H, W)
            img = self.spatial_transf.forward(sprite, state.z_where)

            # Add to the output canvas, masking according to object presence
            # state.z_pres has shape (B, 1)
            canvas += img * state.z_pres[:, :, None, None]

            # Presence mask for current time step
            mask_curr[:, t] = state.z_pres.squeeze(1)

            # Save z_where to visualize bounding boxes
            all_z_where[:, t] = state.z_where  # shape (B, 3)

        # Clip canvas to [0, 1] (gradients are lost where objects overlap)
        if self.likelihood == 'bernoulli':
            canvas = canvas.clamp(min=0., max=1.)

        # Inferred number of objects in each image
        inferred_n = mask_curr.sum(1)   # shape (B,)

        # Output distribution p(x | z)
        output_dist = self.get_output_dist(canvas)

        # Data likelihood log p(x | z)
        likelihood_sep = output_dist.log_prob(x)

        # Sample from p(x | z) with the inferred z
        out_sample = output_dist.sample()

        # Sum over all data dimensions, resulting shape (B, )
        likelihood_sep = likelihood_sep.sum((1, 2, 3))

        # Sum KL over time steps, resulting shape (B, )
        kl = kl.sum(1)

        # ELBO separated per sample, shape (B, )
        elbo_sep = likelihood_sep - kl

        data = {
            'elbo_sep': elbo_sep,
            'elbo': elbo_sep.mean(),
            'inferred_n': inferred_n,
            'data_likelihood': likelihood_sep,
            'recons': -likelihood_sep.mean(),
            'kl': kl.mean(),
            'kl_pres': kl_pres.sum(1).mean(),
            'kl_where': kl_where.sum(1).mean(),
            'kl_what': kl_what.sum(1).mean(),
            'out_mean': canvas,
            'out_sample': out_sample,
            'all_z_where': all_z_where,
            'baseline_target': baseline_target,
            'baseline_value': baseline_value,
            'mask_prev': mask_prev,
            'z_pres_likelihood': z_pres_likelihood,
        }

        return data


    def inference_step(self, prev, x):
        """
        Given the previous (or initial) state and the input image, perform one
        inference step (infer the next object).
        """

        bs = x.size(0)
        
        # Flatten the image
        x_flat = x.view(bs, -1)
        
        # Feed (x, z_{<t}) through the LSTM cell, get encoding h
        lstm_input = torch.cat(
            (x_flat, prev.z_where, prev.z_what, prev.z_pres), dim=1)
        h, c = self.lstm(lstm_input, (prev.h, prev.c))

        # Predict presence and location from h
        z_pres_p, z_where_loc, z_where_scale = self.predictor(h)
        
        # If previous z_pres is 0, force the z_pres probability to 0
        z_pres_p = z_pres_p * prev.z_pres
        
        # Numerical stability
        eps = 1e-12
        z_pres_p = z_pres_p.clamp(min=eps, max=1.0-eps)

        # sample z_pres
        z_pres_post = Bernoulli(z_pres_p)
        z_pres = z_pres_post.sample()

        # If previous z_pres is 0, then this z_pres should also be 0.
        # However, this is sampled from a Bernoulli whose probability is at
        # least eps. In the unlucky event that the sample is 1, we force this
        # to 0 as well.
        z_pres = z_pres * prev.z_pres
        
        # Likelihood: log q(z_pres[i] | x, z_{<i}) (if z_pres[i-1]=1, else 0)
        # Mask with prev.z_pres instead of z_pres: the term is zeroed only when
        # there was already no object at the previous step.
        z_pres_likelihood = z_pres_post.log_prob(z_pres) * prev.z_pres
        z_pres_likelihood = z_pres_likelihood.squeeze()  # shape (B,)

        # Sample z_where
        z_where_post = Normal(z_where_loc, z_where_scale)
        z_where = z_where_post.rsample()
        
        # Get object from image - shape (B, 1, Hobj, Wobj)
        obj = self.spatial_transf.inverse(x, z_where)
        
        # Predict z_what from the object crop
        z_what_loc, z_what_scale = self.encoder(obj)
        z_what_post = Normal(z_what_loc, z_what_scale)
        z_what = z_what_post.rsample()

        # Compute the baseline for this z_pres:
        # b_i(z_{<i}), which depends on the previous-step latent variables only.
        bl_h, bl_c = self.bl_lstm(lstm_input.detach(), (prev.bl_h, prev.bl_c))
        baseline_value = self.bl_regressor(bl_h).squeeze()  # shape (B,)

        # The baseline is not used if z_pres[t-1] is 0 (object not present in
        # the previous step). Mask it out to be on the safe side.
        baseline_value = baseline_value * prev.z_pres.squeeze()
        
        # KL for the current step, sum over data dimension: shape (B,)
        kl_pres = kl_divergence(
            z_pres_post,
            self.pres_prior.expand(z_pres_post.batch_shape)).sum(1)
        kl_where = kl_divergence(
            z_where_post,
            self.where_prior.expand(z_where_post.batch_shape)).sum(1)
        kl_what = kl_divergence(
            z_what_post,
            self.what_prior.expand(z_what_post.batch_shape)).sum(1)
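        # (Each prior is expanded to the posterior's batch shape so that
        # kl_divergence is computed elementwise before summing over the latent
        # dimension, giving shape (B,).)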

        # When z_pres[i] is 0, z_where and z_what are not used -> set KL=0
        kl_where = kl_where * z_pres.squeeze()
        kl_what = kl_what * z_pres.squeeze()

        # When z_pres[i-1] is 0, z_pres is not used -> set KL=0
        kl_pres = kl_pres * prev.z_pres.squeeze()
        
        kl = (kl_pres + kl_where + kl_what)

        # New state
        new_state = State(
            z_pres=z_pres,
            z_where=z_where,
            z_what=z_what,
            h=h,
            c=c,
            bl_c=bl_c,
            bl_h=bl_h,
            )

        out = {
            'state': new_state,
            'kl': kl,
            'kl_pres': kl_pres,
            'kl_where': kl_where,
            'kl_what': kl_what,
            'baseline_value': baseline_value,
            'z_pres_likelihood': z_pres_likelihood,
        }
        return out


    def sample_prior(self, n_imgs, **kwargs):

        # Sample from prior. Shapes:
        # z_pres:  (B, T)
        # z_what:  (B, T, z_what_dim)
        # z_where: (B, T, 3)
        z_pres = self.pres_prior.sample((n_imgs, self.max_steps))
        z_what = self.what_prior.sample((n_imgs, self.max_steps, self.z_what_dim))
        z_where = self.where_prior.sample((n_imgs, self.max_steps))

        # TODO This is only for visualization! Not real model samples
        # The prior of z_pres puts a lot of probability on n=0, which doesn't
        # lead to informative samples. Instead, generate half images with 1
        # object and half with 2.
        # z_pres.fill_(0.)
        # z_pres[:, 0].fill_(1.)
        # z_pres[n_imgs//2:, 1].fill_(1.)

        # If z_pres is sampled from the prior, make sure there are no ones
        # after a zero.
        for t in range(1, self.max_steps):
            z_pres[:, t] *= z_pres[:, t-1]  # if previous=0, this is also 0

        n_obj = z_pres.sum(1)

        # Decode z_what to object appearance
        sprites = self.decoder(z_what)

        # Spatial-transform them to images with shape (B*T, 1, H, W)
        z_where_ = z_where.view(n_imgs * self.max_steps, 3)  # shape (B*T, 3)
        imgs = self.spatial_transf.forward(sprites, z_where_)

        # Reshape images to (B, T, 1, H, W)
        h = w = self.img_size
        ch = self.color_channels
        imgs = imgs.view(n_imgs, self.max_steps, ch, h, w)

        # Make canvas by masking and summing over timesteps
        canvas = imgs * z_pres[:, :, None, None, None]
        canvas = canvas.sum(1)

        return canvas, z_where, n_obj
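
The forward() above returns the ingredients of the objective but not a training loss. Below is a minimal sketch of how those outputs would typically be combined, following the surrogate-loss construction written inline in Example #3 below (the helper name air_loss and the unit weight on the REINFORCE term are assumptions, not part of the original code):

import torch.nn.functional as F

def air_loss(out):
    # Learning signal for the discrete z_pres: summed future KL minus the data
    # log-likelihood, masked after the first z_pres = 0 (hypothetical helper).
    signal = out['baseline_target'] - out['data_likelihood'][:, None]
    signal = signal * out['mask_prev']

    # Score-function (REINFORCE) term with the learned, detached baseline.
    reinforce = ((signal.detach() - out['baseline_value'].detach())
                 * out['z_pres_likelihood']).sum(1).mean()

    # The baseline regressor is trained to predict the detached signal.
    baseline_loss = F.mse_loss(out['baseline_value'], signal.detach())

    # Negative ELBO is the main objective.
    return -out['elbo'] + reinforce + baseline_loss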
Example #3
class AIR(nn.Module):
    def __init__(self, arch=None):
        """
        :param arch: dictionary, for overriding default architecture
        """
        nn.Module.__init__(self)
        self.arch = deepcopy(default_arch)
        if arch is not None:
            self.arch.update(arch)

        self.T = self.arch.max_steps
        self.reinforce_weight = 0.0

        # 4 = 3 (z_where) + 1 (z_pres)
        lstm_input_size = self.arch.input_size + self.arch.z_what_size + 4
        self.lstm_cell = LSTMCell(lstm_input_size, self.arch.lstm_hidden_size)

        # predict z_where, z_pres from h
        self.predict = Predict(self.arch)
        # encode object into what
        self.encoder = Encoder(self.arch)
        # decode what into object
        self.decoder = Decoder(self.arch)

        # spatial transformers
        self.image_to_object = SpatialTransformer(self.arch.input_shape,
                                                  self.arch.object_shape)
        self.object_to_image = SpatialTransformer(self.arch.object_shape,
                                                  self.arch.input_shape)

        # baseline RNN
        self.bl_rnn = LSTMCell(lstm_input_size, self.arch.baseline_hidden_size)
        # predict baseline value
        self.bl_predict = nn.Linear(self.arch.baseline_hidden_size, 1)

        # priors
        self.pres_prior = Bernoulli(probs=self.arch.z_pres_prob_prior)
        self.where_prior = Normal(loc=self.arch.z_where_loc_prior,
                                  scale=self.arch.z_where_scale_prior)
        self.what_prior = Normal(loc=self.arch.z_what_loc_prior,
                                 scale=self.arch.z_what_scale_prior)

        # modules excluding baseline rnn
        self.air_modules = nn.ModuleList(
            [self.predict, self.lstm_cell, self.encoder, self.decoder])

        self.baseline_modules = nn.ModuleList([self.bl_rnn, self.bl_predict])

    def forward(self, x):
        B = x.size(0)
        state = AIRState.get_intial_state(B, self.arch)

        # KL divergence for each step
        kl = []
        # baseline value for each step
        baseline_value = []
        # z_pres likelihood for each step
        z_pres_likelihood = []
        # learning signal for each step
        learning_signal = torch.zeros(B, self.arch.max_steps, device=x.device)
        # signal_mask (prev.z_pres)
        signal_mask = torch.ones(B, self.arch.max_steps, device=x.device)
        # mask (z_pres)
        mask = torch.ones(B, self.arch.max_steps, device=x.device)
        # canvas
        h, w = self.arch.input_shape
        canvas = torch.zeros(B, 1, h, w, device=x.device)

        if DEBUG:
            vis_logger['image'] = x[0]
            vis_logger['z_pres_p_list'] = []
            vis_logger['z_pres_list'] = []
            vis_logger['canvas_list'] = []
            vis_logger['z_where_list'] = []
            vis_logger['object_enc_list'] = []
            vis_logger['object_dec_list'] = []
            vis_logger['kl_pres_list'] = []
            vis_logger['kl_what_list'] = []
            vis_logger['kl_where_list'] = []

        for t in range(self.T):
            # This is prev.z_pres. Its only purpose is to mask the learning signal.
            signal_mask[:, t] = state.z_pres.squeeze()

            # all terms are already masked
            state, this_kl, this_baseline_value, this_z_pres_likelihood = self.infer_step(
                state, x)
            baseline_value.append(this_baseline_value.squeeze())
            kl.append(this_kl)
            z_pres_likelihood.append(this_z_pres_likelihood.squeeze())

            # Add this step's KL to the learning signal of all dependent steps (0..t)
            # NOTE: the KL of the current step's z_pres does not depend on the
            # z_pres sample, but the KLs of z_where and z_what DO, so they cannot
            # be excluded from the learning signal. Hence we use t + 1 instead of
            # t. Although this also includes the current step's z_pres KL, it
            # should not matter much.
            for j in range(t + 1):
                learning_signal[:, j] += this_kl.squeeze()

            # reconstruct
            obj = self.decoder(state.z_what)
            # (B, 1, H, W)
            img = self.object_to_image(obj, state.z_where, inverse=False)
            # Masking is crucial here.
            canvas = canvas + img * state.z_pres[:, :, None, None]

            mask[:, t] = state.z_pres.squeeze()

            if DEBUG:
                vis_logger['canvas_list'].append(canvas[0])
                vis_logger['object_dec_list'].append(obj[0])

        baseline_value = torch.stack(baseline_value, dim=1)
        kl = torch.stack(kl, dim=1)
        z_pres_likelihood = torch.stack(z_pres_likelihood, dim=1)

        # construct output distribution
        output_dist = Normal(canvas, self.arch.x_scale.expand(canvas.shape))
        likelihood = output_dist.log_prob(x)
        # sum over data dimension
        likelihood = likelihood.view(B, -1).sum(1)

        # Construct surrogate loss
        # Note the MINUS sign here!
        learning_signal = learning_signal - likelihood[:, None]
        learning_signal = learning_signal * signal_mask
        reinforce_term = (learning_signal.detach() -
                          baseline_value.detach()) * z_pres_likelihood
        reinforce_term = reinforce_term.sum(1)
        # reinforce_term = torch.zeros_like(reinforce_term)
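        # reinforce_term above is a score-function (REINFORCE) estimator: the
        # gradient of the expected learning signal w.r.t. the parameters that
        # produce the discrete z_pres is estimated by multiplying the detached,
        # baseline-subtracted signal with log q(z_pres | x, z_{<t})
        # (z_pres_likelihood) and summing over time steps.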

        # KL term, sum over time steps
        kl = kl.sum(1)

        loss = self.reinforce_weight * reinforce_term + kl - likelihood
        # mean over batch dimension
        loss = loss.mean()

        vis_logger['reinforce_loss'] = (reinforce_term.mean())
        vis_logger['kl_loss'] = (kl.mean())
        vis_logger['neg_likelihood'] = (-likelihood.mean())

        # compute baseline loss
        baseline_loss = F.mse_loss(baseline_value, learning_signal.detach())

        vis_logger['baseline_loss'] = baseline_loss

        # losslist = (reinforce_term.mean(), kl.mean(), likelihood.mean(), baseline_loss)

        return loss + baseline_loss, mask.sum(1)

    def infer_step(self, prev, x):
        """
        Given previous state, predict next state. We assume that z_pres is 1
        :param prev: AIRState
        :return: new_state, KL, baseline value, z_pres_likelihood
        """

        B = x.size(0)

        # Flatten x
        x_flat = x.view(B, -1)

        # First, compute h_t that encodes (x, z[1:i-1])
        lstm_input = torch.cat(
            (x_flat, prev.z_where, prev.z_what, prev.z_pres), dim=1)
        h, c = self.lstm_cell(lstm_input, (prev.h, prev.c))

        # Predict presence and location
        z_pres_p, z_where_loc, z_where_scale = self.predict(h)

        # In theory, if z_pres is 0, we don't need to continue computation. But
        # for batch processing, we will do this anyway.

        # If the previous z_pres is 0, force its probability to 0 before sampling
        z_pres_p = z_pres_p * prev.z_pres

        # NOTE: for numerical stability, if z_pres_p is exactly 0 or 1 we nudge
        # it into (0, 1); otherwise the gradient explodes
        eps = 1e-6
        z_pres_p = z_pres_p + eps * (z_pres_p == 0).float() - eps * (
            z_pres_p == 1).float()

        z_pres_post = Bernoulli(z_pres_p)
        z_pres = z_pres_post.sample()
        z_pres = z_pres * prev.z_pres

        # Likelihood. Note we must mask with prev.z_pres instead of z_pres
        # because q(z_pres[i]=0 | z_pres[i-1]=1) is non-zero.
        z_pres_likelihood = z_pres_post.log_prob(z_pres) * prev.z_pres
        # (B,)
        z_pres_likelihood = z_pres_likelihood.squeeze()

        # sample z_where
        z_where_post = Normal(z_where_loc, z_where_scale)
        z_where = z_where_post.rsample()

        # extract object
        # (B, 1, Hobj, Wobj)
        obj = self.image_to_object(x, z_where, inverse=True)

        # predict z_what
        z_what_loc, z_what_scale = self.encoder(obj)
        z_what_post = Normal(z_what_loc, z_what_scale)
        z_what = z_what_post.rsample()

        # compute baseline for this z_pres
        bl_h, bl_c = self.bl_rnn(lstm_input.detach(), (prev.bl_h, prev.bl_c))
        # (B,)
        baseline_value = self.bl_predict(bl_h).squeeze()
        # If z_pres[i-1] is 0, the REINFORCE term does not depend on phi, so we
        # don't need it and set it to zero. The learning signal must be masked
        # to zero as well, since it enters the baseline loss.
        baseline_value = baseline_value * prev.z_pres.squeeze()

        # Compute KL as we go, sum over data dimension
        kl_pres = kl_divergence(
            z_pres_post,
            self.pres_prior.expand(z_pres_post.batch_shape)).sum(1)
        kl_where = kl_divergence(
            z_where_post,
            self.where_prior.expand(z_where_post.batch_shape)).sum(1)
        kl_what = kl_divergence(
            z_what_post,
            self.what_prior.expand(z_what_post.batch_shape)).sum(1)

        # For where and what, when z_pres[i] is 0 they are not used -> zero their KL
        kl_where = kl_where * z_pres.squeeze()
        kl_what = kl_what * z_pres.squeeze()
        # For pres, this is not the case. So we use prev.z_pres.
        kl_pres = kl_pres * prev.z_pres.squeeze()

        kl = (kl_pres + kl_where + kl_what)

        # new state
        new_state = AIRState(z_pres=z_pres,
                             z_where=z_where,
                             z_what=z_what,
                             h=h,
                             c=c,
                             bl_c=bl_c,
                             bl_h=bl_h,
                             z_pres_p=z_pres_p)

        # Logging
        if DEBUG:
            vis_logger['z_pres_p_list'].append(z_pres_p[0])
            vis_logger['z_pres_list'].append(z_pres[0])
            vis_logger['z_where_list'].append(z_where[0])
            vis_logger['object_enc_list'].append(obj[0])
            vis_logger['kl_pres_list'].append(kl_pres.mean())
            vis_logger['kl_what_list'].append(kl_what.mean())
            vis_logger['kl_where_list'].append(kl_where.mean())

        return new_state, kl, baseline_value, z_pres_likelihood
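
A minimal training-step sketch for the model above (the optimizer choice, learning rate, and batch shape are assumptions; default_arch, AIRState, Predict, Encoder, Decoder, SpatialTransformer, and vis_logger come from the surrounding repository and are not shown here):

import torch

model = AIR()  # default architecture
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Dummy batch; assumes the default input_shape is 50x50 single-channel images.
x = torch.rand(32, 1, 50, 50)

loss, n_objects = model(x)   # the returned loss already includes the baseline loss
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(loss.item(), n_objects.mean().item())  # loss and mean inferred object count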