def _call(self, inp, mask, is_training):
    """Encode ``inp`` together with ``mask`` into a Gaussian latent and decode it.

    Args:
        inp: input tensor; concatenated with ``mask`` along the last axis, so the
            two must agree on every leading dimension.
        mask: tensor concatenated onto ``inp`` as extra channels.
        is_training: forwarded to both the encoder and the decoder.

    Returns:
        ``(background, kl)``: the raw background decoder output (no output
        nonlinearity is applied here) and the KL term returned by ``normal_vae``
        against a prior with mean 0 and std 1.
    """
    self.maybe_build_subnet('background_encoder')
    self.maybe_build_subnet('background_decoder')

    combined = tf.concat([inp, mask], axis=-1)
    latent = self.background_encoder(combined, 2 * self.n_latents_per_channel, is_training)

    # The second half of the encoder output is an unconstrained network output.
    # Treat it as a log-std and exponentiate it so the std handed to normal_vae is
    # strictly positive. Previously it was used raw as the std and could be <= 0.
    # This matches how the other encoders in this module parameterize std.
    mean, log_std = tf.split(latent, 2, axis=-1)
    std = tf.exp(log_std)

    sample, kl = normal_vae(mean, std, 0, 1)

    # Output size None: presumably lets the decoder pick its own output shape -- confirm.
    background = self.background_decoder(sample, None, is_training)
    return background, kl
def build_representation(self):
    """Build the single-image VAE graph.

    Lazily creates the encoder and decoder via ``cfg.build_encoder`` /
    ``cfg.build_decoder``, freezing a module when its name appears in
    ``self.fixed_weights``. It then:
      * encodes ``self.inp`` into ``self.A`` Gaussian attributes;
      * decodes a 3-channel image cropped to the input's spatial size and
        squashed with a sigmoid;
      * registers the KL and/or reconstruction losses enabled by
        ``self.train_kl`` / ``self.train_reconstruction``.
    """
    # --- init modules: build on first use, optionally freeze ---
    for module_name in ("encoder", "decoder"):
        if getattr(self, module_name) is not None:
            continue
        # Builder is looked up only when actually needed.
        builder = getattr(cfg, "build_" + module_name)
        setattr(self, module_name, builder(scope=module_name))
        if module_name in self.fixed_weights:
            getattr(self, module_name).fix_variables()

    # --- encode: 2 * A channels = A means followed by A log-stds ---
    encoded = self.encoder(self.inp, 2 * self.A, self.is_training)
    mean, log_std = tf.split(encoded, 2, axis=-1)
    std = tf.exp(log_std)
    if not self.noisy:
        # Deterministic mode: zero std (normal_vae is not visible here; presumably
        # this makes the sample equal the mean -- confirm).
        std = tf.zeros_like(std)

    sample, kl = normal_vae(mean, std, self.attr_prior_mean, self.attr_prior_std)

    # "obj" is all ones: same leading dims as the sample, one trailing channel.
    obj_shape = tf.concat([tf.shape(sample)[:-1], [1]], axis=0)
    self._tensors["obj"] = tf.ones(obj_shape)
    self._tensors.update(attr_mean=mean, attr_std=std, attr_kl=kl, attr=sample)

    # --- decode: crop to the input's height/width, then squash to (0, 1) ---
    logits = self.decoder(sample, 3, self.is_training)
    logits = logits[:, :self.inp.shape[1], :self.inp.shape[2], :]
    output = tf.nn.sigmoid(tf.clip_by_value(logits, -10, 10))
    self._tensors["output"] = output

    # --- losses ---
    if self.train_kl:
        self.losses['attr_kl'] = tf_mean_sum(self._tensors["attr_kl"])

    if self.train_reconstruction:
        per_pixel_loss = xent_loss(pred=output, label=self.inp)
        self._tensors['per_pixel_reconstruction_loss'] = per_pixel_loss
        self.losses['reconstruction'] = tf_mean_sum(per_pixel_loss)
