Beispiel #1
0
def extract_affine_glimpse(image, object_shape, cyt, cxt, ys, xs,
                           edge_resampler):
    """Resample one fixed-size glimpse from `image` per box.

    Args:
        image: Image tensor; its shape is unpacked as
            (batch, *spatial, depth).
        object_shape: Spatial shape of every extracted glimpse.
        cyt, cxt: Box centres. They are rescaled with ``2 * c - 1`` before
            warping (presumably from [0, 1] to [-1, 1] -- confirm with
            the caller).
        ys, xs: Box height / width.
        edge_resampler: If truthy, sample with
            ``resampler_edge.resampler_edge``; otherwise use
            ``tf.contrib.resampler.resampler``.

    Returns:
        The resampled glimpses reshaped to
        (*leading, *object_shape, depth), where `leading` is the shape of
        `cyt` without its last axis.
    """
    _, *spatial_shape, n_channels = tf_shape(image)
    warper = snt.AffineGridWarper(
        spatial_shape, object_shape, snt.AffineWarpConstraints.no_shear_2d())

    # Shift the centres into the coordinate system used for warping.
    center_y = 2 * cyt - 1
    center_x = 2 * cxt - 1

    batch_dims = tf_shape(center_y)[:-1]

    # Flatten to one row of four warp parameters per box.
    warp_params = tf.reshape(
        tf.concat([xs, center_x, ys, center_y], axis=-1), (-1, 4))

    sample_grid = tf.reshape(
        warper(warp_params), (*batch_dims, *object_shape, 2))

    resample = (
        resampler_edge.resampler_edge if edge_resampler
        else tf.contrib.resampler.resampler)
    patches = resample(image, sample_grid)

    return tf.reshape(patches, (*batch_dims, *object_shape, n_channels))
