示例#1
0
    def _build(self, zp_old, inputs):
        """Refine latent parameters with a gated correction and a gated prediction.

        Args:
          zp_old: tensor guarded as (B, K, Zp) holding the current parameters.
          inputs: tensor guarded as (B, K, H) from which the correction is
            computed.

        Returns:
          Tensor of shape (B, K, Zp) with the updated parameters.
        """
        guard = shapeguard.ShapeGuard()
        guard.guard(zp_old, "B, K, Zp")
        guard.guard(inputs, "B, K, H")
        n_params = guard.Zp

        # Sub-modules are instantiated up front, in this fixed order.
        correction_net = snt.Linear(n_params)
        correction_gate_net = snt.Linear(n_params)
        prediction_sizes = [*self._hidden_sizes, n_params * 2]
        prediction_net = snt.nets.MLP(
            output_sizes=prediction_sizes,
            activation=self._activation,
        )

        # Fold the component axis into the batch axis for the dense layers.
        zp_rows = guard.reshape(zp_old, "B*K, Zp")
        input_rows = guard.reshape(inputs, "B*K, H")

        # Correction step: blend the old value with a linear proposal.
        gate_logits = correction_gate_net(input_rows) + self._corrector_gate_bias
        corr_gate = tf.nn.sigmoid(gate_logits)
        proposal = correction_net(input_rows)
        # Same as (1 - gate) * old + gate * proposal, with one multiply less.
        corrected = zp_rows + corr_gate * (proposal - zp_rows)

        # Prediction step: the MLP (fed the *old* parameters) emits the
        # proposal in its first Zp outputs and the gate logits in the rest.
        mlp_out = prediction_net(zp_rows)
        pred_proposal = mlp_out[:, :n_params]
        pred_gate = tf.nn.sigmoid(mlp_out[:, n_params:] + self._pred_gate_bias)
        blended = corrected + pred_gate * (pred_proposal - corrected)

        return guard.reshape(blended, "B, K, Zp")
    def _build(self, image):
        """Connect model to TensorFlow graph.

        Runs a ConvNet2D over the image, collapses the spatial dimensions
        according to `self._mode`, and feeds the result to an MLP.

        Args:
          image: tensor that, after `flatten_all_but_last(image, n_dims=3)`,
            is guarded as (B, H, W, C).

        Returns:
          FlatParameters wrapping the MLP output (guarded as (B, Y)) after it
          has been passed through the `unflatten` callable returned by
          `flatten_all_but_last`.

        Raises:
          KeyError: if `self._mode` is neither "flatten" nor "avg_pool".
        """
        assert self._mlp_opt["output_sizes"][
            -1] is not None, "set output_shapes"
        sg = shapeguard.ShapeGuard()
        flat_image, unflatten = flatten_all_but_last(image, n_dims=3)
        sg.guard(flat_image, "B, H, W, C")

        cnn = snt.nets.ConvNet2D(activate_final=True,
                                 paddings=("SAME", ),
                                 normalize_final=False,
                                 **self._cnn_opt)
        mlp = snt.nets.MLP(**self._mlp_opt)

        # run CNN
        net = cnn(flat_image)

        if self._mode == "flatten":
            # Merge the trailing (height, width, channel) axes into one.
            net_shape = net.get_shape().as_list()
            # A statically unknown leading dimension is reported as None,
            # which tf.reshape cannot convert; -1 lets TensorFlow infer it.
            leading = [-1 if dim is None else dim for dim in net_shape[:-3]]
            flat_shape = leading + [np.prod(net_shape[-3:])]
            net = tf.reshape(net, flat_shape)
        elif self._mode == "avg_pool":
            # Average over the two spatial axes.
            net = tf.reduce_mean(net, axis=[1, 2])
        else:
            raise KeyError('Unknown mode "{}"'.format(self._mode))
        # run MLP
        output = sg.guard(mlp(net), "B, Y")
        return FlatParameters(unflatten(output))
    def _build(self, z):
        """Connect model to TensorFlow graph.

        Optionally transforms the latent with an MLP, repeats the result at
        every spatial position of the target shape, appends coordinate
        channels and decodes with a ConvNet2D.

        Args:
          z: tensor that, after `flatten_all_but_last`, is guarded as (B, Z).

        Returns:
          FlatParameters wrapping the CNN output (guarded as (B, H, W, C))
          after it has been passed through `unflatten`.
        """
        assert self._target_out_shape is not None, "Call set_output_shape"
        guard = shapeguard.ShapeGuard()
        # Components are folded into the batch dimension before processing.
        flat_z, unflatten = flatten_all_but_last(z)
        guard.guard(flat_z, "B, Z")
        guard.guard(self._target_out_shape, "H, W, C")

        # Without MLP options the latent is passed through unchanged.
        mlp = (tf.identity if self._mlp_opt is None else snt.nets.MLP(
            activate_final=True, **self._mlp_opt))
        hidden = guard.guard(mlp(flat_z), "B, hidden")

        # Repeat the feature vector at every location: (B, H, W, hidden).
        expanded = hidden[:, tf.newaxis, tf.newaxis]
        spatial_hidden = tf.tile(
            expanded, multiples=tf.constant(guard["1, H, W, 1"]))

        cnn_inputs = self.append_coordinate_channels(spatial_hidden)

        cnn = snt.nets.ConvNet2D(paddings=("SAME", ),
                                 normalize_final=False,
                                 **self._cnn_opt)
        decoded = cnn(cnn_inputs)
        guard.guard(decoded, "B, H, W, C")

        return FlatParameters(unflatten(decoded))