def _build_program_interpreter(self, tensors):
    """Reconstruct an image from per-object boxes via the object encoder/decoder.

    For every box in ``tensors["normalized_box"]``:
      1. A glimpse of shape ``self.object_shape`` is resampled out of
         ``tensors["inp"]``.
      2. The glimpse is encoded into ``self.A`` Gaussian attributes.
      3. The attributes are decoded back into a sprite.
    The sprites, together with ``tensors["background"]``, are then passed to
    ``render_sprites.render_sprites``.

    Args:
        tensors: dict of tensors. Keys read: "max_objects", "normalized_box"
            (split into yt, xt, ys, xs along the last axis), "inp", "obj",
            "n_objects", "background".

    Returns:
        dict with
          output:  result of ``render_sprites.render_sprites``.
          glimpse: (batch, max_objects, *object_shape, image_depth).
          attr:    (batch, max_objects, A) sampled attributes.
          attr_kl: (batch, max_objects, A) KL term from ``normal_vae``.
          objects: (batch, max_objects, *object_shape, image_depth) decoded
                   sprite colours, without the alpha/importance channels that
                   are appended for rendering.
    """
    max_objects = tensors["max_objects"]
    yt, xt, ys, xs = tf.split(tensors["normalized_box"], 4, axis=-1)

    # --- Extract one glimpse per box ---
    transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
    warper = snt.AffineGridWarper(
        (self.image_height, self.image_width), self.object_shape, transform_constraints)

    # Warp parameters: x-scale, x-offset, y-scale, y-offset.
    # Each offset is the box centre (t + s / 2) mapped through 2 * c - 1,
    # i.e. from [0, 1] to [-1, 1], assuming normalized box coordinates -- confirm.
    _boxes = tf.concat(
        [xs, 2 * (xt + xs / 2) - 1, ys, 2 * (yt + ys / 2) - 1], axis=-1)
    _boxes = tf.reshape(_boxes, (self.batch_size * max_objects, 4))

    grid_coords = warper(_boxes)
    grid_coords = tf.reshape(grid_coords, (
        self.batch_size, max_objects, *self.object_shape, 2,
    ))
    glimpse = tf.contrib.resampler.resampler(tensors["inp"], grid_coords)

    # --- Get object attributes using object encoder ---
    object_encoder_in = tf.reshape(
        glimpse, (self.batch_size * max_objects, *self.object_shape, self.image_depth))
    attr = self.object_encoder(object_encoder_in, (1, 1, 2 * self.A), self.is_training)
    attr = tf.reshape(attr, (self.batch_size, max_objects, 2 * self.A))

    attr_mean, attr_log_std = tf.split(attr, [self.A, self.A], axis=-1)
    attr_std = tf.exp(attr_log_std)
    if not self.noisy:
        attr_std = tf.zeros_like(attr_std)

    attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)

    # --- Compute sprites from attr using object decoder ---
    object_decoder_in = tf.reshape(
        attr, (self.batch_size * max_objects, 1, 1, self.A))
    object_logits = self.object_decoder(
        object_decoder_in, self.object_shape + (self.image_depth, ), self.is_training)

    object_rgb = tf.nn.sigmoid(tf.clip_by_value(object_logits, -10., 10.))
    object_rgb = tf.reshape(object_rgb, (
        self.batch_size, max_objects, *self.object_shape, self.image_depth,
    ))

    # Alpha: tensors["obj"] broadcast over every pixel of the sprite.
    # Indexing implies rank 3, presumably (batch, max_objects, 1) -- confirm.
    # Importance is uniformly 1.
    alpha = tensors["obj"][:, :, :, None, None] * tf.ones_like(object_rgb[:, :, :, :, :1])
    importance = tf.ones_like(object_rgb[:, :, :, :, :1])
    sprites = tf.concat([object_rgb, alpha, importance], axis=-1)

    # --- Reconstruct image ---
    scales = tf.concat([ys, xs], axis=-1)
    scales = tf.reshape(scales, (self.batch_size, max_objects, 2))
    offsets = tf.concat([yt, xt], axis=-1)
    offsets = tf.reshape(offsets, (self.batch_size, max_objects, 2))

    output = render_sprites.render_sprites(
        sprites, tensors["n_objects"], scales, offsets, tensors["background"])

    return dict(
        output=output,
        glimpse=tf.reshape(
            glimpse, (self.batch_size, max_objects, *self.object_shape, self.image_depth)),
        attr=tf.reshape(attr, (self.batch_size, max_objects, self.A)),
        attr_kl=tf.reshape(attr_kl, (self.batch_size, max_objects, self.A)),
        # The rendered sprites carry image_depth + 2 channels (alpha, importance).
        # Reshaping them to image_depth channels failed on the element-count
        # mismatch, so return the colour channels, which have exactly this shape.
        objects=tf.reshape(object_rgb, (
            self.batch_size, max_objects, *self.object_shape, self.image_depth,
        )),
    )