Beispiel #2
0
    def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
        """Build one object per (h, w, b) grid slot, one slot at a time.

        For every slot the method runs four sub-networks in sequence
        (box -> attr -> z -> obj). Each stage is fed the base features, the
        pass-through features of the previous stage and the part of the
        slot's "program" built so far.

        Args:
            inp: Image batch. On the first call its shape is unpacked into
                (batch_size, image_height, image_width, image_depth) and
                cached on `self`.
            inp_features: 5-D feature tensor, unpacked as (_, H, W, _, _)
                and indexed as ``inp_features[:, h, w, b, :]``.
            is_training: Training flag. NOTE: it is only stored while
                ``self.initialized`` is falsy; every network call below
                uses the cached ``self.is_training``.
            is_posterior: If True, a glimpse is resampled from `inp` for
                each box and encoded; otherwise the glimpse and its
                encoding are all zeros. Also selects the two-element flag
                appended to the base features.
            prop_state: Optional tensor; when given, its first row is tiled
                to (batch_size, HWB, ...) and stored as both `prop_state`
                and `prior_prop_state`.

        Returns:
            AttrDict in which every per-slot tensor is stacked along
            axis 1 (slot order follows the h, w, b iteration), plus
            `n_objects`, `pred_n_objects` and `pred_n_objects_hard`.
        """
        print("\n" + "-" * 10 + " GridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network", builder=cfg.build_lateral, key="box")
        self.maybe_build_subnet("attr_network", builder=cfg.build_lateral, key="attr")
        self.maybe_build_subnet("z_network", builder=cfg.build_lateral, key="z")
        self.maybe_build_subnet("obj_network", builder=cfg.build_lateral, key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _, _ = tf_shape(inp_features)
        H = int(H)
        W = int(W)

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
            self.H = H
            self.W = W
            self.HWB = H*W*self.B
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        # --- set up the edge element ---

        # Segment widths of a complete per-slot program; they add up to the
        # width asserted at the end of each loop iteration
        # (local_box + attr + z + obj).
        sizes = [4, self.A, 1, 1]
        # Which segments of the learned edge vector are squashed by a sigmoid.
        sigmoids = [True, False, False, True]
        total_sample_size = sum(sizes)

        if self.edge_weights is None:
            self.edge_weights = tf.get_variable("edge_weights", shape=total_sample_size, dtype=tf.float32)
            if "edge" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.edge_weights)

        _edge_weights = tf.split(self.edge_weights, sizes, axis=0)
        _edge_weights = [
            (tf.nn.sigmoid(ew) if sigmoid else ew)
            for ew, sigmoid in zip(_edge_weights, sigmoids)]
        edge_element = tf.concat(_edge_weights, axis=0)
        # Same edge element for every batch entry.
        edge_element = tf.tile(edge_element[None, :], (self.batch_size, 1))

        # --- containers for storing built program ---

        # Object array holding one tensor per slot. The builtin `object` is
        # used because the `np.object` alias is deprecated and is no longer
        # available in recent NumPy releases.
        program = np.empty((H, W, self.B), dtype=object)

        # --- build the program ---

        # Two-element flag per batch entry: [1, 0] for posterior, [0, 1] otherwise.
        is_posterior_tf = tf.ones((self.batch_size, 2))
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        results = []
        for h, w, b in itertools.product(range(H), range(W), range(self.B)):
            built = dict()

            partial_program, features = None, None
            # Context derived from `program`, i.e. the slots already built.
            # `edge_element` is passed along -- presumably as a stand-in for
            # neighbours outside the grid; confirm in _get_sequential_context.
            context = self._get_sequential_context(program, h, w, b, edge_element)
            base_features = tf.concat([inp_features[:, h, w, b, :], context, is_posterior_tf], axis=1)

            # --- box ---

            layer_inp = base_features
            n_features = self.n_passthrough_features
            output_size = 8

            # The network emits `output_size` values for _build_box followed
            # by `n_features` values handed on to the next stage.
            network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
            rep_input, features = tf.split(network_output, (output_size, n_features), axis=1)

            _built = self._build_box(rep_input, self.is_training, hw=(h, w))
            built.update(_built)
            partial_program = built['local_box']

            # --- attr ---

            if is_posterior:
                # --- Get object attributes using object encoder ---

                yt, xt, ys, xs = tf.split(built['normalized_box'], 4, axis=-1)

                yt, xt, ys, xs = coords_to_image_space(
                    yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

                transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
                warper = snt.AffineGridWarper(
                    (self.image_height, self.image_width), self.object_shape, transform_constraints)

                # Centres are rescaled with 2*c - 1 before warping.
                _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)

                grid_coords = warper(_boxes)
                # Singleton axis: one glimpse per batch entry for this slot.
                grid_coords = tf.reshape(grid_coords, (self.batch_size, 1, *self.object_shape, 2,))
                if self.edge_resampler:
                    glimpse = resampler_edge.resampler_edge(inp, grid_coords)
                else:
                    glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
                glimpse = tf.reshape(glimpse, (self.batch_size, *self.object_shape, self.image_depth))
            else:
                glimpse = tf.zeros((self.batch_size, *self.object_shape, self.image_depth))

            # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
            encoded_glimpse = self.object_encoder(glimpse, (1, 1, self.A), self.is_training)
            encoded_glimpse = tf.reshape(encoded_glimpse, (self.batch_size, self.A))

            if not is_posterior:
                # Discard the encoder output so it carries no information.
                encoded_glimpse = tf.zeros_like(encoded_glimpse)

            layer_inp = tf.concat(
                [base_features, features, encoded_glimpse, partial_program], axis=1)
            network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
            attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=1)

            attr_std = self.std_nonlinearity(attr_log_std)

            attr = Normal(loc=attr_mean, scale=attr_std).sample()

            built.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)
            partial_program = tf.concat([partial_program, built['attr']], axis=1)

            # --- z ---

            layer_inp = tf.concat([base_features, features, partial_program], axis=1)
            n_features = self.n_passthrough_features

            network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
            z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=1)
            z_std = self.std_nonlinearity(z_log_std)

            # Forward value is unchanged; the gradient through z_mean / z_std
            # is scaled by (1 - training_wheels).
            z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
            z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
            z_logit = Normal(loc=z_mean, scale=z_std).sample()
            z = self.z_nonlinearity(z_logit)

            built.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)
            partial_program = tf.concat([partial_program, built['z']], axis=1)

            # --- obj ---

            layer_inp = tf.concat([base_features, features, partial_program], axis=1)
            rep_input = self.obj_network(layer_inp, 1, self.is_training)

            _built = self._build_obj(rep_input, self.is_training)
            built.update(_built)

            partial_program = tf.concat([partial_program, built['obj']], axis=1)

            # --- final ---

            results.append(built)

            # Store the finished program so later slots can use it as context.
            program[h, w, b] = partial_program
            assert program[h, w, b].shape[1] == total_sample_size

        # Stack every per-slot tensor along a new axis 1 (one entry per slot).
        objects = AttrDict()
        for k in results[0]:
            objects[k] = tf.stack([r[k] for r in results], axis=1)

        if prop_state is not None:
            # Broadcast the first row of prop_state to every batch entry and slot.
            objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
            objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

        # --- misc ---

        objects.n_objects = tf.fill((self.batch_size,), self.HWB)
        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

        return objects