示例#4
0
 def decode(self, z):
     """Decode latents into output-distribution parameters and the distribution.

     Args:
       z: tensor guarded as (B, K, Z).

     Returns:
       Tuple `(params, out_dist)` where `params` is what `self.decoder`
       returns and `out_dist` is `self.output_dist(*params)`.
     """
     guard = shapeguard.ShapeGuard()
     guard.guard(z, "B, K, Z")
     # legacy: one extra channel filled with the constant 5.0 is appended
     # to every latent before decoding.
     legacy_channel = 5.0 * tf.ones(guard["B, K, 1"], dtype=tf.float32)
     z_extended = tf.concat([z, legacy_channel], axis=2)
     params = self.decoder(z_extended)
     return params, self.output_dist(*params)
示例#5
0
 def __init__(self,
              encoder_net,
              recurrent_net,
              refinement_head,
              name="refinement"):
     """Store the three sub-networks and create a shared shape guard.

     Args:
       encoder_net: kept as `self._encoder_net`.
       recurrent_net: kept as `self._recurrent_net`.
       refinement_head: kept as `self._refinement_head`.
       name: module name forwarded to the parent constructor.
     """
     super().__init__(name=name)
     self._sg = shapeguard.ShapeGuard()
     self._refinement_head = refinement_head
     self._recurrent_net = recurrent_net
     self._encoder_net = encoder_net
示例#6
0
    def __init__(
        self,
        decoder,
        refinement_core,
        latent_dist,
        output_dist,
        n_z,
        num_components,
        num_iters,
        sequential=False,
        factor_evaluator=None,
        stop_gradients=DEFAULT_STOP_GRADIENT,
        iter_loss_weight="linspace",
        inputs=DEFAULT_INPUTS,
        preprocess=None,
        coord_type="linear",
        coord_freqs=3,
        name="iodine",
    ):
        """Store configuration and build the prior and preprocessing layers.

        Args:
          decoder: stored as `self.decoder`.
          refinement_core: stored as `self.refinement_core`.
          latent_dist: stored as `self.latent_dist`; its output shape is set
            to `[n_z]` and it supplies the default prior.
          output_dist: stored as `self.output_dist`.
          n_z: size of the latent, used as the latent distribution's output
            shape.
          num_components: bound to dimension "K" of the instance shape guard
            and passed to `get_default_prior`.
          num_iters: stored as `self.num_iters`.
          sequential: stored as `self.sequential`.
          factor_evaluator: stored as `self.factor_evaluator`.
          stop_gradients: stored as `self.stop_gradients`.
          iter_loss_weight: handed to `self._parse_iter_loss_weights`; the
            result is stored as `self.iter_loss_weights`.
          inputs: stored as `self.inputs`.
          preprocess: iterable of names that each get a LayerNorm;
            `DEFAULT_PREPROCESSING` is used when None.
          coord_type: stored as `self.coord_type`.
          coord_freqs: stored as `self.coord_freqs`.
          name: module name forwarded to the parent constructor.
        """
        super().__init__(name=name)
        self._sg = shapeguard.ShapeGuard(dims={"K": num_components})

        # Collaborating modules and distributions.
        self.decoder = decoder
        self.refinement_core = refinement_core
        self.latent_dist = latent_dist
        self.output_dist = output_dist

        # Sizes and iteration settings; these are assigned before the loss
        # weights are parsed, matching the original statement order.
        self.n_z = n_z
        self.num_components = num_components
        self.num_iters = num_iters
        self.sequential = sequential
        self.iter_loss_weights = self._parse_iter_loss_weights(
            iter_loss_weight)

        # Remaining options.
        self.factor_evaluator = factor_evaluator
        self.stop_gradients = stop_gradients
        self.inputs = inputs
        if preprocess is None:
            self.preprocess = DEFAULT_PREPROCESSING
        else:
            self.preprocess = preprocess
        self.coord_type = coord_type
        self.coord_freqs = coord_freqs

        with self._enter_variable_scope():
            latent_shape = [self.n_z]
            self.latent_dist.set_output_shape(latent_shape)
            logging.info("VAE: z shape: %s", latent_shape)
            with tf.name_scope("prior"):
                self.prior = self.latent_dist.get_default_prior(
                    (self.num_components, ))
            self._sg.guard(self.prior, "K, Z")
            with tf.variable_scope("preprocess"):
                # One LayerNorm module per preprocessed input name.
                layernorms = {}
                for input_name in self.preprocess:
                    layernorms[input_name] = snt.LayerNorm(
                        name="layer_norm_" + input_name)
                self._layernorms = layernorms