def body(step, stopping_sum, prev_state, running_recon, kl_loss, running_digits,
         scale_ta, scale_kl_ta, scale_std_ta,
         shift_ta, shift_kl_ta, shift_std_ta,
         attr_ta, attr_kl_ta, attr_std_ta,
         z_pres_ta, z_pres_probs_ta, z_pres_kl_ta,
         vae_input_ta, vae_output_ta,
         scale, shift, attr, z_pres):
    """One step of the recurrent object-inference loop.

    Presumably this is the body of a ``tf.while_loop``; the caller is not
    visible here. The attribute names (``difference_air``) suggest AIR.

    The step does the following:
      1. Encodes the scene; with ``self.difference_air``, it encodes the
         residual between the scene and the running reconstruction instead.
      2. Advances ``self.cell``.
      3. Samples scale, shift and z_pres.
      4. Crops a window with ``transformer`` and runs the object VAE on it.
      5. Pastes the decoded object back into image space.
      6. Accumulates reconstruction and KL for rows that are still "alive".
    Per-step values are appended to the TensorArrays.

    Returns:
        The updated loop variables in the same order as the parameters:
        ``step`` is incremented, ``prev_state`` is replaced by the new cell
        state, and ``scale``/``shift``/``attr``/``z_pres`` are this step's samples.
    """
    if self.difference_air:
        # Encode what is not yet explained by the reconstruction so far.
        inp = (self._tensors["inp"] - tf.reshape(running_recon, (self.batch_size, *self.obs_shape)))
        encoded_inp = self.image_encoder(inp, 0, self.is_training)
        encoded_inp = tf.layers.flatten(encoded_inp)
    else:
        encoded_inp = self.encoded_inp

    if self.complete_rnn_input:
        # Also feed the previous step's samples back into the RNN.
        rnn_input = tf.concat(
            [encoded_inp, scale, shift, attr, z_pres], axis=1)
    else:
        rnn_input = encoded_inp

    hidden_rep, next_state = self.cell(rnn_input, prev_state)

    # 9 outputs = 2 (scale mean) + 2 (scale log-std) + 2 (shift mean)
    #             + 2 (shift log-std) + 1 (z_pres log-odds).
    outputs = self.output_network(hidden_rep, 9, self.is_training)

    (scale_mean, scale_log_std,
     shift_mean, shift_log_std,
     z_pres_log_odds) = tf.split(outputs, [2, 2, 2, 2, 1], axis=1)

    # --- scale ---
    scale_std = tf.exp(scale_log_std)

    scale_mean = self.apply_fixed_value("scale_mean", scale_mean)
    scale_std = self.apply_fixed_value("scale_std", scale_std)

    scale_logits, scale_kl = normal_vae(
        scale_mean, scale_std, self.scale_prior_mean, self.scale_prior_std)
    scale_kl = tf.reduce_sum(scale_kl, axis=1, keepdims=True)
    # Window scale in (0, 1).
    scale = tf.nn.sigmoid(tf.clip_by_value(scale_logits, -10, 10))

    # --- shift ---
    shift_std = tf.exp(shift_log_std)

    shift_mean = self.apply_fixed_value("shift_mean", shift_mean)
    shift_std = self.apply_fixed_value("shift_std", shift_std)

    shift_logits, shift_kl = normal_vae(
        shift_mean, shift_std, self.shift_prior_mean, self.shift_prior_std)
    shift_kl = tf.reduce_sum(shift_kl, axis=1, keepdims=True)
    # Window shift in (-1, 1).
    shift = tf.nn.tanh(tf.clip_by_value(shift_logits, -10, 10))

    # --- Extract windows from scene ---
    w, h = scale[:, 0:1], scale[:, 1:2]
    x, y = shift[:, 0:1], shift[:, 1:2]

    # Affine matrix [[w, 0, x], [0, h, y]] per batch row.
    theta = tf.concat(
        [w, tf.zeros_like(w), x, tf.zeros_like(h), h, y], axis=1)
    theta = tf.reshape(theta, (-1, 2, 3))

    vae_input = transformer(self._tensors["inp"], theta, self.object_shape)

    # This is a necessary reshape, as the output of transformer will have unknown dims
    vae_input = tf.reshape(
        vae_input, (self.batch_size, *self.object_shape, self.image_depth))

    # --- Apply Object-level VAE (object encoder/object decoder) to windows ---
    attr = self.object_encoder(vae_input, 2 * self.A, self.is_training)
    attr_mean, attr_log_std = tf.split(attr, 2, axis=1)
    attr_std = tf.exp(attr_log_std)
    attr, attr_kl = normal_vae(
        attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)
    attr_kl = tf.reduce_sum(attr_kl, axis=1, keepdims=True)

    vae_output = self.object_decoder(
        attr, self.object_shape[0] * self.object_shape[1] * self.image_depth,
        self.is_training)
    vae_output = tf.nn.sigmoid(tf.clip_by_value(vae_output, -10, 10))

    # --- Place reconstructed objects in image ---
    # Inverse of theta: [[1/w, 0, -x/w], [0, 1/h, -y/h]].
    theta_inverse = tf.concat([
        1. / w, tf.zeros_like(w), -x / w,
        tf.zeros_like(h), 1. / h, -y / h
    ], axis=1)
    theta_inverse = tf.reshape(theta_inverse, (-1, 2, 3))

    vae_output_transformed = transformer(
        tf.reshape(vae_output, (
            self.batch_size,
            *self.object_shape,
            self.image_depth,
        )),
        theta_inverse, self.obs_shape[:2])
    # Flatten to (batch, H * W * depth) to match running_recon.
    vae_output_transformed = tf.reshape(vae_output_transformed, [
        self.batch_size,
        self.image_height * self.image_width * self.image_depth
    ])

    # --- z_pres ---
    if self.run_all_time_steps:
        # Every step counts as present; no z_pres KL.
        z_pres = tf.ones_like(z_pres_log_odds)
        z_pres_prob = tf.ones_like(z_pres_log_odds)
        z_pres_kl = tf.zeros_like(z_pres_log_odds)
    else:
        z_pres_log_odds = tf.clip_by_value(z_pres_log_odds, -10, 10)

        # Relaxed (concrete) binary sample.
        z_pres_pre_sigmoid = concrete_binary_pre_sigmoid_sample(
            z_pres_log_odds, self.z_pres_temperature)
        z_pres = tf.nn.sigmoid(z_pres_pre_sigmoid)
        # Blend relaxed and rounded samples by float_is_training. Presumably
        # 1.0 while training (relaxed) and 0.0 otherwise (hard 0/1) -- confirm.
        z_pres = (self.float_is_training * z_pres +
                  (1 - self.float_is_training) * tf.round(z_pres))
        z_pres_prob = tf.nn.sigmoid(z_pres_log_odds)
        z_pres_kl = concrete_binary_sample_kl(
            z_pres_pre_sigmoid,
            z_pres_log_odds, self.z_pres_temperature,
            self.z_pres_prior_log_odds, self.z_pres_temperature,
        )

    # z_pres lies in [0, 1], so stopping_sum never decreases. A row is alive
    # until its accumulated (1 - z_pres) reaches stopping_threshold, and stays
    # dead afterwards.
    stopping_sum += (1.0 - z_pres)
    alive = tf.less(stopping_sum, self.stopping_threshold)
    running_digits += tf.to_int32(alive)

    # --- adjust reconstruction ---
    # Only alive rows add their (z_pres-weighted) object to the reconstruction.
    running_recon += tf.where(
        tf.tile(alive, (1, vae_output_transformed.shape[1])),
        z_pres * vae_output_transformed, tf.zeros_like(running_recon))

    # --- add kl to loss ---
    # Only alive rows contribute KL.
    kl_loss += tf.where(alive, scale_kl, tf.zeros_like(kl_loss))
    kl_loss += tf.where(alive, shift_kl, tf.zeros_like(kl_loss))
    kl_loss += tf.where(alive, attr_kl, tf.zeros_like(kl_loss))
    kl_loss += tf.where(alive, z_pres_kl, tf.zeros_like(kl_loss))

    # --- record values ---
    scale_ta = scale_ta.write(scale_ta.size(), scale)
    scale_kl_ta = scale_kl_ta.write(scale_kl_ta.size(), scale_kl)
    scale_std_ta = scale_std_ta.write(scale_std_ta.size(), scale_std)

    shift_ta = shift_ta.write(shift_ta.size(), shift)
    shift_kl_ta = shift_kl_ta.write(shift_kl_ta.size(), shift_kl)
    shift_std_ta = shift_std_ta.write(shift_std_ta.size(), shift_std)

    attr_ta = attr_ta.write(attr_ta.size(), attr)
    attr_kl_ta = attr_kl_ta.write(attr_kl_ta.size(), attr_kl)
    attr_std_ta = attr_std_ta.write(attr_std_ta.size(), attr_std)

    vae_input_ta = vae_input_ta.write(vae_input_ta.size(), tf.layers.flatten(vae_input))
    vae_output_ta = vae_output_ta.write(vae_output_ta.size(), vae_output)

    z_pres_ta = z_pres_ta.write(z_pres_ta.size(), z_pres)
    z_pres_probs_ta = z_pres_probs_ta.write(z_pres_probs_ta.size(), z_pres_prob)
    z_pres_kl_ta = z_pres_kl_ta.write(z_pres_kl_ta.size(), z_pres_kl)

    return (
        step + 1, stopping_sum, next_state,
        running_recon, kl_loss, running_digits,
        scale_ta, scale_kl_ta, scale_std_ta,
        shift_ta, shift_kl_ta, shift_std_ta,
        attr_ta, attr_kl_ta, attr_std_ta,
        z_pres_ta, z_pres_probs_ta, z_pres_kl_ta,
        vae_input_ta, vae_output_ta,
        scale, shift, attr, z_pres,
    )
