def _build(self, zp_old, inputs):
    """Apply a gated correction followed by a gated prediction step.

    Args:
      zp_old: tensor guarded as "B, K, Zp"; the parameters to be updated.
      inputs: tensor guarded as "B, K, H"; the features driving the correction.

    Returns:
      The updated parameters, reshaped back to "B, K, Zp".
    """
    shapes = shapeguard.ShapeGuard()
    shapes.guard(zp_old, "B, K, Zp")
    shapes.guard(inputs, "B, K, H")
    n_params = shapes.Zp

    # Modules are instantiated in this order: corrector, its gate, predictor.
    corrector = snt.Linear(n_params)
    corrector_gate = snt.Linear(n_params)
    predictor = snt.nets.MLP(
        output_sizes=list(self._hidden_sizes) + [n_params * 2],
        activation=self._activation,
    )

    # Fold the K axis into the batch axis so every component is processed
    # independently.
    zp_flat = shapes.reshape(zp_old, "B*K, Zp")
    inputs_flat = shapes.reshape(inputs, "B*K, H")

    gate = tf.nn.sigmoid(
        corrector_gate(inputs_flat) + self._corrector_gate_bias)
    proposal = corrector(inputs_flat)
    # Same value as (1 - gate) * zp_flat + gate * proposal, written so that
    # only a single multiplication is needed.
    corrected = zp_flat + gate * (proposal - zp_flat)

    # The predictor emits 2 * Zp values per row: the first half is the
    # predicted value, the second half the pre-sigmoid gate.
    prediction = predictor(zp_flat)
    predicted_value = prediction[:, :n_params]
    predicted_gate = tf.nn.sigmoid(
        prediction[:, n_params:] + self._pred_gate_bias)
    updated = corrected + predicted_gate * (predicted_value - corrected)
    return shapes.reshape(updated, "B, K, Zp")
def _build(self, image):
    """Connect model to TensorFlow graph.

    Runs a ConvNet2D over the image, collapses the spatial axes according to
    `self._mode` and feeds the result through an MLP.

    Args:
      image: tensor that `flatten_all_but_last(image, n_dims=3)` turns into a
        tensor guarded as "B, H, W, C".

    Returns:
      FlatParameters wrapping the MLP output ("B, Y") after it has been passed
      through the `unflatten` callable returned by `flatten_all_but_last`.

    Raises:
      AssertionError: if the last entry of `self._mlp_opt["output_sizes"]` is
        still None.
      KeyError: if `self._mode` is neither "flatten" nor "avg_pool".
    """
    assert self._mlp_opt["output_sizes"][
        -1] is not None, "set output_shapes"
    sg = shapeguard.ShapeGuard()
    flat_image, unflatten = flatten_all_but_last(image, n_dims=3)
    sg.guard(flat_image, "B, H, W, C")

    cnn = snt.nets.ConvNet2D(activate_final=True,
                             paddings=("SAME", ),
                             normalize_final=False,
                             **self._cnn_opt)
    mlp = snt.nets.MLP(**self._mlp_opt)

    # run CNN
    net = cnn(flat_image)

    if self._mode == "flatten":
        # Merge the last three axes (height, width, channels) into one.
        net_shape = net.get_shape().as_list()
        # tf.reshape does not accept None; a statically unknown leading
        # dimension is passed as -1 so that it is inferred at run time.
        leading = [-1 if dim is None else dim for dim in net_shape[:-3]]
        flat_shape = leading + [np.prod(net_shape[-3:])]
        net = tf.reshape(net, flat_shape)
    elif self._mode == "avg_pool":
        # Average over the two spatial axes.
        net = tf.reduce_mean(net, axis=[1, 2])
    else:
        raise KeyError('Unknown mode "{}"'.format(self._mode))

    # run MLP
    output = sg.guard(mlp(net), "B, Y")
    return FlatParameters(unflatten(output))
