def _build_obj(self, obj_logits, is_training, **kwargs):
    # Stop gradients while training wheels are on; the convex combination lets this be annealed smoothly.
    obj_logits = self.training_wheels * tf.stop_gradient(obj_logits) + (1 - self.training_wheels) * obj_logits

    obj_logits = obj_logits / self.obj_temp

    obj_log_odds = tf.clip_by_value(obj_logits, -10., 10.)

    obj_pre_sigmoid = concrete_binary_pre_sigmoid_sample(obj_log_odds, self.obj_concrete_temp)
    raw_obj = tf.nn.sigmoid(obj_pre_sigmoid)

    if self.noisy:
        # Relaxed sample while training, hard (rounded) value at test time.
        obj = (
            self.float_is_training * raw_obj
            + (1 - self.float_is_training) * tf.round(raw_obj)
        )
    else:
        obj = tf.round(raw_obj)

    if "obj" in self.no_gradient:
        obj = tf.stop_gradient(obj)

    if "obj" in self.fixed_values:
        obj = self.fixed_values['obj'] * tf.ones_like(obj)

    return dict(
        obj=obj,
        raw_obj=raw_obj,
        obj_pre_sigmoid=obj_pre_sigmoid,
        obj_log_odds=obj_log_odds,
        obj_prob=tf.nn.sigmoid(obj_log_odds),
    )
def _build_obj(self, obj_logit, is_training, **kwargs):
    obj_logit = self.training_wheels * tf.stop_gradient(obj_logit) + (1 - self.training_wheels) * obj_logit

    obj_log_odds = tf.clip_by_value(obj_logit / self.obj_temp, -10., 10.)

    obj_pre_sigmoid = (
        self._noisy * concrete_binary_pre_sigmoid_sample(obj_log_odds, self.obj_concrete_temp)
        + (1 - self._noisy) * obj_log_odds
    )

    obj = tf.nn.sigmoid(obj_pre_sigmoid)

    return dict(
        obj_log_odds=obj_log_odds,
        obj_prob=tf.nn.sigmoid(obj_log_odds),
        obj_pre_sigmoid=obj_pre_sigmoid,
        obj=obj,
    )
def _build_obj(self, obj_logit, is_training, **kwargs):
    obj_logit = self.training_wheels * tf.stop_gradient(obj_logit) + (1 - self.training_wheels) * obj_logit

    obj_log_odds = tf.clip_by_value(obj_logit / self.obj_temp, -10., 10.)

    if self.noisy:
        obj_pre_sigmoid = concrete_binary_pre_sigmoid_sample(obj_log_odds, self.obj_concrete_temp)
    else:
        obj_pre_sigmoid = obj_log_odds

    obj = tf.nn.sigmoid(obj_pre_sigmoid)

    # Use the relaxed value while training, the hard (rounded) value at test time.
    render_obj = (
        self.float_is_training * obj
        + (1 - self.float_is_training) * tf.round(obj)
    )

    return dict(
        obj_log_odds=obj_log_odds,
        obj_prob=tf.nn.sigmoid(obj_log_odds),
        obj_pre_sigmoid=obj_pre_sigmoid,
        obj=obj,
        render_obj=render_obj,
    )
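# All three `_build_obj` variants above rely on `concrete_binary_pre_sigmoid_sample` to draw
# a relaxed Bernoulli (binary Concrete) sample before the sigmoid. Below is a minimal sketch
# of what such a helper typically computes, following the standard binary Concrete
# construction; the exact signature and the `eps` guard are assumptions, not necessarily the
# implementation used in this codebase.

import tensorflow as tf

def _concrete_binary_pre_sigmoid_sample_sketch(log_odds, temperature, eps=1e-8):
    # Add logistic noise to the log-odds and sharpen by the temperature. Applying
    # sigmoid to the result yields a BinConcrete(sigmoid(log_odds), temperature) sample;
    # as temperature -> 0 the sample approaches a hard Bernoulli draw.
    u = tf.random_uniform(tf.shape(log_odds), minval=0., maxval=1.)
    logistic_noise = tf.log(u + eps) - tf.log(1. - u + eps)
    return (log_odds + logistic_noise) / temperature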
def _body(self, inp, features, objects, is_posterior):
    """ Summary of how updates are done for the different variables:

    glimpse_prime: glimpse_prime_params = where + 0.1 * predicted_logit

    where_y/x: new_where_y/x = where_y/x + where_t_scale * tanh(predicted_logit)

    where_h/w: new_where_h/w_logit = where_h/w_logit + predicted_logit

    what: Hard to summarize here, taken from SQAIR. Kind of like an LSTM.

    depth: new_depth_logit = depth_logit + predicted_logit

    obj: new_obj = obj * sigmoid(predicted_logit)

    """
    batch_size, n_objects, _ = tf_shape(features)

    new_objects = AttrDict()

    is_posterior_tf = tf.ones_like(features[..., 0:2])
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    base_features = tf.concat([features, is_posterior_tf], axis=-1)

    cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

    # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
    glimpse_dim = self.object_shape[0] * self.object_shape[1]
    glimpse_prime_params = apply_object_wise(
        self.glimpse_prime_network, base_features,
        output_size=4 + 2 * glimpse_dim, is_training=self.is_training)

    glimpse_prime_params, glimpse_prime_mask_logit, glimpse_mask_logit = \
        tf.split(glimpse_prime_params, [4, glimpse_dim, glimpse_dim], axis=-1)

    if is_posterior:
        # --- obtain final parameters for glimpse prime by modifying current pose ---

        _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

        g_yt = cyt + 0.1 * _yt
        g_xt = cxt + 0.1 * _xt
        g_ys = ys + 0.1 * _ys
        g_xs = xs + 0.1 * _xs

        # --- extract glimpse prime ---

        _, image_height, image_width, _ = tf_shape(inp)
        g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
            g_yt, g_xt, g_ys, g_xs, (image_height, image_width), self.anchor_box, top_left=False)

        glimpse_prime = extract_affine_glimpse(
            inp, self.object_shape, g_yt, g_xt, g_ys, g_xs, self.edge_resampler)
    else:
        g_yt = tf.zeros_like(cyt)
        g_xt = tf.zeros_like(cxt)
        g_ys = tf.zeros_like(ys)
        g_xs = tf.zeros_like(xs)

        glimpse_prime = tf.zeros((batch_size, n_objects, *self.object_shape, self.image_depth))

    glimpse_prime_mask = tf.nn.sigmoid(glimpse_prime_mask_logit + 1.)
    leading_mask_shape = tf_shape(glimpse_prime)[:-1]
    glimpse_prime_mask = tf.reshape(glimpse_prime_mask, (*leading_mask_shape, 1))

    new_objects.update(
        glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1),
        glimpse_prime=glimpse_prime,
        glimpse_prime_mask=glimpse_prime_mask,
    )

    glimpse_prime *= glimpse_prime_mask

    # --- encode glimpse ---

    encoded_glimpse_prime = apply_object_wise(
        self.glimpse_prime_encoder, glimpse_prime,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- position and scale ---

    # Roughly: base_features == temporal_state, encoded_glimpse_prime == hidden_output.
    # hidden_output conditions on encoded_glimpse, and that's the only place encoded_glimpse_prime
    # is used. Here SQAIR conditions on the actual location values from the previous timestep,
    # but we leave that out for now.

    d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
    d_box_params = apply_object_wise(
        self.d_box_network, d_box_inp, output_size=8, is_training=self.is_training)

    d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

    d_box_std = self.std_nonlinearity(d_box_log_std)

    d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (1 - self.training_wheels) * d_box_mean
    d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (1 - self.training_wheels) * d_box_std

    d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
    d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

    # --- position ---

    # We predict position a bit differently from scale. For scale we want to put a prior on the
    # actual value of the scale, whereas for position we want to put a prior on the difference
    # in position over timesteps.

    d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
    d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

    d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
    d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

    new_cyt = cyt + d_yt
    new_cxt = cxt + d_xt

    new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

    # --- scale ---

    new_ys_mean = objects.ys_logit + d_ys
    new_xs_mean = objects.xs_logit + d_xs

    new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
    new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

    new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
    new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

    if self.use_abs_posn:
        box_params = tf.concat(
            [new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)
    else:
        box_params = tf.concat(
            [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

    new_objects.update(
        abs_posn=new_abs_posn,
        yt=new_cyt,
        xt=new_cxt,
        ys=new_ys,
        xs=new_xs,
        normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs], axis=-1),
        d_yt_logit=d_yt_logit,
        d_xt_logit=d_xt_logit,
        ys_logit=new_ys_logit,
        xs_logit=new_xs_logit,
        d_yt_logit_mean=d_yt_mean,
        d_xt_logit_mean=d_xt_mean,
        ys_logit_mean=new_ys_mean,
        xs_logit_mean=new_xs_mean,
        d_yt_logit_std=d_yt_std,
        d_xt_logit_std=d_xt_std,
        ys_logit_std=ys_std,
        xs_logit_std=xs_std,
    )

    # --- attributes ---

    # --- extract a glimpse using new box ---

    if is_posterior:
        _, image_height, image_width, _ = tf_shape(inp)
        _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
            new_cyt, new_cxt, new_ys, new_xs, (image_height, image_width), self.anchor_box, top_left=False)

        glimpse = extract_affine_glimpse(
            inp, self.object_shape, _new_cyt, _new_cxt, _new_ys, _new_xs, self.edge_resampler)
    else:
        glimpse = tf.zeros((batch_size, n_objects, *self.object_shape, self.image_depth))

    glimpse_mask = tf.nn.sigmoid(glimpse_mask_logit + 1.)
    leading_mask_shape = tf_shape(glimpse)[:-1]
    glimpse_mask = tf.reshape(glimpse_mask, (*leading_mask_shape, 1))

    glimpse *= glimpse_mask

    encoded_glimpse = apply_object_wise(
        self.glimpse_encoder, glimpse,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- predict change in attributes ---

    # Under SQAIR we mix between three different values for the attributes:
    # 1. value from previous timestep
    # 2. value predicted directly from glimpse
    # 3. value predicted based on update of temporal cell...this update conditions on
    #    hidden_output, the prediction in #2, and the where values.
    #
    # How to do this given that we are predicting the change in attr? We could just directly
    # predict the attr instead, but call it d_attr. After all, it is in this function that we
    # control whether d_attr is added to attr.

    # So, make a prediction based on just the input:
    attr_from_inp = apply_object_wise(
        self.predict_attr_inp, encoded_glimpse, output_size=2 * self.A, is_training=self.is_training)
    attr_from_inp_mean, attr_from_inp_log_std = tf.split(attr_from_inp, [self.A, self.A], axis=-1)
    attr_from_inp_std = self.std_nonlinearity(attr_from_inp_log_std)

    # And then a prediction which takes the past into account (predicting gate values at the same time):
    attr_from_temp_inp = tf.concat([base_features, box_params, encoded_glimpse], axis=-1)
    attr_from_temp = apply_object_wise(
        self.predict_attr_temp, attr_from_temp_inp, output_size=5 * self.A, is_training=self.is_training)

    (attr_from_temp_mean, attr_from_temp_log_std,
     f_gate_logit, i_gate_logit, t_gate_logit) = tf.split(attr_from_temp, 5, axis=-1)

    attr_from_temp_std = self.std_nonlinearity(attr_from_temp_log_std)

    # bias the gates
    f_gate = tf.nn.sigmoid(f_gate_logit + 1) * .9999
    i_gate = tf.nn.sigmoid(i_gate_logit + 1) * .9999
    t_gate = tf.nn.sigmoid(t_gate_logit + 1) * .9999

    attr_mean = f_gate * objects.attr + (1 - i_gate) * attr_from_inp_mean + (1 - t_gate) * attr_from_temp_mean
    attr_std = (1 - i_gate) * attr_from_inp_std + (1 - t_gate) * attr_from_temp_std

    new_attr = Normal(loc=attr_mean, scale=attr_std).sample()

    # --- apply change in attributes ---

    new_objects.update(
        attr=new_attr,
        d_attr=new_attr - objects.attr,
        d_attr_mean=attr_mean - objects.attr,
        d_attr_std=attr_std,
        f_gate=f_gate,
        i_gate=i_gate,
        t_gate=t_gate,
        glimpse=glimpse,
        glimpse_mask=glimpse_mask,
    )

    # --- z ---

    d_z_inp = tf.concat([base_features, box_params, new_attr, encoded_glimpse], axis=-1)
    d_z_params = apply_object_wise(
        self.d_z_network, d_z_inp, output_size=2, is_training=self.is_training)

    d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
    d_z_std = self.std_nonlinearity(d_z_log_std)

    d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (1 - self.training_wheels) * d_z_mean
    d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (1 - self.training_wheels) * d_z_std

    d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

    new_z_logit = objects.z_logit + d_z_logit
    new_z = self.z_nonlinearity(new_z_logit)

    new_objects.update(
        z=new_z,
        z_logit=new_z_logit,
        d_z_logit=d_z_logit,
        d_z_logit_mean=d_z_mean,
        d_z_logit_std=d_z_std,
    )

    # --- obj ---

    d_obj_inp = tf.concat([base_features, box_params, new_attr, new_z, encoded_glimpse], axis=-1)
    d_obj_logit = apply_object_wise(
        self.d_obj_network, d_obj_inp, output_size=1, is_training=self.is_training)

    d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (1 - self.training_wheels) * d_obj_logit

    d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10., 10.)

    d_obj_pre_sigmoid = (
        self._noisy * concrete_binary_pre_sigmoid_sample(d_obj_log_odds, self.obj_concrete_temp)
        + (1 - self._noisy) * d_obj_log_odds
    )

    d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

    new_obj = objects.obj * d_obj

    new_objects.update(
        d_obj_log_odds=d_obj_log_odds,
        d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
        d_obj_pre_sigmoid=d_obj_pre_sigmoid,
        d_obj=d_obj,
        obj=new_obj,
    )

    # --- update each object's hidden state ---

    cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

    if is_posterior:
        _, new_objects.prop_state = apply_object_wise(self.cell, cell_input, objects.prop_state)
        new_objects.prior_prop_state = new_objects.prop_state
    else:
        _, new_objects.prior_prop_state = apply_object_wise(self.cell, cell_input, objects.prior_prop_state)
        new_objects.prop_state = new_objects.prior_prop_state

    return new_objects
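# The gated attribute update inside `_body` follows the SQAIR-style mixing described in the
# comment block above. Here is a toy, self-contained restatement of just that arithmetic;
# the names (`prev_attr`, `from_inp_*`, `from_temp_*`) are illustrative, and this mirrors
# the equations above rather than the actual networks:

import tensorflow as tf

def gated_attr_mixture_sketch(prev_attr, from_inp_mean, from_inp_std,
                              from_temp_mean, from_temp_std,
                              f_gate_logit, i_gate_logit, t_gate_logit):
    # Gates are biased toward 1 (the +1) and kept strictly below 1 (the * .9999),
    # so early in training the previous attribute value dominates the mixture.
    f_gate = tf.nn.sigmoid(f_gate_logit + 1) * .9999
    i_gate = tf.nn.sigmoid(i_gate_logit + 1) * .9999
    t_gate = tf.nn.sigmoid(t_gate_logit + 1) * .9999

    # Mixture of: previous value, glimpse-only prediction, temporal prediction.
    mean = f_gate * prev_attr + (1 - i_gate) * from_inp_mean + (1 - t_gate) * from_temp_mean
    std = (1 - i_gate) * from_inp_std + (1 - t_gate) * from_temp_std
    return mean, std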
def _body(self, inp, features, objects, is_posterior):
    batch_size, n_objects, _ = tf_shape(features)

    new_objects = AttrDict()

    is_posterior_tf = tf.ones_like(features[..., 0:2])
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    base_features = tf.concat([features, is_posterior_tf], axis=-1)

    cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

    if self.learn_glimpse_prime:
        # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
        glimpse_prime_params = apply_object_wise(
            self.glimpse_prime_network, base_features, output_size=4, is_training=self.is_training)
    else:
        glimpse_prime_params = tf.zeros_like(base_features[..., :4])

    if is_posterior:
        if self.learn_glimpse_prime:
            # --- obtain final parameters for glimpse prime by modifying current pose ---

            _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

            # This is how it is done in SQAIR
            g_yt = cyt + 0.1 * _yt
            g_xt = cxt + 0.1 * _xt
            g_ys = ys + 0.1 * _ys
            g_xs = xs + 0.1 * _xs

            # g_yt = cyt + self.glimpse_prime_scale * tf.nn.tanh(_yt)
            # g_xt = cxt + self.glimpse_prime_scale * tf.nn.tanh(_xt)
            # g_ys = ys + self.glimpse_prime_scale * tf.nn.tanh(_ys)
            # g_xs = xs + self.glimpse_prime_scale * tf.nn.tanh(_xs)
        else:
            g_yt = cyt
            g_xt = cxt
            g_ys = self.glimpse_prime_scale * ys
            g_xs = self.glimpse_prime_scale * xs

        # --- extract glimpse prime ---

        _, image_height, image_width, _ = tf_shape(inp)
        g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
            g_yt, g_xt, g_ys, g_xs, (image_height, image_width), self.anchor_box, top_left=False)

        glimpse_prime = extract_affine_glimpse(
            inp, self.object_shape, g_yt, g_xt, g_ys, g_xs, self.edge_resampler)
    else:
        g_yt = tf.zeros_like(cyt)
        g_xt = tf.zeros_like(cxt)
        g_ys = tf.zeros_like(ys)
        g_xs = tf.zeros_like(xs)

        glimpse_prime = tf.zeros((batch_size, n_objects, *self.object_shape, self.image_depth))

    new_objects.update(
        glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1),
    )

    # --- encode glimpse ---

    encoded_glimpse_prime = apply_object_wise(
        self.glimpse_prime_encoder, glimpse_prime,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- position and scale ---

    d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
    d_box_params = apply_object_wise(
        self.d_box_network, d_box_inp, output_size=8, is_training=self.is_training)

    d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

    d_box_std = self.std_nonlinearity(d_box_log_std)

    d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (1 - self.training_wheels) * d_box_mean
    d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (1 - self.training_wheels) * d_box_std

    d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
    d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

    # --- position ---

    # We predict position a bit differently from scale. For scale we want to put a prior on the
    # actual value of the scale, whereas for position we want to put a prior on the difference
    # in position over timesteps.

    d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
    d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

    d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
    d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

    new_cyt = cyt + d_yt
    new_cxt = cxt + d_xt

    new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

    # --- scale ---

    new_ys_mean = objects.ys_logit + d_ys
    new_xs_mean = objects.xs_logit + d_xs

    new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
    new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

    new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
    new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

    # Used for conditioning
    if self.use_abs_posn:
        box_params = tf.concat(
            [new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)
    else:
        box_params = tf.concat(
            [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

    new_objects.update(
        abs_posn=new_abs_posn,
        yt=new_cyt,
        xt=new_cxt,
        ys=new_ys,
        xs=new_xs,
        normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs], axis=-1),
        d_yt_logit=d_yt_logit,
        d_xt_logit=d_xt_logit,
        ys_logit=new_ys_logit,
        xs_logit=new_xs_logit,
        d_yt_logit_mean=d_yt_mean,
        d_xt_logit_mean=d_xt_mean,
        ys_logit_mean=new_ys_mean,
        xs_logit_mean=new_xs_mean,
        d_yt_logit_std=d_yt_std,
        d_xt_logit_std=d_xt_std,
        ys_logit_std=ys_std,
        xs_logit_std=xs_std,
        glimpse_prime=glimpse_prime,
    )

    # --- attributes ---

    # --- extract a glimpse using new box ---

    if is_posterior:
        _, image_height, image_width, _ = tf_shape(inp)
        _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
            new_cyt, new_cxt, new_ys, new_xs, (image_height, image_width), self.anchor_box, top_left=False)

        glimpse = extract_affine_glimpse(
            inp, self.object_shape, _new_cyt, _new_cxt, _new_ys, _new_xs, self.edge_resampler)
    else:
        glimpse = tf.zeros((batch_size, n_objects, *self.object_shape, self.image_depth))

    encoded_glimpse = apply_object_wise(
        self.glimpse_encoder, glimpse,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- predict change in attributes ---

    d_attr_inp = tf.concat([base_features, box_params, encoded_glimpse], axis=-1)
    d_attr_params = apply_object_wise(
        self.d_attr_network, d_attr_inp, output_size=2 * self.A + 1, is_training=self.is_training)

    d_attr_mean, d_attr_log_std, gate_logit = tf.split(d_attr_params, [self.A, self.A, 1], axis=-1)
    d_attr_std = self.std_nonlinearity(d_attr_log_std)

    gate = tf.nn.sigmoid(gate_logit)
    if self.gate_d_attr:
        d_attr_mean *= gate

    d_attr = Normal(loc=d_attr_mean, scale=d_attr_std).sample()

    # --- apply change in attributes ---

    new_attr = objects.attr + d_attr

    new_objects.update(
        attr=new_attr,
        d_attr=d_attr,
        d_attr_mean=d_attr_mean,
        d_attr_std=d_attr_std,
        glimpse=glimpse,
        d_attr_gate=gate,
    )

    # --- z ---

    d_z_inp = tf.concat([base_features, box_params, new_attr, encoded_glimpse], axis=-1)
    d_z_params = apply_object_wise(
        self.d_z_network, d_z_inp, output_size=2, is_training=self.is_training)

    d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
    d_z_std = self.std_nonlinearity(d_z_log_std)

    d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (1 - self.training_wheels) * d_z_mean
    d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (1 - self.training_wheels) * d_z_std

    d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

    new_z_logit = objects.z_logit + d_z_logit
    new_z = self.z_nonlinearity(new_z_logit)

    new_objects.update(
        z=new_z,
        z_logit=new_z_logit,
        d_z_logit=d_z_logit,
        d_z_logit_mean=d_z_mean,
        d_z_logit_std=d_z_std,
    )

    # --- obj ---

    d_obj_inp = tf.concat([base_features, box_params, new_attr, new_z, encoded_glimpse], axis=-1)
    d_obj_logit = apply_object_wise(
        self.d_obj_network, d_obj_inp, output_size=1, is_training=self.is_training)

    d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (1 - self.training_wheels) * d_obj_logit

    d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10., 10.)

    d_obj_pre_sigmoid = (
        self._noisy * concrete_binary_pre_sigmoid_sample(d_obj_log_odds, self.obj_concrete_temp)
        + (1 - self._noisy) * d_obj_log_odds
    )

    d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

    new_obj = objects.obj * d_obj

    new_objects.update(
        d_obj_log_odds=d_obj_log_odds,
        d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
        d_obj_pre_sigmoid=d_obj_pre_sigmoid,
        d_obj=d_obj,
        obj=new_obj,
    )

    # --- update each object's hidden state ---

    cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

    if is_posterior:
        _, new_objects.prop_state = apply_object_wise(self.cell, cell_input, objects.prop_state)
        new_objects.prior_prop_state = new_objects.prop_state
    else:
        _, new_objects.prior_prop_state = apply_object_wise(self.cell, cell_input, objects.prior_prop_state)
        new_objects.prop_state = new_objects.prior_prop_state

    return new_objects
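# Both `_body` variants lean heavily on `apply_object_wise`, which runs a network
# independently over the object axis. A minimal sketch of the reshape-apply-reshape
# pattern such a helper implements for the feed-forward call sites above; the exact
# signature and the `n_trailing_dims` handling are assumptions inferred from those call
# sites, and the RNN-cell variant (which also threads a state through) is omitted:

import tensorflow as tf

def apply_object_wise_sketch(network, inp, n_trailing_dims=1, **network_kwargs):
    # Collapse all leading (batch, object, ...) axes into one flat batch axis, keeping
    # the trailing `n_trailing_dims` axes the network consumes (1 for feature vectors,
    # 3 for glimpse images of shape (H, W, C)).
    static_shape = inp.shape.as_list()
    n_leading = len(static_shape) - n_trailing_dims
    trailing = static_shape[n_leading:]          # assumed fully static
    leading = tf.shape(inp)[:n_leading]          # may be dynamic (e.g. batch size)

    flat = tf.reshape(inp, [-1] + trailing)
    out = network(flat, **network_kwargs)

    # Restore the leading axes; the network may have changed the trailing size.
    out_trailing = out.shape.as_list()[1:]
    return tf.reshape(out, tf.concat([leading, out_trailing], axis=0))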
def body(step, stopping_sum, prev_state, running_recon, kl_loss, running_digits,
         scale_ta, scale_kl_ta, scale_std_ta,
         shift_ta, shift_kl_ta, shift_std_ta,
         attr_ta, attr_kl_ta, attr_std_ta,
         z_pres_ta, z_pres_probs_ta, z_pres_kl_ta,
         vae_input_ta, vae_output_ta,
         scale, shift, attr, z_pres):

    if self.difference_air:
        inp = (self._tensors["inp"]
               - tf.reshape(running_recon, (self.batch_size, *self.obs_shape)))
        encoded_inp = self.image_encoder(inp, 0, self.is_training)
        encoded_inp = tf.layers.flatten(encoded_inp)
    else:
        encoded_inp = self.encoded_inp

    if self.complete_rnn_input:
        rnn_input = tf.concat([encoded_inp, scale, shift, attr, z_pres], axis=1)
    else:
        rnn_input = encoded_inp

    hidden_rep, next_state = self.cell(rnn_input, prev_state)

    outputs = self.output_network(hidden_rep, 9, self.is_training)

    (scale_mean, scale_log_std, shift_mean, shift_log_std,
     z_pres_log_odds) = tf.split(outputs, [2, 2, 2, 2, 1], axis=1)

    # --- scale ---

    scale_std = tf.exp(scale_log_std)

    scale_mean = self.apply_fixed_value("scale_mean", scale_mean)
    scale_std = self.apply_fixed_value("scale_std", scale_std)

    scale_logits, scale_kl = normal_vae(scale_mean, scale_std, self.scale_prior_mean, self.scale_prior_std)
    scale_kl = tf.reduce_sum(scale_kl, axis=1, keepdims=True)
    scale = tf.nn.sigmoid(tf.clip_by_value(scale_logits, -10, 10))

    # --- shift ---

    shift_std = tf.exp(shift_log_std)

    shift_mean = self.apply_fixed_value("shift_mean", shift_mean)
    shift_std = self.apply_fixed_value("shift_std", shift_std)

    shift_logits, shift_kl = normal_vae(shift_mean, shift_std, self.shift_prior_mean, self.shift_prior_std)
    shift_kl = tf.reduce_sum(shift_kl, axis=1, keepdims=True)
    shift = tf.nn.tanh(tf.clip_by_value(shift_logits, -10, 10))

    # --- Extract windows from scene ---

    w, h = scale[:, 0:1], scale[:, 1:2]
    x, y = shift[:, 0:1], shift[:, 1:2]

    theta = tf.concat([w, tf.zeros_like(w), x, tf.zeros_like(h), h, y], axis=1)
    theta = tf.reshape(theta, (-1, 2, 3))

    vae_input = transformer(self._tensors["inp"], theta, self.object_shape)

    # This is a necessary reshape, as the output of transformer will have unknown dims
    vae_input = tf.reshape(vae_input, (self.batch_size, *self.object_shape, self.image_depth))

    # --- Apply object-level VAE (object encoder/object decoder) to windows ---

    attr = self.object_encoder(vae_input, 2 * self.A, self.is_training)
    attr_mean, attr_log_std = tf.split(attr, 2, axis=1)
    attr_std = tf.exp(attr_log_std)

    attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)
    attr_kl = tf.reduce_sum(attr_kl, axis=1, keepdims=True)

    vae_output = self.object_decoder(
        attr, self.object_shape[0] * self.object_shape[1] * self.image_depth, self.is_training)
    vae_output = tf.nn.sigmoid(tf.clip_by_value(vae_output, -10, 10))

    # --- Place reconstructed objects in image ---

    theta_inverse = tf.concat(
        [1. / w, tf.zeros_like(w), -x / w, tf.zeros_like(h), 1. / h, -y / h], axis=1)
    theta_inverse = tf.reshape(theta_inverse, (-1, 2, 3))

    vae_output_transformed = transformer(
        tf.reshape(vae_output, (self.batch_size, *self.object_shape, self.image_depth)),
        theta_inverse, self.obs_shape[:2])
    vae_output_transformed = tf.reshape(
        vae_output_transformed,
        [self.batch_size, self.image_height * self.image_width * self.image_depth])

    # --- z_pres ---

    if self.run_all_time_steps:
        z_pres = tf.ones_like(z_pres_log_odds)
        z_pres_prob = tf.ones_like(z_pres_log_odds)
        z_pres_kl = tf.zeros_like(z_pres_log_odds)
    else:
        z_pres_log_odds = tf.clip_by_value(z_pres_log_odds, -10, 10)

        z_pres_pre_sigmoid = concrete_binary_pre_sigmoid_sample(z_pres_log_odds, self.z_pres_temperature)
        z_pres = tf.nn.sigmoid(z_pres_pre_sigmoid)
        z_pres = (self.float_is_training * z_pres
                  + (1 - self.float_is_training) * tf.round(z_pres))

        z_pres_prob = tf.nn.sigmoid(z_pres_log_odds)

        z_pres_kl = concrete_binary_sample_kl(
            z_pres_pre_sigmoid,
            z_pres_log_odds, self.z_pres_temperature,
            self.z_pres_prior_log_odds, self.z_pres_temperature,
        )

    stopping_sum += (1.0 - z_pres)
    alive = tf.less(stopping_sum, self.stopping_threshold)
    running_digits += tf.to_int32(alive)

    # --- adjust reconstruction ---

    running_recon += tf.where(
        tf.tile(alive, (1, vae_output_transformed.shape[1])),
        z_pres * vae_output_transformed,
        tf.zeros_like(running_recon))

    # --- add kl to loss ---

    kl_loss += tf.where(alive, scale_kl, tf.zeros_like(kl_loss))
    kl_loss += tf.where(alive, shift_kl, tf.zeros_like(kl_loss))
    kl_loss += tf.where(alive, attr_kl, tf.zeros_like(kl_loss))
    kl_loss += tf.where(alive, z_pres_kl, tf.zeros_like(kl_loss))

    # --- record values ---

    scale_ta = scale_ta.write(scale_ta.size(), scale)
    scale_kl_ta = scale_kl_ta.write(scale_kl_ta.size(), scale_kl)
    scale_std_ta = scale_std_ta.write(scale_std_ta.size(), scale_std)
    shift_ta = shift_ta.write(shift_ta.size(), shift)
    shift_kl_ta = shift_kl_ta.write(shift_kl_ta.size(), shift_kl)
    shift_std_ta = shift_std_ta.write(shift_std_ta.size(), shift_std)
    attr_ta = attr_ta.write(attr_ta.size(), attr)
    attr_kl_ta = attr_kl_ta.write(attr_kl_ta.size(), attr_kl)
    attr_std_ta = attr_std_ta.write(attr_std_ta.size(), attr_std)
    vae_input_ta = vae_input_ta.write(vae_input_ta.size(), tf.layers.flatten(vae_input))
    vae_output_ta = vae_output_ta.write(vae_output_ta.size(), vae_output)
    z_pres_ta = z_pres_ta.write(z_pres_ta.size(), z_pres)
    z_pres_probs_ta = z_pres_probs_ta.write(z_pres_probs_ta.size(), z_pres_prob)
    z_pres_kl_ta = z_pres_kl_ta.write(z_pres_kl_ta.size(), z_pres_kl)

    return (
        step + 1, stopping_sum, next_state, running_recon, kl_loss, running_digits,
        scale_ta, scale_kl_ta, scale_std_ta,
        shift_ta, shift_kl_ta, shift_std_ta,
        attr_ta, attr_kl_ta, attr_std_ta,
        z_pres_ta, z_pres_probs_ta, z_pres_kl_ta,
        vae_input_ta, vae_output_ta,
        scale, shift, attr, z_pres,
    )
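# The AIR-style `body` above calls `normal_vae` for scale, shift and attr. A minimal sketch
# of the usual contract for such a helper (a reparameterized sample plus the analytic
# diagonal-Gaussian KL); the actual helper may differ in argument order or naming:

import tensorflow as tf

def normal_vae_sketch(mean, std, prior_mean, prior_std):
    # Reparameterization trick: sample = mean + std * eps, with eps ~ N(0, I),
    # so gradients flow through mean and std.
    sample = mean + std * tf.random_normal(tf.shape(mean))

    # Elementwise KL( N(mean, std^2) || N(prior_mean, prior_std^2) ); the callers
    # above reduce_sum this over the latent axis.
    kl = (tf.log(prior_std / std)
          + (std ** 2 + (mean - prior_mean) ** 2) / (2. * prior_std ** 2)
          - 0.5)
    return sample, kl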