def build_representation(self):
    """Build the video auto-encoder graph.

    Frames of ``self.inp`` are encoded independently, with time folded into the
    batch axis. If a recurrent ``self.cell`` is configured, it is run over the
    (flattened) per-frame encodings and its outputs are used instead. The
    result is split into a mean half and a pre-std half (softplus is applied),
    then sampled with ``normal_vae``. Every frame is decoded and squashed with
    a sigmoid, and the KL / reconstruction losses enabled by ``self.train_kl``
    / ``self.train_reconstruction`` are registered.

    Raises:
        NotImplementedError: a recurrent cell is configured but
            ``self.flat_latent`` is false.
    """
    # --- init modules ---
    if self.encoder is None:
        self.encoder = self.build_encoder(scope="encoder")
        if "encoder" in self.fixed_weights:
            self.encoder.fix_variables()

    # NOTE(review): the guard reads ``self.build_cell`` but the builder invoked is
    # ``cfg.build_cell``; confirm these are meant to be the same callable.
    if self.cell is None and self.build_cell is not None:
        self.cell = cfg.build_cell(scope="cell")
        if "cell" in self.fixed_weights:
            self.cell.fix_variables()

    if self.decoder is None:
        self.decoder = cfg.build_decoder(scope="decoder")
        if "decoder" in self.fixed_weights:
            self.decoder.fix_variables()

    # --- encode ---
    # Fold time into the batch axis so the encoder sees individual frames.
    inp_trailing_shape = tf_shape(self.inp)[2:]
    video = tf.reshape(self.inp, (self.batch_size * self.dynamic_n_frames, *inp_trailing_shape))
    encoder_output = self.encoder(video, 2 * self.A, self.is_training)

    # Unfold back to (batch, n_frames, ...).
    eo_trailing_shape = tf_shape(encoder_output)[1:]
    encoder_output = tf.reshape(
        encoder_output, (self.batch_size, self.dynamic_n_frames, *eo_trailing_shape))

    if self.cell is None:
        attr = encoder_output
    else:
        if not self.flat_latent:
            # Unreachable code that used to follow this raise has been removed.
            # Applying its tf.layers.flatten would have made the dynamic_rnn
            # input rank 2.
            raise NotImplementedError(
                "Only flat latents are supported when a recurrent cell is used.")

        # dynamic_rnn (time_major=False) needs (batch, time, features).
        n_trailing_dims = int(np.prod(eo_trailing_shape))
        encoder_output = tf.reshape(
            encoder_output, (self.batch_size, self.dynamic_n_frames, n_trailing_dims))

        attr, _ = dynamic_rnn(
            self.cell, encoder_output,
            initial_state=self.cell.zero_state(self.batch_size, tf.float32),
            parallel_iterations=1, swap_memory=False, time_major=False)

    # The second half goes through softplus, so despite its name it is not
    # exponentiated like a log-std.
    attr_mean, attr_log_std = tf.split(attr, 2, axis=-1)
    attr_std = tf.math.softplus(attr_log_std)

    if not self.noisy:
        attr_std = tf.zeros_like(attr_std)

    attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)

    self._tensors.update(attr_mean=attr_mean, attr_std=attr_std, attr_kl=attr_kl, attr=attr)

    # --- decode ---
    # Decode frame by frame, then crop to the observed height/width and restore
    # the (batch, n_frames, ...) layout.
    decoder_input = tf.reshape(attr, (self.batch_size*self.dynamic_n_frames, *tf_shape(attr)[2:]))

    reconstruction = self.decoder(decoder_input, tf_shape(self.inp)[2:], self.is_training)
    reconstruction = reconstruction[:, :self.obs_shape[1], :self.obs_shape[2], :]
    reconstruction = tf.reshape(
        reconstruction, (self.batch_size, self.dynamic_n_frames, *self.obs_shape[1:]))

    reconstruction = tf.nn.sigmoid(tf.clip_by_value(reconstruction, -10, 10))
    self._tensors["output"] = reconstruction

    # --- losses ---
    if self.train_kl:
        self.losses['attr_kl'] = tf_mean_sum(self._tensors["attr_kl"])

    if self.train_reconstruction:
        self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=reconstruction, label=self.inp)
        self.losses['reconstruction'] = tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])