def _build(self, z):
    """Connect model to TensorFlow graph.

    Optionally transforms `z` with an MLP, tiles the result over an H x W
    grid, appends coordinate channels and decodes the grid with a ConvNet2D.

    Args:
      z: tensor that `flatten_all_but_last(z)` turns into a "B, Z" tensor.

    Returns:
      FlatParameters wrapping the CNN output ("B, H, W, C") after it has been
      passed through the `unflatten` callable from `flatten_all_but_last`.

    Raises:
      AssertionError: if `self._target_out_shape` is None.
    """
    assert self._target_out_shape is not None, "Call set_output_shape"

    shapes = shapeguard.ShapeGuard()
    # Components are folded into the batch axis before any processing.
    z_flat, restore_leading_dims = flatten_all_but_last(z)
    shapes.guard(z_flat, "B, Z")
    shapes.guard(self._target_out_shape, "H, W, C")

    # Without MLP options the latent is passed through unchanged.
    if self._mlp_opt is not None:
        mlp = snt.nets.MLP(activate_final=True, **self._mlp_opt)
    else:
        mlp = tf.identity
    hidden = shapes.guard(mlp(z_flat), "B, hidden")

    # Repeat the feature vector at every spatial location: (B, H, W, hidden).
    expanded = hidden[:, tf.newaxis, tf.newaxis]
    multiples = tf.constant(shapes["1, H, W, 1"])
    tiled = tf.tile(expanded, multiples=multiples)

    cnn_inputs = self.append_coordinate_channels(tiled)
    cnn = snt.nets.ConvNet2D(paddings=("SAME", ),
                             normalize_final=False,
                             **self._cnn_opt)
    decoded = cnn(cnn_inputs)
    shapes.guard(decoded, "B, H, W, C")
    return FlatParameters(restore_leading_dims(decoded))
def decode(self, z):
    """Decode latents into output parameters and an output distribution.

    Args:
      z: tensor guarded as "B, K, Z".

    Returns:
      A pair (params, out_dist): the decoder output and the result of
      calling `self.output_dist` on the unpacked parameters.
    """
    shapes = shapeguard.ShapeGuard()
    shapes.guard(z, "B, K, Z")

    # legacy: every latent vector gets one extra channel fixed at 5.0.
    legacy_channel = 5.0 * tf.ones(shapes["B, K, 1"], dtype=tf.float32)
    z_extended = tf.concat([z, legacy_channel], axis=2)

    params = self.decoder(z_extended)
    return params, self.output_dist(*params)
def __init__(
        self,
        encoder_net,
        recurrent_net,
        refinement_head,
        name="refinement",
):
    """Store the sub-networks and create a ShapeGuard for this module.

    Args:
      encoder_net: kept as `self._encoder_net`.
      recurrent_net: kept as `self._recurrent_net`.
      refinement_head: kept as `self._refinement_head`.
      name: module name forwarded to the parent constructor.
    """
    super().__init__(name=name)
    self._sg = shapeguard.ShapeGuard()
    self._refinement_head = refinement_head
    self._recurrent_net = recurrent_net
    self._encoder_net = encoder_net
def __init__(
        self,
        decoder,
        refinement_core,
        latent_dist,
        output_dist,
        n_z,
        num_components,
        num_iters,
        sequential=False,
        factor_evaluator=None,
        stop_gradients=DEFAULT_STOP_GRADIENT,
        iter_loss_weight="linspace",
        inputs=DEFAULT_INPUTS,
        preprocess=None,
        coord_type="linear",
        coord_freqs=3,
        name="iodine",
):
    """Store the configuration and set up prior and preprocessing layers.

    All arguments are stored on the instance under the same name, except:
      * `iter_loss_weight` is converted by `self._parse_iter_loss_weights`
        and stored as `self.iter_loss_weights`;
      * `preprocess` falls back to DEFAULT_PREPROCESSING when it is None.

    Inside the module's variable scope the latent distribution is given the
    output shape `[n_z]`, a default prior for `num_components` components is
    created (guarded as "K, Z"), and one LayerNorm per entry of
    `self.preprocess` is created.
    """
    super().__init__(name=name)
    # K is fixed up front so that every later guard agrees on it.
    self._sg = shapeguard.ShapeGuard(dims={"K": num_components})

    self.decoder = decoder
    self.refinement_core = refinement_core
    self.latent_dist = latent_dist
    self.output_dist = output_dist

    self.n_z = n_z
    self.num_components = num_components
    self.num_iters = num_iters
    self.sequential = sequential
    self.iter_loss_weights = self._parse_iter_loss_weights(iter_loss_weight)

    self.factor_evaluator = factor_evaluator
    self.stop_gradients = stop_gradients
    self.inputs = inputs
    if preprocess is None:
        preprocess = DEFAULT_PREPROCESSING
    self.preprocess = preprocess
    self.coord_type = coord_type
    self.coord_freqs = coord_freqs

    with self._enter_variable_scope():
        self.latent_dist.set_output_shape([self.n_z])
        logging.info("VAE: z shape: %s", [self.n_z])

        with tf.name_scope("prior"):
            prior_shape = (self.num_components, )
            self.prior = self.latent_dist.get_default_prior(prior_shape)
            self._sg.guard(self.prior, "K, Z")

        with tf.variable_scope("preprocess"):
            layernorms = {}
            for key in self.preprocess:
                layernorms[key] = snt.LayerNorm(name="layer_norm_" + key)
            self._layernorms = layernorms
