def __model(self, x, x_bar, is_training, dropout_p, masks, input_masks, network_size="medium"):
    """Build the conv-VAE + conditional PixelCNN graph.

    Stores the given tensors on ``self`` and encodes ``x`` into
    ``self.z_mu`` / ``self.z_log_sigma_sq``. It then obtains ``self.z``:
    sampled in 'train' mode, a float32 placeholder in 'test' mode. Decoded
    ``z`` features condition ``cond_pixel_cnn`` over ``x_bar``, and
    ``self.x_hat`` is sampled from the resulting mixture-of-logistics
    parameters.

    Args:
        x: encoder input; ``int_shape(x)[1]`` (32 or 64) selects the architecture.
        x_bar: input to the conditional PixelCNN.
        is_training: forwarded to encoder/decoder/PixelCNN through ``arg_scope``.
        dropout_p: dropout rate passed to ``cond_pixel_cnn``.
        masks: stored as ``self.masks``; not otherwise used in this method.
        input_masks: stored as ``self.input_masks``; not otherwise used here.
        network_size: 'medium', 'large' or 'large1'. Only consulted when
            ``int_shape(x)[1] == 32``; the 64 architecture is always 'medium'.

    Raises:
        ValueError: unsupported spatial size, unknown ``network_size`` for the
            32 architecture, or ``self.use_mode`` not in {'train', 'test'}.
    """
    print("****** Building Graph ******")
    self.x = x
    self.x_bar = x_bar
    self.is_training = is_training
    self.dropout_p = dropout_p
    self.masks = masks
    self.input_masks = input_masks

    img_size = int_shape(x)[1]
    if img_size == 64:
        conv_encoder = conv_encoder_64_medium
        conv_decoder = conv_decoder_64_medium
    elif img_size == 32:
        if network_size == 'medium':
            conv_encoder = conv_encoder_32_medium
            conv_decoder = conv_decoder_32_medium
        elif network_size == 'large':
            conv_encoder = conv_encoder_32_large
            conv_decoder = conv_decoder_32_large
        elif network_size == 'large1':
            conv_encoder = conv_encoder_32_large1
            conv_decoder = conv_decoder_32_large1
        else:
            raise ValueError("unknown network type")
    else:
        # Fail fast: otherwise conv_encoder/conv_decoder would be unbound below.
        raise ValueError(
            "unsupported input size {0}; expected 32 or 64".format(img_size))

    with arg_scope([conv_encoder, conv_decoder, cond_pixel_cnn],
                   nonlinearity=self.nonlinearity,
                   bn=self.bn,
                   kernel_initializer=self.kernel_initializer,
                   kernel_regularizer=self.kernel_regularizer,
                   is_training=self.is_training,
                   counters=self.counters):
        self.z_mu, self.z_log_sigma_sq = conv_encoder(self.x, self.z_dim)
        # z_log_sigma_sq is a log-variance, so exp(x / 2) gives the std.
        sigma = tf.exp(self.z_log_sigma_sq / 2.)
        if self.use_mode == 'train':
            self.z = gaussian_sampler(self.z_mu, sigma)
        elif self.use_mode == 'test':
            # In test mode z is fed directly by the caller.
            self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
        else:
            # Without this, self.z would be missing when the decoder runs.
            raise ValueError("unknown use_mode: {0}".format(self.use_mode))
        print("use mode:{0}".format(self.use_mode))

        self.decoded_features = conv_decoder(self.z, output_features=True)
        self.mix_logistic_params = cond_pixel_cnn(
            self.x_bar,
            sh=self.decoded_features,
            bn=False,
            dropout_p=self.dropout_p,
            nr_resnet=self.nr_resnet,
            nr_filters=self.nr_filters,
            nr_logistic_mix=self.nr_logistic_mix)
        self.x_hat = mix_logistic_sampler(
            self.mix_logistic_params,
            nr_logistic_mix=self.nr_logistic_mix,
            sample_range=self.sample_range,
            counters=self.counters)