示例#7
0
    def _build(self, zp_old, inputs):
        """Add a linear function of `inputs` to the latent parameters.

        Args:
          zp_old: tensor guarded as (B, K, Zp) holding the current parameters.
          inputs: tensor guarded as (B, K, H).

        Returns:
          Tensor of shape (B, K, Zp): `zp_old` plus the linear update.
        """
        guard = shapeguard.ShapeGuard()
        guard.guard(zp_old, "B, K, Zp")
        guard.guard(inputs, "B, K, H")
        delta_net = snt.Linear(guard.Zp)

        # Fold the component axis into the batch axis for the dense layer.
        zp_rows = guard.reshape(zp_old, "B*K, Zp")
        input_rows = guard.reshape(inputs, "B*K, H")
        updated = zp_rows + delta_net(input_rows)

        return guard.reshape(updated, "B, K, Zp")
    def _build(self, data):
        """Connect model to TensorFlow graph.

        Args:
          data: tensor that, after `flatten_all_but_last`, is guarded as
            (B, N).

        Returns:
          FlatParameters wrapping the MLP output (guarded as (B, Y)) after it
          has been passed through `unflatten`.
        """
        final_size = self._mlp_opt["output_sizes"][-1]
        assert final_size is not None, "set output_shapes"

        guard = shapeguard.ShapeGuard()
        flat_data, unflatten = flatten_all_but_last(data)
        guard.guard(flat_data, "B, N")

        encoder = snt.nets.MLP(**self._mlp_opt)
        encoded = guard.guard(encoder(flat_data), "B, Y")
        return FlatParameters(unflatten(encoded))
示例#9
0
    def predict(self, z):
        """Predict every mapped quantity from a batch of latents.

        Args:
          z: tensor guarded as (B, Z).

        Returns:
          Dict from each mapping's name to the slice of `self.predictor(z)`
          that belongs to it, guarded as (B, size). Slices are taken
          consecutively in the order of `self._mapping`.
        """
        guard = shapeguard.ShapeGuard()
        z = guard.guard(z, "B, Z")
        joint = guard.guard(self.predictor(z), "B, M")

        predictions = {}
        start = 0
        for entry in self._mapping:
            stop = start + entry.size
            with tf.name_scope(entry.name):
                chunk = joint[:, start:stop]
                predictions[entry.name] = guard.guard(
                    chunk, "B, {}".format(entry.size))
            start = stop
        return predictions
示例#10
0
 def _build(self, z):
     """Connect model to TensorFlow graph.

     Args:
       z: tensor that, after `flatten_all_but_last`, is guarded as (B, Z).

     Returns:
       FlatParameters wrapping the transpose-convolution output (guarded
       as (B, H, W, C), with H, W, C bound from `self._target_out_shape`)
       after it has been passed through `unflatten`.
     """
     guard = shapeguard.ShapeGuard()
     flat_z, unflatten = flatten_all_but_last(z)
     guard.guard(flat_z, "B, Z")
     guard.guard(self._target_out_shape, "H, W, C")

     # Both sub-modules are created before either is connected.
     mlp = snt.nets.MLP(**self._mlp_opt)
     deconv = snt.nets.ConvNet2DTranspose(paddings=("SAME", ),
                                          normalize_final=False,
                                          **self._cnn_opt)

     decoded = guard.guard(deconv(mlp(flat_z)), "B, H, W, C")
     return FlatParameters(unflatten(decoded))
