Esempio n. 1
0
 def __model(self,
             x,
             x_bar,
             is_training,
             dropout_p,
             masks,
             input_masks,
             network_size="medium"):
     """Build the VAE encoder/decoder + conditional PixelCNN graph.

     Stores the given tensors on ``self`` and selects a conv encoder/decoder
     pair from the spatial size ``int_shape(x)[1]`` and ``network_size``.
     It then encodes ``x`` into ``self.z_mu`` / ``self.z_log_sigma_sq`` and
     obtains a latent ``self.z``. The latent is decoded into conditioning
     features for ``cond_pixel_cnn``, whose output parameters are stored in
     ``self.mix_logistic_params`` and sampled into ``self.x_hat``.

     Args:
         x: input batch; ``int_shape(x)[1]`` must be 64 or 32.
         x_bar: tensor passed as the input of ``cond_pixel_cnn``.
         is_training: forwarded to every layer through ``arg_scope``.
         dropout_p: dropout probability forwarded to ``cond_pixel_cnn``.
         masks: stored on ``self.masks``; not used in this method.
         input_masks: stored on ``self.input_masks``; not used here.
         network_size: ``'medium'``, ``'large'`` or ``'large1'``. Only
             consulted for 32-pixel inputs; 64-pixel inputs always use
             the medium networks.

     Raises:
         Exception: unknown ``network_size`` for 32-pixel inputs.
         ValueError: unsupported input size, or ``self.use_mode`` is
             neither ``'train'`` nor ``'test'``.
     """
     print("******   Building Graph   ******")
     self.x = x
     self.x_bar = x_bar
     self.is_training = is_training
     self.dropout_p = dropout_p
     self.masks = masks
     self.input_masks = input_masks
     image_size = int_shape(x)[1]
     if image_size == 64:
         conv_encoder = conv_encoder_64_medium
         conv_decoder = conv_decoder_64_medium
     elif image_size == 32:
         if network_size == 'medium':
             conv_encoder = conv_encoder_32_medium
             conv_decoder = conv_decoder_32_medium
         elif network_size == 'large':
             conv_encoder = conv_encoder_32_large
             conv_decoder = conv_decoder_32_large
         elif network_size == 'large1':
             conv_encoder = conv_encoder_32_large1
             conv_decoder = conv_decoder_32_large1
         else:
             raise Exception("unknown network type")
     else:
         # Without this branch the method failed further down with an
         # UnboundLocalError on ``conv_encoder``.
         raise ValueError(
             "unsupported input size {0}; expected 64 or 32".format(
                 image_size))
     with arg_scope([conv_encoder, conv_decoder, cond_pixel_cnn],
                    nonlinearity=self.nonlinearity,
                    bn=self.bn,
                    kernel_initializer=self.kernel_initializer,
                    kernel_regularizer=self.kernel_regularizer,
                    is_training=self.is_training,
                    counters=self.counters):
         inputs = self.x
         self.z_mu, self.z_log_sigma_sq = conv_encoder(inputs, self.z_dim)
         # z_log_sigma_sq is a log-variance, so the std is exp(log_var / 2).
         sigma = tf.exp(self.z_log_sigma_sq / 2.)
         if self.use_mode == 'train':
             self.z = gaussian_sampler(self.z_mu, sigma)
         elif self.use_mode == 'test':
             # In test mode the latent is fed in from outside.
             self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
         else:
             # Previously self.z stayed unset and decoding failed later
             # with an AttributeError.
             raise ValueError("unknown use_mode: {0}".format(self.use_mode))
         print("use mode:{0}".format(self.use_mode))
         self.decoded_features = conv_decoder(self.z, output_features=True)
         sh = self.decoded_features
         self.mix_logistic_params = cond_pixel_cnn(
             self.x_bar,
             sh=sh,
             bn=False,
             dropout_p=self.dropout_p,
             nr_resnet=self.nr_resnet,
             nr_filters=self.nr_filters,
             nr_logistic_mix=self.nr_logistic_mix)
         self.x_hat = mix_logistic_sampler(
             self.mix_logistic_params,
             nr_logistic_mix=self.nr_logistic_mix,
             sample_range=self.sample_range,
             counters=self.counters)
    def _model(self):
        """Build the neural-process graph and return target predictions.

        The context (``X_c``/``y_c``) and target (``X_t``/``y_t``) sets are
        concatenated and encoded with ``self.sample_encoder``; the result
        is kept on ``self.r_ct``. ``self.aggregator`` then yields prior and
        posterior latent Gaussian parameters. The latent is a posterior
        sample in ``'train'`` mode and the posterior mean in ``'eval'``
        mode. It is then blended with ``self.z_ph`` through
        ``self.use_z_ph`` and decoded for the target inputs.

        Returns:
            The output of ``self.conditional_decoder(self.X_t, z)``.

        Raises:
            Exception: ``self.user_mode`` is neither ``'train'`` nor
                ``'eval'``.
        """
        scope_args = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([self.conditional_decoder], **scope_args):
            # The encoder and aggregator share the settings above, minus
            # batch normalisation.
            no_bn_args = dict(scope_args, bn=False)
            with arg_scope([self.sample_encoder, self.aggregator],
                           **no_bn_args):
                num_c = tf.shape(self.X_c)[0]
                X_ct = tf.concat([self.X_c, self.X_t], axis=0)
                y_ct = tf.concat([self.y_c, self.y_t], axis=0)
                self.r_ct = self.sample_encoder(X_ct, y_ct, self.r_dim)

                (self.z_mu_pr, self.z_log_sigma_sq_pr,
                 self.z_mu_pos, self.z_log_sigma_sq_pos) = self.aggregator(
                     self.r_ct, num_c, self.z_dim)

                mode = self.user_mode
                if mode == 'eval':
                    z = self.z_mu_pos
                elif mode == 'train':
                    # Std from log-variance: exp(0.5 * log_sigma_sq).
                    z = gaussian_sampler(self.z_mu_pos,
                                         tf.exp(0.5 * self.z_log_sigma_sq_pos))
                else:
                    raise Exception("unknown user_mode")

                # use_z_ph acts as a blend weight between the computed
                # latent and the externally supplied z_ph (presumably 0/1).
                z = (1 - self.use_z_ph) * z + self.use_z_ph * self.z_ph
                return self.conditional_decoder(self.X_t, z)
 def __model(self, x, is_training):
     """Build a plain convolutional VAE graph.

     Selects the encoder/decoder pair from ``int_shape(x)[1]`` and encodes
     ``x`` into ``self.z_mu`` / ``self.z_log_sigma_sq``. It then obtains a
     latent ``self.z`` and decodes it into ``self.x_hat``.

     Args:
         x: input batch; ``int_shape(x)[1]`` must be 64 or 32.
         is_training: forwarded to the encoder/decoder via ``arg_scope``.

     Raises:
         ValueError: unsupported input size, or ``self.use_mode`` is
             neither ``'train'`` nor ``'test'``.
     """
     print("******   Building Graph   ******")
     self.x = x
     self.is_training = is_training
     image_size = int_shape(x)[1]
     if image_size == 64:
         encoder = conv_encoder_64_medium
         decoder = conv_decoder_64_medium
     elif image_size == 32:
         encoder = conv_encoder_32_medium
         decoder = conv_decoder_32_medium
     else:
         # Previously fell through to an UnboundLocalError on ``encoder``.
         raise ValueError(
             "unsupported input size {0}; expected 64 or 32".format(
                 image_size))
     with arg_scope([encoder, decoder],
                    nonlinearity=self.nonlinearity,
                    bn=self.bn,
                    kernel_initializer=self.kernel_initializer,
                    kernel_regularizer=self.kernel_regularizer,
                    is_training=self.is_training,
                    counters=self.counters):
         self.z_mu, self.z_log_sigma_sq = encoder(x, self.z_dim)
         # Log-variance -> standard deviation.
         sigma = tf.exp(self.z_log_sigma_sq / 2.)
         if self.use_mode == 'train':
             self.z = gaussian_sampler(self.z_mu, sigma)
         elif self.use_mode == 'test':
             # In test mode the latent is fed in from outside.
             self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
         else:
             # Previously self.z stayed unset and decoder(self.z) raised
             # an AttributeError.
             raise ValueError("unknown use_mode: {0}".format(self.use_mode))
         print("use mode:{0}".format(self.use_mode))
         self.x_hat = decoder(self.z)