def _model(self):
    """Build the encoder/aggregator/decoder graph and return target predictions.

    Context pairs (``self.X_c``, ``self.y_c``) and target pairs
    (``self.X_t``, ``self.y_t``) are concatenated along axis 0 and encoded
    together by ``self.sample_encoder``. The encoding is stored as
    ``self.r_ct``. ``self.aggregator`` maps it, together with the
    context count, to four tensors stored as ``self.z_mu_pr``,
    ``self.z_log_sigma_sq_pr``, ``self.z_mu_pos`` and
    ``self.z_log_sigma_sq_pos``.

    ``z`` is drawn from the ``_pos`` Gaussian in 'train' mode or set to its
    mean in 'eval' mode. It is then blended with ``self.z_ph`` via
    ``self.use_z_ph`` (presumably a 0/1 selector — confirm with the caller).
    The decoded predictions for ``self.X_t`` are returned.

    Raises:
        Exception: if ``self.user_mode`` is neither 'train' nor 'eval'.
    """
    decoder_args = dict(
        nonlinearity=self.nonlinearity,
        bn=self.bn,
        kernel_initializer=self.kernel_initializer,
        kernel_regularizer=self.kernel_regularizer,
        is_training=self.is_training,
        counters=self.counters,
    )
    # The sample encoder and aggregator share the decoder's defaults but
    # always run without batch norm.
    encoder_args = dict(decoder_args, bn=False)

    with arg_scope([self.conditional_decoder], **decoder_args):
        with arg_scope([self.sample_encoder, self.aggregator], **encoder_args):
            n_context = tf.shape(self.X_c)[0]
            inputs_all = tf.concat([self.X_c, self.X_t], axis=0)
            targets_all = tf.concat([self.y_c, self.y_t], axis=0)
            encoded = self.sample_encoder(inputs_all, targets_all, self.r_dim)
            self.r_ct = encoded

            (self.z_mu_pr, self.z_log_sigma_sq_pr,
             self.z_mu_pos, self.z_log_sigma_sq_pos) = self.aggregator(
                 encoded, n_context, self.z_dim)

            mode = self.user_mode
            if mode == 'train':
                # Log-variance -> std via exp(0.5 * log_var).
                latent = gaussian_sampler(
                    self.z_mu_pos, tf.exp(0.5 * self.z_log_sigma_sq_pos))
            elif mode == 'eval':
                latent = self.z_mu_pos
            else:
                raise Exception("unknown user_mode")

            latent = (1 - self.use_z_ph) * latent + self.use_z_ph * self.z_ph
            return self.conditional_decoder(self.X_t, latent)
def __model(self, x, is_training):
    """Build a convolutional VAE graph.

    Stores ``x`` and ``is_training`` on ``self`` and encodes ``x`` into
    ``self.z_mu`` / ``self.z_log_sigma_sq``. ``self.z`` is sampled in
    'train' mode or created as a float32 placeholder in 'test' mode, and
    ``self.x_hat`` is set to the decoder output for ``self.z``.

    Args:
        x: input tensor; ``int_shape(x)[1]`` (32 or 64) selects the
            medium-size encoder/decoder pair.
        is_training: forwarded to encoder/decoder through ``arg_scope``.

    Raises:
        ValueError: unsupported spatial size, or ``self.use_mode`` not in
            {'train', 'test'}.
    """
    print("****** Building Graph ******")
    self.x = x
    self.is_training = is_training

    img_size = int_shape(x)[1]
    if img_size == 64:
        encoder = conv_encoder_64_medium
        decoder = conv_decoder_64_medium
    elif img_size == 32:
        encoder = conv_encoder_32_medium
        decoder = conv_decoder_32_medium
    else:
        # Fail fast: otherwise encoder/decoder would be unbound below.
        raise ValueError(
            "unsupported input size {0}; expected 32 or 64".format(img_size))

    with arg_scope([encoder, decoder],
                   nonlinearity=self.nonlinearity,
                   bn=self.bn,
                   kernel_initializer=self.kernel_initializer,
                   kernel_regularizer=self.kernel_regularizer,
                   is_training=self.is_training,
                   counters=self.counters):
        self.z_mu, self.z_log_sigma_sq = encoder(x, self.z_dim)
        # z_log_sigma_sq is a log-variance, so exp(x / 2) gives the std.
        sigma = tf.exp(self.z_log_sigma_sq / 2.)
        if self.use_mode == 'train':
            self.z = gaussian_sampler(self.z_mu, sigma)
        elif self.use_mode == 'test':
            # In test mode z is fed directly by the caller.
            self.z = tf.placeholder(tf.float32, shape=int_shape(self.z_mu))
        else:
            # Without this, self.z would be missing when the decoder runs.
            raise ValueError("unknown use_mode: {0}".format(self.use_mode))
        print("use mode:{0}".format(self.use_mode))
        self.x_hat = decoder(self.z)