示例#11
0
def construct_iterations_image(images,
                               recons,
                               masks,
                               border_width=2,
                               nr_seqs=2,
                               clip=True):
    """Build one diagnostic image from inputs, reconstructions and masks.

    For each of the first `nr_seqs` sequences the padded image,
    reconstruction and recoloured mask are stacked along the height axis;
    iterations are then laid out along the width and sequences along the
    height of a single image.

    Args:
      images: (B, T, 1, H, W, C); a tensor with T == 1 is tiled to match the
        T of `recons`.
      recons: (B, T, 1, H, W, C)
      masks:  (B, T, K, H, W, 1)
      border_width: int. Width of the border in pixels, added on both sides
        of the height and width axes with value 0.5. (default=2)
      nr_seqs: int. Number of sequences to include. (default=2)
      clip: bool. Whether to clip the final image to range [0, 1].

    Returns:
      Tensor of shape (1, S*3*Hp, Wp*T, 3), where S = min(nr_seqs, B) and
      Hp, Wp are the padded height and width.
    """
    guard = shapeguard.ShapeGuard()
    guard.guard(recons, "B, T, 1, H, W, C")
    n_image_steps = images.get_shape().as_list()[1]
    if n_image_steps == 1:
        # A single input frame is repeated so it lines up with every step.
        images = tf.tile(images, guard["1, T, 1, 1, 1, 1"])
    guard.guard(images, "B, T, 1, H, W, C")
    guard.guard(masks, " B, T, K, H, W, 1")
    if guard.C == 1:
        # Grayscale input: replicate the single channel three times.
        channel_repeat = [1, 1, 1, 1, 1, 3]
        images = tf.tile(images, channel_repeat)
        recons = tf.tile(recons, channel_repeat)
    guard.S = min(nr_seqs, guard.B)

    with tf.name_scope("diagnostic_image"):
        # Swap the component axis with the trailing singleton axis, then let
        # color_transform turn the masks into an RGB image.
        component_last = tf.transpose(masks[:nr_seqs], [0, 1, 5, 3, 4, 2])
        rgb_masks = color_transform(component_last)

        # Only the two spatial axes (3 and 4) receive a border.
        untouched = (0, 0)
        border = (border_width, border_width)
        paddings = tf.constant(
            [untouched, untouched, untouched, border, border, untouched])
        padded_parts = [
            tf.pad(part, paddings, constant_values=0.5)
            for part in (images[:nr_seqs], recons[:nr_seqs], rgb_masks)
        ]

        # Stack image / reconstruction / mask along the height axis and drop
        # the singleton axis 2.
        stacked = tf.concat(padded_parts, axis=3)
        stacked = guard.guard(stacked[:, :, 0], "S, T, 3*Hp, Wp, 3")

        # Iterations go side by side (width); sequences go below each other.
        height_major = tf.transpose(stacked, [0, 2, 1, 3, 4])
        final = tf.reshape(height_major, guard["1, S*3*Hp, Wp*T, 3"])
        if clip:
            final = tf.clip_by_value(final, 0.0, 1.0)
        return final
示例#12
0
    def _build(self, data, prev_states):
        """Run the stacked LSTM layers with components folded into the batch.

        Args:
          data: tensor guarded as (B, K, H).
          prev_states: sequence with one previous state per entry of
            `self._hidden_sizes`.

        Returns:
          Tuple `(out, new_states)`: `out` has shape (B, K, Y) and
          `new_states` lists the state returned by each layer, in layer
          order.
        """
        assert not self._hidden_sizes or self._hidden_sizes[-1] is not None
        assert len(prev_states) == len(self._hidden_sizes)
        guard = shapeguard.ShapeGuard()
        guard.guard(data, "B, K, H")

        # Each layer consumes the previous layer's output.
        activations = guard.reshape(data, "B*K, H")
        next_states = []
        for layer, layer_state in zip(self._lstm_layers, prev_states):
            activations, updated_state = layer(activations, layer_state)
            next_states.append(updated_state)

        guard.guard(activations, "B*K, Y")
        return guard.reshape(activations, "B, K, Y"), next_states