def build_background(self):
    """Build the background tensor and store it in ``self._tensors["background"]``.

    The construction depends on ``cfg.background_cfg.mode``:
      * "colour": a constant colour ``cfg.background_cfg.colour``.
      * "learn_solid": a learned solid colour (sigmoid of a 3-vector variable).
      * "scalor": nothing is built here, and "background" is left unset.
      * "learn": a latent encoded from the first frame, decoded to one colour
        per batch element and tiled over every frame and pixel.
      * "learn_and_transform": a learned background image of shape
        ``cfg.background_cfg.bg_shape``, resampled per frame through a
        learned, time-integrated no-shear affine window.
      * "data": ``self._tensors["ground_truth_background"]``.
    The stored background is truncated to the first ``self.dynamic_n_frames``
    entries along axis 1.

    Raises:
        ValueError: the mode is not one of the above.
    """
    mode = cfg.background_cfg.mode

    if mode == "colour":
        rgb = np.array(to_rgb(cfg.background_cfg.colour))[None, None, None, :]
        background = rgb * tf.ones_like(self.inp)

    elif mode == "learn_solid":
        # Learn a solid colour for the background
        self.solid_background_logits = tf.get_variable("solid_background", initializer=[0.0, 0.0, 0.0])
        if "background" in self.fixed_weights:
            tf.add_to_collection(FIXED_COLLECTION, self.solid_background_logits)
        solid_background = tf.nn.sigmoid(10 * self.solid_background_logits)
        background = solid_background[None, None, None, :] * tf.ones_like(self.inp)

    elif mode == "scalor":
        # Nothing to build here. This used to be ``pass``, which fell through to
        # the final slice with ``background`` unbound and raised UnboundLocalError.
        # NOTE(review): presumably the background for this mode is produced
        # elsewhere -- confirm.
        return

    elif mode == "learn":
        self.maybe_build_subnet("background_encoder")
        self.maybe_build_subnet("background_decoder")

        # Here I'm just encoding the first frame...
        bg_attr = self.background_encoder(self.inp[:, 0], 2 * cfg.background_cfg.A, self.is_training)
        bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
        bg_attr_std = tf.exp(bg_attr_log_std)
        if not self.noisy:
            bg_attr_std = tf.zeros_like(bg_attr_std)

        bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

        self._tensors.update(
            bg_attr_mean=bg_attr_mean,
            bg_attr_std=bg_attr_std,
            bg_attr_kl=bg_attr_kl,
            bg_attr=bg_attr)

        # --- decode ---
        _, T, H, W, _ = tf_shape(self.inp)

        # The decoder yields one 3-channel colour per batch element.
        background = self.background_decoder(bg_attr, 3, self.is_training)
        assert len(background.shape) == 2 and background.shape[1] == 3

        background = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))
        # Broadcast that colour over every frame and pixel.
        background = tf.tile(background[:, None, None, None, :], (1, T, H, W, 1))

    elif mode == "learn_and_transform":
        self.maybe_build_subnet("background_encoder")
        self.maybe_build_subnet("background_decoder")

        # --- encode ---
        n_transform_latents = 4
        n_latents = (2 * cfg.background_cfg.A, 2 * n_transform_latents)

        bg_attr, bg_transform_params = self.background_encoder(self.inp, n_latents, self.is_training)

        # --- bg attributes ---
        bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
        bg_attr_std = self.std_nonlinearity(bg_attr_log_std)

        bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

        # --- bg location ---
        bg_transform_params = tf.reshape(
            bg_transform_params,
            (self.batch_size, self.dynamic_n_frames, 2*n_transform_latents))

        mean, log_std = tf.split(bg_transform_params, 2, axis=2)
        std = self.std_nonlinearity(log_std)

        logits, kl = normal_vae(mean, std, 0.0, 1.0)

        # integrate across timesteps
        logits = tf.cumsum(logits, axis=1)
        logits = tf.reshape(logits, (self.batch_size*self.dynamic_n_frames, n_transform_latents))

        # h, w land in (0.5, 0.9). |y| < 1 - h and |x| < 1 - w, presumably so
        # the window stays inside the [-1, 1] background frame -- confirm.
        y, x, h, w = tf.split(logits, n_transform_latents, axis=1)
        h = (0.9 - 0.5) * tf.nn.sigmoid(h) + 0.5
        w = (0.9 - 0.5) * tf.nn.sigmoid(w) + 0.5
        y = (1 - h) * tf.nn.tanh(y)
        x = (1 - w) * tf.nn.tanh(x)

        # --- decode ---
        background = self.background_decoder(bg_attr, self.image_depth, self.is_training)
        bg_shape = cfg.background_cfg.bg_shape
        background = background[:, :bg_shape[0], :bg_shape[1], :]
        assert background.shape[1:3] == bg_shape
        background_raw = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))

        # Resample an (image_height, image_width) window of the raw background
        # for every frame.
        transform_constraints = snt.AffineWarpConstraints.no_shear_2d()

        warper = snt.AffineGridWarper(
            bg_shape, (self.image_height, self.image_width), transform_constraints)
        transforms = tf.concat([w, x, h, y], axis=-1)
        grid_coords = warper(transforms)

        grid_coords = tf.reshape(
            grid_coords,
            (self.batch_size, self.dynamic_n_frames, *tf_shape(grid_coords)[1:]))

        background = tf.contrib.resampler.resampler(background_raw, grid_coords)

        self._tensors.update(
            bg_attr_mean=bg_attr_mean,
            bg_attr_std=bg_attr_std,
            bg_attr_kl=bg_attr_kl,
            bg_attr=bg_attr,
            bg_y=tf.reshape(y, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_x=tf.reshape(x, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_h=tf.reshape(h, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_w=tf.reshape(w, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_transform_kl=kl,
            bg_raw=background_raw,
        )

    elif mode == "data":
        background = self._tensors["ground_truth_background"]

    else:
        # ValueError subclasses Exception, so existing ``except Exception`` still works.
        raise ValueError("Unrecognized background mode: {}.".format(cfg.background_cfg.mode))

    self._tensors["background"] = background[:, :self.dynamic_n_frames]