def _model(self, inner_iters=1, eval_iters=10):
    """Build the graph with MAML-style inner-loop adaptation of the decoder.

    Context and target pairs are encoded jointly by ``self.sample_encoder``.
    ``self.aggregator`` produces the ``_pr`` / ``_pos`` Gaussian parameters
    stored on ``self``. ``z`` is sampled from the ``_pos`` Gaussian ('train')
    or set to its mean ('eval'), then blended with ``self.z_ph`` via
    ``self.use_z_ph``.

    The trainable variables of 'conditional_decoder' are then updated for
    ``max(inner_iters, eval_iters)`` steps of plain gradient descent, with
    step size ``self.alpha``, on the sum-squared error over the context set.
    After step 0 (no update) and after every step, target predictions are
    recorded.

    Args:
        inner_iters: number of adaptation steps whose target predictions are
            returned (default 1, as before).
        eval_iters: number of adaptation steps to unroll for evaluation
            (default 10, as before).

    Returns:
        Target predictions after ``inner_iters`` adaptation steps.
        ``self.eval_ops`` holds the predictions after 0..max(inner_iters,
        eval_iters) steps.

    Raises:
        ValueError: if ``inner_iters`` or ``eval_iters`` is negative.
        Exception: if ``self.user_mode`` is neither 'train' nor 'eval'.
    """
    if inner_iters < 0 or eval_iters < 0:
        raise ValueError("inner_iters and eval_iters must be non-negative")

    default_args = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    with arg_scope([self.conditional_decoder], **default_args):
        # Safe to mutate: the scope above received an unpacked copy.
        default_args.update({"bn": False})
        with arg_scope([self.sample_encoder, self.aggregator], **default_args):
            num_c = tf.shape(self.X_c)[0]
            X_ct = tf.concat([self.X_c, self.X_t], axis=0)
            y_ct = tf.concat([self.y_c, self.y_t], axis=0)
            r_ct = self.sample_encoder(X_ct, y_ct, self.r_dim)
            (self.z_mu_pr, self.z_log_sigma_sq_pr,
             self.z_mu_pos, self.z_log_sigma_sq_pos) = self.aggregator(
                 r_ct, num_c, self.z_dim)
            if self.user_mode == 'train':
                z = gaussian_sampler(self.z_mu_pos,
                                     tf.exp(0.5 * self.z_log_sigma_sq_pos))
            elif self.user_mode == 'eval':
                z = self.z_mu_pos
            else:
                raise Exception("unknown user_mode")
            z = (1 - self.use_z_ph) * z + self.use_z_ph * self.z_ph

            # MAML inner loop: adapt decoder params on the context set.
            y_hat = self.conditional_decoder(self.X_c, z)
            params = get_trainable_variables(['conditional_decoder'])
            y_hat_test_arr = [
                self.conditional_decoder(self.X_t, z, params=params.copy())
            ]
            for _ in range(max(inner_iters, eval_iters)):
                loss = sum_squared_error(labels=self.y_c, predictions=y_hat)
                grads = tf.gradients(loss, params,
                                     colocate_gradients_with_ops=True)
                params = [p - self.alpha * g for p, g in zip(params, grads)]
                y_hat = self.conditional_decoder(self.X_c, z,
                                                 params=params.copy())
                y_hat_test_arr.append(
                    self.conditional_decoder(self.X_t, z,
                                             params=params.copy()))
            self.eval_ops = y_hat_test_arr
            return y_hat_test_arr[inner_iters]