Beispiel #3
0
    def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
        """Build objects for all grid cells at once with convolutional sub-networks.

        The four stages (box -> attr -> z -> obj) each operate on the whole
        (batch, H, W, ...) grid. Only ``self.B == 1`` is supported.

        Args:
            inp: Image batch. On the first call its shape supplies
                image_height, image_width and image_depth, and
                ``tf.shape(inp)[0]`` is cached as the batch size.
            inp_features: 5-D feature tensor, unpacked as
                (_, H, W, _, n_channels) and reshaped to
                (batch_size, H, W, n_channels).
            is_training: Training flag. NOTE: it is only stored while
                ``self.initialized`` is falsy; every network call below
                uses the cached ``self.is_training``.
            is_posterior: If True, a glimpse is resampled from `inp` for
                each cell and encoded; otherwise the glimpse and its
                encoding are all zeros. Also selects the two-channel flag
                appended to the base features.
            prop_state: Optional tensor; when given, its first row is tiled
                to (batch_size, HWB, ...) and stored as both `prop_state`
                and `prior_prop_state`.

        Returns:
            AttrDict whose tensors have the H and W axes flattened into a
            single axis of length ``self.HWB``, plus `n_objects`,
            `pred_n_objects` and `pred_n_objects_hard`.

        Raises:
            NotImplementedError: If ``self.B != 1``.
        """
        print("\n" + "-" * 10 + " ConvGridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network", builder=cfg.build_conv_lateral, key="box")
        self.maybe_build_subnet("attr_network", builder=cfg.build_conv_lateral, key="attr")
        self.maybe_build_subnet("z_network", builder=cfg.build_conv_lateral, key="z")
        self.maybe_build_subnet("obj_network", builder=cfg.build_conv_lateral, key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _, n_channels = tf_shape(inp_features)

        if self.B != 1:
            # NotImplementedError is still an Exception subclass, so callers
            # that catch Exception keep working.
            raise NotImplementedError(
                "Only B == 1 is supported, got B={}".format(self.B))

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            # The batch dimension is taken from tf.shape below instead.
            _, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
            self.H = H
            self.W = W
            self.HWB = H*W
            self.batch_size = tf.shape(inp)[0]
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        # Drop the 4th axis of inp_features (B is known to be 1 here).
        inp_features = tf.reshape(inp_features, (self.batch_size, H, W, n_channels))
        # Two-channel flag per cell: [1, 0] for posterior, [0, 1] otherwise.
        is_posterior_tf = tf.ones_like(inp_features[..., :2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        objects = AttrDict()

        base_features = tf.concat([inp_features, is_posterior_tf], axis=-1)

        # --- box ---

        layer_inp = base_features
        n_features = self.n_passthrough_features
        output_size = 8

        # The network emits `output_size` channels for _build_box followed by
        # `n_features` channels handed on to the next stage.
        network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
        rep_input, features = tf.split(network_output, (output_size, n_features), axis=-1)

        _objects = self._build_box(rep_input, self.is_training)
        objects.update(_objects)

        # --- attr ---

        if is_posterior:
            # --- Get object attributes using object encoder ---

            yt, xt, ys, xs = tf.split(objects['normalized_box'], 4, axis=-1)

            yt, xt, ys, xs = coords_to_image_space(
                yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
            warper = snt.AffineGridWarper(
                (self.image_height, self.image_width), self.object_shape, transform_constraints)

            # Centres are rescaled with 2*c - 1 before warping.
            _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)
            # One row of warp parameters per (batch entry, cell).
            _boxes = tf.reshape(_boxes, (self.batch_size*H*W, 4))
            grid_coords = warper(_boxes)
            grid_coords = tf.reshape(grid_coords, (self.batch_size, H, W, *self.object_shape, 2,))

            if self.edge_resampler:
                glimpse = resampler_edge.resampler_edge(inp, grid_coords)
            else:
                glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
        else:
            glimpse = tf.zeros((self.batch_size, H, W, *self.object_shape, self.image_depth))

        # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
        encoded_glimpse = apply_object_wise(
            self.object_encoder, glimpse, n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

        if not is_posterior:
            # Discard the encoder output so it carries no information.
            encoded_glimpse = tf.zeros_like(encoded_glimpse)

        layer_inp = tf.concat([base_features, features, encoded_glimpse, objects['local_box']], axis=-1)
        network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
        attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=-1)

        attr_std = self.std_nonlinearity(attr_log_std)

        attr = Normal(loc=attr_mean, scale=attr_std).sample()

        objects.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)

        # --- z ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr']], axis=-1)
        n_features = self.n_passthrough_features

        network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
        z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=-1)
        z_std = self.std_nonlinearity(z_log_std)

        # Forward value is unchanged; the gradient through z_mean / z_std is
        # scaled by (1 - training_wheels).
        z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
        z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
        z_logit = Normal(loc=z_mean, scale=z_std).sample()
        z = self.z_nonlinearity(z_logit)

        objects.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)

        # --- obj ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr'], objects['z']], axis=-1)
        rep_input = self.obj_network(layer_inp, 1, self.is_training)

        _objects = self._build_obj(rep_input, self.is_training)
        objects.update(_objects)

        # --- final ---

        # Flatten the (H, W) grid axes of every tensor into one axis of length HWB.
        _objects = AttrDict()
        for k, v in objects.items():
            _, _, _, *trailing_dims = tf_shape(v)
            _objects[k] = tf.reshape(v, (self.batch_size, self.HWB, *trailing_dims))
        objects = _objects

        if prop_state is not None:
            # Broadcast the first row of prop_state to every batch entry and cell.
            objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
            objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

        # --- misc ---

        objects.n_objects = tf.fill((self.batch_size,), self.HWB)
        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

        return objects