def _build(self, zp_old, inputs):
    """Add a linear function of `inputs` to the previous parameters.

    Args:
      zp_old: tensor guarded as "B, K, Zp".
      inputs: tensor guarded as "B, K, H".

    Returns:
      Tensor of shape "B, K, Zp": `zp_old` plus the linear update.
    """
    shapes = shapeguard.ShapeGuard()
    shapes.guard(zp_old, "B, K, Zp")
    shapes.guard(inputs, "B, K, H")

    linear = snt.Linear(shapes.Zp)

    # Fold K into the batch axis for the linear layer.
    zp_flat = shapes.reshape(zp_old, "B*K, Zp")
    inputs_flat = shapes.reshape(inputs, "B*K, H")

    delta = linear(inputs_flat)
    return shapes.reshape(zp_flat + delta, "B, K, Zp")
def _build(self, data):
    """Connect model to TensorFlow graph.

    Args:
      data: tensor that `flatten_all_but_last(data)` turns into a "B, N"
        tensor.

    Returns:
      FlatParameters wrapping the MLP output ("B, Y") after it has been
      passed through the `unflatten` callable from `flatten_all_but_last`.

    Raises:
      AssertionError: if the last entry of `self._mlp_opt["output_sizes"]` is
        still None.
    """
    final_size = self._mlp_opt["output_sizes"][-1]
    assert final_size is not None, "set output_shapes"

    shapes = shapeguard.ShapeGuard()
    data_flat, restore_leading_dims = flatten_all_but_last(data)
    shapes.guard(data_flat, "B, N")

    mlp = snt.nets.MLP(**self._mlp_opt)
    encoded = shapes.guard(mlp(data_flat), "B, Y")
    return FlatParameters(restore_leading_dims(encoded))
def predict(self, z):
    """Predict every mapped factor from `z`.

    Args:
      z: tensor guarded as "B, Z".

    Returns:
      Dict from `m.name` to a "B, m.size" slice of the predictor output, for
      every entry `m` of `self._mapping`, taken consecutively along the last
      axis.
    """
    shapes = shapeguard.ShapeGuard()
    z = shapes.guard(z, "B, Z")
    all_preds = shapes.guard(self.predictor(z), "B, M")

    predictions = {}
    start = 0
    for m in self._mapping:
        stop = start + m.size
        with tf.name_scope(m.name):
            chunk = all_preds[:, start:stop]
            predictions[m.name] = shapes.guard(chunk,
                                               "B, {}".format(m.size))
        start = stop
    return predictions
def _build(self, z):
    """Connect model to TensorFlow graph.

    Args:
      z: tensor that `flatten_all_but_last(z)` turns into a "B, Z" tensor.

    Returns:
      FlatParameters wrapping the transposed-convolution output
      ("B, H, W, C") after it has been passed through the `unflatten`
      callable from `flatten_all_but_last`.
    """
    shapes = shapeguard.ShapeGuard()
    z_flat, restore_leading_dims = flatten_all_but_last(z)
    shapes.guard(z_flat, "B, Z")
    shapes.guard(self._target_out_shape, "H, W, C")

    mlp = snt.nets.MLP(**self._mlp_opt)
    deconv = snt.nets.ConvNet2DTranspose(paddings=("SAME", ),
                                         normalize_final=False,
                                         **self._cnn_opt)

    # NOTE(review): the MLP output is handed to the transposed-convolution
    # stack without any visible reshape to a 4-D tensor -- confirm that this
    # is valid for the options this module is configured with.
    hidden = mlp(z_flat)
    decoded = shapes.guard(deconv(hidden), "B, H, W, C")
    return FlatParameters(restore_leading_dims(decoded))