def __model(self, network_type="large"):
    """Build the masked-input VAE + reverse/forward PixelCNN graph.

    Creates the input placeholders on ``self``. The encoder input is
    ``self.x``, with masked-out pixels (where ``self.input_masks`` is 0)
    replaced by uniform noise in [-1, 1), and the mask concatenated as an
    extra channel. It is encoded into ``self.z_mu`` / ``self.z_log_sigma_sq``.

    The latent used for decoding is one of:
    - the posterior sample ``self.z``;
    - a standard-normal sample ``self.z_pr``, when ``self.use_prior`` is fed True;
    - ``self.z_ph``, when ``self.use_z_ph`` is fed True.

    Decoded features are concatenated with a reverse PixelCNN's output over
    ``self.x_bar`` / ``self.masks``. Together they condition a forward
    PixelCNN, whose mixture-of-logistics parameters are sampled into
    ``self.x_hat``.

    Args:
        network_type: unused; ``self.network_type`` selects the architecture
            and channel count. Kept for interface compatibility.

    Raises:
        ValueError: if ``self.img_size`` is not 32.
    """
    print("****** Building Graph ******")
    # placeholders
    num_channels = 1 if self.network_type == 'binary' else 3
    image_shape = (self.batch_size, self.img_size, self.img_size, num_channels)
    mask_shape = (self.batch_size, self.img_size, self.img_size)
    self.x = tf.placeholder(tf.float32, shape=image_shape)
    self.x_bar = tf.placeholder(tf.float32, shape=image_shape)
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32, shape=mask_shape)
    self.input_masks = tf.placeholder(tf.float32, shape=mask_shape)
    self.use_prior = tf.placeholder_with_default(False, shape=())

    # choose network size
    if self.img_size != 32:
        # Fail fast: otherwise encoder/decoder would be unbound below.
        raise ValueError(
            "unsupported img_size {0}; expected 32".format(self.img_size))
    if self.network_type == 'large':
        encoder = conv_encoder_32_large
        decoder = conv_decoder_32_large
    else:
        encoder = conv_encoder_32
        decoder = conv_decoder_32
    forward_pixelcnn = forward_pixel_cnn_32_small
    reverse_pixelcnn = reverse_pixel_cnn_32_small

    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    with arg_scope([forward_pixelcnn, reverse_pixelcnn, encoder, decoder],
                   **kwargs):
        kwargs_pixelcnn = {
            "nr_resnet": self.nr_resnet,
            "nr_filters": self.nr_filters,
            "nr_logistic_mix": self.nr_logistic_mix,
            "dropout_p": self.dropout_p,
            "bn": False,
        }
        with arg_scope([forward_pixelcnn, reverse_pixelcnn], **kwargs_pixelcnn):
            inputs = self.x
            if self.input_masks is not None:
                # Broadcast to the image's real channel count (1 for
                # 'binary'); a fixed 3 would silently widen 1-channel input.
                image_masks = broadcast_masks_tf(self.input_masks,
                                                 num_channels=num_channels)
                inputs = inputs * image_masks
                inputs += tf.random_uniform(int_shape(inputs), -1, 1) * (
                    1 - image_masks)
                inputs = tf.concat([
                    inputs,
                    broadcast_masks_tf(self.input_masks, num_channels=1)
                ], axis=-1)
            self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
            # z_log_sigma_sq is a log-variance, so exp(x / 2) gives the std.
            sigma = tf.exp(self.z_log_sigma_sq / 2.)
            self.z = gaussian_sampler(self.z_mu, sigma)
            self.z_pr = gaussian_sampler(tf.zeros_like(self.z_mu),
                                         tf.ones_like(sigma))
            self.z_ph = tf.placeholder_with_default(
                tf.zeros_like(self.z_mu), shape=int_shape(self.z_mu))
            self.use_z_ph = tf.placeholder_with_default(False, shape=())

            # Turn the boolean switches into 0/1 floats for arithmetic
            # selection; use_z_ph takes precedence over use_prior.
            use_prior = tf.cast(tf.cast(self.use_prior, tf.int32), tf.float32)
            use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
            z = (use_prior * self.z_pr + (1 - use_prior) * self.z) * (
                1 - use_z_ph) + use_z_ph * self.z_ph

            decoded_features = decoder(z, output_features=True)
            r_outputs = reverse_pixelcnn(self.x_bar, self.masks, bn=False)
            cond_features = tf.concat([r_outputs, decoded_features], axis=-1)
            self.mix_logistic_params = forward_pixelcnn(self.x_bar,
                                                        cond_features,
                                                        bn=False)
            self.x_hat = mix_logistic_sampler(
                self.mix_logistic_params,
                nr_logistic_mix=self.nr_logistic_mix,
                sample_range=self.sample_range,
                counters=self.counters)