Beispiel #4
0
    def _call(self, inp, inp_features, is_training, is_posterior=True):
        """Build one object per (h, w, b) grid slot, one slot at a time.

        For every slot the method runs four sub-networks in sequence
        (box -> attr -> z -> obj). Per-slot results are collected in
        object arrays and finally merged into tensors (or distributions)
        laid out as (batch, H, W, B, ...).

        Args:
            inp: Image batch. On the first call its last three static
                dimensions are cached as image_height, image_width and
                image_depth, and ``tf.shape(inp)[0]`` as the batch size.
            inp_features: 5-D feature tensor whose static shape is unpacked
                as (_, H, W, _, _) and which is indexed as
                ``inp_features[:, h, w, b, :]``.
            is_training: Training flag. NOTE: it is only stored while
                ``self.initialized`` is falsy; every network call below
                uses the cached ``self.is_training``.
            is_posterior: If True, a glimpse is resampled from `inp` for
                each box and encoded; otherwise the glimpse and its
                encoding are all zeros. Also selects the two-element flag
                appended to the base features.

        Returns:
            AttrDict with one merged entry per key produced in the loop
            (keys ending in ``_dist`` become distribution objects), plus
            `all`, `n_objects`, `pred_n_objects` and `pred_n_objects_hard`.
        """
        print("\n" + "-" * 10 +
              " GridObjectLayer(is_posterior={}) ".format(is_posterior) +
              "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network",
                                builder=cfg.build_lateral,
                                key="box")
        self.maybe_build_subnet("attr_network",
                                builder=cfg.build_lateral,
                                key="attr")
        self.maybe_build_subnet("z_network",
                                builder=cfg.build_lateral,
                                key="z")
        self.maybe_build_subnet("obj_network",
                                builder=cfg.build_lateral,
                                key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _, _ = inp_features.shape
        H = int(H)
        W = int(W)

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.image_height = int(inp.shape[-3])
            self.image_width = int(inp.shape[-2])
            self.image_depth = int(inp.shape[-1])
            self.H = H
            self.W = W
            self.HWB = H * W * self.B
            self.batch_size = tf.shape(inp)[0]
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        # --- set up the edge element ---

        # Segment widths of a complete per-slot program; they add up to the
        # width asserted at the end of each loop iteration
        # (box + attr + z + obj).
        sizes = [4, self.A, 1, 1]
        # Which segments of the learned edge vector are squashed by a sigmoid.
        sigmoids = [True, False, False, True]
        total_sample_size = sum(sizes)

        if self.edge_weights is None:
            self.edge_weights = tf.get_variable("edge_weights",
                                                shape=total_sample_size,
                                                dtype=tf.float32)
            if "edge" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.edge_weights)

        _edge_weights = tf.split(self.edge_weights, sizes, axis=0)
        _edge_weights = [(tf.nn.sigmoid(ew) if sigmoid else ew)
                         for ew, sigmoid in zip(_edge_weights, sigmoids)]
        edge_element = tf.concat(_edge_weights, axis=0)
        # Same edge element for every batch entry.
        edge_element = tf.tile(edge_element[None, :], (self.batch_size, 1))

        # --- containers for storing built program ---

        # One container per output key, created on first use by
        # self._make_empty (presumably an (H, W, B) object array, given the
        # [h, w, b] indexing below -- confirm in _make_empty).
        _tensors = collections.defaultdict(self._make_empty)
        # Object array holding one tensor per slot. The builtin `object` is
        # used because the `np.object` alias is deprecated and is no longer
        # available in recent NumPy releases.
        program = np.empty((H, W, self.B), dtype=object)

        # --- build the program ---

        # Two-element flag per batch entry: [1, 0] for posterior, [0, 1] otherwise.
        is_posterior_tf = tf.ones((self.batch_size, 2))
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        for h, w, b in itertools.product(range(H), range(W), range(self.B)):
            partial_program, features = None, None
            # Context derived from `program`, i.e. the slots already built.
            # `edge_element` is passed along -- presumably as a stand-in for
            # neighbours outside the grid; confirm in _get_sequential_context.
            context = self._get_sequential_context(program, h, w, b,
                                                   edge_element)
            base_features = tf.concat(
                [inp_features[:, h, w, b, :], context, is_posterior_tf],
                axis=1)

            # --- box ---

            layer_inp = base_features
            n_features = self.n_passthrough_features
            output_size = 8

            # The network emits `output_size` values for _build_box followed
            # by `n_features` values handed on to the next stage.
            network_output = self.box_network(layer_inp,
                                              output_size + n_features,
                                              self.is_training)
            rep_input, features = tf.split(network_output,
                                           (output_size, n_features),
                                           axis=1)

            built = self._build_box(rep_input, h, w, b, self.is_training)

            for key, value in built.items():
                _tensors[key][h, w, b] = value
            partial_program = built['box']

            # --- attr ---

            if is_posterior:
                # --- Get object attributes using object encoder ---

                yt, xt, ys, xs = tf.split(built['normalized_box'], 4, axis=-1)

                # yt/xt give top/left but here we need center
                yt += ys / 2
                xt += xs / 2

                transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
                warper = snt.AffineGridWarper(
                    (self.image_height, self.image_width), self.object_shape,
                    transform_constraints)

                # Centres are rescaled with 2*c - 1 before warping.
                _boxes = tf.concat([xs, 2 * xt - 1, ys, 2 * yt - 1], axis=-1)

                grid_coords = warper(_boxes)
                # Singleton axis: one glimpse per batch entry for this slot.
                grid_coords = tf.reshape(grid_coords, (
                    self.batch_size,
                    1,
                    *self.object_shape,
                    2,
                ))
                glimpse = resampler_edge.resampler_edge(inp, grid_coords)
                glimpse = tf.reshape(
                    glimpse,
                    (self.batch_size, *self.object_shape, self.image_depth))
            else:
                glimpse = tf.zeros(
                    (self.batch_size, *self.object_shape, self.image_depth))

            # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
            encoded_glimpse = self.object_encoder(glimpse, (1, 1, self.A),
                                                  self.is_training)
            encoded_glimpse = tf.reshape(encoded_glimpse,
                                         (self.batch_size, self.A))

            if not is_posterior:
                # Discard the encoder output so it carries no information.
                encoded_glimpse = tf.zeros_like(encoded_glimpse)

            layer_inp = tf.concat(
                [base_features, features, encoded_glimpse, partial_program],
                axis=1)
            network_output = self.attr_network(layer_inp,
                                               2 * self.A + n_features,
                                               self.is_training)
            attr_mean, attr_log_std, features = tf.split(
                network_output, (self.A, self.A, n_features), axis=1)

            attr_std = self.std_nonlinearity(attr_log_std)

            attr_dist = Normal(loc=attr_mean, scale=attr_std)
            attr = attr_dist.sample()

            if "attr" in self.no_gradient:
                attr = tf.stop_gradient(attr)

            built = dict(attr_dist=attr_dist, attr=attr, glimpse=glimpse)

            for key, value in built.items():
                _tensors[key][h, w, b] = value
            partial_program = tf.concat([partial_program, built['attr']],
                                        axis=1)

            # --- z ---

            layer_inp = tf.concat([base_features, features, partial_program],
                                  axis=1)
            n_features = self.n_passthrough_features

            network_output = self.z_network(layer_inp, 2 + n_features,
                                            self.is_training)
            z_mean, z_log_std, features = tf.split(network_output,
                                                   (1, 1, n_features),
                                                   axis=1)
            z_std = self.std_nonlinearity(z_log_std)

            # Forward value is unchanged; the gradient through z_mean / z_std
            # is scaled by (1 - training_wheels).
            z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (
                1 - self.training_wheels) * z_mean
            z_std = self.training_wheels * tf.stop_gradient(z_std) + (
                1 - self.training_wheels) * z_std
            z_logit_dist = Normal(loc=z_mean, scale=z_std)
            z_logits = z_logit_dist.sample()
            z = self.z_nonlinearity(z_logits)

            if "z" in self.no_gradient:
                z = tf.stop_gradient(z)

            if "z" in self.fixed_values:
                # Replace the sampled value with the configured constant.
                z = self.fixed_values['z'] * tf.ones_like(z)

            built = dict(z_logit_dist=z_logit_dist, z=z)

            for key, value in built.items():
                _tensors[key][h, w, b] = value
            partial_program = tf.concat([partial_program, built['z']], axis=1)

            # --- obj ---

            layer_inp = tf.concat([base_features, features, partial_program],
                                  axis=1)
            rep_input = self.obj_network(layer_inp, 1, self.is_training)

            built = self._build_obj(rep_input, self.is_training)

            for key, value in built.items():
                _tensors[key][h, w, b] = value
            partial_program = tf.concat([partial_program, built['obj']],
                                        axis=1)

            # --- final ---

            # Store the finished program so later slots can use it as context.
            program[h, w, b] = partial_program
            assert program[h, w, b].shape[1] == total_sample_size

        # --- merge tensors from different grid cells ---

        # Stacking at axis=1 three times (over b, then w, then h) yields the
        # layout (batch, H, W, B, ...).
        objects = AttrDict()
        for k, v in _tensors.items():
            if k.endswith('_dist'):
                # Rebuild a single distribution of the same class: its
                # tensor-valued parameters are stacked across slots, all
                # other parameters are copied from the first slot.
                dist = v[0, 0, 0]
                dist_class = type(dist)
                params = dist.parameters.copy()
                tensor_keys = sorted(key for key, t in params.items()
                                     if isinstance(t, tf.Tensor))
                tensor_params = {}

                for key in tensor_keys:
                    t1 = []
                    for h in range(H):
                        t2 = []
                        for w in range(W):
                            t2.append(
                                tf.stack([
                                    v[h, w, b].parameters[key]
                                    for b in range(self.B)
                                ],
                                         axis=1))
                        t1.append(tf.stack(t2, axis=1))
                    tensor_params[key] = tf.stack(t1, axis=1)

                params.update(tensor_params)
                objects[k] = dist_class(**params)
            else:
                t1 = []
                for h in range(H):
                    t2 = []
                    for w in range(W):
                        t2.append(
                            tf.stack([v[h, w, b] for b in range(self.B)],
                                     axis=1))
                    t1.append(tf.stack(t2, axis=1))
                objects[k] = tf.stack(t1, axis=1)

        # Full per-slot program: box, attr, z and obj joined on the last axis.
        objects.all = tf.concat(
            [objects.box, objects.attr, objects.z, objects.obj], axis=-1)

        # --- misc ---

        objects.n_objects = tf.fill((self.batch_size, ), self.HWB)
        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2, 3, 4))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj),
                                                    axis=(1, 2, 3, 4))

        return objects