def construct_iterations_image(images,
                               recons,
                               masks,
                               border_width=2,
                               nr_seqs=2,
                               clip=True):
    """Construct a single image containing images, recons and colored masks.

    Args:
      images: (B, T, 1, H, W, C); a single time step (T == 1) is tiled to the
        T of `recons`.
      recons: (B, T, 1, H, W, C)
      masks: (B, T, K, H, W, 1)
      border_width: int. width of the border in pixels. (default=2)
      nr_seqs: int. Number of sequences to include. (default=2)
      clip: bool. Whether to clip the final image to range [0, 1].

    Returns:
      A tensor of shape (1, S * 3 * Hp, Wp * T, 3) with
      S = min(nr_seqs, B), Hp = H + 2 * border_width and
      Wp = W + 2 * border_width. For each sequence the image, the
      reconstruction and the colored masks are stacked along the height;
      iterations are laid out along the width.
    """
    shapes = shapeguard.ShapeGuard()
    shapes.guard(recons, "B, T, 1, H, W, C")
    if images.get_shape().as_list()[1] == 1:
        images = tf.tile(images, shapes["1, T, 1, 1, 1, 1"])
    shapes.guard(images, "B, T, 1, H, W, C")
    shapes.guard(masks, " B, T, K, H, W, 1")

    if shapes.C == 1:
        # Grayscale inputs are repeated over three channels.
        to_rgb = [1, 1, 1, 1, 1, 3]
        images = tf.tile(images, to_rgb)
        recons = tf.tile(recons, to_rgb)

    shapes.S = min(nr_seqs, shapes.B)

    with tf.name_scope("diagnostic_image"):
        # Move the K axis last and let color_transform turn it into colors.
        masks_k_last = tf.transpose(masks[:nr_seqs], [0, 1, 5, 3, 4, 2])
        colored_masks = color_transform(masks_k_last)

        # Only the two spatial axes receive a border.
        border = (border_width, border_width)
        untouched = (0, 0)
        paddings = tf.constant(
            [untouched, untouched, untouched, border, border, untouched])

        def add_border(tensor):
            return tf.pad(tensor, paddings, constant_values=0.5)

        parts = [
            add_border(images[:nr_seqs]),
            add_border(recons[:nr_seqs]),
            add_border(colored_masks),
        ]

        # Axis 3 is the height axis, so the three parts end up on top of
        # each other.
        stacked = tf.concat(parts, axis=3)
        stacked = shapes.guard(stacked[:, :, 0], "S, T, 3*Hp, Wp, 3")

        # Iterations go side by side (width), sequences below each other
        # (height).
        by_row = tf.transpose(stacked, [0, 2, 1, 3, 4])
        canvas = tf.reshape(by_row, shapes["1, S*3*Hp, Wp*T, 3"])

    if clip:
        canvas = tf.clip_by_value(canvas, 0.0, 1.0)
    return canvas
def _build(self, data, prev_states):
    """Run the stacked LSTM layers for one step.

    Args:
      data: tensor guarded as "B, K, H".
      prev_states: one previous state per layer; must have as many entries
        as `self._hidden_sizes`.

    Returns:
      A pair (out, new_states): the output of the last layer reshaped to
      "B, K, Y" and the list of new states, one per layer.

    Raises:
      AssertionError: if the last hidden size is None or the number of states
        does not match the number of hidden sizes.
    """
    assert not self._hidden_sizes or self._hidden_sizes[-1] is not None
    assert len(prev_states) == len(self._hidden_sizes)

    shapes = shapeguard.ShapeGuard()
    shapes.guard(data, "B, K, H")

    # Fold K into the batch axis; each layer consumes the previous output.
    hidden = shapes.reshape(data, "B*K, H")
    next_states = []
    for layer, state in zip(self._lstm_layers, prev_states):
        hidden, updated_state = layer(hidden, state)
        next_states.append(updated_state)

    shapes.guard(hidden, "B*K, Y")
    return shapes.reshape(hidden, "B, K, Y"), next_states
