def extract_affine_glimpse(image, object_shape, cyt, cxt, ys, xs, edge_resampler):
    """ (cyt, cxt) are rectangle center. (ys, xs) are rectangle height/width. """
    _, *image_shape, image_depth = tf_shape(image)
    transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
    warper = snt.AffineGridWarper(image_shape, object_shape, transform_constraints)

    # change coordinate system
    cyt = 2 * cyt - 1
    cxt = 2 * cxt - 1

    leading_shape = tf_shape(cyt)[:-1]

    _boxes = tf.concat([xs, cxt, ys, cyt], axis=-1)
    _boxes = tf.reshape(_boxes, (-1, 4))

    grid_coords = warper(_boxes)
    grid_coords = tf.reshape(grid_coords, (*leading_shape, *object_shape, 2))

    if edge_resampler:
        glimpses = resampler_edge.resampler_edge(image, grid_coords)
    else:
        glimpses = tf.contrib.resampler.resampler(image, grid_coords)

    glimpses = tf.reshape(glimpses, (*leading_shape, *object_shape, image_depth))

    return glimpses
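# Usage sketch (our addition; assumed shapes, not from the source): with `image` of
# shape (B, H, W, C) and each of cyt, cxt, ys, xs of shape (B, n_objects, 1), where
# centers lie in [0, 1] and sizes are fractions of the image, something like
#
#     glimpses = extract_affine_glimpse(image, (14, 14), cyt, cxt, ys, xs, edge_resampler=False)
#
# should return glimpses of shape (B, n_objects, 14, 14, C), since the leading shape
# of `cyt` (minus its last axis) is preserved by the reshapes above.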
def _call(self, input_signal, input_locs, output_locs, is_training):
    if not self.is_built:
        self.value_func = self.build_mlp(scope="value_func")
        self.after_func = self.build_mlp(scope="after")

        if self.do_object_wise:
            self.object_wise_func = self.build_object_wise(scope="object_wise")

        self.is_built = True

    batch_size, n_inp, _ = tf_shape(input_signal)
    loc_dim = tf_shape(input_locs)[-1]
    n_outp = tf_shape(output_locs)[-2]
    input_locs = tf.broadcast_to(input_locs, (batch_size, n_inp, loc_dim))
    output_locs = tf.broadcast_to(output_locs, (batch_size, n_outp, loc_dim))

    dist = output_locs[:, :, None, :] - input_locs[:, None, :, :]
    proximity = tf.exp(-0.5 * tf.reduce_sum((dist / self.kernel_std)**2, axis=3))
    proximity = proximity / (2 * np.pi)**(0.5 * loc_dim) / self.kernel_std**loc_dim

    V = apply_object_wise(
        self.value_func, input_signal,
        output_size=self.n_hidden, is_training=is_training)  # (batch_size, n_inp, value_dim)

    result = tf.matmul(proximity, V)  # (batch_size, n_outp, value_dim)

    # `after_func` is applied to the kernel-weighted values, followed by dropout and layer norm.
    # Then, if `do_object_wise` is True, `object_wise_func` is applied object-wise in a
    # ResNet-style (residual) manner.
    output = apply_object_wise(
        self.after_func, result,
        output_size=self.n_hidden, is_training=is_training)
    output = tf.layers.dropout(output, self.p_dropout, training=is_training)
    signal = tf.contrib.layers.layer_norm(output)

    if self.do_object_wise:
        output = apply_object_wise(
            self.object_wise_func, signal,
            output_size=self.n_hidden, is_training=is_training)
        output = tf.layers.dropout(output, self.p_dropout, training=is_training)
        signal = tf.contrib.layers.layer_norm(signal + output)

    return signal
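# A minimal numpy sketch (our illustration, with made-up shapes) of the Gaussian-kernel
# proximity computed in `_call` above: each output location attends to every input
# location under an isotropic Gaussian density with standard deviation `kernel_std`.
import numpy as np

kernel_std = 0.1
input_locs = np.random.rand(1, 5, 2)   # (batch_size, n_inp, loc_dim)
output_locs = np.random.rand(1, 3, 2)  # (batch_size, n_outp, loc_dim)

dist = output_locs[:, :, None, :] - input_locs[:, None, :, :]
proximity = np.exp(-0.5 * np.sum((dist / kernel_std) ** 2, axis=3))
proximity = proximity / (2 * np.pi) ** (0.5 * 2) / kernel_std ** 2  # loc_dim == 2

# proximity has shape (1, 3, 5); row o holds the attention weights used to mix the
# value vectors V when computing output location o.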
def _call(self, inp, output_size, is_training):
    if self.bg_head is None:
        self.bg_head = ConvNet(
            layers=[
                dict(filters=None, kernel_size=1, strides=1, padding="SAME"),
                dict(filters=None, kernel_size=1, strides=1, padding="SAME"),
            ],
            scope="bg_head"
        )

    if self.transform_head is None:
        self.transform_head = MLP(n_units=[64, 64], scope="transform_head")

    n_attr_channels, n_transform_values = output_size
    processed = super()._call(inp, n_attr_channels, is_training)
    B, F, H, W, C = tf_shape(processed)

    # Map `processed` to shapes (B, H, W, C) and (B, F, n_transform_values)
    bg_attrs = self.bg_head(tf.reduce_mean(processed, axis=1), None, is_training)

    transform_values = self.transform_head(
        tf.reshape(processed, (B*F, H*W*C)), n_transform_values, is_training)
    transform_values = tf.reshape(transform_values, (B, F, n_transform_values))

    return bg_attrs, transform_values
def __call__(self, tensors):
    batch_size = tf_shape(tensors["obj"])[0]
    exp_rate = self.exp_rate

    assert_exp_rate_non_negative = tf.Assert(
        exp_rate >= 0, [exp_rate], name='assert_exp_rate_non_negative')

    with tf.control_dependencies([assert_exp_rate_non_negative]):
        posterior_log_pdf = logistic_log_pdf(
            tensors["obj_log_odds"], tensors["obj_pre_sigmoid"], self.obj_concrete_temp)
        posterior_log_pdf = tf.reduce_sum(
            tf.reshape(posterior_log_pdf, (batch_size, -1)), axis=1)

        # This is different from the true log prior pdf by a constant factor,
        # namely the log of the normalization constant for the prior.
        concrete_sum = tf.reduce_sum(tf.reshape(tensors["obj"], (batch_size, -1)), axis=1)
        # prior_pdf = exp_rate * tf.exp(-exp_rate * concrete_sum)
        prior_log_pdf = -exp_rate * concrete_sum

    return posterior_log_pdf - prior_log_pdf
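# Sanity note (our derivation, not from the source): for an exponential prior over the
# concrete sum s with rate lambda = exp_rate,
#     p(s) = lambda * exp(-lambda * s),  so  log p(s) = log(lambda) - lambda * s.
# Dropping the constant log(lambda) leaves -exp_rate * concrete_sum, which is exactly
# `prior_log_pdf` above; for a fixed exp_rate this only shifts the result by a constant.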
def __call__(self, tensors):
    kl = concrete_binary_sample_kl(
        tensors["obj_pre_sigmoid"], tensors["obj_log_odds"], self.obj_concrete_temp,
        self.prior_log_odds, self.obj_concrete_temp)
    batch_size = tf_shape(tensors["obj_pre_sigmoid"])[0]
    return tf.reduce_sum(tf.reshape(kl, (batch_size, -1)), 1)
def build_representation(self):
    assert cfg.background_cfg.mode == 'colour'
    self.build_background()

    # dummy variable to satisfy dps
    tf.get_variable("dummy", shape=(1,), dtype=tf.float32)

    B, T, *rest = tf_shape(self._tensors["background"])

    inp = tf.reshape(self._tensors["inp"], (T*B, *rest))
    bg = tf.reshape(self._tensors["background"], (T*B, *rest))

    program_tensors = tf_find_connected_components(
        inp, bg, self.cc_threshold, self.colours, self.cosine_threshold)
    self._tensors.update({
        k: tf.reshape(v, (B, T, *tf_shape(v)[1:]))
        for k, v in program_tensors.items()
        if k != 'max_objects'
    })

    if "n_annotations" in self._tensors:
        count_1norm = tf.to_float(
            tf.abs(tf.to_int32(self._tensors["n_objects"]) - self._tensors["n_valid_annotations"]))
        count_1norm_relative = (
            count_1norm / tf.maximum(tf.cast(self._tensors["n_valid_annotations"], tf.float32), 1e-6))

        self.record_tensors(
            count_1norm_relative=count_1norm_relative,
            count_1norm=count_1norm,
            count_error=count_1norm > 0.5,
            n_objects_per_frame=self._tensors["n_objects"],
        )
def apply_object_wise(func, signal, output_size, is_training, restore_shape=True, n_trailing_dims=1):
    """ Treat `signal` as a batch of objects. Apply function `func` separately to each object.

    The final `n_trailing_dims`-many dimensions are treated as "within-object" dimensions.
    By default, objects are assumed to be vectors, but this can be changed by increasing
    `n_trailing_dims`. For example, n_trailing_dims == 2 means each object is a matrix,
    i.e. the last 2 dimensions of `signal` are dimensions of the object.

    """
    shape = tf_shape(signal)
    leading_dim = tf.reduce_prod(shape[:-n_trailing_dims])
    signal = tf.reshape(signal, (leading_dim, *shape[-n_trailing_dims:]))

    output = func(signal, output_size, is_training)

    if restore_shape:
        if not isinstance(output_size, tuple):
            output_size = [output_size]
        output = tf.reshape(output, (*shape[:-n_trailing_dims], *output_size))

    return output
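# Usage sketch (our addition; `my_mlp` is a hypothetical ScopedFunction-style network):
# with `signal` of shape (batch, n_objects, D) and output_size=K, apply_object_wise
# flattens to (batch * n_objects, D), applies `func`, and restores (batch, n_objects, K):
#
#     out = apply_object_wise(my_mlp, signal, output_size=K, is_training=True)
#
# With n_trailing_dims=3 (e.g. glimpses of shape (batch, n_objects, h, w, c)), the
# leading two dims are flattened instead, and `output_size` may be a tuple.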
def tile_input_for_iwae(tensor, iw_samples, with_time=False):
    """ Tiles tensor `tensor` in such a way that tiled samples are contiguous in memory;
    i.e. it tiles along the axis after the batch axis and reshapes to have the same rank
    as the original tensor.

    :param tensor: tf.Tensor to be tiled
    :param iw_samples: int, number of importance-weighted samples
    :param with_time: boolean, if True then an additional time axis at the beginning is assumed
    :return:
    """
    shape = list(tf_shape(tensor))
    shape[with_time] *= iw_samples  # `with_time` indexes axis 1 if True, axis 0 if False

    tiles = [1, iw_samples] + [1] * (tensor.shape.ndims - (1 + with_time))
    if with_time:
        tiles = [1] + tiles

    tensor = tf.expand_dims(tensor, 1 + with_time)
    tensor = tf.tile(tensor, tiles)
    tensor = tf.reshape(tensor, shape)

    return tensor
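# Example (our addition): with iw_samples=2 and `tensor` of shape (B, D) (with_time=False),
# the result has shape (2*B, D) with rows ordered [x_0, x_0, x_1, x_1, ...], i.e. the
# importance-weighted copies of each batch element are contiguous. With with_time=True
# and shape (T, B, D), the result is (T, 2*B, D), tiled along the batch axis only.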
def null_object_set(self, batch_size):
    n_prop_objects = self.n_prop_objects

    new_objects = AttrDict(
        normalized_box=tf.zeros((batch_size, n_prop_objects, 4)),
        attr=tf.zeros((batch_size, n_prop_objects, self.A)),
        z=tf.zeros((batch_size, n_prop_objects, 1)),
        obj=tf.zeros((batch_size, n_prop_objects, 1)),
    )

    yt, xt, ys, xs = tf.split(new_objects.normalized_box, 4, axis=-1)

    new_objects.update(
        abs_posn=new_objects.normalized_box[..., :2] + 0.0,
        yt=yt,
        xt=xt,
        ys=ys,
        xs=xs,
        ys_logit=ys + 0.0,
        xs_logit=xs + 0.0,
        # d_yt=yt + 0.0,
        # d_xt=xt + 0.0,
        # d_ys=ys + 0.0,
        # d_xs=xs + 0.0,
        # d_attr=new_objects.attr + 0.0,
        # d_z=new_objects.z + 0.0,
        z_logit=new_objects.z + 0.0,
    )

    prop_state = self.cell.initial_state(batch_size * n_prop_objects, tf.float32)
    trailing_shape = tf_shape(prop_state)[1:]
    new_objects.prop_state = tf.reshape(
        prop_state, (batch_size, n_prop_objects, *trailing_shape))
    new_objects.prior_prop_state = new_objects.prop_state

    return new_objects
def build_representation(self):
    # --- init modules ---

    if self.backbone is None:
        self.backbone = self.build_backbone(scope="backbone")

    if self.cell is None:
        self.cell = cfg.build_cell(self.n_hidden, name="cell")

        # self.cell must be a Sonnet RNNCore
        if self.learn_initial_state:
            self.initial_hidden_state = snt.trainable_initial_state(
                1, self.cell.state_size, tf.float32, name="initial_hidden_state")

    if self.key_network is None:
        self.key_network = cfg.build_key_network(scope="key_network")
    if self.beta_network is None:
        self.beta_network = cfg.build_beta_network(scope="beta_network")
    if self.write_network is None:
        self.write_network = cfg.build_write_network(scope="write_network")
    if self.erase_network is None:
        self.erase_network = cfg.build_erase_network(scope="erase_network")
    if self.output_network is None:
        self.output_network = cfg.build_output_network(scope="output_network")

    d_n_frames, n_trackers, batch_size = self.dynamic_n_frames, self.n_trackers, self.batch_size

    # --- encode ---

    video = tf.reshape(self.inp, (batch_size * d_n_frames, *self.obs_shape[1:]))

    zero_to_one_Y = tf.linspace(0., 1., self.image_height)
    zero_to_one_X = tf.linspace(0., 1., self.image_width)
    X, Y = tf.meshgrid(zero_to_one_X, zero_to_one_Y)
    X = tf.tile(X[None, :, :, None], (batch_size * d_n_frames, 1, 1, 1))
    Y = tf.tile(Y[None, :, :, None], (batch_size * d_n_frames, 1, 1, 1))
    video = tf.concat([video, Y, X], axis=-1)

    encoded_frames, _, _ = self.backbone(video, self.S, self.is_training)

    _, H, W, _ = tf_shape(encoded_frames)
    self.H = H
    self.W = W
    encoded_frames = tf.reshape(encoded_frames, (batch_size, d_n_frames, H*W, self.S))
    self.encoded_frames = encoded_frames

    cts = tf.minimum(1., self.float_is_training + (1 - float(self.discrete_eval)))
    self.mask_temperature = cts * 1.0 + (1 - cts) * 1e-5
    self.layer_temperature = cts * 1.0 + (1 - cts) * 1e-5

    f = tf.constant(0, dtype=tf.int32)

    if self.learn_initial_state:
        hidden_state = self.initial_hidden_state[None, ...]
    else:
        hidden_state = self.cell.zero_state(1, tf.float32)[None, ...]
    hidden_states = tf.tile(
        hidden_state,
        (self.n_trackers, self.batch_size) + (1,) * (len(hidden_state.shape) - 2))

    conf = tf.zeros((n_trackers, batch_size, 1))

    structure = dict(
        hidden_states=hidden_states,
        tracker_output=tf.zeros((self.n_trackers, self.batch_size, self.n_hidden)),
        attention_result=tf.zeros((self.n_trackers, self.batch_size, self.S)),
        attention_weights=tf.zeros((self.n_trackers, self.batch_size, H*W)),
        memory_activation=tf.zeros((self.n_trackers, self.batch_size, H*W)),
        conf=conf,
        layer=tf.zeros((self.n_trackers, self.batch_size, self.n_layers)),
        pose=tf.zeros((self.n_trackers, self.batch_size, 4)),
        mask=tf.zeros((self.n_trackers, self.batch_size, np.prod(self.object_shape))),
        appearance=tf.zeros((self.n_trackers, self.batch_size, 3 * np.prod(self.object_shape))),
        order=tf.zeros((self.n_trackers, self.batch_size, 1), dtype=tf.int32),
    )

    tensor_arrays = make_tensor_arrays(structure, self.dynamic_n_frames)

    loop_vars = [f, conf, hidden_states, *tensor_arrays]

    result = tf.while_loop(self._loop_cond, self._loop_body, loop_vars)

    first_ta_idx = min(i for i, ta in enumerate(result) if isinstance(ta, tf.TensorArray))
    tensor_arrays = result[first_ta_idx:]

    def finalize_ta(ta):
        t = ta.stack()
        # reshape from (n_frames, n_trackers, batch_size, *other)
        # to (batch_size, n_frames, n_trackers, *other)
        return tf.transpose(t, (2, 0, 1, *range(3, len(t.shape))))

    tensors = map_structure(
        finalize_ta, tensor_arrays, is_leaf=lambda t: isinstance(t, tf.TensorArray))
    tensors = apply_keys(structure, tensors)

    self._tensors.update(tensors)

    pprint.pprint(self._tensors)

    # --- render/decode ---

    ys, xs, yt, xt = tf.split(tensors["pose"], 4, axis=-1)

    self.record_tensors(
        conf=tensors["conf"],
        ys=ys,
        xs=xs,
        yt=yt,
        xt=xt,
    )

    yt_normed = (yt + 1) / 2
    xt_normed = (xt + 1) / 2
    ys_normed = 1 + self.eta[0] * ys
    xs_normed = 1 + self.eta[1] * xs

    normalized_box = tf.concat([yt_normed, xt_normed, ys_normed, xs_normed], axis=-1)

    # expose values for plotting
    self._tensors.update(
        normalized_box=normalized_box,
        mask=tf.reshape(
            tensors["mask"], (batch_size, d_n_frames, n_trackers, *self.object_shape)),
        appearance=tf.reshape(
            tensors["appearance"],
            (batch_size, d_n_frames, n_trackers, *self.object_shape, self.image_depth)),
        conf=tensors["conf"],
        layer=tensors["layer"],
        order=tensors["order"],
    )

    # --- reshape values ---

    N = batch_size * d_n_frames * n_trackers
    ys_normed = tf.reshape(ys_normed, (N, 1))
    xs_normed = tf.reshape(xs_normed, (N, 1))
    yt_normed = tf.reshape(yt_normed, (N, 1))
    xt_normed = tf.reshape(xt_normed, (N, 1))

    _yt, _xt, _ys, _xs = tba_coords_to_image_space(
        yt_normed, xt_normed, ys_normed, xs_normed,
        (self.image_height, self.image_width), self.anchor_box, top_left=False)

    mask = tf.reshape(tensors["mask"], (N, *self.object_shape, 1))
    appearance = tf.reshape(tensors["appearance"], (N, *self.object_shape, self.image_depth))

    transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
    warper = snt.AffineGridWarper(
        (self.image_height, self.image_width), self.object_shape, transform_constraints)
    inverse_warper = warper.inverse()

    transforms = tf.concat([_xs, _xt, _ys, _yt], axis=-1)
    grid_coords = inverse_warper(transforms)

    transformed_masks = tf.contrib.resampler.resampler(mask, grid_coords)
    transformed_masks = tf.reshape(
        transformed_masks,
        (batch_size, d_n_frames, n_trackers, self.image_height, self.image_width, 1))

    transformed_appearances = tf.contrib.resampler.resampler(appearance, grid_coords)
    transformed_appearances = tf.reshape(
        transformed_appearances,
        (batch_size, d_n_frames, n_trackers, self.image_height, self.image_width, self.image_depth))

    layer_masks = []
    layer_appearances = []

    conf = tensors["conf"][:, :, :, :, None, None]

    # TODO: currently assuming a black background
    final_frames = tf.zeros(
        (batch_size, d_n_frames, self.image_height, self.image_width, self.image_depth))

    # For each layer, create a mask image and an appearance image
    for layer_idx in range(self.n_layers):
        layer_weight = tensors["layer"][:, :, :, layer_idx, None, None, None]

        # (batch_size, n_frames, image_height, image_width, 1)
        layer_mask = tf.reduce_sum(conf * layer_weight * transformed_masks, axis=2)
        layer_mask = tf.minimum(1.0, layer_mask)

        # (batch_size, n_frames, image_height, image_width, 3)
        layer_appearance = tf.reduce_sum(
            conf * layer_weight * transformed_masks * transformed_appearances, axis=2)

        if self.clamp_appearance:
            layer_appearance = tf.minimum(1.0, layer_appearance)

        final_frames = (1 - layer_mask) * final_frames + layer_appearance

        layer_masks.append(layer_mask)
        layer_appearances.append(layer_appearance)

    self._tensors["output"] = final_frames

    # --- losses ---

    self._tensors['per_pixel_reconstruction_loss'] = (self.inp - final_frames)**2
    self.losses['reconstruction'] = (
        tf.reduce_sum(self._tensors["per_pixel_reconstruction_loss"])
        / tf.cast(d_n_frames * self.batch_size, tf.float32))
    # self.losses['reconstruction'] = tf.reduce_mean(self._tensors["per_pixel_reconstruction_loss"])

    self.losses['area'] = self.lmbda * tf.reduce_mean(ys_normed * xs_normed)
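# Minimal numpy sketch (our illustration, hypothetical values) of the back-to-front
# layer compositing used above: each layer occludes what lies behind it according to
# its mask, then contributes its mask-premultiplied appearance.
import numpy as np

final = np.zeros(3)  # black background, a single RGB pixel
layers = [
    (0.8, np.array([0.40, 0.16, 0.08])),  # (layer_mask, layer_appearance) at that pixel
    (0.5, np.array([0.00, 0.15, 0.15])),
]
for layer_mask, layer_appearance in layers:
    final = (1 - layer_mask) * final + layer_appearance
# `final` now holds the composited pixel, mirroring
#     final_frames = (1 - layer_mask) * final_frames + layer_appearance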
def _body(self, inp, features, objects, is_posterior):
    """
    Summary of how updates are done for the different variables:

    glimpse_prime: glimpse_prime_box = where + 0.1 * predicted_logit
    where_y/x: new_where_y/x = where_y/x + where_t_scale * tanh(predicted_logit)
    where_h/w: new_where_h/w_logit = where_h/w_logit + predicted_logit
    what: Hard to summarize here, taken from SQAIR. Kind of like an LSTM.
    depth: new_depth_logit = depth_logit + predicted_logit
    obj: new_obj = obj * sigmoid(predicted_logit)

    """
    batch_size, n_objects, _ = tf_shape(features)

    new_objects = AttrDict()

    is_posterior_tf = tf.ones_like(features[..., 0:2])
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    base_features = tf.concat([features, is_posterior_tf], axis=-1)

    cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

    # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
    glimpse_dim = self.object_shape[0] * self.object_shape[1]
    glimpse_prime_params = apply_object_wise(
        self.glimpse_prime_network, base_features,
        output_size=4 + 2 * glimpse_dim, is_training=self.is_training)

    glimpse_prime_params, glimpse_prime_mask_logit, glimpse_mask_logit = \
        tf.split(glimpse_prime_params, [4, glimpse_dim, glimpse_dim], axis=-1)

    if is_posterior:
        # --- obtain final parameters for glimpse prime by modifying current pose ---

        _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

        g_yt = cyt + 0.1 * _yt
        g_xt = cxt + 0.1 * _xt
        g_ys = ys + 0.1 * _ys
        g_xs = xs + 0.1 * _xs

        # --- extract glimpse prime ---

        _, image_height, image_width, _ = tf_shape(inp)
        g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
            g_yt, g_xt, g_ys, g_xs,
            (image_height, image_width), self.anchor_box, top_left=False)

        glimpse_prime = extract_affine_glimpse(
            inp, self.object_shape, g_yt, g_xt, g_ys, g_xs, self.edge_resampler)
    else:
        g_yt = tf.zeros_like(cyt)
        g_xt = tf.zeros_like(cxt)
        g_ys = tf.zeros_like(ys)
        g_xs = tf.zeros_like(xs)

        glimpse_prime = tf.zeros(
            (batch_size, n_objects, *self.object_shape, self.image_depth))

    glimpse_prime_mask = tf.nn.sigmoid(glimpse_prime_mask_logit + 1.)
    leading_mask_shape = tf_shape(glimpse_prime)[:-1]
    glimpse_prime_mask = tf.reshape(glimpse_prime_mask, (*leading_mask_shape, 1))

    new_objects.update(
        glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1),
        glimpse_prime=glimpse_prime,
        glimpse_prime_mask=glimpse_prime_mask,
    )

    glimpse_prime *= glimpse_prime_mask

    # --- encode glimpse ---

    encoded_glimpse_prime = apply_object_wise(
        self.glimpse_prime_encoder, glimpse_prime,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- position and scale ---

    # Roughly: base_features == temporal_state, encoded_glimpse_prime == hidden_output.
    # hidden_output conditions on encoded_glimpse, and that's the only place
    # encoded_glimpse_prime is used. Here SQAIR conditions on the actual location values
    # from the previous timestep, but we leave that out for now.
    d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
    d_box_params = apply_object_wise(
        self.d_box_network, d_box_inp, output_size=8, is_training=self.is_training)

    d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

    d_box_std = self.std_nonlinearity(d_box_log_std)

    # While `training_wheels` is 1, the predicted deltas receive no gradient.
    d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (1 - self.training_wheels) * d_box_mean
    d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (1 - self.training_wheels) * d_box_std

    d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
    d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

    # --- position ---

    # We predict position a bit differently from scale. For scale we want to put a prior
    # on the actual value of the scale, whereas for position we want to put a prior on
    # the difference in position over timesteps.

    d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
    d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

    d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
    d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

    new_cyt = cyt + d_yt
    new_cxt = cxt + d_xt

    new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

    # --- scale ---

    new_ys_mean = objects.ys_logit + d_ys
    new_xs_mean = objects.xs_logit + d_xs

    new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
    new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

    new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
        tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
    new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
        tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

    if self.use_abs_posn:
        box_params = tf.concat(
            [new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)
    else:
        box_params = tf.concat(
            [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

    new_objects.update(
        abs_posn=new_abs_posn,
        yt=new_cyt,
        xt=new_cxt,
        ys=new_ys,
        xs=new_xs,
        normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs], axis=-1),
        d_yt_logit=d_yt_logit,
        d_xt_logit=d_xt_logit,
        ys_logit=new_ys_logit,
        xs_logit=new_xs_logit,
        d_yt_logit_mean=d_yt_mean,
        d_xt_logit_mean=d_xt_mean,
        ys_logit_mean=new_ys_mean,
        xs_logit_mean=new_xs_mean,
        d_yt_logit_std=d_yt_std,
        d_xt_logit_std=d_xt_std,
        ys_logit_std=ys_std,
        xs_logit_std=xs_std,
    )

    # --- attributes ---

    # --- extract a glimpse using new box ---

    if is_posterior:
        _, image_height, image_width, _ = tf_shape(inp)
        _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
            new_cyt, new_cxt, new_ys, new_xs,
            (image_height, image_width), self.anchor_box, top_left=False)

        glimpse = extract_affine_glimpse(
            inp, self.object_shape, _new_cyt, _new_cxt, _new_ys, _new_xs, self.edge_resampler)
    else:
        glimpse = tf.zeros((batch_size, n_objects, *self.object_shape, self.image_depth))

    glimpse_mask = tf.nn.sigmoid(glimpse_mask_logit + 1.)
    leading_mask_shape = tf_shape(glimpse)[:-1]
    glimpse_mask = tf.reshape(glimpse_mask, (*leading_mask_shape, 1))

    glimpse *= glimpse_mask

    encoded_glimpse = apply_object_wise(
        self.glimpse_encoder, glimpse,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- predict change in attributes ---

    # Under SQAIR we mix between three different values for the attributes:
    # 1. value from previous timestep
    # 2. value predicted directly from glimpse
    # 3. value predicted based on update of temporal cell... this update conditions on
    #    hidden_output, the prediction in #2, and the where values.
    #
    # How to do this given that we are predicting the change in attr? We could just
    # directly predict the attr instead, but call it d_attr. After all, it is in this
    # function that we control whether d_attr is added to attr.

    # So, make a prediction based on just the input:
    attr_from_inp = apply_object_wise(
        self.predict_attr_inp, encoded_glimpse,
        output_size=2 * self.A, is_training=self.is_training)
    attr_from_inp_mean, attr_from_inp_log_std = tf.split(attr_from_inp, [self.A, self.A], axis=-1)
    attr_from_inp_std = self.std_nonlinearity(attr_from_inp_log_std)

    # And then a prediction which takes the past into account (predicting gate values
    # at the same time):
    attr_from_temp_inp = tf.concat([base_features, box_params, encoded_glimpse], axis=-1)
    attr_from_temp = apply_object_wise(
        self.predict_attr_temp, attr_from_temp_inp,
        output_size=5 * self.A, is_training=self.is_training)

    (attr_from_temp_mean, attr_from_temp_log_std,
     f_gate_logit, i_gate_logit, t_gate_logit) = tf.split(attr_from_temp, 5, axis=-1)

    attr_from_temp_std = self.std_nonlinearity(attr_from_temp_log_std)

    # bias the gates
    f_gate = tf.nn.sigmoid(f_gate_logit + 1) * .9999
    i_gate = tf.nn.sigmoid(i_gate_logit + 1) * .9999
    t_gate = tf.nn.sigmoid(t_gate_logit + 1) * .9999

    attr_mean = f_gate * objects.attr + (1 - i_gate) * attr_from_inp_mean + (1 - t_gate) * attr_from_temp_mean
    attr_std = (1 - i_gate) * attr_from_inp_std + (1 - t_gate) * attr_from_temp_std

    new_attr = Normal(loc=attr_mean, scale=attr_std).sample()

    # --- apply change in attributes ---

    new_objects.update(
        attr=new_attr,
        d_attr=new_attr - objects.attr,
        d_attr_mean=attr_mean - objects.attr,
        d_attr_std=attr_std,
        f_gate=f_gate,
        i_gate=i_gate,
        t_gate=t_gate,
        glimpse=glimpse,
        glimpse_mask=glimpse_mask,
    )

    # --- z ---

    d_z_inp = tf.concat([base_features, box_params, new_attr, encoded_glimpse], axis=-1)
    d_z_params = apply_object_wise(
        self.d_z_network, d_z_inp, output_size=2, is_training=self.is_training)

    d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
    d_z_std = self.std_nonlinearity(d_z_log_std)

    d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (1 - self.training_wheels) * d_z_mean
    d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (1 - self.training_wheels) * d_z_std

    d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

    new_z_logit = objects.z_logit + d_z_logit
    new_z = self.z_nonlinearity(new_z_logit)

    new_objects.update(
        z=new_z,
        z_logit=new_z_logit,
        d_z_logit=d_z_logit,
        d_z_logit_mean=d_z_mean,
        d_z_logit_std=d_z_std,
    )

    # --- obj ---

    d_obj_inp = tf.concat([base_features, box_params, new_attr, new_z, encoded_glimpse], axis=-1)
    d_obj_logit = apply_object_wise(
        self.d_obj_network, d_obj_inp, output_size=1, is_training=self.is_training)

    d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (1 - self.training_wheels) * d_obj_logit

    d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10., 10.)
    d_obj_pre_sigmoid = (
        self._noisy * concrete_binary_pre_sigmoid_sample(d_obj_log_odds, self.obj_concrete_temp)
        + (1 - self._noisy) * d_obj_log_odds
    )

    d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

    new_obj = objects.obj * d_obj

    new_objects.update(
        d_obj_log_odds=d_obj_log_odds,
        d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
        d_obj_pre_sigmoid=d_obj_pre_sigmoid,
        d_obj=d_obj,
        obj=new_obj,
    )

    # --- update each object's hidden state ---

    cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

    if is_posterior:
        _, new_objects.prop_state = apply_object_wise(self.cell, cell_input, objects.prop_state)
        new_objects.prior_prop_state = new_objects.prop_state
    else:
        _, new_objects.prior_prop_state = apply_object_wise(self.cell, cell_input, objects.prior_prop_state)
        new_objects.prop_state = new_objects.prior_prop_state

    return new_objects
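# Note on the `training_wheels` pattern used repeatedly above (our reading): the mix
#     mixed = w * tf.stop_gradient(x) + (1 - w) * x
# leaves the forward value of x unchanged but scales its gradient by (1 - w). With
# training_wheels == 1, the propagation deltas receive no gradient (so other parts of
# the model can train first); annealing it to 0 restores full end-to-end training.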
def _call(self, inp, features, objects, is_training, is_posterior):
    print("\n" + "-" * 10 + " PropagationLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

    self._build_networks()

    if not self.initialized:
        # Note this limits the re-usability of this module to images
        # with a fixed shape (the shape of the first image it is used on)
        self.image_height = int(inp.shape[-3])
        self.image_width = int(inp.shape[-2])
        self.image_depth = int(inp.shape[-1])
        self.batch_size = tf.shape(inp)[0]
        self.is_training = is_training
        self.float_is_training = tf.to_float(is_training)

    if self.do_lateral:
        # hasn't been updated to make use of abs_posn
        raise NotImplementedError()

        batch_size, n_objects, _ = tf_shape(features)

        new_objects = []

        for i in range(n_objects):
            # apply lateral to running objects with the feature vector for the current object

            _features = features[:, i:i+1, :]

            if i > 0:
                normalized_box = tf.concat([o.normalized_box for o in new_objects], axis=1)
                attr = tf.concat([o.attr for o in new_objects], axis=1)
                z = tf.concat([o.z for o in new_objects], axis=1)
                obj = tf.concat([o.obj for o in new_objects], axis=1)
                completed_features = tf.concat([normalized_box[:, :, 2:], attr, z, obj], axis=2)
                completed_locs = normalized_box[:, :, :2]

                current_features = tf.concat(
                    [objects.normalized_box[:, i:i+1, 2:],
                     objects.attr[:, i:i+1],
                     objects.z[:, i:i+1],
                     objects.obj[:, i:i+1]], axis=2)
                current_locs = objects.normalized_box[:, i:i+1, :2]

                # if i > max_completed_objects:
                #     # squared_distances = tf.reduce_sum((completed_locs - current_locs)**2, axis=2)
                #     # _, top_k_indices = tf.nn.top_k(squared_distances, k=max_completed_objects, sorted=False)

                _features = self.lateral_network(
                    completed_locs, completed_features, current_locs, current_features, is_training)

            _objects = AttrDict(
                normalized_box=objects.normalized_box[:, i:i+1],
                attr=objects.attr[:, i:i+1],
                z=objects.z[:, i:i+1],
                obj=objects.obj[:, i:i+1],
            )

            _new_objects = self._body(inp, _features, _objects, is_posterior)
            new_objects.append(_new_objects)

        _new_objects = AttrDict()
        for k in new_objects[0]:
            _new_objects[k] = tf.concat([no[k] for no in new_objects], axis=1)

        return _new_objects
    else:
        return self._body(inp, features, objects, is_posterior)
def _body(self, inp, features, objects, is_posterior):
    batch_size, n_objects, _ = tf_shape(features)

    new_objects = AttrDict()

    is_posterior_tf = tf.ones_like(features[..., 0:2])
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    base_features = tf.concat([features, is_posterior_tf], axis=-1)

    cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

    if self.learn_glimpse_prime:
        # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
        glimpse_prime_params = apply_object_wise(
            self.glimpse_prime_network, base_features,
            output_size=4, is_training=self.is_training)
    else:
        glimpse_prime_params = tf.zeros_like(base_features[..., :4])

    if is_posterior:
        if self.learn_glimpse_prime:
            # --- obtain final parameters for glimpse prime by modifying current pose ---
            _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

            # This is how it is done in SQAIR
            g_yt = cyt + 0.1 * _yt
            g_xt = cxt + 0.1 * _xt
            g_ys = ys + 0.1 * _ys
            g_xs = xs + 0.1 * _xs

            # g_yt = cyt + self.glimpse_prime_scale * tf.nn.tanh(_yt)
            # g_xt = cxt + self.glimpse_prime_scale * tf.nn.tanh(_xt)
            # g_ys = ys + self.glimpse_prime_scale * tf.nn.tanh(_ys)
            # g_xs = xs + self.glimpse_prime_scale * tf.nn.tanh(_xs)
        else:
            g_yt = cyt
            g_xt = cxt
            g_ys = self.glimpse_prime_scale * ys
            g_xs = self.glimpse_prime_scale * xs

        # --- extract glimpse prime ---

        _, image_height, image_width, _ = tf_shape(inp)
        g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
            g_yt, g_xt, g_ys, g_xs,
            (image_height, image_width), self.anchor_box, top_left=False)

        glimpse_prime = extract_affine_glimpse(
            inp, self.object_shape, g_yt, g_xt, g_ys, g_xs, self.edge_resampler)
    else:
        g_yt = tf.zeros_like(cyt)
        g_xt = tf.zeros_like(cxt)
        g_ys = tf.zeros_like(ys)
        g_xs = tf.zeros_like(xs)

        glimpse_prime = tf.zeros(
            (batch_size, n_objects, *self.object_shape, self.image_depth))

    new_objects.update(
        glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1),
    )

    # --- encode glimpse ---

    encoded_glimpse_prime = apply_object_wise(
        self.glimpse_prime_encoder, glimpse_prime,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- position and scale ---

    d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
    d_box_params = apply_object_wise(
        self.d_box_network, d_box_inp, output_size=8, is_training=self.is_training)

    d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

    d_box_std = self.std_nonlinearity(d_box_log_std)

    d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (1 - self.training_wheels) * d_box_mean
    d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (1 - self.training_wheels) * d_box_std

    d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
    d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

    # --- position ---

    # We predict position a bit differently from scale. For scale we want to put a prior
    # on the actual value of the scale, whereas for position we want to put a prior on
    # the difference in position over timesteps.
    d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
    d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

    d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
    d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

    new_cyt = cyt + d_yt
    new_cxt = cxt + d_xt

    new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

    # --- scale ---

    new_ys_mean = objects.ys_logit + d_ys
    new_xs_mean = objects.xs_logit + d_xs

    new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
    new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

    new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
        tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
    new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
        tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

    # Used for conditioning
    if self.use_abs_posn:
        box_params = tf.concat(
            [new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)
    else:
        box_params = tf.concat(
            [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

    new_objects.update(
        abs_posn=new_abs_posn,
        yt=new_cyt,
        xt=new_cxt,
        ys=new_ys,
        xs=new_xs,
        normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs], axis=-1),
        d_yt_logit=d_yt_logit,
        d_xt_logit=d_xt_logit,
        ys_logit=new_ys_logit,
        xs_logit=new_xs_logit,
        d_yt_logit_mean=d_yt_mean,
        d_xt_logit_mean=d_xt_mean,
        ys_logit_mean=new_ys_mean,
        xs_logit_mean=new_xs_mean,
        d_yt_logit_std=d_yt_std,
        d_xt_logit_std=d_xt_std,
        ys_logit_std=ys_std,
        xs_logit_std=xs_std,
        glimpse_prime=glimpse_prime,
    )

    # --- attributes ---

    # --- extract a glimpse using new box ---

    if is_posterior:
        _, image_height, image_width, _ = tf_shape(inp)
        _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
            new_cyt, new_cxt, new_ys, new_xs,
            (image_height, image_width), self.anchor_box, top_left=False)

        glimpse = extract_affine_glimpse(
            inp, self.object_shape, _new_cyt, _new_cxt, _new_ys, _new_xs, self.edge_resampler)
    else:
        glimpse = tf.zeros((batch_size, n_objects, *self.object_shape, self.image_depth))

    encoded_glimpse = apply_object_wise(
        self.glimpse_encoder, glimpse,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse = tf.zeros((batch_size, n_objects, self.A), dtype=tf.float32)

    # --- predict change in attributes ---

    d_attr_inp = tf.concat([base_features, box_params, encoded_glimpse], axis=-1)
    d_attr_params = apply_object_wise(
        self.d_attr_network, d_attr_inp,
        output_size=2 * self.A + 1, is_training=self.is_training)

    d_attr_mean, d_attr_log_std, gate_logit = tf.split(d_attr_params, [self.A, self.A, 1], axis=-1)
    d_attr_std = self.std_nonlinearity(d_attr_log_std)

    gate = tf.nn.sigmoid(gate_logit)

    if self.gate_d_attr:
        d_attr_mean *= gate

    d_attr = Normal(loc=d_attr_mean, scale=d_attr_std).sample()

    # --- apply change in attributes ---

    new_attr = objects.attr + d_attr

    new_objects.update(
        attr=new_attr,
        d_attr=d_attr,
        d_attr_mean=d_attr_mean,
        d_attr_std=d_attr_std,
        glimpse=glimpse,
        d_attr_gate=gate,
    )

    # --- z ---

    d_z_inp = tf.concat([base_features, box_params, new_attr, encoded_glimpse], axis=-1)
    d_z_params = apply_object_wise(
        self.d_z_network, d_z_inp, output_size=2, is_training=self.is_training)

    d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
    d_z_std = self.std_nonlinearity(d_z_log_std)

    d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (1 - self.training_wheels) * d_z_mean
    d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (1 - self.training_wheels) * d_z_std

    d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

    new_z_logit = objects.z_logit + d_z_logit
    new_z = self.z_nonlinearity(new_z_logit)

    new_objects.update(
        z=new_z,
        z_logit=new_z_logit,
        d_z_logit=d_z_logit,
        d_z_logit_mean=d_z_mean,
        d_z_logit_std=d_z_std,
    )

    # --- obj ---

    d_obj_inp = tf.concat([base_features, box_params, new_attr, new_z, encoded_glimpse], axis=-1)
    d_obj_logit = apply_object_wise(
        self.d_obj_network, d_obj_inp, output_size=1, is_training=self.is_training)

    d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (1 - self.training_wheels) * d_obj_logit

    d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10., 10.)

    d_obj_pre_sigmoid = (
        self._noisy * concrete_binary_pre_sigmoid_sample(d_obj_log_odds, self.obj_concrete_temp)
        + (1 - self._noisy) * d_obj_log_odds
    )

    d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

    new_obj = objects.obj * d_obj

    new_objects.update(
        d_obj_log_odds=d_obj_log_odds,
        d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
        d_obj_pre_sigmoid=d_obj_pre_sigmoid,
        d_obj=d_obj,
        obj=new_obj,
    )

    # --- update each object's hidden state ---

    cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

    if is_posterior:
        _, new_objects.prop_state = apply_object_wise(self.cell, cell_input, objects.prop_state)
        new_objects.prior_prop_state = new_objects.prop_state
    else:
        _, new_objects.prior_prop_state = apply_object_wise(self.cell, cell_input, objects.prior_prop_state)
        new_objects.prop_state = new_objects.prior_prop_state

    return new_objects
def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
    print("\n" + "-" * 10 + " ConvGridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

    # --- set up sub networks and attributes ---

    self.maybe_build_subnet("box_network", builder=cfg.build_conv_lateral, key="box")
    self.maybe_build_subnet("attr_network", builder=cfg.build_conv_lateral, key="attr")
    self.maybe_build_subnet("z_network", builder=cfg.build_conv_lateral, key="z")
    self.maybe_build_subnet("obj_network", builder=cfg.build_conv_lateral, key="obj")

    self.maybe_build_subnet("object_encoder")

    _, H, W, _, n_channels = tf_shape(inp_features)

    if self.B != 1:
        raise NotImplementedError()

    if not self.initialized:
        # Note this limits the re-usability of this module to images
        # with a fixed shape (the shape of the first image it is used on)
        self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
        self.H = H
        self.W = W
        self.HWB = H*W
        self.batch_size = tf.shape(inp)[0]
        self.is_training = is_training
        self.float_is_training = tf.to_float(is_training)

    inp_features = tf.reshape(inp_features, (self.batch_size, H, W, n_channels))

    is_posterior_tf = tf.ones_like(inp_features[..., :2])
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    objects = AttrDict()

    base_features = tf.concat([inp_features, is_posterior_tf], axis=-1)

    # --- box ---

    layer_inp = base_features
    n_features = self.n_passthrough_features
    output_size = 8

    network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
    rep_input, features = tf.split(network_output, (output_size, n_features), axis=-1)

    _objects = self._build_box(rep_input, self.is_training)
    objects.update(_objects)

    # --- attr ---

    if is_posterior:
        # --- Get object attributes using object encoder ---

        yt, xt, ys, xs = tf.split(objects['normalized_box'], 4, axis=-1)

        yt, xt, ys, xs = coords_to_image_space(
            yt, xt, ys, xs, (self.image_height, self.image_width),
            self.anchor_box, top_left=False)

        transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
        warper = snt.AffineGridWarper(
            (self.image_height, self.image_width), self.object_shape, transform_constraints)

        _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)
        _boxes = tf.reshape(_boxes, (self.batch_size*H*W, 4))
        grid_coords = warper(_boxes)
        grid_coords = tf.reshape(grid_coords, (self.batch_size, H, W, *self.object_shape, 2))

        if self.edge_resampler:
            glimpse = resampler_edge.resampler_edge(inp, grid_coords)
        else:
            glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
    else:
        glimpse = tf.zeros((self.batch_size, H, W, *self.object_shape, self.image_depth))

    # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
    encoded_glimpse = apply_object_wise(
        self.object_encoder, glimpse,
        n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse = tf.zeros_like(encoded_glimpse)

    layer_inp = tf.concat(
        [base_features, features, encoded_glimpse, objects['local_box']], axis=-1)
    network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
    attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=-1)

    attr_std = self.std_nonlinearity(attr_log_std)

    attr = Normal(loc=attr_mean, scale=attr_std).sample()

    objects.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)

    # --- z ---

    layer_inp = tf.concat(
        [base_features, features, objects['local_box'], objects['attr']], axis=-1)
    n_features = self.n_passthrough_features
    network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
    z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=-1)
    z_std = self.std_nonlinearity(z_log_std)

    z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
    z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
    z_logit = Normal(loc=z_mean, scale=z_std).sample()
    z = self.z_nonlinearity(z_logit)

    objects.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)

    # --- obj ---

    layer_inp = tf.concat(
        [base_features, features, objects['local_box'], objects['attr'], objects['z']], axis=-1)
    rep_input = self.obj_network(layer_inp, 1, self.is_training)

    _objects = self._build_obj(rep_input, self.is_training)
    objects.update(_objects)

    # --- final ---

    _objects = AttrDict()
    for k, v in objects.items():
        _, _, _, *trailing_dims = tf_shape(v)
        _objects[k] = tf.reshape(v, (self.batch_size, self.HWB, *trailing_dims))
    objects = _objects

    if prop_state is not None:
        objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
        objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

    # --- misc ---

    objects.n_objects = tf.fill((self.batch_size,), self.HWB)
    objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
    objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

    return objects
def _call(self, data, is_training):
    self.data = data

    inp = data["image"]
    self._tensors = AttrDict(
        inp=inp,
        is_training=is_training,
        float_is_training=tf.to_float(is_training),
        batch_size=tf.shape(inp)[0],
    )

    if "annotations" in data:
        self._tensors.update(
            annotations=data["annotations"]["data"],
            n_annotations=data["annotations"]["shapes"][:, 1],
            n_valid_annotations=tf.to_int32(
                tf.reduce_sum(
                    data["annotations"]["data"][:, :, :, 0]
                    * tf.to_float(data["annotations"]["mask"][:, :, :, 0]),
                    axis=2
                )
            )
        )

    if "label" in data:
        self._tensors.update(targets=data["label"])

    if "background" in data:
        self._tensors.update(ground_truth_background=data["background"])

    if "offset" in data:
        self._tensors.update(offset=data["offset"])

    max_n_frames = tf_shape(inp)[1]

    if self.stage_steps is None:
        self.current_stage = tf.constant(0, tf.int32)
        dynamic_n_frames = max_n_frames
    else:
        self.current_stage = tf.cast(tf.train.get_or_create_global_step(), tf.int32) // self.stage_steps
        dynamic_n_frames = tf.minimum(
            self.initial_n_frames + self.n_frames_scale * self.current_stage, max_n_frames)

    # Use the curriculum length at training time, the full video at evaluation time.
    dynamic_n_frames = (
        self.float_is_training * tf.cast(dynamic_n_frames, tf.float32)
        + (1 - self.float_is_training) * tf.cast(max_n_frames, tf.float32))
    self.dynamic_n_frames = tf.cast(dynamic_n_frames, tf.int32)

    self._tensors.current_stage = self.current_stage
    self._tensors.dynamic_n_frames = self.dynamic_n_frames

    self._tensors.inp = self._tensors.inp[:, :self.dynamic_n_frames]

    if 'annotations' in self._tensors:
        self._tensors.annotations = self._tensors.annotations[:, :self.dynamic_n_frames]
        # self._tensors.n_annotations = self._tensors.n_annotations[:, :self.dynamic_n_frames]
        self._tensors.n_valid_annotations = self._tensors.n_valid_annotations[:, :self.dynamic_n_frames]

    self.record_tensors(
        batch_size=tf.to_float(self.batch_size),
        float_is_training=self.float_is_training,
        current_stage=self.current_stage,
        dynamic_n_frames=self.dynamic_n_frames,
    )

    self.losses = dict()

    with tf.variable_scope("representation", reuse=self.initialized):
        if self.needs_background:
            self.build_background()
        self.build_representation()

    return dict(
        tensors=self._tensors,
        recorded_tensors=self.recorded_tensors,
        losses=self.losses,
    )
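# Note (our reading of the code above; the numbers are hypothetical):
# dynamic_n_frames implements a frame-count curriculum. With initial_n_frames=2,
# n_frames_scale=1 and stage_steps=1000, at global step 3500 the stage is 3 and
# training uses min(2 + 1*3, max_n_frames) frames; at evaluation time
# float_is_training == 0, so the full max_n_frames is always used.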
def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
    print("\n" + "-" * 10 + " GridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

    # --- set up sub networks and attributes ---

    self.maybe_build_subnet("box_network", builder=cfg.build_lateral, key="box")
    self.maybe_build_subnet("attr_network", builder=cfg.build_lateral, key="attr")
    self.maybe_build_subnet("z_network", builder=cfg.build_lateral, key="z")
    self.maybe_build_subnet("obj_network", builder=cfg.build_lateral, key="obj")

    self.maybe_build_subnet("object_encoder")

    _, H, W, _, _ = tf_shape(inp_features)
    H = int(H)
    W = int(W)

    if not self.initialized:
        # Note this limits the re-usability of this module to images
        # with a fixed shape (the shape of the first image it is used on)
        self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
        self.H = H
        self.W = W
        self.HWB = H*W*self.B
        self.is_training = is_training
        self.float_is_training = tf.to_float(is_training)

    # --- set up the edge element ---

    sizes = [4, self.A, 1, 1]
    sigmoids = [True, False, False, True]
    total_sample_size = sum(sizes)

    if self.edge_weights is None:
        self.edge_weights = tf.get_variable(
            "edge_weights", shape=total_sample_size, dtype=tf.float32)
        if "edge" in self.fixed_weights:
            tf.add_to_collection(FIXED_COLLECTION, self.edge_weights)

    _edge_weights = tf.split(self.edge_weights, sizes, axis=0)
    _edge_weights = [
        (tf.nn.sigmoid(ew) if sigmoid else ew)
        for ew, sigmoid in zip(_edge_weights, sigmoids)]
    edge_element = tf.concat(_edge_weights, axis=0)
    edge_element = tf.tile(edge_element[None, :], (self.batch_size, 1))

    # --- containers for storing built program ---

    program = np.empty((H, W, self.B), dtype=object)

    # --- build the program ---

    is_posterior_tf = tf.ones((self.batch_size, 2))
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    results = []
    for h, w, b in itertools.product(range(H), range(W), range(self.B)):
        built = dict()

        partial_program, features = None, None
        context = self._get_sequential_context(program, h, w, b, edge_element)
        base_features = tf.concat([inp_features[:, h, w, b, :], context, is_posterior_tf], axis=1)

        # --- box ---

        layer_inp = base_features
        n_features = self.n_passthrough_features
        output_size = 8

        network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
        rep_input, features = tf.split(network_output, (output_size, n_features), axis=1)

        _built = self._build_box(rep_input, self.is_training, hw=(h, w))
        built.update(_built)
        partial_program = built['local_box']

        # --- attr ---

        if is_posterior:
            # --- Get object attributes using object encoder ---

            yt, xt, ys, xs = tf.split(built['normalized_box'], 4, axis=-1)

            yt, xt, ys, xs = coords_to_image_space(
                yt, xt, ys, xs, (self.image_height, self.image_width),
                self.anchor_box, top_left=False)

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
            warper = snt.AffineGridWarper(
                (self.image_height, self.image_width), self.object_shape, transform_constraints)

            _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)
            grid_coords = warper(_boxes)
            grid_coords = tf.reshape(grid_coords, (self.batch_size, 1, *self.object_shape, 2))

            if self.edge_resampler:
                glimpse = resampler_edge.resampler_edge(inp, grid_coords)
            else:
                glimpse = tf.contrib.resampler.resampler(inp, grid_coords)

            glimpse = tf.reshape(glimpse, (self.batch_size, *self.object_shape, self.image_depth))
        else:
            glimpse = tf.zeros((self.batch_size, *self.object_shape, self.image_depth))

        # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
        encoded_glimpse = self.object_encoder(glimpse, (1, 1, self.A), self.is_training)
        encoded_glimpse = tf.reshape(encoded_glimpse, (self.batch_size, self.A))

        if not is_posterior:
            encoded_glimpse = tf.zeros_like(encoded_glimpse)

        layer_inp = tf.concat(
            [base_features, features, encoded_glimpse, partial_program], axis=1)
        network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
        attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=1)

        attr_std = self.std_nonlinearity(attr_log_std)

        attr = Normal(loc=attr_mean, scale=attr_std).sample()

        built.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)
        partial_program = tf.concat([partial_program, built['attr']], axis=1)

        # --- z ---

        layer_inp = tf.concat([base_features, features, partial_program], axis=1)
        n_features = self.n_passthrough_features
        network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
        z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=1)
        z_std = self.std_nonlinearity(z_log_std)

        z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
        z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
        z_logit = Normal(loc=z_mean, scale=z_std).sample()
        z = self.z_nonlinearity(z_logit)

        built.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)
        partial_program = tf.concat([partial_program, built['z']], axis=1)

        # --- obj ---

        layer_inp = tf.concat([base_features, features, partial_program], axis=1)
        rep_input = self.obj_network(layer_inp, 1, self.is_training)

        _built = self._build_obj(rep_input, self.is_training)
        built.update(_built)

        partial_program = tf.concat([partial_program, built['obj']], axis=1)

        # --- final ---

        results.append(built)

        program[h, w, b] = partial_program
        assert program[h, w, b].shape[1] == total_sample_size

    objects = AttrDict()
    for k in results[0]:
        objects[k] = tf.stack([r[k] for r in results], axis=1)

    if prop_state is not None:
        objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
        objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

    # --- misc ---

    objects.n_objects = tf.fill((self.batch_size,), self.HWB)
    objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
    objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

    return objects
def _compute_obj_kl(self, tensors, existing_objects=None):
    # --- compute obj_kl ---

    obj_pre_sigmoid = tensors["obj_pre_sigmoid"]
    obj_log_odds = tensors["obj_log_odds"]
    obj_prob = tensors["obj_prob"]
    obj = tensors["obj"]

    batch_size, n_objects, _ = tf_shape(obj)

    max_n_objects = n_objects

    if existing_objects is not None:
        _, n_existing_objects, _ = tf_shape(existing_objects)
        existing_objects = tf.reshape(existing_objects, (batch_size, n_existing_objects))
        max_n_objects += n_existing_objects

    count_support = tf.range(max_n_objects+1, dtype=tf.float32)

    if self.count_prior_dist is not None:
        assert len(self.count_prior_dist) == (max_n_objects + 1)
        count_distribution = tf.constant(self.count_prior_dist, dtype=tf.float32)
    else:
        count_prior_prob = tf.nn.sigmoid(self.count_prior_log_odds)
        count_distribution = (1 - count_prior_prob) * (count_prior_prob ** count_support)

    normalizer = tf.reduce_sum(count_distribution)
    count_distribution = count_distribution / tf.maximum(normalizer, 1e-6)
    count_distribution = tf.tile(count_distribution[None, :], (batch_size, 1))

    if existing_objects is not None:
        count_so_far = tf.reduce_sum(tf.round(existing_objects), axis=1, keepdims=True)

        count_distribution = (
            count_distribution
            * tf_binomial_coefficient(count_support, count_so_far)
            * tf_binomial_coefficient(max_n_objects - count_support, n_existing_objects - count_so_far)
        )

        normalizer = tf.reduce_sum(count_distribution, axis=1, keepdims=True)
        count_distribution = count_distribution / tf.maximum(normalizer, 1e-6)
    else:
        count_so_far = tf.zeros((batch_size, 1), dtype=tf.float32)

    obj_kl = []
    for i in range(n_objects):
        p_z_given_Cz_raw = (count_support[None, :] - count_so_far) / (max_n_objects - i)
        p_z_given_Cz = tf.clip_by_value(p_z_given_Cz_raw, 0.0, 1.0)

        # Doing this instead of 1 - p_z_given_Cz seems to be more numerically stable.
        inv_p_z_given_Cz_raw = (max_n_objects - i - count_support[None, :] + count_so_far) / (max_n_objects - i)
        inv_p_z_given_Cz = tf.clip_by_value(inv_p_z_given_Cz_raw, 0.0, 1.0)

        p_z = tf.reduce_sum(count_distribution * p_z_given_Cz, axis=1, keepdims=True)

        if self.use_concrete_kl:
            prior_log_odds = tf_safe_log(p_z) - tf_safe_log(1-p_z)
            _obj_kl = concrete_binary_sample_kl(
                obj_pre_sigmoid[:, i, :],
                obj_log_odds[:, i, :], self.obj_concrete_temp,
                prior_log_odds, self.obj_concrete_temp,
            )
        else:
            prob = obj_prob[:, i, :]

            _obj_kl = (
                prob * (tf_safe_log(prob) - tf_safe_log(p_z))
                + (1-prob) * (tf_safe_log(1-prob) - tf_safe_log(1-p_z))
            )

        obj_kl.append(_obj_kl)

        sample = tf.to_float(obj[:, i, :] > 0.5)
        mult = sample * p_z_given_Cz + (1-sample) * inv_p_z_given_Cz
        raw_count_distribution = mult * count_distribution
        normalizer = tf.reduce_sum(raw_count_distribution, axis=1, keepdims=True)
        normalizer = tf.maximum(normalizer, 1e-6)

        # invalid = tf.logical_and(p_z_given_Cz_raw > 1, count_distribution > 1e-8)
        # float_invalid = tf.cast(invalid, tf.float32)
        # diagnostic = tf.stack(
        #     [float_invalid, p_z_given_Cz, count_distribution, mult, raw_count_distribution], axis=-1)

        # assert_op = tf.Assert(
        #     tf.reduce_all(tf.logical_not(invalid)),
        #     [invalid, diagnostic, count_so_far, sample, tf.constant(i, dtype=tf.float32)],
        #     summarize=100000)

        count_distribution = raw_count_distribution / normalizer
        count_so_far += sample

        # this avoids buildup of inaccuracies that can cause problems in computing p_z_given_Cz_raw
        count_so_far = tf.round(count_so_far)

    obj_kl = tf.reshape(tf.concat(obj_kl, axis=1), (batch_size, n_objects, 1))

    return obj_kl
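# Minimal numpy sketch (our illustration, hypothetical numbers) of the count-posterior
# update inside the loop above: after observing each object's on/off sample, the
# distribution over total object counts is reweighted and renormalized.
import numpy as np

n_objects = 3
count_support = np.arange(n_objects + 1, dtype=float)  # possible counts 0..3
count_distribution = np.array([0.4, 0.3, 0.2, 0.1])    # prior over counts
count_so_far = 0.0

for i, sample in enumerate([1.0, 0.0, 1.0]):           # observed obj samples
    p_z_given_Cz = np.clip((count_support - count_so_far) / (n_objects - i), 0.0, 1.0)
    mult = sample * p_z_given_Cz + (1 - sample) * (1 - p_z_given_Cz)
    count_distribution = mult * count_distribution
    count_distribution /= max(count_distribution.sum(), 1e-6)
    count_so_far += sample

# count_distribution now concentrates on counts consistent with 2 objects being "on".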
def build_representation(self):
    processed_image = index.tile_input_for_iwae(
        tf.transpose(self.inp, (1, 0, 2, 3, 4)), self.k_particles, with_time=True)
    shape = list(tf_shape(processed_image))
    shape[1] = cfg.batch_size * self.k_particles
    processed_image = tf.reshape(processed_image, shape)

    self._tensors.update(
        processed_image=processed_image,
        mean_img=self.data['mean_img'],
    )

    _, _, *img_size = processed_image.shape.as_list()

    layers = [self.n_hidden] * self.n_layers

    def glimpse_encoder():
        return AIREncoder(
            img_size, self.object_shape, self.n_what, Encoder(layers),
            masked_glimpse=self.masked_glimpse, debug=self.debug)

    steps_pred_hidden = self.n_hidden // 2

    transform_estimator = partial(StochasticTransformParam, layers, self.transform_var_bias)

    if self.fixed_presence:
        disc_steps_predictor = partial(FixedStepsPredictor, discovery=True)
    else:
        disc_steps_predictor = partial(StepsPredictor, steps_pred_hidden, self.disc_step_bias)

    if cfg.build_input_encoder is None:
        input_encoder = partial(Encoder, layers)
    else:
        input_encoder = cfg.build_input_encoder

    _input_encoder = input_encoder()

    T, B, *rest = tf_shape(processed_image)
    images = tf.reshape(processed_image, (T*B, *rest))
    encoded_input = _input_encoder(images)
    encoded_input = tf.reshape(encoded_input, (T, B, *tf_shape(encoded_input)[1:]))

    with tf.variable_scope('discovery'):
        discover_cell = DiscoveryCore(
            processed_image, encoded_input, self.object_shape, self.n_what,
            self.rnn_class(self.n_hidden), glimpse_encoder, transform_estimator,
            disc_steps_predictor, debug=self.debug)

        if self.fast_discovery:
            object_state_predictor = MLP([256, 256, self.n_hidden])
            discover = FastDiscover(
                object_state_predictor, self.n_objects, discover_cell,
                step_success_prob=self.step_success_prob,
                where_mean=[*self.scale_prior, 0, 0],
                disc_prior_type=self.disc_prior_type,
                rec_where_prior=self.rec_where_prior)
        else:
            discover = Discover(
                self.n_objects, discover_cell,
                step_success_prob=self.step_success_prob,
                where_mean=[*self.scale_prior, 0, 0],
                disc_prior_type=self.disc_prior_type,
                rec_where_prior=self.rec_where_prior)

    with tf.variable_scope('propagation'):
        # Prop cell should have a different rnn cell but should share all other estimators
        glimpse_encoder = lambda: discover_cell._glimpse_encoder
        transform_estimator = partial(StochasticTransformParam, layers, self.transform_var_bias)

        if self.fixed_presence:
            prop_steps_predictor = partial(FixedStepsPredictor, discovery=False)
        else:
            prop_steps_predictor = partial(StepsPredictor, steps_pred_hidden, self.prop_step_bias)

        prior_rnn = self.prior_rnn_class(self.n_hidden)
        propagation_prior = make_prior(
            self.prop_prior_type, self.n_what, prior_rnn, self.prop_prior_step_bias)

        propagate_rnn_cell = self.rnn_class(self.n_hidden)
        temporal_rnn_cell = self.time_rnn_class(self.n_hidden)

        propagation_cell = PropagationCore(
            processed_image, encoded_input, self.object_shape, self.n_what,
            propagate_rnn_cell, glimpse_encoder, transform_estimator,
            prop_steps_predictor, temporal_rnn_cell, debug=self.debug)

        if self.fast_propagation:
            propagate = FastPropagate(propagation_cell, propagation_prior)
        else:
            propagate = Propagate(propagation_cell, propagation_prior)

    with tf.variable_scope('decoder'):
        glimpse_decoder = partial(Decoder, layers, output_scale=self.output_scale)
        decoder = AIRDecoder(
            img_size, self.object_shape, glimpse_decoder,
            batch_dims=2,
            mean_img=self._tensors.mean_img,
            output_std=self.output_std,
            scale_bounds=self.scale_bounds)

    with tf.variable_scope('sequence'):
        time_cell = self.time_rnn_class(self.n_hidden)

        sequence_apdr = SequentialAIR(
            self.n_objects, self.object_shape, discover, propagate, time_cell, decoder,
            prior_start_step=self._prior_start_step)

    outputs = sequence_apdr(processed_image)
    outputs['where_coords'] = decoder._transformer.to_coords(outputs['where'])

    self._tensors.update(outputs)
def build_background(self):
    if cfg.background_cfg.mode == "colour":
        rgb = np.array(to_rgb(cfg.background_cfg.colour))[None, None, None, :]
        background = rgb * tf.ones_like(self.inp)

    elif cfg.background_cfg.mode == "learn_solid":
        # Learn a solid colour for the background
        self.solid_background_logits = tf.get_variable("solid_background", initializer=[0.0, 0.0, 0.0])
        if "background" in self.fixed_weights:
            tf.add_to_collection(FIXED_COLLECTION, self.solid_background_logits)

        solid_background = tf.nn.sigmoid(10 * self.solid_background_logits)
        background = solid_background[None, None, None, :] * tf.ones_like(self.inp)

    elif cfg.background_cfg.mode == "scalor":
        pass

    elif cfg.background_cfg.mode == "learn":
        self.maybe_build_subnet("background_encoder")
        self.maybe_build_subnet("background_decoder")

        # Here I'm just encoding the first frame...
        bg_attr = self.background_encoder(self.inp[:, 0], 2 * cfg.background_cfg.A, self.is_training)
        bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
        bg_attr_std = tf.exp(bg_attr_log_std)

        if not self.noisy:
            bg_attr_std = tf.zeros_like(bg_attr_std)

        bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

        self._tensors.update(
            bg_attr_mean=bg_attr_mean,
            bg_attr_std=bg_attr_std,
            bg_attr_kl=bg_attr_kl,
            bg_attr=bg_attr)

        # --- decode ---

        _, T, H, W, _ = tf_shape(self.inp)

        background = self.background_decoder(bg_attr, 3, self.is_training)
        assert len(background.shape) == 2 and background.shape[1] == 3
        background = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))
        background = tf.tile(background[:, None, None, None, :], (1, T, H, W, 1))

    elif cfg.background_cfg.mode == "learn_and_transform":
        self.maybe_build_subnet("background_encoder")
        self.maybe_build_subnet("background_decoder")

        # --- encode ---

        n_transform_latents = 4
        n_latents = (2 * cfg.background_cfg.A, 2 * n_transform_latents)

        bg_attr, bg_transform_params = self.background_encoder(self.inp, n_latents, self.is_training)

        # --- bg attributes ---

        bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
        bg_attr_std = self.std_nonlinearity(bg_attr_log_std)
        bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

        # --- bg location ---

        bg_transform_params = tf.reshape(
            bg_transform_params,
            (self.batch_size, self.dynamic_n_frames, 2*n_transform_latents))

        mean, log_std = tf.split(bg_transform_params, 2, axis=2)
        std = self.std_nonlinearity(log_std)

        logits, kl = normal_vae(mean, std, 0.0, 1.0)

        # integrate across timesteps
        logits = tf.cumsum(logits, axis=1)
        logits = tf.reshape(logits, (self.batch_size*self.dynamic_n_frames, n_transform_latents))

        y, x, h, w = tf.split(logits, n_transform_latents, axis=1)
        h = (0.9 - 0.5) * tf.nn.sigmoid(h) + 0.5
        w = (0.9 - 0.5) * tf.nn.sigmoid(w) + 0.5
        y = (1 - h) * tf.nn.tanh(y)
        x = (1 - w) * tf.nn.tanh(x)

        # --- decode ---

        background = self.background_decoder(bg_attr, self.image_depth, self.is_training)
        bg_shape = cfg.background_cfg.bg_shape
        background = background[:, :bg_shape[0], :bg_shape[1], :]
        assert background.shape[1:3] == bg_shape
        background_raw = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))

        transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
        warper = snt.AffineGridWarper(
            bg_shape, (self.image_height, self.image_width), transform_constraints)

        transforms = tf.concat([w, x, h, y], axis=-1)
        grid_coords = warper(transforms)

        grid_coords = tf.reshape(
            grid_coords,
            (self.batch_size, self.dynamic_n_frames, *tf_shape(grid_coords)[1:]))

        background = tf.contrib.resampler.resampler(background_raw, grid_coords)

        self._tensors.update(
            bg_attr_mean=bg_attr_mean,
            bg_attr_std=bg_attr_std,
            bg_attr_kl=bg_attr_kl,
            bg_attr=bg_attr,
            bg_y=tf.reshape(y, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_x=tf.reshape(x, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_h=tf.reshape(h, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_w=tf.reshape(w, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_transform_kl=kl,
            bg_raw=background_raw,
        )

    elif cfg.background_cfg.mode == "data":
        background = self._tensors["ground_truth_background"]

    else:
        raise Exception("Unrecognized background mode: {}.".format(cfg.background_cfg.mode))

    self._tensors["background"] = background[:, :self.dynamic_n_frames]
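# A minimal numpy sketch (not part of the model code) of the squashing used in
# the "learn_and_transform" branch above: window height/width are confined to
# [0.5, 0.9] and the offsets to +/-(1 - size), so the sampled window always
# stays inside the learned background canvas. All names here are illustrative.
def _bg_transform_squash_demo():
    import numpy as np

    def squash(y, x, h, w):
        sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
        h = (0.9 - 0.5) * sigmoid(h) + 0.5   # h in (0.5, 0.9)
        w = (0.9 - 0.5) * sigmoid(w) + 0.5   # w in (0.5, 0.9)
        y = (1 - h) * np.tanh(y)             # |y| < 1 - h
        x = (1 - w) * np.tanh(x)             # |x| < 1 - w
        return y, x, h, w

    # Even extreme logits yield a window inside the canvas:
    y, x, h, w = squash(10., -10., 10., -10.)
    assert abs(y) + h <= 1.0 and abs(x) + w <= 1.0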
def _find_connected_componenents_body(mask):
    components = tf.contrib.image.connected_components(mask)
    total_n_objects = tf.to_int32(tf.reduce_max(components))
    indices = tf.range(1, total_n_objects + 1)

    maxs = tf.reduce_max(components, axis=(1, 2))

    # So that we don't pick up zeros.
    for_mins = tf.where(mask, components, (total_n_objects + 1) * tf.ones_like(components))
    mins = tf.reduce_min(for_mins, axis=(1, 2))

    n_objects = tf.to_int32(tf.maximum((maxs - mins) + 1, 0))

    under = indices[None, :] <= maxs[:, None]
    over = indices[None, :] >= mins[:, None]
    both = tf.to_int32(tf.logical_and(under, over))
    batch_indices_for_objects = tf.argmax(both, axis=0)

    assert_valid_batch_indices = tf.Assert(
        tf.reduce_all(tf.equal(tf.reduce_sum(both, axis=0), 1)),
        [both], name="assert_valid_batch_indices")

    with tf.control_dependencies([assert_valid_batch_indices]):
        batch_indices_for_objects = tf.identity(batch_indices_for_objects)

    _, image_height, image_width, *_ = tf_shape(mask)
    cell = BboxCell(components, batch_indices_for_objects, image_height, image_width)

    # For each object, get its bounding box by using `indices` to figure out which element of
    # `components` the object appears in, and then check that element
    object_bboxes, _ = dynamic_rnn(
        cell, indices[:, None, None],
        initial_state=cell.zero_state(1, tf.float32),
        parallel_iterations=10, swap_memory=False, time_major=True)

    # Couldn't I have just iterated through all object indices and used tf.where on `components`
    # to simultaneously get both the bounding box and the batch index? Yes, but I think I thought
    # that would be expensive (have to look through the entirety of `components` once for each object).

    # Get rid of dummy batch dim created for dynamic_rnn
    object_bboxes = object_bboxes[:, 0, :]

    obj = tf.sequence_mask(n_objects)
    routing = tf.reshape(tf.to_int32(obj), (-1,))
    routing = tf.cumsum(routing, exclusive=True)
    routing = tf.reshape(routing, tf.shape(obj))
    obj = tf.to_float(obj[:, :, None])

    return dict(
        normalized_box=tf.gather(object_bboxes, routing, axis=0),
        obj=obj,
        n_objects=n_objects,
        max_objects=tf.reduce_max(n_objects),
    )
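# Plain-numpy sketch (illustrative only) of the obj/routing bookkeeping above:
# `sequence_mask` gives a padded validity mask from per-image object counts,
# and an exclusive cumsum over the flattened mask gives, for each padded slot,
# the row of the flat `object_bboxes` array to gather from.
def _routing_demo():
    import numpy as np

    n_objects = np.array([2, 0, 1])   # objects found per image
    max_objects = n_objects.max()

    obj = np.arange(max_objects)[None, :] < n_objects[:, None]   # like tf.sequence_mask
    flat = obj.astype(np.int32).reshape(-1)
    routing = np.concatenate([[0], np.cumsum(flat)[:-1]]).reshape(obj.shape)  # exclusive cumsum

    # Invalid slots gather an arbitrary row of `object_bboxes`, but they are
    # zeroed out by `obj` downstream.
    assert (routing == [[0, 1], [2, 2], [2, 3]]).all()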
def _call(self, objects, background, is_training, appearance_only=False):
    if not self.initialized:
        self.image_depth = tf_shape(background)[-1]

    self.maybe_build_subnet("object_decoder")

    # --- compute sprite appearance from attr using object decoder ---

    appearance_logits = apply_object_wise(
        self.object_decoder, objects.attr,
        self.object_shape + (self.image_depth + 1,), is_training)

    appearance_logits = appearance_logits * ([self.color_logit_scale] * 3 + [self.alpha_logit_scale])
    appearance_logits = appearance_logits + ([0.] * 3 + [self.alpha_logit_bias])

    appearance = tf.nn.sigmoid(tf.clip_by_value(appearance_logits, -10., 10.))

    if appearance_only:
        return dict(appearance=appearance)

    appearance_for_output = appearance

    batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
    n_objects = np.prod(obj_leading_shape)
    appearance = tf.reshape(
        appearance, (batch_size, n_objects, *self.object_shape, self.image_depth + 1))

    obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1], axis=-1)

    if "alpha" in self.no_gradient:
        obj_alpha = tf.stop_gradient(obj_alpha)

    if "alpha" in self.fixed_values:
        obj_alpha = float(self.fixed_values["alpha"]) * tf.ones_like(obj_alpha)

    obj_alpha *= tf.reshape(objects.obj, (batch_size, n_objects, 1, 1, 1))

    z = tf.reshape(objects.z, (batch_size, n_objects, 1, 1, 1))
    obj_importance = tf.maximum(obj_alpha * z, 0.01)

    object_maps = tf.concat([obj_colors, obj_alpha, obj_importance], axis=-1)

    ys, xs, yt, xt = objects.ys, objects.xs, objects.yt, objects.xt

    scales = tf.concat([ys, xs], axis=-1)
    scales = tf.reshape(scales, (batch_size, n_objects, 2))

    offsets = tf.concat([yt, xt], axis=-1)
    offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

    # --- Compose images ---

    n_objects_per_image = tf.fill((batch_size,), int(n_objects))

    output = render_sprites.render_sprites(
        object_maps, n_objects_per_image, scales, offsets, background)

    return dict(appearance=appearance_for_output, output=output)
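# Assumed behaviour of `apply_object_wise` as used throughout (a sketch, not
# the project's actual implementation): fold all leading dimensions into a
# single batch axis, apply the network, then restore the leading dimensions
# around the requested output size. Assumes static shapes for simplicity.
def _apply_object_wise_sketch(func, signal, output_size, is_training):
    import numpy as np
    import tensorflow as tf

    *leading_dims, d = signal.shape.as_list()
    flat = tf.reshape(signal, (-1, d))
    out = func(flat, output_size, is_training)
    out_shape = np.atleast_1d(output_size).tolist()
    return tf.reshape(out, (*leading_dims, *out_shape))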
def build_representation(self):
    # --- init modules ---

    if self.encoder is None:
        self.encoder = self.build_encoder(scope="encoder")
        if "encoder" in self.fixed_weights:
            self.encoder.fix_variables()

    if self.cell is None and self.build_cell is not None:
        self.cell = cfg.build_cell(scope="cell")
        if "cell" in self.fixed_weights:
            self.cell.fix_variables()

    if self.decoder is None:
        self.decoder = cfg.build_decoder(scope="decoder")
        if "decoder" in self.fixed_weights:
            self.decoder.fix_variables()

    # --- encode ---

    inp_trailing_shape = tf_shape(self.inp)[2:]
    video = tf.reshape(self.inp, (self.batch_size * self.dynamic_n_frames, *inp_trailing_shape))
    encoder_output = self.encoder(video, 2 * self.A, self.is_training)

    eo_trailing_shape = tf_shape(encoder_output)[1:]
    encoder_output = tf.reshape(
        encoder_output, (self.batch_size, self.dynamic_n_frames, *eo_trailing_shape))

    if self.cell is None:
        attr = encoder_output
    else:
        if self.flat_latent:
            n_trailing_dims = int(np.prod(eo_trailing_shape))
            encoder_output = tf.reshape(
                encoder_output, (self.batch_size, self.dynamic_n_frames, n_trailing_dims))
        else:
            raise NotImplementedError()

            # NOTE: unreachable; kept as written in the original.
            n_objects = int(np.prod(eo_trailing_shape[:-1]))
            D = eo_trailing_shape[-1]
            encoder_output = tf.reshape(
                encoder_output, (self.batch_size, self.dynamic_n_frames, n_objects, D))
            encoder_output = tf.layers.flatten(encoder_output)

        attr, final_state = dynamic_rnn(
            self.cell, encoder_output,
            initial_state=self.cell.zero_state(self.batch_size, tf.float32),
            parallel_iterations=1, swap_memory=False, time_major=False)

    attr_mean, attr_log_std = tf.split(attr, 2, axis=-1)
    attr_std = tf.math.softplus(attr_log_std)

    if not self.noisy:
        attr_std = tf.zeros_like(attr_std)

    attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)

    self._tensors.update(attr_mean=attr_mean, attr_std=attr_std, attr_kl=attr_kl, attr=attr)

    # --- decode ---

    decoder_input = tf.reshape(attr, (self.batch_size*self.dynamic_n_frames, *tf_shape(attr)[2:]))
    reconstruction = self.decoder(decoder_input, tf_shape(self.inp)[2:], self.is_training)
    reconstruction = reconstruction[:, :self.obs_shape[1], :self.obs_shape[2], :]
    reconstruction = tf.reshape(
        reconstruction, (self.batch_size, self.dynamic_n_frames, *self.obs_shape[1:]))
    reconstruction = tf.nn.sigmoid(tf.clip_by_value(reconstruction, -10, 10))
    self._tensors["output"] = reconstruction

    # --- losses ---

    if self.train_kl:
        self.losses['attr_kl'] = tf_mean_sum(self._tensors["attr_kl"])

    if self.train_reconstruction:
        self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=reconstruction, label=self.inp)
        self.losses['reconstruction'] = tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])
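# Sketch of what `normal_vae` is assumed to compute (a standard reparameterized
# sample plus the analytic KL between diagonal Gaussians); the real helper
# lives elsewhere in the codebase.
def _normal_vae_sketch(mean, std, prior_mean, prior_std):
    import tensorflow as tf

    sample = mean + std * tf.random_normal(tf.shape(mean))

    # KL(N(mean, std^2) || N(prior_mean, prior_std^2)), elementwise. The
    # tf.maximum guards against std == 0, which occurs when `noisy` is False.
    kl = (tf.log(prior_std / tf.maximum(std, 1e-8))
          + (std**2 + (mean - prior_mean)**2) / (2. * prior_std**2)
          - 0.5)
    return sample, kl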
def _call(self, objects, background, is_training, appearance_only=False, mask_only=False):
    """ If mask_only==True, then we ignore the provided background, using a black background
        instead, and also ignore the computed appearance, using all-white appearances instead.
    """
    if not self.initialized:
        self.image_depth = tf_shape(background)[-1]

    single = False
    if isinstance(objects, dict):
        single = True
        objects = [objects]

    _object_maps = []
    _scales = []
    _offsets = []
    _appearance = []

    for i, obj in enumerate(objects):
        anchor_box = self.anchor_boxes[i]
        object_shape = self.object_shapes[i]

        object_decoder = self.maybe_build_subnet(
            "object_decoder_for_flight_{}".format(i), builder_name='build_object_decoder')

        # --- compute sprite appearance from attr using object decoder ---

        appearance_logit = apply_object_wise(
            object_decoder, obj.attr,
            output_size=object_shape + (self.image_depth+1,),
            is_training=is_training)

        appearance_logit = appearance_logit * (
            [self.color_logit_scale] * self.image_depth + [self.alpha_logit_scale])
        appearance_logit = appearance_logit + (
            [0.] * self.image_depth + [self.alpha_logit_bias])

        appearance = tf.nn.sigmoid(tf.clip_by_value(appearance_logit, -10., 10.))
        _appearance.append(appearance)

        if appearance_only:
            continue

        batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
        n_objects = np.prod(obj_leading_shape)
        appearance = tf.reshape(
            appearance, (batch_size, n_objects, *object_shape, self.image_depth+1))

        obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1], axis=-1)

        if mask_only:
            obj_colors = tf.ones_like(obj_colors)

        obj_alpha *= tf.reshape(obj.obj, (batch_size, n_objects, 1, 1, 1))

        z = tf.reshape(obj.z, (batch_size, n_objects, 1, 1, 1))
        obj_importance = tf.maximum(obj_alpha * z / self.importance_temp, 0.01)

        object_maps = tf.concat([obj_colors, obj_alpha, obj_importance], axis=-1)

        *_, image_height, image_width, _ = tf_shape(background)
        yt, xt, ys, xs = coords_to_image_space(
            obj.yt, obj.xt, obj.ys, obj.xs,
            (image_height, image_width), anchor_box, top_left=True)

        scales = tf.concat([ys, xs], axis=-1)
        scales = tf.reshape(scales, (batch_size, n_objects, 2))

        offsets = tf.concat([yt, xt], axis=-1)
        offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

        _object_maps.append(object_maps)
        _scales.append(scales)
        _offsets.append(offsets)

    if single:
        _appearance = _appearance[0]

    if appearance_only:
        return dict(appearance=_appearance)

    if mask_only:
        background = tf.zeros_like(background)

    # --- Compose images ---

    output = render_sprites.render_sprites(_object_maps, _scales, _offsets, background)

    return dict(appearance=_appearance, output=output)
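# Assumed behaviour of `maybe_build_subnet` (a sketch; the real helper is part
# of the surrounding framework, and this may differ in detail): build the named
# subnet once via a builder looked up on the config, cache it as an attribute,
# and return it.
def _maybe_build_subnet_sketch(self, name, builder_name=None):
    if getattr(self, name, None) is None:
        builder = getattr(cfg, builder_name or "build_" + name)
        setattr(self, name, builder(scope=name))
    return getattr(self, name)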
def _compute_obj_kl(self, tensors, existing_objects=None):
    # --- compute obj_kl ---

    obj_pre_sigmoid = tensors["obj_pre_sigmoid"]
    obj_log_odds = tensors["obj_log_odds"]
    obj_prob = tensors["obj_prob"]
    obj = tensors["obj"]

    batch_size, n_objects, _ = tf_shape(obj)
    max_n_objects = n_objects

    if existing_objects is not None:
        _, n_existing_objects, _ = tf_shape(existing_objects)
        existing_objects = tf.reshape(existing_objects, (batch_size, n_existing_objects))
        max_n_objects += n_existing_objects

    count_support = tf.range(max_n_objects + 1, dtype=tf.float32)

    if self.count_prior_dist is not None:
        assert len(self.count_prior_dist) == (max_n_objects + 1)
        count_distribution = tf.constant(self.count_prior_dist, dtype=tf.float32)
    else:
        count_prior_prob = tf.nn.sigmoid(self.count_prior_log_odds)
        count_distribution = (1 - count_prior_prob) * (count_prior_prob ** count_support)

    normalizer = tf.reduce_sum(count_distribution)
    count_distribution = count_distribution / tf.maximum(normalizer, 1e-6)
    count_distribution = tf.tile(count_distribution[None, :], (batch_size, 1))

    if existing_objects is not None:
        count_so_far = tf.reduce_sum(tf.round(existing_objects), axis=1, keepdims=True)

        count_distribution = (
            count_distribution
            * tf_binomial_coefficient(count_support, count_so_far)
            * tf_binomial_coefficient(max_n_objects - count_support, n_existing_objects - count_so_far))

        normalizer = tf.reduce_sum(count_distribution, axis=1, keepdims=True)
        count_distribution = count_distribution / tf.maximum(normalizer, 1e-6)
    else:
        count_so_far = tf.zeros((batch_size, 1), dtype=tf.float32)

    obj_kl = []
    for i in range(n_objects):
        p_z_given_Cz = tf.maximum(count_support[None, :] - count_so_far, 0) / (max_n_objects - i)

        # Reshape for batch matmul
        _count_distribution = count_distribution[:, None, :]
        _p_z_given_Cz = p_z_given_Cz[:, :, None]

        p_z = tf.matmul(_count_distribution, _p_z_given_Cz)[:, :, 0]

        if self.use_concrete_kl:
            prior_log_odds = tf_safe_log(p_z) - tf_safe_log(1 - p_z)
            _obj_kl = concrete_binary_sample_kl(
                obj_pre_sigmoid[:, i, :],
                obj_log_odds[:, i, :], self.obj_concrete_temp,
                prior_log_odds, self.obj_concrete_temp,
            )
        else:
            prob = obj_prob[:, i, :]

            _obj_kl = (
                prob * (tf_safe_log(prob) - tf_safe_log(p_z))
                + (1 - prob) * (tf_safe_log(1 - prob) - tf_safe_log(1 - p_z)))

        obj_kl.append(_obj_kl)

        sample = tf.to_float(obj[:, i, :] > 0.5)
        mult = sample * p_z_given_Cz + (1 - sample) * (1 - p_z_given_Cz)
        count_distribution = mult * count_distribution

        normalizer = tf.reduce_sum(count_distribution, axis=1, keepdims=True)
        normalizer = tf.maximum(normalizer, 1e-6)
        count_distribution = count_distribution / normalizer

        count_so_far += sample

    obj_kl = tf.reshape(tf.concat(obj_kl, axis=1), (batch_size, n_objects, 1))

    return obj_kl
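# Plain-numpy walkthrough (illustrative) of the sequential count-prior update
# in the loop above: given the current distribution over total counts C,
# p(z_i = 1) is the expectation of (C - count_so_far) / (slots remaining),
# and after observing a sample the count distribution is re-weighted and
# renormalized by Bayes' rule.
def _count_prior_demo():
    import numpy as np

    count_support = np.arange(4, dtype=np.float64)   # support: 0..3 objects
    count_dist = np.array([0.4, 0.3, 0.2, 0.1])      # prior over counts
    count_so_far = 0.0
    max_n_objects = n_objects = 3

    for i in range(n_objects):
        p_z_given_C = np.maximum(count_support - count_so_far, 0) / (max_n_objects - i)
        p_z = (count_dist * p_z_given_C).sum()       # prior prob. object i is "on"

        sample = 1.0                                 # pretend the posterior turned it on
        mult = sample * p_z_given_C + (1 - sample) * (1 - p_z_given_C)
        count_dist = mult * count_dist
        count_dist = count_dist / max(count_dist.sum(), 1e-6)
        count_so_far += sample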
def _call(self, input_locs, input_features, reference_locs, reference_features, is_training):
    """
    input_features: (B, n_inp, n_hidden)
    input_locs: (B, n_inp, loc_dim)
    reference_locs: (B, n_ref, loc_dim)
    """
    assert (reference_features is not None) == self.do_object_wise

    if not self.is_built:
        self.relation_func = self.build_mlp(scope="relation_func")

        if self.do_object_wise:
            self.object_wise_func = self.build_mlp(scope="object_wise_func")

        self.is_built = True

    loc_dim = tf_shape(input_locs)[-1]
    n_ref = tf_shape(reference_locs)[-2]
    batch_size, n_inp, _ = tf_shape(input_features)

    input_locs = tf.broadcast_to(input_locs, (batch_size, n_inp, loc_dim))
    reference_locs = tf.broadcast_to(reference_locs, (batch_size, n_ref, loc_dim))

    adjusted_locs = input_locs[:, None, :, :] - reference_locs[:, :, None, :]  # (B, n_ref, n_inp, loc_dim)
    adjusted_features = tf.tile(input_features[:, None], (1, n_ref, 1, 1))  # (B, n_ref, n_inp, features_dim)

    relation_input = tf.concat([adjusted_features, adjusted_locs], axis=-1)

    if self.do_object_wise:
        object_wise = apply_object_wise(
            self.object_wise_func, reference_features,
            output_size=self.n_hidden, is_training=is_training)  # (B, n_ref, n_hidden)

        _object_wise = tf.tile(object_wise[:, :, None], (1, 1, n_inp, 1))
        relation_input = tf.concat([relation_input, _object_wise], axis=-1)
    else:
        object_wise = None

    V = apply_object_wise(
        self.relation_func, relation_input,
        output_size=self.n_hidden, is_training=is_training)  # (B, n_ref, n_inp, n_hidden)

    attention_weights = tf.exp(-0.5 * tf.reduce_sum((adjusted_locs / self.kernel_std)**2, axis=3))
    attention_weights = (
        attention_weights / (2 * np.pi)**(loc_dim / 2) / self.kernel_std**loc_dim)  # (B, n_ref, n_inp)

    result = tf.reduce_sum(V * attention_weights[..., None], axis=2)  # (B, n_ref, n_hidden)

    if self.do_object_wise:
        result += object_wise

    # result = tf.contrib.layers.layer_norm(result)

    return result
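# The attention weights above are exactly an isotropic Gaussian density
# evaluated at each reference-to-input displacement; equivalent numpy form
# (illustrative only):
def _gaussian_weight(dist, std):
    """dist: (..., loc_dim) displacements; returns the N(0, std^2 I) density."""
    import numpy as np

    d = dist.shape[-1]
    return (np.exp(-0.5 * np.sum((dist / std) ** 2, axis=-1))
            / (2 * np.pi) ** (d / 2)
            / std ** d)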
def _call(self, objects, background, is_training, appearance_only=False):
    if not self.initialized:
        self.image_depth = tf_shape(background)[-1]

    self.maybe_build_subnet("object_decoder")

    # --- compute sprite appearance from attr using object decoder ---

    appearance_logit = apply_object_wise(
        self.object_decoder, objects.attr,
        output_size=self.object_shape + (self.image_depth+1,),
        is_training=is_training)

    appearance_logit = appearance_logit * ([self.color_logit_scale] * 3 + [self.alpha_logit_scale])
    appearance_logit = appearance_logit + ([0.] * 3 + [self.alpha_logit_bias])

    appearance = tf.nn.sigmoid(tf.clip_by_value(appearance_logit, -10., 10.))

    if appearance_only:
        return dict(appearance=appearance)

    appearance_for_output = appearance

    batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
    n_objects = np.prod(obj_leading_shape)
    appearance = tf.reshape(
        appearance, (batch_size, n_objects, *self.object_shape, self.image_depth+1))

    obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1], axis=-1)

    obj_alpha *= tf.reshape(objects.render_obj, (batch_size, n_objects, 1, 1, 1))

    z = tf.reshape(objects.z, (batch_size, n_objects, 1, 1, 1))
    obj_importance = tf.maximum(obj_alpha * z / self.importance_temp, 0.01)

    object_maps = tf.concat([obj_colors, obj_alpha, obj_importance], axis=-1)

    *_, image_height, image_width, _ = tf_shape(background)

    yt, xt, ys, xs = coords_to_image_space(
        objects.yt, objects.xt, objects.ys, objects.xs,
        (image_height, image_width), self.anchor_box, top_left=True)

    scales = tf.concat([ys, xs], axis=-1)
    scales = tf.reshape(scales, (batch_size, n_objects, 2))

    offsets = tf.concat([yt, xt], axis=-1)
    offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

    # --- Compose images ---

    n_objects_per_image = tf.fill((batch_size,), int(n_objects))

    output = render_sprites.render_sprites(
        object_maps, n_objects_per_image, scales, offsets, background)

    return dict(appearance=appearance_for_output, output=output)
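# Why the alpha logit scale/bias above matter at initialization: a freshly
# initialized decoder emits logits near zero, so the initial alpha is roughly
# sigmoid(alpha_logit_bias), i.e. the bias sets how visible objects are at the
# start of training. Quick numeric check (illustrative only):
def _alpha_bias_demo():
    import numpy as np

    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    for bias in (-1.0, 0.0, 1.0):
        # raw logit ~ 0 at init, so alpha ~= sigmoid(bias)
        print(bias, sigmoid(0.0 + bias))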
def build_representation(self):
    # --- init modules ---

    self.B = len(self.anchor_boxes)

    if self.backbone is None:
        self.backbone = self.build_backbone(scope="backbone")
        if "backbone" in self.fixed_weights:
            self.backbone.fix_variables()

    if self.feature_fuser is None:
        self.feature_fuser = self.build_feature_fuser(scope="feature_fuser")
        if "feature_fuser" in self.fixed_weights:
            self.feature_fuser.fix_variables()

    if self.obj_feature_extractor is None and self.build_obj_feature_extractor is not None:
        self.obj_feature_extractor = self.build_obj_feature_extractor(scope="obj_feature_extractor")
        if "obj_feature_extractor" in self.fixed_weights:
            self.obj_feature_extractor.fix_variables()

    backbone_output, n_grid_cells, grid_cell_size = self.backbone(
        self.inp, self.B*self.n_backbone_features, self.is_training)

    self.H, self.W = [int(i) for i in n_grid_cells]
    self.HWB = self.H * self.W * self.B
    self.pixels_per_cell = tuple(int(i) for i in grid_cell_size)
    H, W, B = self.H, self.W, self.B

    if self.object_layer is None:
        self.object_layer = ObjectLayer(self.pixels_per_cell, scope="objects")

    self.object_rep_tensors = []
    object_rep_tensors = None
    _tensors = defaultdict(list)

    for f in range(self.n_frames):
        print("Building network for frame {}".format(f))

        early_frame_features = backbone_output[:, f]

        if f > 0 and self.obj_feature_extractor is not None:
            object_features = object_rep_tensors["all"]
            object_features = tf.reshape(
                object_features, (self.batch_size, H, W, B*tf_shape(object_features)[-1]))
            early_frame_features += self.obj_feature_extractor(
                object_features, B*self.n_backbone_features, self.is_training)

        frame_features = self.feature_fuser(
            early_frame_features, B*self.n_backbone_features, self.is_training)

        frame_features = tf.reshape(
            frame_features, (self.batch_size, H, W, B, self.n_backbone_features))

        object_rep_tensors = self.object_layer(
            self.inp[:, f], frame_features, self._tensors["background"][:, f], self.is_training)

        self.object_rep_tensors.append(object_rep_tensors)

        for k, v in object_rep_tensors.items():
            _tensors[k].append(v)

    self._tensors.update(**{k: tf.stack(v, axis=1) for k, v in _tensors.items()})

    # --- specify values to record ---

    obj = self._tensors["obj"]
    pred_n_objects = self._tensors["pred_n_objects"]

    self.record_tensors(
        batch_size=self.batch_size,
        float_is_training=self.float_is_training,
        cell_y=self._tensors["cell_y"],
        cell_x=self._tensors["cell_x"],
        h=self._tensors["h"],
        w=self._tensors["w"],
        z=self._tensors["z"],
        area=self._tensors["area"],
        cell_y_std=self._tensors["cell_y_std"],
        cell_x_std=self._tensors["cell_x_std"],
        h_std=self._tensors["h_std"],
        w_std=self._tensors["w_std"],
        z_std=self._tensors["z_std"],
        n_objects=pred_n_objects,
        obj=obj,
        latent_area=self._tensors["latent_area"],
        latent_hw=self._tensors["latent_hw"],
        attr=self._tensors["attr"],
    )

    # --- losses ---

    if self.train_reconstruction:
        output = self._tensors['output']
        inp = self._tensors['inp']
        self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=output, label=inp)
        self.losses['reconstruction'] = (
            self.reconstruction_weight * tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])
        )

    if self.train_kl:
        self.losses.update(
            obj_kl=self.kl_weight * tf_mean_sum(self._tensors["obj_kl"]),
            cell_y_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["cell_y_kl"]),
            cell_x_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["cell_x_kl"]),
            h_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["h_kl"]),
            w_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["w_kl"]),
            z_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["z_kl"]),
            attr_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["attr_kl"]),
        )

        if cfg.background_cfg.mode == "learn_and_transform":
            self.losses.update(
                bg_attr_kl=self.kl_weight * tf_mean_sum(self._tensors["bg_attr_kl"]),
                bg_transform_kl=self.kl_weight * tf_mean_sum(self._tensors["bg_transform_kl"]),
            )

    # --- other evaluation metrics ---

    if "n_annotations" in self._tensors:
        count_1norm = tf.to_float(
            tf.abs(tf.to_int32(self._tensors["pred_n_objects_hard"])
                   - self._tensors["n_valid_annotations"]))

        self.record_tensors(
            count_1norm=count_1norm,
            count_error=count_1norm > 0.5,
        )
def __call__(self, inp):
    output = self.wrapped(inp, None, True)[0]
    batch_size = tf_shape(output)[0]
    n_trailing = np.prod(tf_shape(output)[1:])
    return tf.reshape(output, (batch_size, n_trailing))
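# Example usage of the adapter above (names hypothetical): it turns a
# (signal, output_size, is_training)-style network into a plain callable
# returning a flat (batch, features) tensor.
#
#     flatten_net = NetworkFlattenWrapper(some_convnet)
#     features = flatten_net(images)   # -> (batch_size, n_features)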