示例#13
0
  def _build(self, pixel, mask):
    """Combine per-component pixel parameters and masks into a mixture.

    Args:
      pixel: tensor guarded as (B, K, H, W, Cp) with component parameters.
      mask: tensor guarded as (B, K, H, W, 1) with mixture logits (after
        `self._mask_activation`).

    Returns:
      A `tfd.MixtureSameFamily` whose mixture axis is K.
    """
    guard = shapeguard.ShapeGuard()

    # Mixture weights: move K last and drop the trailing singleton axis.
    guard.guard(mask, "B, K, H, W, 1")
    mask_k_last = tf.transpose(mask, perm=[0, 2, 3, 4, 1])
    mask_logits = guard.reshape(mask_k_last, "B, H, W, K")
    mask_logits = self._mask_activation(mask_logits)
    # The mixture removes K, so a K=1 axis is inserted after the batch axis.
    mask_logits = mask_logits[:, tf.newaxis]
    mixture = tfd.Categorical(logits=mask_logits)

    # Components: K moves to the second-to-last axis, parameters stay last.
    guard.guard(pixel, "B, K, H, W, Cp")
    component_params = tf.transpose(pixel, perm=[0, 2, 3, 1, 4])
    component_params = component_params[:, tf.newaxis]  # same K=1 axis
    components = self._dist(component_params)

    return tfd.MixtureSameFamily(
        mixture_distribution=mixture, components_distribution=components)
示例#14
0
 def append_coordinate_channels(self, output):
     """Concatenate positional channels to a feature map.

     Args:
       output: tensor guarded as (B, H, W, C).

     Returns:
       `output` unchanged when `self._coord_type` is None; otherwise
       `output` with extra channels appended on the last axis: two linear
       coordinate channels for "linear", or F*F - 1 cosine basis channels
       for "cos" (F being the length of `tf.range(0.0, self._coord_freqs)`).

     Raises:
       KeyError: if `self._coord_type` is not None, "linear" or "cos".
     """
     guard = shapeguard.ShapeGuard()
     guard.guard(output, "B, H, W, C")
     coord_type = self._coord_type

     if coord_type is None:
         return output

     if coord_type == "linear":
         # Coordinates run from -1 to 1 along width and height.
         col_coords = tf.linspace(-1.0, 1.0, guard.W)[None, None, :, None]
         row_coords = tf.linspace(-1.0, 1.0, guard.H)[None, :, None, None]
         col_coords = tf.tile(col_coords, guard["B, H, 1, 1"])
         row_coords = tf.tile(row_coords, guard["B, 1, W, 1"])
         return tf.concat([output, row_coords, col_coords], axis=-1)

     if coord_type == "cos":
         freqs = guard.guard(tf.range(0.0, self._coord_freqs), "F")
         x_phase = tf.linspace(0.0, np.pi, guard.W)[None, None, :, None, None]
         y_phase = tf.linspace(0.0, np.pi, guard.H)[None, :, None, None, None]
         cos_x = tf.cos(x_phase * freqs[None, None, None, :, None])
         cos_y = tf.cos(y_phase * freqs[None, None, None, None, :])
         # Outer product over the two frequency axes, flattened to F*F.
         basis = tf.reshape(cos_x * cos_y, guard["1, H, W, F*F"])
         # The first basis channel is dropped after tiling over the batch.
         coords = tf.tile(basis, guard["B,  1, 1, 1"])[Ellipsis, 1:]
         return tf.concat([output, coords], axis=-1)

     raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
示例#15
0
 def set_output_shapes(self, shape):
     """Set the MLP's final layer size from a shape guarded as "1, Y".

     Args:
       shape: value accepted by the shape guard under the template "1, Y";
         Y becomes the last entry of `self._mlp_opt["output_sizes"]`.
     """
     # The last entry is overwritten unconditionally; the earlier check
     # remains disabled:
     # assert self._mlp_opt['output_sizes'][-1] is None, self._mlp_opt
     guard = shapeguard.ShapeGuard()
     guard.guard(shape, "1, Y")
     output_sizes = self._mlp_opt["output_sizes"]
     output_sizes[-1] = guard.Y