def _build(self, pixel, mask):
    """Combine per-component pixel parameters and masks into a mixture.

    Args:
      pixel: tensor guarded as "B, K, H, W, Cp".
      mask: tensor guarded as "B, K, H, W, 1"; used as mixture logits after
        `self._mask_activation`.

    Returns:
      A tfd.MixtureSameFamily whose mixture is a Categorical over K and whose
      components come from `self._dist`.
    """
    shapes = shapeguard.ShapeGuard()

    # MASKING
    shapes.guard(mask, "B, K, H, W, 1")
    logits = tf.transpose(mask, perm=[0, 2, 3, 4, 1])
    logits = shapes.reshape(logits, "B, H, W, K")
    logits = self._mask_activation(logits)
    # A singleton axis stands in for K, which the mixture removes.
    logits = logits[:, tf.newaxis]
    mixture = tfd.Categorical(logits=logits)

    # COMPONENTS
    shapes.guard(pixel, "B, K, H, W, Cp")
    component_params = tf.transpose(pixel, perm=[0, 2, 3, 1, 4])
    # Same singleton axis as for the mask logits.
    component_params = component_params[:, tf.newaxis]
    components = self._dist(component_params)

    return tfd.MixtureSameFamily(mixture_distribution=mixture,
                                 components_distribution=components)
def append_coordinate_channels(self, output):
    """Concatenate coordinate channels to the last axis of `output`.

    Args:
      output: tensor guarded as "B, H, W, C".

    Returns:
      `output` unchanged when `self._coord_type` is None; with two extra
      channels (h and w in [-1, 1]) for "linear"; with F * F - 1 cosine
      channels for "cos", where F is the length of
      `tf.range(0.0, self._coord_freqs)`.

    Raises:
      KeyError: for any other value of `self._coord_type`.
    """
    shapes = shapeguard.ShapeGuard()
    shapes.guard(output, "B, H, W, C")

    coord_type = self._coord_type
    if coord_type is None:
        return output

    if coord_type == "linear":
        along_w = tf.linspace(-1.0, 1.0, shapes.W)[None, None, :, None]
        along_h = tf.linspace(-1.0, 1.0, shapes.H)[None, :, None, None]
        along_w = tf.tile(along_w, shapes["B, H, 1, 1"])
        along_h = tf.tile(along_h, shapes["B, 1, W, 1"])
        return tf.concat([output, along_h, along_w], axis=-1)

    if coord_type != "cos":
        raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))

    freqs = shapes.guard(tf.range(0.0, self._coord_freqs), "F")
    x_vals = tf.linspace(0.0, np.pi, shapes.W)[None, None, :, None, None]
    y_vals = tf.linspace(0.0, np.pi, shapes.H)[None, :, None, None, None]
    # x frequencies vary along axis 3 and y frequencies along axis 4, so the
    # product enumerates every (x, y) frequency pair.
    cos_x = tf.cos(x_vals * freqs[None, None, None, :, None])
    cos_y = tf.cos(y_vals * freqs[None, None, None, None, :])
    basis = tf.reshape(cos_x * cos_y, shapes["1, H, W, F*F"])
    # The first channel (both frequencies zero) is dropped.
    coords = tf.tile(basis, shapes["B, 1, 1, 1"])[Ellipsis, 1:]
    return tf.concat([output, coords], axis=-1)
def set_output_shapes(self, shape):
    """Set the size of the last MLP layer from `shape`.

    Args:
      shape: must match the ShapeGuard template "1, Y"; Y becomes the last
        entry of `self._mlp_opt["output_sizes"]` (modified in place).
    """
    # A check that the last output size was still None used to be here; it
    # is disabled, so an existing value is overwritten.
    shapes = shapeguard.ShapeGuard()
    shapes.guard(shape, "1, Y")
    output_sizes = self._mlp_opt["output_sizes"]
    output_sizes[-1] = shapes.Y