Esempio n. 4
0
    def _model(self, inner_iters=1, eval_iters=10):
        """Build the neural-process graph with MAML-style decoder adaptation.

        The method proceeds in these steps:

        1. Encode the concatenated context/target sets.
        2. Aggregate them into prior/posterior latent Gaussians.
        3. Pick a latent ``z``: a posterior sample in ``'train'`` mode, the
           posterior mean in ``'eval'`` mode, optionally overridden through
           ``self.use_z_ph`` / ``self.z_ph``.
        4. Take gradient steps on the ``conditional_decoder`` variables,
           each step using ``sum_squared_error`` on the context set
           (``X_c``/``y_c``).

        The target prediction after every step, with index 0 being the
        unadapted decoder, is stored in ``self.eval_ops``.

        Args:
            inner_iters: number of adaptation steps whose target prediction
                is returned. Defaults to 1, the previously hard-coded value.
            eval_iters: number of adaptation steps built for evaluation.
                Defaults to 10, the previously hard-coded value.

        Returns:
            The target prediction after ``inner_iters`` adaptation steps.

        Raises:
            ValueError: ``inner_iters`` or ``eval_iters`` is negative.
            Exception: ``self.user_mode`` is neither ``'train'`` nor
                ``'eval'``.
        """
        if inner_iters < 0 or eval_iters < 0:
            raise ValueError(
                "inner_iters and eval_iters must be non-negative, got "
                "{0} and {1}".format(inner_iters, eval_iters))
        default_args = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([self.conditional_decoder], **default_args):
            # Encoder and aggregator use the same settings without batch norm.
            no_bn_args = dict(default_args, bn=False)
            with arg_scope([self.sample_encoder, self.aggregator],
                           **no_bn_args):
                num_c = tf.shape(self.X_c)[0]
                X_ct = tf.concat([self.X_c, self.X_t], axis=0)
                y_ct = tf.concat([self.y_c, self.y_t], axis=0)
                r_ct = self.sample_encoder(X_ct, y_ct, self.r_dim)

                self.z_mu_pr, self.z_log_sigma_sq_pr, self.z_mu_pos, self.z_log_sigma_sq_pos = self.aggregator(
                    r_ct, num_c, self.z_dim)
                if self.user_mode == 'train':
                    z = gaussian_sampler(self.z_mu_pos,
                                         tf.exp(0.5 * self.z_log_sigma_sq_pos))
                elif self.user_mode == 'eval':
                    z = self.z_mu_pos
                else:
                    raise Exception("unknown user_mode")
                z = (1 - self.use_z_ph) * z + self.use_z_ph * self.z_ph

                # MAML ops. The first decoder call must come before
                # get_trainable_variables so the decoder variables exist.
                y_hat = self.conditional_decoder(self.X_c, z)
                params = get_trainable_variables(['conditional_decoder'])
                y_hat_test_arr = [
                    self.conditional_decoder(self.X_t, z, params=params.copy())
                ]
                for _ in range(max(inner_iters, eval_iters)):
                    # One gradient step on the context-set squared error.
                    loss = sum_squared_error(labels=self.y_c,
                                             predictions=y_hat)
                    grads = tf.gradients(loss,
                                         params,
                                         colocate_gradients_with_ops=True)
                    params = [v - self.alpha * g for v, g in zip(params, grads)]
                    y_hat = self.conditional_decoder(self.X_c,
                                                     z,
                                                     params=params.copy())
                    y_hat_test = self.conditional_decoder(self.X_t,
                                                          z,
                                                          params=params.copy())
                    y_hat_test_arr.append(y_hat_test)
                self.eval_ops = y_hat_test_arr
                return y_hat_test_arr[inner_iters]
    def __model(self, network_type="large"):
        """Create placeholders and build the VAE + two-PixelCNN graph.

        The method proceeds in these steps:

        1. Create the input placeholders on ``self``.
        2. Mask ``self.x`` with ``self.input_masks``, fill the masked-out
           entries with uniform noise and append the mask as an extra
           channel.
        3. Encode the result and choose the latent. It is the posterior
           sample, the standard-normal prior sample or ``self.z_ph``,
           depending on the ``use_prior`` / ``use_z_ph`` flags.
        4. Concatenate the decoded features with reverse-PixelCNN features
           of ``self.x_bar`` and pass them to the forward PixelCNN.

        The forward PixelCNN's output is stored in
        ``self.mix_logistic_params`` and sampled into ``self.x_hat``.

        Args:
            network_type: ignored. The architecture is selected from
                ``self.network_type``; the parameter is kept only for
                interface compatibility.

        Raises:
            ValueError: ``self.img_size`` is not 32, the only size for which
                networks are configured here.
        """
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            num_channels = 1
        else:
            num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        self.use_prior = tf.placeholder_with_default(False, shape=())
        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large
                decoder = conv_decoder_32_large
            else:
                encoder = conv_encoder_32
                decoder = conv_decoder_32
            forward_pixelcnn = forward_pixel_cnn_32_small
            reverse_pixelcnn = reverse_pixel_cnn_32_small
        else:
            # Previously fell through to an UnboundLocalError on ``encoder``.
            raise ValueError(
                "unsupported img_size {0}; only 32 is supported".format(
                    self.img_size))
        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([forward_pixelcnn, reverse_pixelcnn, encoder, decoder],
                       **kwargs):
            kwargs_pixelcnn = {
                "nr_resnet": self.nr_resnet,
                "nr_filters": self.nr_filters,
                "nr_logistic_mix": self.nr_logistic_mix,
                "dropout_p": self.dropout_p,
                "bn": False,
            }
            with arg_scope([forward_pixelcnn, reverse_pixelcnn],
                           **kwargs_pixelcnn):
                # self.input_masks is the placeholder created above, so it is
                # never None and the masking below always applies.
                # Broadcast to the image's own channel count; this was
                # hard-coded to 3, which silently widened 1-channel
                # 'binary' inputs to 3 channels.
                input_masks = broadcast_masks_tf(self.input_masks,
                                                 num_channels=num_channels)
                inputs = self.x * input_masks
                # Add uniform noise in [-1, 1) where the mask is 0
                # (assuming a 0/1 mask); kept entries are unchanged.
                inputs += tf.random_uniform(int_shape(inputs), -1, 1) * (
                    1 - input_masks)
                inputs = tf.concat([
                    inputs,
                    broadcast_masks_tf(self.input_masks, num_channels=1)
                ],
                                   axis=-1)

                self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
                sigma = tf.exp(self.z_log_sigma_sq / 2.)
                self.z = gaussian_sampler(self.z_mu, sigma)
                # Standard-normal prior sample with the posterior's shape.
                self.z_pr = gaussian_sampler(tf.zeros_like(self.z_mu),
                                             tf.ones_like(sigma))

                self.z_ph = tf.placeholder_with_default(
                    tf.zeros_like(self.z_mu), shape=int_shape(self.z_mu))
                self.use_z_ph = tf.placeholder_with_default(False, shape=())

                # Boolean flags -> 0.0/1.0 so they can gate the latents.
                use_prior = tf.cast(tf.cast(self.use_prior, tf.int32),
                                    tf.float32)
                use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32),
                                   tf.float32)
                z = (use_prior * self.z_pr + (1 - use_prior) * self.z) * (
                    1 - use_z_ph) + use_z_ph * self.z_ph

                decoded_features = decoder(z, output_features=True)
                r_outputs = reverse_pixelcnn(self.x_bar, self.masks, bn=False)
                cond_features = tf.concat([r_outputs, decoded_features],
                                          axis=-1)
                self.mix_logistic_params = forward_pixelcnn(self.x_bar,
                                                            cond_features,
                                                            bn=False)
                self.x_hat = mix_logistic_sampler(
                    self.mix_logistic_params,
                    nr_logistic_mix=self.nr_logistic_mix,
                    sample_range=self.sample_range,
                    counters=self.counters)
    def __model(self):
        """Create placeholders and build a masked-input VAE graph.

        The method proceeds in these steps:

        1. Create the input placeholders on ``self``. ``self.num_channels``
           is set to 1 for the ``'binary'`` network type and to 3 otherwise.
        2. Multiply ``self.x`` by ``self.masks`` and append the mask as an
           extra channel, then encode.
        3. Choose the latent: the posterior sample, or ``self.z_ph`` when
           ``self.use_z_ph`` is set.
        4. Decode it into ``self.pixel_params``. For the binary network these
           are sampled with ``bernoulli_sampler``; otherwise they are
           sampled with ``mix_logistic_sampler``. The sample is stored in
           ``self.x_hat``.

        Raises:
            Exception: ``self.network_type`` is not supported for the image
                size (``'large'`` for 32, ``'binary'`` for 28).
            ValueError: ``self.img_size`` is neither 32 nor 28.
        """
        print("******   Building Graph   ******")
        # placeholders
        if self.network_type == 'binary':
            self.num_channels = 1
        else:
            self.num_channels = 3
        self.x = tf.placeholder(tf.float32,
                                shape=(self.batch_size, self.img_size,
                                       self.img_size, self.num_channels))
        self.x_bar = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size, self.num_channels))
        self.is_training = tf.placeholder(tf.bool, shape=())
        self.dropout_p = tf.placeholder(tf.float32, shape=())
        self.masks = tf.placeholder(tf.float32,
                                    shape=(self.batch_size, self.img_size,
                                           self.img_size))
        self.input_masks = tf.placeholder(tf.float32,
                                          shape=(self.batch_size,
                                                 self.img_size, self.img_size))
        # choose network size
        if self.img_size == 32:
            if self.network_type == 'large':
                encoder = conv_encoder_32_large_bn
                decoder = conv_decoder_32_large_mixture_logistic
            else:
                raise Exception("unknown network type")
        elif self.img_size == 28:
            if self.network_type == 'binary':
                encoder = conv_encoder_28_binary
                decoder = conv_decoder_28_binary
            else:
                raise Exception("unknown network type")
        else:
            # Previously fell through to an UnboundLocalError on ``encoder``.
            raise ValueError(
                "unsupported img_size {0}; expected 32 or 28".format(
                    self.img_size))
        kwargs = {
            "nonlinearity": self.nonlinearity,
            "bn": self.bn,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "is_training": self.is_training,
            "counters": self.counters,
        }
        with arg_scope([encoder, decoder], **kwargs):
            inputs = self.x
            # Encoder input: masked image plus the mask as its own channel.
            inputs = inputs * broadcast_masks_tf(
                self.masks, num_channels=self.num_channels)
            inputs = tf.concat(
                [inputs,
                 broadcast_masks_tf(self.masks, num_channels=1)],
                axis=-1)
            self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
            sigma = tf.exp(self.z_log_sigma_sq / 2.)
            self.z = gaussian_sampler(self.z_mu, sigma)

            self.z_ph = tf.placeholder_with_default(tf.zeros_like(self.z_mu),
                                                    shape=int_shape(self.z_mu))
            self.use_z_ph = tf.placeholder_with_default(False, shape=())

            # Boolean flag -> 0.0/1.0 so it can select between latents.
            use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
            z = self.z * (1 - use_z_ph) + use_z_ph * self.z_ph

            if self.network_type == 'binary':
                self.pixel_params = decoder(z)
                self.x_hat = bernoulli_sampler(self.pixel_params,
                                               counters=self.counters)
            else:
                self.pixel_params = decoder(
                    z, nr_logistic_mix=self.nr_logistic_mix)
                self.x_hat = mix_logistic_sampler(
                    self.pixel_params,
                    nr_logistic_mix=self.nr_logistic_mix,
                    sample_range=self.sample_range,
                    counters=self.counters)