def __model(self):
    """Build a masked-input conv VAE with Bernoulli or mixture-logistic output.

    Creates the input placeholders on ``self``; ``self.num_channels`` is 1
    for 'binary' and 3 otherwise. The encoder sees ``self.x`` multiplied by
    ``self.masks``, with the mask appended as an extra channel.
    ``self.input_masks`` is created but not used in this method.

    The latent is the posterior sample ``self.z``, or ``self.z_ph`` when
    ``self.use_z_ph`` is fed True. For 'binary', ``self.pixel_params`` are
    decoded and sampled with ``bernoulli_sampler``. Otherwise they are
    mixture-of-logistics parameters sampled with ``mix_logistic_sampler``.

    Raises:
        ValueError: if (``self.img_size``, ``self.network_type``) is not one
            of (32, 'large') or (28, 'binary').
    """
    print("****** Building Graph ******")
    # placeholders
    self.num_channels = 1 if self.network_type == 'binary' else 3
    image_shape = (self.batch_size, self.img_size, self.img_size,
                   self.num_channels)
    mask_shape = (self.batch_size, self.img_size, self.img_size)
    self.x = tf.placeholder(tf.float32, shape=image_shape)
    self.x_bar = tf.placeholder(tf.float32, shape=image_shape)
    self.is_training = tf.placeholder(tf.bool, shape=())
    self.dropout_p = tf.placeholder(tf.float32, shape=())
    self.masks = tf.placeholder(tf.float32, shape=mask_shape)
    self.input_masks = tf.placeholder(tf.float32, shape=mask_shape)

    # choose network size
    if self.img_size == 32 and self.network_type == 'large':
        encoder = conv_encoder_32_large_bn
        decoder = conv_decoder_32_large_mixture_logistic
    elif self.img_size == 28 and self.network_type == 'binary':
        encoder = conv_encoder_28_binary
        decoder = conv_decoder_28_binary
    elif self.img_size in (32, 28):
        raise ValueError("unknown network type")
    else:
        # Fail fast: otherwise encoder/decoder would be unbound below.
        raise ValueError(
            "unsupported img_size {0}; expected 32 or 28".format(self.img_size))

    kwargs = {
        "nonlinearity": self.nonlinearity,
        "bn": self.bn,
        "kernel_initializer": self.kernel_initializer,
        "kernel_regularizer": self.kernel_regularizer,
        "is_training": self.is_training,
        "counters": self.counters,
    }
    with arg_scope([encoder, decoder], **kwargs):
        inputs = self.x * broadcast_masks_tf(
            self.masks, num_channels=self.num_channels)
        inputs = tf.concat(
            [inputs, broadcast_masks_tf(self.masks, num_channels=1)], axis=-1)
        self.z_mu, self.z_log_sigma_sq = encoder(inputs, self.z_dim)
        # z_log_sigma_sq is a log-variance, so exp(x / 2) gives the std.
        sigma = tf.exp(self.z_log_sigma_sq / 2.)
        self.z = gaussian_sampler(self.z_mu, sigma)
        self.z_ph = tf.placeholder_with_default(tf.zeros_like(self.z_mu),
                                                shape=int_shape(self.z_mu))
        self.use_z_ph = tf.placeholder_with_default(False, shape=())
        # Boolean switch -> 0/1 float for arithmetic selection of z.
        use_z_ph = tf.cast(tf.cast(self.use_z_ph, tf.int32), tf.float32)
        z = self.z * (1 - use_z_ph) + use_z_ph * self.z_ph

        if self.network_type == 'binary':
            self.pixel_params = decoder(z)
            self.x_hat = bernoulli_sampler(self.pixel_params,
                                           counters=self.counters)
        else:
            self.pixel_params = decoder(z, nr_logistic_mix=self.nr_logistic_mix)
            self.x_hat = mix_logistic_sampler(
                self.pixel_params,
                nr_logistic_mix=self.nr_logistic_mix,
                sample_range=self.sample_range,
                counters=self.counters)