def eval(self, data):
    """Build the model on `data` and collect per-iteration diagnostics.

    Args:
      data: dict with the entries "image" (guarded as "B, 1, H, W, C"),
        "mask" ("B, 1, L, H, W, 1"), "visibility" ("B, L") and "factors".

    Returns:
      A nested dict with the groups "data", "inputs", "latent", "outputs",
      "losses" and "metrics". Per-iteration values are stacked along axis 1
      (the "T" axis). When `self.factor_evaluator` is set, "losses" gains a
      "factor" entry and a "factor_regressor" group is added.
    """
    total_loss, scalars, iterations = self._build(data)
    sg = shapeguard.ShapeGuard()

    def get_components(dist):
        # Select index 0 of axis 1 of the component means and move axis 3 of
        # the result to position 1.
        return tf.transpose(
            dist.components_distribution.mean()[:, 0, :, :, :, :],
            [0, 3, 1, 2, 4])

    def get_mask(dist):
        # Swap axes 1 and 4 of the mixture probabilities.
        return tf.transpose(dist.mixture_distribution.probs[:, :, :, :, :],
                            [0, 4, 2, 3, 1])

    def get_mask_logits(dist):
        # Same axis swap as get_mask, applied to the logits.
        return tf.transpose(
            dist.mixture_distribution.logits[:, :, :, :, :],
            [0, 4, 2, 3, 1])

    def stack_iters(list_of_variables, pad_zero=False):
        # Stack along a new axis 1. With pad_zero a zero tensor is inserted
        # at the front; note that this modifies the given list in place.
        if pad_zero:
            list_of_variables.insert(0, tf.zeros_like(list_of_variables[0]))
        return tf.stack(list_of_variables, axis=1)

    # data
    image = sg.guard(data["image"], "B, 1, H, W, C")
    true_mask = sg.guard(data["mask"], "B, 1, L, H, W, 1")
    visibility = sg.guard(data["visibility"], "B, L")
    factors = data["factors"]

    # inputs
    # (zero-padded at the front; presumably to line up with the other
    # per-iteration lists -- TODO confirm)
    inputs_flat = {
        k: stack_iters([inp["flat"][k] for inp in iterations["inputs"]],
                       pad_zero=True)
        for k in iterations["inputs"][0]["flat"].keys()
    }
    inputs_spatial = {
        k: stack_iters([inp["spatial"][k] for inp in iterations["inputs"]],
                       pad_zero=True)
        for k in iterations["inputs"][0]["spatial"].keys()
    }

    # latent
    # (this guard is what defines T, K and Z for the guards below)
    z = sg.guard(stack_iters(iterations["z"]), "B, T, K, Z")
    z_mean = stack_iters([zd.mean() for zd in iterations["z_dist"]])
    z_std = stack_iters([zd.stddev() for zd in iterations["z_dist"]])

    # outputs
    recons = stack_iters([xd.mean() for xd in iterations["x_dist"]])
    pred_mask = stack_iters([get_mask(xd) for xd in iterations["x_dist"]])
    pred_mask_logits = stack_iters(
        [get_mask_logits(xd) for xd in iterations["x_dist"]])
    components = stack_iters(
        [get_components(xd) for xd in iterations["x_dist"]])

    # metrics
    # Both masks are brought to (B * T, H * W, n) before computing the
    # adjusted rand index; the true mask is tiled over T first.
    tm = tf.transpose(true_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
    tm = tf.reshape(tf.tile(tm, sg["1, T, 1, 1, 1"]), sg["B * T, H * W, L"])
    pm = tf.transpose(pred_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
    pm = tf.reshape(pm, sg["B * T, H * W, K"])
    ari = tf.reshape(adjusted_rand_index(tm, pm), sg["B, T"])
    # Same metric without the first true-mask channel (presumably the
    # background, going by the name -- TODO confirm).
    ari_nobg = tf.reshape(adjusted_rand_index(tm[Ellipsis, 1:], pm),
                          sg["B, T"])

    # Squared error averaged over every axis except the first two.
    mse = tf.reduce_mean(tf.square(recons - image[:, None]),
                         axis=[2, 3, 4, 5])

    # losses
    loss_recons = stack_iters(iterations["re"])
    kl = stack_iters(iterations["kl"])

    info = {
        "data": {
            "image": sg.guard(image, "B, 1, H, W, C"),
            "true_mask": sg.guard(true_mask, "B, 1, L, H, W, 1"),
            "visibility": sg.guard(visibility, "B, L"),
            "factors": factors,
        },
        "inputs": {
            "flat": inputs_flat,
            "spatial": inputs_spatial
        },
        "latent": {
            "z": sg.guard(z, "B, T, K, Z"),
            "z_mean": sg.guard(z_mean, "B, T, K, Z"),
            "z_std": sg.guard(z_std, "B, T, K, Z"),
        },
        "outputs": {
            "recons": sg.guard(recons, "B, T, 1, H, W, C"),
            "pred_mask": sg.guard(pred_mask, "B, T, K, H, W, 1"),
            "pred_mask_logits": sg.guard(pred_mask_logits,
                                         "B, T, K, H, W, 1"),
            "components": sg.guard(components, "B, T, K, H, W, C"),
        },
        "losses": {
            "total": total_loss,
            "recons": sg.guard(loss_recons, "B, T"),
            "kl": sg.guard(kl, "B, T, K"),
        },
        "metrics": {
            "ari": ari,
            "ari_nobg": ari_nobg,
            "mse": mse
        },
    }

    if self.factor_evaluator:
        # factor evaluation information
        factor_info = {
            "loss": [],
            "metrics": collections.defaultdict(list),
            "predictions": collections.defaultdict(list),
            "assignment": [],
        }
        # The factor evaluator is run once per iteration on that
        # iteration's latents and predicted masks.
        for t in range(z.get_shape().as_list()[1]):
            floss, fscalars, _, fpred, fass = self.factor_evaluator(
                z[:, t], factors, visibility, pred_mask[:, t],
                true_mask[:, 0])
            factor_info["loss"].append(floss)
            factor_info["assignment"].append(fass)
            for k in fpred:
                # Weight the predictions by the assignment and sum over
                # axis 2.
                factor_info["predictions"][k].append(
                    tf.reduce_sum(fpred[k] * fass[Ellipsis, None], axis=2))
                factor_info["metrics"][k].append(fscalars[k])

        info["losses"]["factor"] = sg.guard(tf.stack(factor_info["loss"]),
                                            "T")
        info["factor_regressor"] = {
            "assignment":
            sg.guard(stack_iters(factor_info["assignment"]), "B, T, L, K"),
            "metrics": {
                k: tf.stack(factor_info["metrics"][k], axis=0)
                for k in factor_info["metrics"]
            },
            "predictions": {
                k: stack_iters(factor_info["predictions"][k])
                for k in factor_info["predictions"]
            },
        }

    return info
def set_output_shapes(self, shape):
    """Set the size of the last MLP layer from `shape`.

    Args:
      shape: must match the ShapeGuard template "1, Y"; Y becomes the last
        entry of `self._mlp_opt["output_sizes"]` (modified in place).
    """
    shapes = shapeguard.ShapeGuard()
    shapes.guard(shape, "1, Y")
    output_sizes = self._mlp_opt["output_sizes"]
    output_sizes[-1] = shapes.Y
def _build(self, z, latent, visibility, pred_mask, true_mask):
    """Predict the factors from `z` and score them against the true factors.

    Args:
      z: tensor guarded as "B, K, Z".
      latent: dict of true factor values, indexed by `m.name` for every `m`
        in `self._mapping`; each entry must be "B, L, m.size" after its
        preprocessing function.
      visibility: tensor guarded as "B, L"; multiplied into the assignment.
      pred_mask: tensor guarded as "B, K, H, W, 1".
      true_mask: tensor guarded as "B, L, H, W, 1".

    Returns:
      A tuple (total_error, monitored_scalars, mean_var_tot, predictions,
      assignment):
        total_error: assigned pairwise errors summed and divided by the sum
          of `visibility`.
        monitored_scalars: dict from factor name to its metric value.
        mean_var_tot: dict holding the OnlineMeanVarEstimator outputs of the
          "scalar" and "angle" factors.
        predictions: dict from factor name to a "B, L, K, m.size" tensor.
        assignment: "B, L, K" one-hot assignment, zeroed where `visibility`
          is zero.
    """
    sg = shapeguard.ShapeGuard()
    z = sg.guard(z, "B, K, Z")
    pred_mask = sg.guard(pred_mask, "B, K, H, W, 1")
    true_mask = sg.guard(true_mask, "B, L, H, W, 1")

    visibility = sg.guard(visibility, "B, L")
    num_visible_obj = tf.reduce_sum(visibility)

    # Map z to predictions for all latents
    # M is the total size of all factors; one linear layer predicts them all.
    sg.M = sum([m.size for m in self._mapping])
    self.predictor = snt.Linear(sg.M, name="predict_latents")
    z_flat = sg.reshape(z, "B*K, Z")
    all_preds = sg.guard(self.predictor(z_flat), "B*K, M")
    # Add an L axis and tile over it so that every (L, K) pair is covered.
    all_preds = sg.reshape(all_preds, "B, 1, K, M")
    all_preds = tf.tile(all_preds, sg["1, L, 1, 1"])

    # prepare latents
    latents = {}
    mean_var_tot = {}
    for m in self._mapping:
        with tf.name_scope(m.name):
            # preprocess, reshape, and tile
            lat_preprocess = self.get_preprocessing(m)
            lat = sg.guard(lat_preprocess(latent[m.name]),
                           "B, L, {}".format(m.size))

            # compute mean over latent by training a variable using mse
            if m.type in {"scalar", "angle"}:
                mvt = utils.OnlineMeanVarEstimator(
                    axis=[0, 1], ddof=1, name="{}_mean_var".format(m.name))
                mean_var_tot[m.name] = mvt(lat, visibility[:, :, tf.newaxis])

            # Add a K axis and tile over it, mirroring all_preds.
            lat = tf.reshape(lat, sg["B, L, 1"] + [-1])
            lat = tf.tile(lat, sg["1, 1, K, 1"])
            latents[m.name] = lat

    # prepare predictions
    # Each factor takes the next m.size entries of the last axis.
    idx = 0
    predictions = {}
    for m in self._mapping:
        with tf.name_scope(m.name):
            assert m.name in latent, "{} not in {}".format(
                m.name, latent.keys())
            pred = all_preds[Ellipsis, idx:idx + m.size]
            predictions[m.name] = sg.guard(pred,
                                           "B, L, K, {}".format(m.size))
            idx += m.size

    # compute error
    # The per-factor errors ("B, L, K") are summed over all factors.
    total_pairwise_errors = None
    for m in self._mapping:
        with tf.name_scope(m.name):
            error_fn = self.get_error_func(m)
            sg.guard(latents[m.name], "B, L, K, {}".format(m.size))
            sg.guard(predictions[m.name], "B, L, K, {}".format(m.size))
            err = error_fn(latents[m.name], predictions[m.name])
            sg.guard(err, "B, L, K")
            if total_pairwise_errors is None:
                total_pairwise_errors = err
            else:
                total_pairwise_errors += err

    # determine best assignment by comparing masks
    # Broadcast to (B, L, K, H, W, 1) and sum the product over the last
    # three axes to obtain a "B, L, K" overlap.
    obj_mask = true_mask[:, :, tf.newaxis]
    pred_mask = pred_mask[:, tf.newaxis]
    pairwise_overlap = tf.reduce_sum(obj_mask * pred_mask, axis=[3, 4, 5])
    # For every true object pick the index along K with the largest overlap.
    best_match = sg.guard(tf.argmax(pairwise_overlap, axis=2), "B, L")
    assignment = tf.one_hot(best_match, sg.K)
    assignment *= visibility[:, :, tf.newaxis]  # Mask non-visible objects

    # total error
    # NOTE(review): if visibility sums to zero this is a division by zero --
    # confirm that callers always provide at least one visible object.
    total_error = (tf.reduce_sum(assignment * total_pairwise_errors) /
                   num_visible_obj)

    # compute scalars
    monitored_scalars = {}
    for m in self._mapping:
        with tf.name_scope(m.name):
            metric = self.get_metric(m)
            scalar = metric(
                latents[m.name],
                predictions[m.name],
                assignment[:, :, :, tf.newaxis],
                mean_var_tot.get(m.name),
                num_visible_obj,
            )
            monitored_scalars[m.name] = scalar
    return total_error, monitored_scalars, mean_var_tot, predictions, assignment
def __init__(
        self,
        pixel_decoder,
        name="component_decoder",
):
    """Store the pixel decoder and create a ShapeGuard for this module.

    Args:
      pixel_decoder: kept as `self._pixel_decoder`.
      name: module name forwarded to the parent constructor.
    """
    super().__init__(name=name)
    self._sg = shapeguard.ShapeGuard()
    self._pixel_decoder = pixel_decoder