示例#16
0
    def eval(self, data):
        """Run the model on `data` and collect per-iteration diagnostics.

        Args:
          data: dict with keys "image" (guarded as (B, 1, H, W, C)), "mask"
            (B, 1, L, H, W, 1), "visibility" (B, L) and "factors" (handed
            unchanged to the factor evaluator).

        Returns:
          Nested dict with the groups "data", "inputs", "latent", "outputs",
          "losses" and "metrics". When `self.factor_evaluator` is set, the
          dict additionally holds "factor_regressor" and
          `info["losses"]["factor"]`.
        """
        # NOTE: `scalars` is not used in this method.
        total_loss, scalars, iterations = self._build(data)
        sg = shapeguard.ShapeGuard()

        def get_components(dist):
            # Component means with the axis-1 singleton removed and the
            # component axis moved to position 1 (guarded below, after
            # stacking, as (B, T, K, H, W, C)).
            return tf.transpose(
                dist.components_distribution.mean()[:, 0, :, :, :, :],
                [0, 3, 1, 2, 4])

        def get_mask(dist):
            # Mixture probabilities with axes 1 and 4 swapped (guarded below,
            # after stacking, as (B, T, K, H, W, 1)).
            return tf.transpose(dist.mixture_distribution.probs[:, :, :, :, :],
                                [0, 4, 2, 3, 1])

        def get_mask_logits(dist):
            # Same layout as get_mask, but for the mixture logits.
            return tf.transpose(
                dist.mixture_distribution.logits[:, :, :, :, :],
                [0, 4, 2, 3, 1])

        def stack_iters(list_of_variables, pad_zero=False):
            # Stacks per-iteration tensors along a new axis 1. With pad_zero
            # a zero tensor shaped like the first entry is inserted at the
            # front; note that this modifies the passed list in place.
            if pad_zero:
                list_of_variables.insert(0,
                                         tf.zeros_like(list_of_variables[0]))
            return tf.stack(list_of_variables, axis=1)

        # data
        image = sg.guard(data["image"], "B, 1, H, W, C")
        true_mask = sg.guard(data["mask"], "B, 1, L, H, W, 1")
        visibility = sg.guard(data["visibility"], "B, L")
        factors = data["factors"]

        # inputs: one stacked tensor per key, zero-padded at the front
        inputs_flat = {
            k: stack_iters([inp["flat"][k] for inp in iterations["inputs"]],
                           pad_zero=True)
            for k in iterations["inputs"][0]["flat"].keys()
        }
        inputs_spatial = {
            k: stack_iters([inp["spatial"][k] for inp in iterations["inputs"]],
                           pad_zero=True)
            for k in iterations["inputs"][0]["spatial"].keys()
        }
        # latent (this guard is what binds T, K and Z for the templates below)
        z = sg.guard(stack_iters(iterations["z"]), "B, T, K, Z")
        z_mean = stack_iters([zd.mean() for zd in iterations["z_dist"]])
        z_std = stack_iters([zd.stddev() for zd in iterations["z_dist"]])
        # outputs
        recons = stack_iters([xd.mean() for xd in iterations["x_dist"]])
        pred_mask = stack_iters([get_mask(xd) for xd in iterations["x_dist"]])
        pred_mask_logits = stack_iters(
            [get_mask_logits(xd) for xd in iterations["x_dist"]])
        components = stack_iters(
            [get_components(xd) for xd in iterations["x_dist"]])

        # metrics
        # True mask: drop the trailing singleton, move L last, repeat over T
        # and flatten to (B*T, H*W, L).
        tm = tf.transpose(true_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
        tm = tf.reshape(tf.tile(tm, sg["1, T, 1, 1, 1"]),
                        sg["B * T, H * W, L"])
        # Predicted mask: same treatment, giving (B*T, H*W, K).
        pm = tf.transpose(pred_mask[Ellipsis, 0], [0, 1, 3, 4, 2])
        pm = tf.reshape(pm, sg["B * T, H * W, K"])
        ari = tf.reshape(adjusted_rand_index(tm, pm), sg["B, T"])
        # Same score with the first true-mask channel left out (presumably
        # the background, going by the name -- TODO confirm).
        ari_nobg = tf.reshape(adjusted_rand_index(tm[Ellipsis, 1:], pm),
                              sg["B, T"])

        # Mean squared error per batch element and iteration; the image gets
        # an extra axis so it broadcasts against the stacked reconstructions.
        mse = tf.reduce_mean(tf.square(recons - image[:, None]),
                             axis=[2, 3, 4, 5])

        # losses
        loss_recons = stack_iters(iterations["re"])
        kl = stack_iters(iterations["kl"])

        # The guards below also validate the shapes of everything returned.
        info = {
            "data": {
                "image": sg.guard(image, "B, 1, H, W, C"),
                "true_mask": sg.guard(true_mask, "B, 1, L, H, W, 1"),
                "visibility": sg.guard(visibility, "B, L"),
                "factors": factors,
            },
            "inputs": {
                "flat": inputs_flat,
                "spatial": inputs_spatial
            },
            "latent": {
                "z": sg.guard(z, "B, T, K, Z"),
                "z_mean": sg.guard(z_mean, "B, T, K, Z"),
                "z_std": sg.guard(z_std, "B, T, K, Z"),
            },
            "outputs": {
                "recons": sg.guard(recons, "B, T, 1, H, W, C"),
                "pred_mask": sg.guard(pred_mask, "B, T, K, H, W, 1"),
                "pred_mask_logits": sg.guard(pred_mask_logits,
                                             "B, T, K, H, W, 1"),
                "components": sg.guard(components, "B, T, K, H, W, C"),
            },
            "losses": {
                "total": total_loss,
                "recons": sg.guard(loss_recons, "B, T"),
                "kl": sg.guard(kl, "B, T, K"),
            },
            "metrics": {
                "ari": ari,
                "ari_nobg": ari_nobg,
                "mse": mse
            },
        }

        if self.factor_evaluator:
            # factor evaluation information
            factor_info = {
                "loss": [],
                "metrics": collections.defaultdict(list),
                "predictions": collections.defaultdict(list),
                "assignment": [],
            }
            # The evaluator is applied separately to every iteration t; the
            # third of its five return values is ignored here.
            for t in range(z.get_shape().as_list()[1]):
                floss, fscalars, _, fpred, fass = self.factor_evaluator(
                    z[:, t], factors, visibility, pred_mask[:, t],
                    true_mask[:, 0])
                factor_info["loss"].append(floss)
                factor_info["assignment"].append(fass)
                for k in fpred:
                    # Weight the predictions by the assignment and sum over
                    # axis 2.
                    factor_info["predictions"][k].append(
                        tf.reduce_sum(fpred[k] * fass[Ellipsis, None], axis=2))
                    factor_info["metrics"][k].append(fscalars[k])

            info["losses"]["factor"] = sg.guard(tf.stack(factor_info["loss"]),
                                                "T")
            info["factor_regressor"] = {
                "assignment":
                sg.guard(stack_iters(factor_info["assignment"]), "B, T, L, K"),
                "metrics": {
                    k: tf.stack(factor_info["metrics"][k], axis=0)
                    for k in factor_info["metrics"]
                },
                "predictions": {
                    k: stack_iters(factor_info["predictions"][k])
                    for k in factor_info["predictions"]
                },
            }

        return info
示例#17
0
 def set_output_shapes(self, shape):
     """Use the Y of a "1, Y" shape as the size of the MLP's last layer.

     Args:
       shape: value that the shape guard accepts under the template "1, Y".
     """
     guard = shapeguard.ShapeGuard()
     guard.guard(shape, "1, Y")
     layer_sizes = self._mlp_opt["output_sizes"]
     layer_sizes[-1] = guard.Y
示例#18
0
    def _build(self, z, latent, visibility, pred_mask, true_mask):
        """Predict ground-truth factors from latents and score the result.

        Every component's latent is mapped linearly to predictions for all
        entries of `self._mapping`. Each true object is matched to the
        component whose predicted mask overlaps its true mask the most, and
        the errors of the matched pairs are accumulated.

        Args:
          z: tensor guarded as (B, K, Z).
          latent: dict indexed by mapping name; each entry, after the
            mapping's preprocessing, is guarded as (B, L, size).
          visibility: tensor guarded as (B, L); multiplied into the
            assignment to mask out non-visible objects.
          pred_mask: tensor guarded as (B, K, H, W, 1).
          true_mask: tensor guarded as (B, L, H, W, 1).

        Returns:
          Tuple `(total_error, monitored_scalars, mean_var_tot, predictions,
          assignment)`: the summed assigned error divided by the visibility
          sum, a dict of per-mapping metrics, a dict of mean/variance
          estimator outputs (only for "scalar" and "angle" mappings), a dict
          of predictions guarded as (B, L, K, size), and the assignment
          tensor of shape (B, L, K).
        """
        sg = shapeguard.ShapeGuard()
        z = sg.guard(z, "B, K, Z")
        pred_mask = sg.guard(pred_mask, "B, K, H, W, 1")
        true_mask = sg.guard(true_mask, "B, L, H, W, 1")

        visibility = sg.guard(visibility, "B, L")
        # NOTE(review): this sum is used as a divisor below; if nothing is
        # visible the division is by zero -- confirm callers rule that out.
        num_visible_obj = tf.reduce_sum(visibility)

        # Map z to predictions for all latents
        sg.M = sum([m.size for m in self._mapping])
        # The linear predictor is stored on the instance.
        self.predictor = snt.Linear(sg.M, name="predict_latents")
        z_flat = sg.reshape(z, "B*K, Z")
        all_preds = sg.guard(self.predictor(z_flat), "B*K, M")
        # Add an object axis and repeat over it so that every (object,
        # component) pair has a prediction: (B, L, K, M).
        all_preds = sg.reshape(all_preds, "B, 1, K, M")
        all_preds = tf.tile(all_preds, sg["1, L, 1, 1"])

        # prepare latents
        latents = {}
        mean_var_tot = {}
        for m in self._mapping:
            with tf.name_scope(m.name):
                # preprocess, reshape, and tile
                lat_preprocess = self.get_preprocessing(m)
                lat = sg.guard(lat_preprocess(latent[m.name]),
                               "B, L, {}".format(m.size))
                # compute mean over latent by training a variable using mse
                if m.type in {"scalar", "angle"}:
                    mvt = utils.OnlineMeanVarEstimator(
                        axis=[0, 1], ddof=1, name="{}_mean_var".format(m.name))
                    mean_var_tot[m.name] = mvt(lat, visibility[:, :,
                                                               tf.newaxis])

                # Insert a component axis and repeat over it: (B, L, K, size).
                lat = tf.reshape(lat, sg["B, L, 1"] + [-1])
                lat = tf.tile(lat, sg["1, 1, K, 1"])
                latents[m.name] = lat

        # prepare predictions: consecutive slices of the last axis, in the
        # order of self._mapping
        idx = 0
        predictions = {}
        for m in self._mapping:
            with tf.name_scope(m.name):
                # NOTE(review): latent[m.name] was already read in the loop
                # above, so a missing key raises KeyError there before this
                # assertion can fire.
                assert m.name in latent, "{} not in {}".format(
                    m.name, latent.keys())
                pred = all_preds[Ellipsis, idx:idx + m.size]
                predictions[m.name] = sg.guard(pred,
                                               "B, L, K, {}".format(m.size))
                idx += m.size

        # compute error: per-mapping errors of shape (B, L, K) are summed
        total_pairwise_errors = None
        for m in self._mapping:
            with tf.name_scope(m.name):
                error_fn = self.get_error_func(m)
                sg.guard(latents[m.name], "B, L, K, {}".format(m.size))
                sg.guard(predictions[m.name], "B, L, K, {}".format(m.size))
                err = error_fn(latents[m.name], predictions[m.name])
                sg.guard(err, "B, L, K")
                if total_pairwise_errors is None:
                    total_pairwise_errors = err
                else:
                    total_pairwise_errors += err

        # determine best assignment by comparing masks
        # obj_mask gets a component axis and pred_mask an object axis so that
        # their product, summed over axes 3-5, gives a (B, L, K) overlap.
        obj_mask = true_mask[:, :, tf.newaxis]
        pred_mask = pred_mask[:, tf.newaxis]
        pairwise_overlap = tf.reduce_sum(obj_mask * pred_mask, axis=[3, 4, 5])
        best_match = sg.guard(tf.argmax(pairwise_overlap, axis=2), "B, L")
        assignment = tf.one_hot(best_match, sg.K)
        assignment *= visibility[:, :, tf.newaxis]  # Mask non-visible objects

        # total error
        total_error = (tf.reduce_sum(assignment * total_pairwise_errors) /
                       num_visible_obj)

        # compute scalars
        monitored_scalars = {}
        for m in self._mapping:
            with tf.name_scope(m.name):
                metric = self.get_metric(m)
                # mean_var_tot.get() yields None for mappings whose type is
                # neither "scalar" nor "angle".
                scalar = metric(
                    latents[m.name],
                    predictions[m.name],
                    assignment[:, :, :, tf.newaxis],
                    mean_var_tot.get(m.name),
                    num_visible_obj,
                )
                monitored_scalars[m.name] = scalar
        return total_error, monitored_scalars, mean_var_tot, predictions, assignment
示例#19
0
 def __init__(self, pixel_decoder, name="component_decoder"):
     """Keep the pixel decoder and create a shape guard for this module.

     Args:
       pixel_decoder: stored as `self._pixel_decoder`.
       name: module name forwarded to the parent constructor.
     """
     super().__init__(name=name)
     self._sg = shapeguard.ShapeGuard()
     self._pixel_decoder = pixel_decoder