def _call(self, input_signal, input_locs, output_locs, is_training):
    if not self.is_built:
        self.value_func = self.build_mlp(scope="value_func")
        self.after_func = self.build_mlp(scope="after")

        if self.do_object_wise:
            self.object_wise_func = self.build_object_wise(scope="object_wise")

        self.is_built = True

    batch_size, n_inp, _ = tf_shape(input_signal)
    loc_dim = tf_shape(input_locs)[-1]
    n_outp = tf_shape(output_locs)[-2]
    input_locs = tf.broadcast_to(input_locs, (batch_size, n_inp, loc_dim))
    output_locs = tf.broadcast_to(output_locs, (batch_size, n_outp, loc_dim))

    dist = output_locs[:, :, None, :] - input_locs[:, None, :, :]
    proximity = tf.exp(-0.5 * tf.reduce_sum((dist / self.kernel_std)**2, axis=3))
    proximity = proximity / (2 * np.pi)**(0.5 * loc_dim) / self.kernel_std**loc_dim

    V = apply_object_wise(
        self.value_func, input_signal,
        output_size=self.n_hidden, is_training=is_training)  # (batch_size, n_inp, n_hidden)

    result = tf.matmul(proximity, V)  # (batch_size, n_outp, n_hidden)

    # `after_func` is applied to the attention result, followed by dropout and layer norm.
    # Then, if `do_object_wise` is True, `object_wise_func` is applied object-wise in a
    # ResNet-style manner.
    output = apply_object_wise(
        self.after_func, result, output_size=self.n_hidden, is_training=is_training)
    output = tf.layers.dropout(output, self.p_dropout, training=is_training)
    signal = tf.contrib.layers.layer_norm(output)

    if self.do_object_wise:
        output = apply_object_wise(
            self.object_wise_func, signal, output_size=self.n_hidden, is_training=is_training)
        output = tf.layers.dropout(output, self.p_dropout, training=is_training)
        signal = tf.contrib.layers.layer_norm(signal + output)

    return signal
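# A minimal NumPy sketch (illustration only, not part of the model) of the Gaussian
# proximity kernel computed in `_call` above: each output location aggregates the
# per-input values with weights given by an isotropic Gaussian density in the distance
# between locations. The function name is hypothetical.
def _gaussian_proximity_sketch(output_locs, input_locs, kernel_std):
    import numpy as np
    loc_dim = output_locs.shape[-1]
    dist = output_locs[:, None, :] - input_locs[None, :, :]  # (n_outp, n_inp, loc_dim)
    prox = np.exp(-0.5 * np.sum((dist / kernel_std) ** 2, axis=-1))
    return prox / (2 * np.pi) ** (0.5 * loc_dim) / kernel_std ** loc_dim

# `prox @ V` (i.e. tf.matmul(proximity, V) above) then yields one aggregated value
# vector per output location.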
def _call(self, input_locs, input_features, reference_locs, reference_features, is_training):
    """ Assumes `input_features` and `reference_features` are identical. """
    assert self.do_object_wise

    if not self.is_built:
        self.object_wise_func = self.build_mlp(scope="object_wise_func")
        self.is_built = True

    object_wise = apply_object_wise(
        self.object_wise_func, reference_features,
        output_size=self.n_hidden, is_training=is_training)  # (B, n_ref, n_hidden)

    return object_wise
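# Hedged sketch of what `apply_object_wise` is assumed to do throughout this file:
# flatten all leading "object" dimensions, apply the wrapped network to each object
# independently, then restore the leading dimensions. The real helper also supports
# `n_trailing_dims` and recurrent cells; this simplified illustration covers only
# the common case used above.
def _apply_object_wise_sketch(func, signal, output_size, is_training):
    leading_shape = tf.shape(signal)[:-1]
    flat = tf.reshape(signal, tf.stack([-1, tf.shape(signal)[-1]]))  # (prod(leading), D)
    out = func(flat, output_size, is_training)                       # (prod(leading), output_size)
    return tf.reshape(out, tf.concat([leading_shape, [output_size]], axis=0))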
def _loop_body(self, f, conf, hidden_states, *tensor_arrays):
    batch_size = self.batch_size

    memory = self.encoded_frames[:, f]  # (batch_size, H*W, S)

    # Break ties in the confidence-based ordering with a small, tracker-index-dependent offset.
    delta = 0.0001 * tf.range(self.n_trackers, dtype=tf.float32)[:, None, None]
    sort_criteria = tf.round(conf) - delta
    sorted_order = tf.contrib.framework.argsort(sort_criteria, axis=0, direction='DESCENDING')
    sorted_order = tf.reshape(sorted_order, (self.n_trackers, batch_size, 1))

    order = tf.cond(
        tf.logical_or(tf.equal(f, 0), not self.prioritize),
        lambda: tf.tile(tf.range(self.n_trackers)[:, None, None], (1, batch_size, 1)),
        lambda: sorted_order,
    )
    order = tf.reshape(order, (self.n_trackers, batch_size, 1))
    inverse_order = tf.contrib.framework.argsort(order, axis=0, direction='ASCENDING')

    tensors = defaultdict(list)

    for i in range(self.n_trackers):
        tensors["memory_activation"].append(tf.reduce_mean(tf.abs(memory), axis=2))

        # --- apply ordering if applicable ---

        indices = order[i]  # (batch_size, 1)
        indexor = tf.concat([indices, tf.range(batch_size)[:, None]], axis=1)  # (batch_size, 2)
        _hidden_states = tf.gather_nd(hidden_states, indexor)  # (batch_size, n_hidden)

        # --- access the memory using spatial attention ---

        keys = self.key_network(_hidden_states, self.S, self.is_training)  # (batch_size, self.S)
        beta_logit = self.beta_network(_hidden_states, 1, self.is_training)  # (batch_size, 1)

        # beta = 1 + tf.math.softplus(beta_logit), computed in a numerically stable way:
        beta_pos = tf.maximum(0.0, beta_logit)
        beta_neg = tf.minimum(0.0, beta_logit)
        beta = tf.log1p(tf.exp(beta_neg)) + beta_pos + tf.log1p(tf.exp(-beta_pos)) + (1 - np.log(2))

        _memory = tf.identity(memory)
        _memory = limit_grad_norm(_memory, 1.)

        key_activation = beta * tf_cosine_similarity(_memory, keys[:, None, :])  # (batch_size, H*W)
        attention_weights = tf.nn.softmax(key_activation, axis=1)[:, :, None]  # (batch_size, H*W, 1)

        _attention_weights = tf.identity(attention_weights)
        _attention_weights = limit_grad_norm(_attention_weights, 1.)

        attention_result = tf.reduce_sum(_attention_weights * memory, axis=1)  # (batch_size, S)

        # --- update tracker hidden state and output ---

        tracker_output, new_hidden = self.cell(attention_result, _hidden_states)

        # --- update the memory for the next trackers ---

        write = self.write_network(tracker_output, self.S, self.is_training)
        erase = self.erase_network(tracker_output, self.S, self.is_training)
        erase = tf.nn.sigmoid(erase)

        memory = ((1 - attention_weights * erase[:, None, :]) * memory
                  + attention_weights * write[:, None, :])

        tensors["hidden_states"].append(new_hidden)
        tensors["tracker_output"].append(tracker_output)
        tensors["attention_result"].append(attention_result)
        tensors["attention_weights"].append(attention_weights[..., 0])

    tensors = {k: tf.stack(v, axis=0) for k, v in tensors.items()}

    # --- invert the ordering ---

    batch_indices = tf.tile(tf.range(batch_size)[None, :, None], (self.n_trackers, 1, 1))
    inverse_indexor = tf.concat([inverse_order, batch_indices], axis=2)  # (n_trackers, batch_size, 2)
    tensors = {k: tf.gather_nd(v, inverse_indexor) for k, v in tensors.items()}

    # --- compute the output values ---

    output = apply_object_wise(
        self.output_network, tensors["tracker_output"],
        output_size=self.output_size_per_object, is_training=self.is_training)

    conf, layer, pose, mask, appearance = tf.split(
        output,
        [1, self.n_layers, 4, np.prod(self.object_shape),
         self.image_depth * np.prod(self.object_shape)],
        axis=-1)

    conf = tf.abs(tf.nn.tanh(conf))
    conf = (self.float_is_training * conf
            + (1 - self.float_is_training) * tf.round(conf))

    layer = tf.nn.softmax(layer, axis=-1)
    layer = tf.transpose(layer, (1, 0, 2))
    layer = limit_grad_norm(layer, 10.)
    layer = tf.transpose(layer, (1, 0, 2))
    layer = tfp.distributions.RelaxedOneHotCategorical(self.layer_temperature, probs=layer).sample()

    pose = tf.nn.tanh(pose)

    mask = tfp.distributions.RelaxedBernoulli(self.mask_temperature, logits=mask).sample()
    if self.fixed_mask:
        mask = tf.ones_like(mask)

    appearance = tf.nn.sigmoid(appearance)

    output = dict(
        conf=conf, layer=layer, pose=pose, mask=mask, appearance=appearance,
        order=order, **tensors)

    tensor_arrays = append_to_tensor_arrays(f, output, tensor_arrays)

    f += 1

    return [f, conf, tensors["hidden_states"], *tensor_arrays]
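# The `beta` computation in `_loop_body` above is a numerically stable evaluation of
# beta = 1 + softplus(beta_logit), as the comment there notes. A quick NumPy check of
# the identity (illustration only; not called anywhere):
def _check_stable_softplus_identity():
    import numpy as np
    x = np.array([-50., -1., 0., 1., 50.])
    pos, neg = np.maximum(0.0, x), np.minimum(0.0, x)
    stable = np.log1p(np.exp(neg)) + pos + np.log1p(np.exp(-pos)) + (1 - np.log(2))
    assert np.allclose(stable, 1 + np.log1p(np.exp(x)))  # 1 + softplus(x)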
def _body(self, inp, features, objects, is_posterior):
    """
    Summary of how updates are done for the different variables:

    glimpse_prime: glimpse_prime_params = where + 0.1 * predicted_logit
    where_y/x: new_where_y/x = where_y/x + where_t_scale * tanh(predicted_logit)
    where_h/w: new_where_h/w_logit = where_h/w_logit + predicted_logit
    what: Hard to summarize here; taken from SQAIR. Kind of like an LSTM.
    depth: new_depth_logit = depth_logit + predicted_logit
    obj: new_obj = obj * sigmoid(predicted_logit)
    """
    batch_size, n_objects, _ = tf_shape(features)

    new_objects = AttrDict()

    is_posterior_tf = tf.ones_like(features[..., 0:2])
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    base_features = tf.concat([features, is_posterior_tf], axis=-1)

    cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

    # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
    glimpse_dim = self.object_shape[0] * self.object_shape[1]
    glimpse_prime_params = apply_object_wise(
        self.glimpse_prime_network, base_features,
        output_size=4 + 2 * glimpse_dim, is_training=self.is_training)

    glimpse_prime_params, glimpse_prime_mask_logit, glimpse_mask_logit = \
        tf.split(glimpse_prime_params, [4, glimpse_dim, glimpse_dim], axis=-1)

    if is_posterior:
        # --- obtain final parameters for glimpse prime by modifying current pose ---

        _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

        g_yt = cyt + 0.1 * _yt
        g_xt = cxt + 0.1 * _xt
        g_ys = ys + 0.1 * _ys
        g_xs = xs + 0.1 * _xs

        # --- extract glimpse prime ---

        _, image_height, image_width, _ = tf_shape(inp)
        g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
            g_yt, g_xt, g_ys, g_xs, (image_height, image_width), self.anchor_box, top_left=False)
        glimpse_prime = extract_affine_glimpse(
            inp, self.object_shape, g_yt, g_xt, g_ys, g_xs, self.edge_resampler)
    else:
        g_yt = tf.zeros_like(cyt)
        g_xt = tf.zeros_like(cxt)
        g_ys = tf.zeros_like(ys)
        g_xs = tf.zeros_like(xs)
        glimpse_prime = tf.zeros((batch_size, n_objects, *self.object_shape, self.image_depth))

    glimpse_prime_mask = tf.nn.sigmoid(glimpse_prime_mask_logit + 1.)
    leading_mask_shape = tf_shape(glimpse_prime)[:-1]
    glimpse_prime_mask = tf.reshape(glimpse_prime_mask, (*leading_mask_shape, 1))

    new_objects.update(
        glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1),
        glimpse_prime=glimpse_prime,
        glimpse_prime_mask=glimpse_prime_mask,
    )

    glimpse_prime *= glimpse_prime_mask

    # --- encode glimpse ---

    encoded_glimpse_prime = apply_object_wise(
        self.glimpse_prime_encoder, glimpse_prime,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- position and scale ---

    # Roughly: base_features == temporal_state, encoded_glimpse_prime == hidden_output.
    # hidden_output conditions on encoded_glimpse, and that's the only place
    # encoded_glimpse_prime is used. Here SQAIR conditions on the actual location values
    # from the previous timestep, but we leave that out for now.

    d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
    d_box_params = apply_object_wise(
        self.d_box_network, d_box_inp, output_size=8, is_training=self.is_training)

    d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

    d_box_std = self.std_nonlinearity(d_box_log_std)

    d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (1 - self.training_wheels) * d_box_mean
    d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (1 - self.training_wheels) * d_box_std

    d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
    d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

    # --- position ---

    # We predict position a bit differently from scale. For scale we want to put a prior
    # on the actual value of the scale, whereas for position we want to put a prior on
    # the difference in position over timesteps.

    d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
    d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

    d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
    d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

    new_cyt = cyt + d_yt
    new_cxt = cxt + d_xt

    new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

    # --- scale ---

    new_ys_mean = objects.ys_logit + d_ys
    new_xs_mean = objects.xs_logit + d_xs

    new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
    new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

    new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
    new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

    if self.use_abs_posn:
        box_params = tf.concat([new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)
    else:
        box_params = tf.concat([d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

    new_objects.update(
        abs_posn=new_abs_posn,
        yt=new_cyt,
        xt=new_cxt,
        ys=new_ys,
        xs=new_xs,
        normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs], axis=-1),
        d_yt_logit=d_yt_logit,
        d_xt_logit=d_xt_logit,
        ys_logit=new_ys_logit,
        xs_logit=new_xs_logit,
        d_yt_logit_mean=d_yt_mean,
        d_xt_logit_mean=d_xt_mean,
        ys_logit_mean=new_ys_mean,
        xs_logit_mean=new_xs_mean,
        d_yt_logit_std=d_yt_std,
        d_xt_logit_std=d_xt_std,
        ys_logit_std=ys_std,
        xs_logit_std=xs_std,
    )

    # --- attributes ---

    # --- extract a glimpse using new box ---

    if is_posterior:
        _, image_height, image_width, _ = tf_shape(inp)
        _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
            new_cyt, new_cxt, new_ys, new_xs, (image_height, image_width), self.anchor_box, top_left=False)

        glimpse = extract_affine_glimpse(
            inp, self.object_shape, _new_cyt, _new_cxt, _new_ys, _new_xs, self.edge_resampler)
    else:
        glimpse = tf.zeros((batch_size, n_objects, *self.object_shape, self.image_depth))

    glimpse_mask = tf.nn.sigmoid(glimpse_mask_logit + 1.)
    leading_mask_shape = tf_shape(glimpse)[:-1]
    glimpse_mask = tf.reshape(glimpse_mask, (*leading_mask_shape, 1))

    glimpse *= glimpse_mask

    encoded_glimpse = apply_object_wise(
        self.glimpse_encoder, glimpse,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- predict change in attributes ---

    # Under SQAIR we mix between three different values for the attributes:
    # 1. the value from the previous timestep
    # 2. a value predicted directly from the glimpse
    # 3. a value predicted based on an update of the temporal cell; this update conditions
    #    on hidden_output, the prediction in #2, and the `where` values.
    #
    # How to do this given that we are predicting the change in attr? We could just
    # directly predict the attr instead, but call it d_attr. After all, it is in this
    # function that we control whether d_attr is added to attr.

    # So, make a prediction based on just the input:

    attr_from_inp = apply_object_wise(
        self.predict_attr_inp, encoded_glimpse,
        output_size=2 * self.A, is_training=self.is_training)
    attr_from_inp_mean, attr_from_inp_log_std = tf.split(attr_from_inp, [self.A, self.A], axis=-1)
    attr_from_inp_std = self.std_nonlinearity(attr_from_inp_log_std)

    # And then a prediction which takes the past into account (predicting gate values at
    # the same time):

    attr_from_temp_inp = tf.concat([base_features, box_params, encoded_glimpse], axis=-1)
    attr_from_temp = apply_object_wise(
        self.predict_attr_temp, attr_from_temp_inp,
        output_size=5 * self.A, is_training=self.is_training)

    (attr_from_temp_mean, attr_from_temp_log_std,
     f_gate_logit, i_gate_logit, t_gate_logit) = tf.split(attr_from_temp, 5, axis=-1)

    attr_from_temp_std = self.std_nonlinearity(attr_from_temp_log_std)

    # bias the gates
    f_gate = tf.nn.sigmoid(f_gate_logit + 1) * .9999
    i_gate = tf.nn.sigmoid(i_gate_logit + 1) * .9999
    t_gate = tf.nn.sigmoid(t_gate_logit + 1) * .9999

    attr_mean = f_gate * objects.attr + (1 - i_gate) * attr_from_inp_mean + (1 - t_gate) * attr_from_temp_mean
    attr_std = (1 - i_gate) * attr_from_inp_std + (1 - t_gate) * attr_from_temp_std

    new_attr = Normal(loc=attr_mean, scale=attr_std).sample()

    # --- apply change in attributes ---

    new_objects.update(
        attr=new_attr,
        d_attr=new_attr - objects.attr,
        d_attr_mean=attr_mean - objects.attr,
        d_attr_std=attr_std,
        f_gate=f_gate,
        i_gate=i_gate,
        t_gate=t_gate,
        glimpse=glimpse,
        glimpse_mask=glimpse_mask,
    )

    # --- z ---

    d_z_inp = tf.concat([base_features, box_params, new_attr, encoded_glimpse], axis=-1)
    d_z_params = apply_object_wise(
        self.d_z_network, d_z_inp, output_size=2, is_training=self.is_training)

    d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
    d_z_std = self.std_nonlinearity(d_z_log_std)

    d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (1 - self.training_wheels) * d_z_mean
    d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (1 - self.training_wheels) * d_z_std

    d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

    new_z_logit = objects.z_logit + d_z_logit
    new_z = self.z_nonlinearity(new_z_logit)

    new_objects.update(
        z=new_z,
        z_logit=new_z_logit,
        d_z_logit=d_z_logit,
        d_z_logit_mean=d_z_mean,
        d_z_logit_std=d_z_std,
    )

    # --- obj ---

    d_obj_inp = tf.concat([base_features, box_params, new_attr, new_z, encoded_glimpse], axis=-1)
    d_obj_logit = apply_object_wise(
        self.d_obj_network, d_obj_inp, output_size=1, is_training=self.is_training)

    d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (1 - self.training_wheels) * d_obj_logit

    d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10., 10.)

    d_obj_pre_sigmoid = (
        self._noisy * concrete_binary_pre_sigmoid_sample(d_obj_log_odds, self.obj_concrete_temp)
        + (1 - self._noisy) * d_obj_log_odds)

    d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

    new_obj = objects.obj * d_obj

    new_objects.update(
        d_obj_log_odds=d_obj_log_odds,
        d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
        d_obj_pre_sigmoid=d_obj_pre_sigmoid,
        d_obj=d_obj,
        obj=new_obj,
    )

    # --- update each object's hidden state ---

    cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

    if is_posterior:
        _, new_objects.prop_state = apply_object_wise(self.cell, cell_input, objects.prop_state)
        new_objects.prior_prop_state = new_objects.prop_state
    else:
        _, new_objects.prior_prop_state = apply_object_wise(self.cell, cell_input, objects.prior_prop_state)
        new_objects.prop_state = new_objects.prior_prop_state

    return new_objects
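# Sketch (illustration only; helper name hypothetical) of the SQAIR-style gated
# mixture used for `attr_mean` above: the previous attribute value is blended with
# the glimpse-only and temporal predictions. Because the gate logits are biased by
# +1, the gates start near sigmoid(1) ~ 0.73, so early in training the previous
# value receives the largest weight while the two predictions enter with smaller
# weights of roughly 0.27 each.
def _gated_attr_mean(prev_attr, from_inp_mean, from_temp_mean, f_gate, i_gate, t_gate):
    return (f_gate * prev_attr
            + (1 - i_gate) * from_inp_mean
            + (1 - t_gate) * from_temp_mean)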
def _body(self, inp, features, objects, is_posterior):
    batch_size, n_objects, _ = tf_shape(features)

    new_objects = AttrDict()

    is_posterior_tf = tf.ones_like(features[..., 0:2])
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    base_features = tf.concat([features, is_posterior_tf], axis=-1)

    cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

    if self.learn_glimpse_prime:
        # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
        glimpse_prime_params = apply_object_wise(
            self.glimpse_prime_network, base_features,
            output_size=4, is_training=self.is_training)
    else:
        glimpse_prime_params = tf.zeros_like(base_features[..., :4])

    if is_posterior:
        if self.learn_glimpse_prime:
            # --- obtain final parameters for glimpse prime by modifying current pose ---
            _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

            # This is how it is done in SQAIR
            g_yt = cyt + 0.1 * _yt
            g_xt = cxt + 0.1 * _xt
            g_ys = ys + 0.1 * _ys
            g_xs = xs + 0.1 * _xs

            # g_yt = cyt + self.glimpse_prime_scale * tf.nn.tanh(_yt)
            # g_xt = cxt + self.glimpse_prime_scale * tf.nn.tanh(_xt)
            # g_ys = ys + self.glimpse_prime_scale * tf.nn.tanh(_ys)
            # g_xs = xs + self.glimpse_prime_scale * tf.nn.tanh(_xs)
        else:
            g_yt = cyt
            g_xt = cxt
            g_ys = self.glimpse_prime_scale * ys
            g_xs = self.glimpse_prime_scale * xs

        # --- extract glimpse prime ---

        _, image_height, image_width, _ = tf_shape(inp)
        g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
            g_yt, g_xt, g_ys, g_xs, (image_height, image_width), self.anchor_box, top_left=False)
        glimpse_prime = extract_affine_glimpse(
            inp, self.object_shape, g_yt, g_xt, g_ys, g_xs, self.edge_resampler)
    else:
        g_yt = tf.zeros_like(cyt)
        g_xt = tf.zeros_like(cxt)
        g_ys = tf.zeros_like(ys)
        g_xs = tf.zeros_like(xs)
        glimpse_prime = tf.zeros((batch_size, n_objects, *self.object_shape, self.image_depth))

    new_objects.update(
        glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1),
    )

    # --- encode glimpse ---

    encoded_glimpse_prime = apply_object_wise(
        self.glimpse_prime_encoder, glimpse_prime,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- position and scale ---

    d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
    d_box_params = apply_object_wise(
        self.d_box_network, d_box_inp, output_size=8, is_training=self.is_training)

    d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

    d_box_std = self.std_nonlinearity(d_box_log_std)

    d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (1 - self.training_wheels) * d_box_mean
    d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (1 - self.training_wheels) * d_box_std

    d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
    d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

    # --- position ---

    # We predict position a bit differently from scale. For scale we want to put a prior
    # on the actual value of the scale, whereas for position we want to put a prior on
    # the difference in position over timesteps.

    d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
    d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

    d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
    d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

    new_cyt = cyt + d_yt
    new_cxt = cxt + d_xt

    new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

    # --- scale ---

    new_ys_mean = objects.ys_logit + d_ys
    new_xs_mean = objects.xs_logit + d_xs

    new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
    new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

    new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
    new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

    # Used for conditioning
    if self.use_abs_posn:
        box_params = tf.concat([new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)
    else:
        box_params = tf.concat([d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

    new_objects.update(
        abs_posn=new_abs_posn,
        yt=new_cyt,
        xt=new_cxt,
        ys=new_ys,
        xs=new_xs,
        normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs], axis=-1),
        d_yt_logit=d_yt_logit,
        d_xt_logit=d_xt_logit,
        ys_logit=new_ys_logit,
        xs_logit=new_xs_logit,
        d_yt_logit_mean=d_yt_mean,
        d_xt_logit_mean=d_xt_mean,
        ys_logit_mean=new_ys_mean,
        xs_logit_mean=new_xs_mean,
        d_yt_logit_std=d_yt_std,
        d_xt_logit_std=d_xt_std,
        ys_logit_std=ys_std,
        xs_logit_std=xs_std,
        glimpse_prime=glimpse_prime,
    )

    # --- attributes ---

    # --- extract a glimpse using new box ---

    if is_posterior:
        _, image_height, image_width, _ = tf_shape(inp)
        _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
            new_cyt, new_cxt, new_ys, new_xs, (image_height, image_width), self.anchor_box, top_left=False)

        glimpse = extract_affine_glimpse(
            inp, self.object_shape, _new_cyt, _new_cxt, _new_ys, _new_xs, self.edge_resampler)
    else:
        glimpse = tf.zeros((batch_size, n_objects, *self.object_shape, self.image_depth))

    encoded_glimpse = apply_object_wise(
        self.glimpse_encoder, glimpse,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- predict change in attributes ---

    d_attr_inp = tf.concat([base_features, box_params, encoded_glimpse], axis=-1)
    d_attr_params = apply_object_wise(
        self.d_attr_network, d_attr_inp,
        output_size=2 * self.A + 1, is_training=self.is_training)

    d_attr_mean, d_attr_log_std, gate_logit = tf.split(d_attr_params, [self.A, self.A, 1], axis=-1)
    d_attr_std = self.std_nonlinearity(d_attr_log_std)

    gate = tf.nn.sigmoid(gate_logit)
    if self.gate_d_attr:
        d_attr_mean *= gate

    d_attr = Normal(loc=d_attr_mean, scale=d_attr_std).sample()

    # --- apply change in attributes ---

    new_attr = objects.attr + d_attr

    new_objects.update(
        attr=new_attr,
        d_attr=d_attr,
        d_attr_mean=d_attr_mean,
        d_attr_std=d_attr_std,
        glimpse=glimpse,
        d_attr_gate=gate,
    )

    # --- z ---

    d_z_inp = tf.concat([base_features, box_params, new_attr, encoded_glimpse], axis=-1)
    d_z_params = apply_object_wise(
        self.d_z_network, d_z_inp, output_size=2, is_training=self.is_training)

    d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
    d_z_std = self.std_nonlinearity(d_z_log_std)

    d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (1 - self.training_wheels) * d_z_mean
    d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (1 - self.training_wheels) * d_z_std

    d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

    new_z_logit = objects.z_logit + d_z_logit
    new_z = self.z_nonlinearity(new_z_logit)

    new_objects.update(
        z=new_z,
        z_logit=new_z_logit,
        d_z_logit=d_z_logit,
        d_z_logit_mean=d_z_mean,
        d_z_logit_std=d_z_std,
    )

    # --- obj ---

    d_obj_inp = tf.concat([base_features, box_params, new_attr, new_z, encoded_glimpse], axis=-1)
    d_obj_logit = apply_object_wise(
        self.d_obj_network, d_obj_inp, output_size=1, is_training=self.is_training)

    d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (1 - self.training_wheels) * d_obj_logit

    d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10., 10.)

    d_obj_pre_sigmoid = (
        self._noisy * concrete_binary_pre_sigmoid_sample(d_obj_log_odds, self.obj_concrete_temp)
        + (1 - self._noisy) * d_obj_log_odds)

    d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

    new_obj = objects.obj * d_obj

    new_objects.update(
        d_obj_log_odds=d_obj_log_odds,
        d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
        d_obj_pre_sigmoid=d_obj_pre_sigmoid,
        d_obj=d_obj,
        obj=new_obj,
    )

    # --- update each object's hidden state ---

    cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

    if is_posterior:
        _, new_objects.prop_state = apply_object_wise(self.cell, cell_input, objects.prop_state)
        new_objects.prior_prop_state = new_objects.prop_state
    else:
        _, new_objects.prior_prop_state = apply_object_wise(self.cell, cell_input, objects.prior_prop_state)
        new_objects.prop_state = new_objects.prior_prop_state

    return new_objects
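# The `training_wheels` pattern used repeatedly above, in isolation (illustration
# only; helper name hypothetical): the forward value of x is unchanged, since
# tw * x + (1 - tw) * x == x, but the gradient flowing into x is scaled by
# (1 - tw). Setting training_wheels = 1 therefore freezes the networks producing
# x while the rest of the model continues to train.
def _apply_training_wheels(x, training_wheels):
    return training_wheels * tf.stop_gradient(x) + (1 - training_wheels) * x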
def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
    print("\n" + "-" * 10 + " ConvGridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

    # --- set up sub networks and attributes ---

    self.maybe_build_subnet("box_network", builder=cfg.build_conv_lateral, key="box")
    self.maybe_build_subnet("attr_network", builder=cfg.build_conv_lateral, key="attr")
    self.maybe_build_subnet("z_network", builder=cfg.build_conv_lateral, key="z")
    self.maybe_build_subnet("obj_network", builder=cfg.build_conv_lateral, key="obj")

    self.maybe_build_subnet("object_encoder")

    _, H, W, _, n_channels = tf_shape(inp_features)

    if self.B != 1:
        raise NotImplementedError("B != 1 is not supported")

    if not self.initialized:
        # Note this limits the re-usability of this module to images
        # with a fixed shape (the shape of the first image it is used on)
        self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
        self.H = H
        self.W = W
        self.HWB = H * W
        self.batch_size = tf.shape(inp)[0]
        self.is_training = is_training
        self.float_is_training = tf.to_float(is_training)

    inp_features = tf.reshape(inp_features, (self.batch_size, H, W, n_channels))

    is_posterior_tf = tf.ones_like(inp_features[..., :2])
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    objects = AttrDict()

    base_features = tf.concat([inp_features, is_posterior_tf], axis=-1)

    # --- box ---

    layer_inp = base_features
    n_features = self.n_passthrough_features
    output_size = 8

    network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
    rep_input, features = tf.split(network_output, (output_size, n_features), axis=-1)

    _objects = self._build_box(rep_input, self.is_training)
    objects.update(_objects)

    # --- attr ---

    if is_posterior:
        # --- Get object attributes using object encoder ---

        yt, xt, ys, xs = tf.split(objects['normalized_box'], 4, axis=-1)

        yt, xt, ys, xs = coords_to_image_space(
            yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

        transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
        warper = snt.AffineGridWarper(
            (self.image_height, self.image_width), self.object_shape, transform_constraints)

        _boxes = tf.concat([xs, 2 * xt - 1, ys, 2 * yt - 1], axis=-1)
        _boxes = tf.reshape(_boxes, (self.batch_size * H * W, 4))
        grid_coords = warper(_boxes)
        grid_coords = tf.reshape(grid_coords, (self.batch_size, H, W, *self.object_shape, 2,))

        if self.edge_resampler:
            glimpse = resampler_edge.resampler_edge(inp, grid_coords)
        else:
            glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
    else:
        glimpse = tf.zeros((self.batch_size, H, W, *self.object_shape, self.image_depth))

    # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
    encoded_glimpse = apply_object_wise(
        self.object_encoder, glimpse,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse = tf.zeros_like(encoded_glimpse)

    layer_inp = tf.concat([base_features, features, encoded_glimpse, objects['local_box']], axis=-1)
    network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
    attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=-1)

    attr_std = self.std_nonlinearity(attr_log_std)

    attr = Normal(loc=attr_mean, scale=attr_std).sample()

    objects.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)

    # --- z ---

    layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr']], axis=-1)
    n_features = self.n_passthrough_features
    network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
    z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=-1)

    z_std = self.std_nonlinearity(z_log_std)

    z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1 - self.training_wheels) * z_mean
    z_std = self.training_wheels * tf.stop_gradient(z_std) + (1 - self.training_wheels) * z_std
    z_logit = Normal(loc=z_mean, scale=z_std).sample()
    z = self.z_nonlinearity(z_logit)

    objects.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)

    # --- obj ---

    layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr'], objects['z']], axis=-1)
    rep_input = self.obj_network(layer_inp, 1, self.is_training)

    _objects = self._build_obj(rep_input, self.is_training)
    objects.update(_objects)

    # --- final ---

    _objects = AttrDict()
    for k, v in objects.items():
        _, _, _, *trailing_dims = tf_shape(v)
        _objects[k] = tf.reshape(v, (self.batch_size, self.HWB, *trailing_dims))
    objects = _objects

    if prop_state is not None:
        objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
        objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

    # --- misc ---

    objects.n_objects = tf.fill((self.batch_size,), self.HWB)
    objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
    objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

    return objects
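# Note on the box layout passed to `AffineGridWarper` above: with
# AffineWarpConstraints.no_shear_2d(), sonnet's warper appears to expect parameters
# in the order (x_scale, x_translation, y_scale, y_translation), with translations
# in [-1, 1] normalized coordinates; hence the 2*xt - 1 and 2*yt - 1 shifts from
# [0, 1] box centres. A hedged helper (hypothetical name) making the conversion
# explicit, under that assumption about the parameter convention:
def _box_to_warper_params(yt, xt, ys, xs):
    return tf.concat([xs, 2 * xt - 1, ys, 2 * yt - 1], axis=-1)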
def _call(self, objects, background, is_training, appearance_only=False):
    if not self.initialized:
        self.image_depth = tf_shape(background)[-1]

    self.maybe_build_subnet("object_decoder")

    # --- compute sprite appearance from attr using object decoder ---

    appearance_logit = apply_object_wise(
        self.object_decoder, objects.attr,
        output_size=self.object_shape + (self.image_depth + 1,), is_training=is_training)

    appearance_logit = appearance_logit * ([self.color_logit_scale] * 3 + [self.alpha_logit_scale])
    appearance_logit = appearance_logit + ([0.] * 3 + [self.alpha_logit_bias])

    appearance = tf.nn.sigmoid(tf.clip_by_value(appearance_logit, -10., 10.))

    if appearance_only:
        return dict(appearance=appearance)

    appearance_for_output = appearance

    batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
    n_objects = np.prod(obj_leading_shape)
    appearance = tf.reshape(
        appearance, (batch_size, n_objects, *self.object_shape, self.image_depth + 1))

    obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1], axis=-1)

    obj_alpha *= tf.reshape(objects.render_obj, (batch_size, n_objects, 1, 1, 1))

    z = tf.reshape(objects.z, (batch_size, n_objects, 1, 1, 1))
    obj_importance = tf.maximum(obj_alpha * z / self.importance_temp, 0.01)

    object_maps = tf.concat([obj_colors, obj_alpha, obj_importance], axis=-1)

    *_, image_height, image_width, _ = tf_shape(background)

    yt, xt, ys, xs = coords_to_image_space(
        objects.yt, objects.xt, objects.ys, objects.xs,
        (image_height, image_width), self.anchor_box, top_left=True)

    scales = tf.concat([ys, xs], axis=-1)
    scales = tf.reshape(scales, (batch_size, n_objects, 2))

    offsets = tf.concat([yt, xt], axis=-1)
    offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

    # --- Compose images ---

    n_objects_per_image = tf.fill((batch_size,), int(n_objects))

    output = render_sprites.render_sprites(
        object_maps, n_objects_per_image, scales, offsets, background)

    return dict(
        appearance=appearance_for_output,
        output=output)
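# Rough single-pixel sketch of the compositing scheme that the importance channel
# above feeds into. This is an assumption about render_sprites' behaviour, not its
# actual implementation (which is a custom op): overlapping sprites are averaged
# with weights proportional to their importance values, then alpha-composited over
# the background. Helper name and exact weighting are hypothetical.
def _composite_pixel_sketch(colors, alphas, importances, bg_color):
    import numpy as np
    w = np.asarray(importances, dtype=float)
    w = w / w.sum()
    sprite_color = sum(wi * ai * ci for wi, ai, ci in zip(w, alphas, colors))
    sprite_alpha = sum(wi * ai for wi, ai in zip(w, alphas))
    return sprite_color + (1 - sprite_alpha) * bg_color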
def _call(self, input_locs, input_features, reference_locs, reference_features, is_training):
    """
    input_locs: (B, n_inp, loc_dim)
    input_features: (B, n_inp, n_hidden)
    reference_locs: (B, n_ref, loc_dim)
    """
    assert (reference_features is not None) == self.do_object_wise

    if not self.is_built:
        self.relation_func = self.build_mlp(scope="relation_func")

        if self.do_object_wise:
            self.object_wise_func = self.build_mlp(scope="object_wise_func")

        self.is_built = True

    loc_dim = tf_shape(input_locs)[-1]
    n_ref = tf_shape(reference_locs)[-2]
    batch_size, n_inp, _ = tf_shape(input_features)

    input_locs = tf.broadcast_to(input_locs, (batch_size, n_inp, loc_dim))
    reference_locs = tf.broadcast_to(reference_locs, (batch_size, n_ref, loc_dim))

    adjusted_locs = input_locs[:, None, :, :] - reference_locs[:, :, None, :]  # (B, n_ref, n_inp, loc_dim)
    adjusted_features = tf.tile(input_features[:, None], (1, n_ref, 1, 1))  # (B, n_ref, n_inp, features_dim)
    relation_input = tf.concat([adjusted_features, adjusted_locs], axis=-1)

    if self.do_object_wise:
        object_wise = apply_object_wise(
            self.object_wise_func, reference_features,
            output_size=self.n_hidden, is_training=is_training)  # (B, n_ref, n_hidden)
        _object_wise = tf.tile(object_wise[:, :, None], (1, 1, n_inp, 1))
        relation_input = tf.concat([relation_input, _object_wise], axis=-1)
    else:
        object_wise = None

    V = apply_object_wise(
        self.relation_func, relation_input,
        output_size=self.n_hidden, is_training=is_training)  # (B, n_ref, n_inp, n_hidden)

    attention_weights = tf.exp(-0.5 * tf.reduce_sum((adjusted_locs / self.kernel_std)**2, axis=3))
    attention_weights = (
        attention_weights / (2 * np.pi)**(loc_dim / 2) / self.kernel_std**loc_dim)  # (B, n_ref, n_inp)

    result = tf.reduce_sum(V * attention_weights[..., None], axis=2)  # (B, n_ref, n_hidden)

    if self.do_object_wise:
        result += object_wise

    # result = tf.contrib.layers.layer_norm(result)

    return result
def _call(self, signal, is_training, memory=None):
    if not self.is_built:
        self.query_funcs = [self.build_mlp(scope="query_head_{}".format(j)) for j in range(self.n_heads)]
        self.key_funcs = [self.build_mlp(scope="key_head_{}".format(j)) for j in range(self.n_heads)]
        self.value_funcs = [self.build_mlp(scope="value_head_{}".format(j)) for j in range(self.n_heads)]

        self.after_func = self.build_mlp(scope="after")

        if self.do_object_wise:
            self.object_wise_func = self.build_object_wise(scope="object_wise")

        if memory is not None:
            # Cache keys and values computed from the memory provided at build time.
            self.K = [
                apply_object_wise(self.key_funcs[j], memory,
                                  output_size=self.key_dim, is_training=is_training)
                for j in range(self.n_heads)]
            self.V = [
                apply_object_wise(self.value_funcs[j], memory,
                                  output_size=self.value_dim, is_training=is_training)
                for j in range(self.n_heads)]

        self.is_built = True

    n_signal_dim = len(signal.shape)
    assert n_signal_dim in [2, 3]

    if isinstance(memory, tuple):
        # keys and values passed in directly
        K, V = memory
    elif memory is not None:
        # `memory` is a value to which we apply key_funcs and value_funcs to obtain keys and values
        K = [
            apply_object_wise(self.key_funcs[j], memory,
                              output_size=self.key_dim, is_training=is_training)
            for j in range(self.n_heads)]
        V = [
            apply_object_wise(self.value_funcs[j], memory,
                              output_size=self.value_dim, is_training=is_training)
            for j in range(self.n_heads)]
    elif self.K is not None:
        K = self.K
        V = self.V
    else:
        # self-attention: `signal` is used for queries, keys and values
        K = [
            apply_object_wise(self.key_funcs[j], signal,
                              output_size=self.key_dim, is_training=is_training)
            for j in range(self.n_heads)]
        V = [
            apply_object_wise(self.value_funcs[j], signal,
                              output_size=self.value_dim, is_training=is_training)
            for j in range(self.n_heads)]

    head_outputs = []
    for j in range(self.n_heads):
        Q = apply_object_wise(self.query_funcs[j], signal,
                              output_size=self.key_dim, is_training=is_training)

        if n_signal_dim == 2:
            Q = Q[:, None, :]

        attention_logits = tf.matmul(Q, K[j], transpose_b=True) / tf.sqrt(tf.to_float(self.key_dim))
        attention = tf.nn.softmax(attention_logits)
        attended = tf.matmul(attention, V[j])  # (..., n_queries, value_dim)

        if n_signal_dim == 2:
            attended = attended[:, 0, :]

        head_outputs.append(attended)

    head_outputs = tf.concat(head_outputs, axis=-1)

    # `after_func` is applied to the concatenation of the head outputs, and the result is
    # added to the original signal. Then, if `do_object_wise` is True, `object_wise_func`
    # is applied object-wise, also in a ResNet-style manner.
    output = apply_object_wise(self.after_func, head_outputs,
                               output_size=self.n_hidden, is_training=is_training)
    output = tf.layers.dropout(output, self.p_dropout, training=is_training)
    signal = tf.contrib.layers.layer_norm(signal + output)

    if self.do_object_wise:
        output = apply_object_wise(self.object_wise_func, signal,
                                   output_size=self.n_hidden, is_training=is_training)
        output = tf.layers.dropout(output, self.p_dropout, training=is_training)
        signal = tf.contrib.layers.layer_norm(signal + output)

    return signal
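# The core of each attention head above, in isolation (illustration only; helper
# name hypothetical): standard scaled dot-product attention, with logits divided
# by sqrt(key_dim) to keep their variance roughly independent of the key dimension.
def _scaled_dot_product_attention(Q, K, V, key_dim):
    logits = tf.matmul(Q, K, transpose_b=True) / tf.sqrt(tf.to_float(key_dim))
    return tf.matmul(tf.nn.softmax(logits), V)  # (..., n_queries, value_dim)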
def _call(self, objects, background, is_training, appearance_only=False):
    if not self.initialized:
        self.image_depth = tf_shape(background)[-1]

    self.maybe_build_subnet("object_decoder")

    # --- compute sprite appearance from attr using object decoder ---

    appearance_logits = apply_object_wise(
        self.object_decoder, objects.attr,
        self.object_shape + (self.image_depth + 1,), is_training)

    appearance_logits = appearance_logits * ([self.color_logit_scale] * 3 + [self.alpha_logit_scale])
    appearance_logits = appearance_logits + ([0.] * 3 + [self.alpha_logit_bias])

    appearance = tf.nn.sigmoid(tf.clip_by_value(appearance_logits, -10., 10.))

    if appearance_only:
        return dict(appearance=appearance)

    appearance_for_output = appearance

    batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
    n_objects = np.prod(obj_leading_shape)
    appearance = tf.reshape(
        appearance, (batch_size, n_objects, *self.object_shape, self.image_depth + 1))

    obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1], axis=-1)

    if "alpha" in self.no_gradient:
        obj_alpha = tf.stop_gradient(obj_alpha)

    if "alpha" in self.fixed_values:
        obj_alpha = float(self.fixed_values["alpha"]) * tf.ones_like(obj_alpha)

    obj_alpha *= tf.reshape(objects.obj, (batch_size, n_objects, 1, 1, 1))

    z = tf.reshape(objects.z, (batch_size, n_objects, 1, 1, 1))
    obj_importance = tf.maximum(obj_alpha * z, 0.01)

    object_maps = tf.concat([obj_colors, obj_alpha, obj_importance], axis=-1)

    ys, xs, yt, xt = objects.ys, objects.xs, objects.yt, objects.xt

    scales = tf.concat([ys, xs], axis=-1)
    scales = tf.reshape(scales, (batch_size, n_objects, 2))

    offsets = tf.concat([yt, xt], axis=-1)
    offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

    # --- Compose images ---

    n_objects_per_image = tf.fill((batch_size,), int(n_objects))

    output = render_sprites.render_sprites(
        object_maps, n_objects_per_image, scales, offsets, background)

    return dict(
        appearance=appearance_for_output,
        output=output)
def _call(self, objects, background, is_training, appearance_only=False, mask_only=False):
    """ If mask_only==True, then we ignore the provided background, using a black
        background instead, and also ignore the computed appearance, using all-white
        appearances instead.
    """
    if not self.initialized:
        self.image_depth = tf_shape(background)[-1]

    single = False
    if isinstance(objects, dict):
        single = True
        objects = [objects]

    _object_maps = []
    _scales = []
    _offsets = []
    _appearance = []

    for i, obj in enumerate(objects):
        anchor_box = self.anchor_boxes[i]
        object_shape = self.object_shapes[i]

        object_decoder = self.maybe_build_subnet(
            "object_decoder_for_flight_{}".format(i), builder_name='build_object_decoder')

        # --- compute sprite appearance from attr using object decoder ---

        appearance_logit = apply_object_wise(
            object_decoder, obj.attr,
            output_size=object_shape + (self.image_depth + 1,), is_training=is_training)

        appearance_logit = appearance_logit * ([self.color_logit_scale] * self.image_depth + [self.alpha_logit_scale])
        appearance_logit = appearance_logit + ([0.] * self.image_depth + [self.alpha_logit_bias])

        appearance = tf.nn.sigmoid(tf.clip_by_value(appearance_logit, -10., 10.))
        _appearance.append(appearance)

        if appearance_only:
            continue

        batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
        n_objects = np.prod(obj_leading_shape)
        appearance = tf.reshape(
            appearance, (batch_size, n_objects, *object_shape, self.image_depth + 1))

        obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1], axis=-1)

        if mask_only:
            obj_colors = tf.ones_like(obj_colors)

        obj_alpha *= tf.reshape(obj.obj, (batch_size, n_objects, 1, 1, 1))

        z = tf.reshape(obj.z, (batch_size, n_objects, 1, 1, 1))
        obj_importance = tf.maximum(obj_alpha * z / self.importance_temp, 0.01)

        object_maps = tf.concat([obj_colors, obj_alpha, obj_importance], axis=-1)

        *_, image_height, image_width, _ = tf_shape(background)

        yt, xt, ys, xs = coords_to_image_space(
            obj.yt, obj.xt, obj.ys, obj.xs, (image_height, image_width), anchor_box, top_left=True)

        scales = tf.concat([ys, xs], axis=-1)
        scales = tf.reshape(scales, (batch_size, n_objects, 2))

        offsets = tf.concat([yt, xt], axis=-1)
        offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

        _object_maps.append(object_maps)
        _scales.append(scales)
        _offsets.append(offsets)

    if single:
        _appearance = _appearance[0]

    if appearance_only:
        return dict(appearance=_appearance)

    if mask_only:
        background = tf.zeros_like(background)

    # --- Compose images ---

    output = render_sprites.render_sprites(_object_maps, _scales, _offsets, background)

    return dict(
        appearance=_appearance,
        output=output)