Example #1
    def __init__(self, nr_resnet, nr_filters, resnet_nonlinearity):
        super(PixelCNNLayer_down, self).__init__()
        self.nr_resnet = nr_resnet
        # stream from pixels above
        self.u_stream = nn.ModuleList([gated_resnet(nr_filters, down_shifted_conv2d,
                                                    resnet_nonlinearity, skip_connection=1)
                                       for _ in range(nr_resnet)])

        # stream from pixels above and to the left
        self.ul_stream = nn.ModuleList([gated_resnet(nr_filters, down_right_shifted_conv2d,
                                                     resnet_nonlinearity, skip_connection=2)
                                        for _ in range(nr_resnet)])
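
For orientation, here is a minimal, self-contained sketch of the gated residual pattern that gated_resnet realises in these layers. It is not the repository's implementation; the class name, kernel sizes, and ELU nonlinearity below are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleGatedResnet(nn.Module):
    # Minimal gated residual block (illustrative, not the repo's gated_resnet):
    # two convolutions, the second producing 2*C channels that are split into a
    # value half and a gate half and recombined as value * sigmoid(gate).
    def __init__(self, num_filters):
        super().__init__()
        self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(num_filters, 2 * num_filters, kernel_size=3, padding=1)

    def forward(self, x):
        h = self.conv1(F.elu(x))
        h = self.conv2(F.elu(h))
        a, b = h.chunk(2, dim=1)          # value / gate halves
        return x + a * torch.sigmoid(b)   # gated residual update

y = SimpleGatedResnet(160)(torch.randn(1, 160, 8, 8))
print(y.shape)  # torch.Size([1, 160, 8, 8])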
Example #2
def cond_pixel_cnn(x, gh=None, sh=None, nonlinearity=tf.nn.elu, nr_resnet=5,
                   nr_filters=160, nr_logistic_mix=10, bn=True, dropout_p=0.0,
                   kernel_initializer=None, kernel_regularizer=None,
                   is_training=False, counters={}):
    name = nn.get_name("conv_pixel_cnn", counters)
    print("construct", name, "...")
    print("    * nr_resnet: ", nr_resnet)
    print("    * nr_filters: ", nr_filters)
    print("    * nr_logistic_mix: ", nr_logistic_mix)
    with tf.variable_scope(name):
        # do not use batch normalization for the auto-regressive model; force bn to be False
        with arg_scope([gated_resnet], gh=gh, sh=sh, nonlinearity=nonlinearity,
                       dropout_p=dropout_p, counters=counters):
            with arg_scope([gated_resnet, down_shifted_conv2d, down_right_shifted_conv2d,
                            down_shifted_deconv2d, down_right_shifted_deconv2d],
                           bn=False, kernel_initializer=kernel_initializer,
                           kernel_regularizer=kernel_regularizer, is_training=is_training):
                xs = nn.int_shape(x)
                x_pad = tf.concat([x,tf.ones(xs[:-1]+[1])],3) # add channel of ones to distinguish image from padding later on

                u_list = [down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 3]))] # stream for pixels above
                ul_list = [down_shift(down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1,3])) + \
                        right_shift(down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1]))] # stream for up and to the left
                receptive_field = (2, 3)
                for rep in range(nr_resnet):
                    u_list.append(gated_resnet(u_list[-1], conv=down_shifted_conv2d))
                    ul_list.append(gated_resnet(ul_list[-1], u_list[-1], conv=down_right_shifted_conv2d))
                    receptive_field = (receptive_field[0]+1, receptive_field[1]+2)
                x_out = nin(tf.nn.elu(ul_list[-1]), 10*nr_logistic_mix)
                print("    * receptive_field", receptive_field)
                return x_out
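
The receptive_field value printed at the end can be reproduced in isolation; below is a small sketch of that bookkeeping, with the starting value and per-block growth taken directly from the loop above.

# Starts at (2, 3); each gated_resnet block extends it by (1, 2).
nr_resnet = 5
receptive_field = (2, 3)
for _ in range(nr_resnet):
    receptive_field = (receptive_field[0] + 1, receptive_field[1] + 2)
print(receptive_field)  # (7, 13) with the default nr_resnet=5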
Example #3
    def _inference_bu(self, x, ema):

        kwargs = {
            "training": self.ph_is_training,
            "ema": ema,
            "init": self.init
        }

        stochastic_inference_bottom_up = []
        deterministic_path_top_down = []
        deterministic_path_bottom_up = []

        skip = None
        d_bu = x
        d_td = x

        dim_length = len(self.deterministic_layers[0])
        for i, dims in enumerate(self.deterministic_layers):
            assert len(dims) == dim_length
            # Build deterministic block
            for j, dim in enumerate(dims):
                scope = "deterministic_bottom_up_%i_%i" % (i, j)
                residual = False if i == 0 and j == 0 else True

                if i > 0:
                    skip = tf.concat([
                        deterministic_path_bottom_up[i - 1],
                        deterministic_path_top_down[i - 1]
                    ], axis=-1)

                d_bu = gated_resnet(d_bu, skip, dim, self.activation, scope,
                                    residual, self.dropout_inference, **kwargs)

            deterministic_path_bottom_up += [d_bu]

            for j, dim in enumerate(dims):
                scope = "deterministic_top_down_%i_%i" % (i, j)
                residual = False if i == 0 and j == 0 else True
                d_td = gated_resnet(d_td, deterministic_path_bottom_up[i], dim,
                                    self.activation, scope, residual,
                                    self.dropout_inference, **kwargs)

            deterministic_path_top_down += [d_td]

            if i == len(self.deterministic_layers) - 1:
                break  # Do not add the top z layer before building the top-down inference model.

            # Build stochastic layer
            dim = self.stochastic_layers[i]
            scope = "qz_bottom_up_%i" % (i + 1)
            q_z_bottom_up, q_mean_bottom_up, q_var_bottom_up = self._stochastic(
                d_bu, dim, scope, ema)

            stochastic_inference_bottom_up += [
                (q_z_bottom_up, q_mean_bottom_up, q_var_bottom_up)
            ]

            if len(q_z_bottom_up.get_shape()) == 2:
                flatten_shape = [int(dim) for dim in d_bu.get_shape()[1:]]
                scope = "dense2conv_bottom_up_%i" % (i + 1)
                d_bu = dense(q_z_bottom_up, np.prod(flatten_shape), scope,
                             **kwargs)
                d_bu = tf.reshape(d_bu, [-1] + flatten_shape)
            else:
                d_bu = q_z_bottom_up

        # Build top stochastic layer q(z_L | x) from the input of bottom-up and top-down inference.
        dim = self.deterministic_layers[-1][-1]
        d = tf.concat([d_td, d_bu], axis=-1)
        scope = "deterministic_top"
        d = gated_resnet(d, None, [dim[0], dim[1], [1, 1]], self.activation,
                         scope, True, self.dropout_inference, **kwargs)
        flatten_shapes = [[int(dim) for dim in d.get_shape()[1:]]]

        scope = "qz_top_%i" % len(self.stochastic_layers)
        q_z_top, q_mean_top, q_var_top = self._stochastic(
            d, self.stochastic_layers[-1], scope, ema)

        stochastic_inference_bottom_up += [(q_z_top, q_mean_top, q_var_top)]

        return stochastic_inference_bottom_up, deterministic_path_top_down, flatten_shapes
Example #4
    def _inference_td(self, stochastic_inference_bottom_up,
                      deterministic_path_top_down, flatten_shapes, ema):
        kwargs = {
            "training": self.ph_is_training,
            "ema": ema,
            "init": self.init
        }

        stochastic_inference_top_down = []

        z_top = stochastic_inference_bottom_up[-1][0]

        stochastic_inference_bottom_up = stochastic_inference_bottom_up[:-1]
        # Reverse the remaining lists (dropping the top entry, which was handled
        # above) so the loop below walks the hierarchy top-down.
        stochastic_layers_reordered = self.stochastic_layers[::-1][1:]
        deterministic_path_top_down_reordered = deterministic_path_top_down[::-1][1:]
        generative_skip_connections = []

        for i, dims in enumerate(self.deterministic_layers[::-1][:-1]):

            z_top_reshaped = z_top
            scope_index = len(self.stochastic_layers) - (i + 1)

            if len(z_top_reshaped.get_shape()) == 2:
                scope = "dense2conv_%i" % scope_index
                z_top_reshaped = dense(z_top_reshaped,
                                       np.prod(flatten_shapes[i]), scope,
                                       **kwargs)
                z_top_reshaped = tf.reshape(z_top_reshaped,
                                            [-1] + flatten_shapes[i])

            # Build deterministic block of generative model.
            d_p = z_top_reshaped
            for j, dim in enumerate(dims[::-1]):
                skip = None
                if i > 0:
                    skip = generative_skip_connections.pop()
                residual = False if j == 0 else True
                scope = "deterministic_generative_%i_%i" % (i, j)
                d_p = transposed_gated_resnet(d_p, skip, dim, self.activation,
                                              scope, residual,
                                              self.dropout_generative,
                                              **kwargs)
                generative_skip_connections = [d_p] + generative_skip_connections

            q_z_bottom_up = stochastic_inference_bottom_up[::-1][i][0]

            # Build top-down stochastic layer q(z_(L-(i+1)) | z_(L-i)) and p(z_(L-(i+1)) | z_(L-i))
            d_td = deterministic_path_top_down_reordered[i]
            scope = "qz_top_down_pz_merge_%i" % scope_index
            dim = dims[::-1][-1]
            dim = [dim[0], dim[1], [1, 1]]
            d_td = gated_resnet(d_td, d_p, dim, self.activation, scope, True,
                                self.dropout_generative, **kwargs)

            flatten_shapes += [[int(d) for d in d_td.get_shape()[1:]]]
            scope = "qz_top_down_%i" % scope_index
            q_z_top_down, q_mean_top_down, q_var_top_down = self._stochastic(
                d_td, stochastic_layers_reordered[i], scope, ema)

            stochastic_inference_top_down += [(q_z_top_down, q_mean_top_down,
                                               q_var_top_down)]

            z_top = tf.concat([q_z_top_down, q_z_bottom_up], axis=-1)

        return stochastic_inference_top_down[::-1]
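
Taken together, the last two examples form a single bidirectional inference pass: the stochastic triples (sample, mean, variance), the deterministic top-down features, and the flatten shapes returned by _inference_bu are exactly the first three arguments of _inference_td, which then walks the hierarchy top-down to produce the remaining posteriors. A hedged wiring sketch, where model, x, and ema are stand-ins for the surrounding model instance, an input batch, and the EMA object these methods expect:

# Illustrative wiring only; `model`, `x`, and `ema` are assumed stand-ins.
bu, td_path, shapes = model._inference_bu(x, ema)              # bottom-up pass (Example #3)
td_posteriors = model._inference_td(bu, td_path, shapes, ema)  # top-down pass (Example #4)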