def reconstruct_images(self, save_file, sess, x, block_shape, show_original_images=True,
                       batch_size=20, dec_output_2_img_func=None, **kwargs):
    """Reconstruct ``x`` with the model and save the result as an image grid.

    Args:
        save_file: Destination passed to the image-saving helper.
        sess: Session forwarded to ``self.reconstruct``.
        x: Input samples. ``len(x)`` is used for batching, and ``x`` is
            sliced with the batch ids yielded by ``iterate_data``.
        block_shape: Grid layout. Both outputs are reshaped to
            ``to_list(block_shape) + self.x_shape``, so the element counts
            must match that shape.
        show_original_images: If True, save originals and reconstructions
            together via ``save_img_blocks_col_by_col``. Otherwise save only
            the reconstructions via ``save_img_block``.
        batch_size: Number of samples per ``self.reconstruct`` call. A
            negative value, zero, or None reconstructs all of ``x`` in a
            single call.
        dec_output_2_img_func: Optional function applied to the
            reconstructions (and to the originals, when shown) before
            reshaping.
        **kwargs: Extra keyword arguments forwarded to ``self.reconstruct``.
    """
    if batch_size is None or batch_size <= 0:
        # No usable batch size: reconstruct everything in one call. A zero
        # batch size cannot partition the data, so it is treated like "all".
        x1 = self.reconstruct(sess, x, **kwargs)
    else:
        x1 = np.concatenate(
            [self.reconstruct(sess, x[batch_ids], **kwargs)
             for batch_ids in iterate_data(len(x), batch_size, shuffle=False)],
            axis=0)

    grid_shape = to_list(block_shape) + self.x_shape

    if dec_output_2_img_func is not None:
        x1 = dec_output_2_img_func(x1)
    x1 = np.reshape(x1, grid_shape)

    if show_original_images:
        # The originals are converted/reshaped only when they are actually saved.
        if dec_output_2_img_func is not None:
            x = dec_output_2_img_func(x)
        x = np.reshape(x, grid_shape)
        save_img_blocks_col_by_col(save_file, [x, x1])
    else:
        save_img_block(save_file, x1)
def __init__(self, x_shape, z_shape):
    """Record the data and latent shapes and create their float32 placeholders.

    Args:
        x_shape: Shape of one data sample; normalized with ``to_list``.
        z_shape: Shape of one latent code; normalized with ``to_list``.
    """
    super(BaseLatentModel, self).__init__()

    data_dims = to_list(x_shape)
    self.x_shape = data_dims
    latent_dims = to_list(z_shape)
    self.z_shape = latent_dims

    # A leading None leaves the batch dimension unspecified.
    self.x_ph = tf.placeholder(tf.float32, [None] + data_dims, name="x")
    self.z_ph = tf.placeholder(tf.float32, [None] + latent_dims, name="z")
def __init__(self, z_shape, stochastic=False, activation=tf.nn.relu):
    """Store the latent shape, its flattened size, a stochastic flag and the activation.

    Args:
        z_shape: Latent shape; normalized with ``to_list``. Its ``prod`` is
            stored as ``self.z_dim``.
        stochastic: Flag stored unchanged on ``self.stochastic``.
        activation: Activation callable stored on ``self.activation``
            (defaults to ``tf.nn.relu``).
    """
    latent_shape = to_list(z_shape)
    self.z_shape = latent_shape
    self.z_dim = prod(latent_shape)
    self.stochastic, self.activation = stochastic, activation

    # Announce the configured activation, tagged with the concrete class name.
    owner = self.__class__.__name__
    print("[{}] activation: {}".format(owner, activation))
def __init__(self, x_shape, y_shape, input_perturber, main_classifier,
             cons_mode='mse', cons_4_unlabeled_only=False,
             same_perturbed_inputs=False, weight_decay=-1.0):
    """Create the placeholders and templated sub-networks and store the consistency options.

    Args:
        x_shape: Shape of one input sample; normalized with ``to_list``.
        y_shape: Label shape; after ``to_list`` it must have length 1, and
            its single entry is stored as ``self.num_classes``.
        input_perturber: Builder wrapped by ``tf.make_template`` under the
            scope name "input_perturber".
        main_classifier: Builder wrapped by ``tf.make_template`` under the
            scope name "main_classifier".
        cons_mode: One of 'mse', 'kld', 'rev_kld', '2rand'.
        cons_4_unlabeled_only: Flag stored on ``self.cons_4_unlabeled_only``.
        same_perturbed_inputs: Flag stored on ``self.same_perturbed_inputs``.
        weight_decay: Value stored on ``self.weight_decay`` (default -1.0).

    Raises:
        AssertionError: If ``y_shape`` has more than one entry or
            ``cons_mode`` is not a supported mode.
    """
    LiteBaseModel.__init__(self)

    self.x_shape = to_list(x_shape)
    label_shape = to_list(y_shape)
    self.y_shape = label_shape
    assert len(label_shape) == 1, "'y_shape' must be a scalar or an array of length 1!"
    self.num_classes = label_shape[0]

    # Graph inputs: samples, integer class labels, and a boolean per-sample flag.
    self.x_ph = tf.placeholder(tf.float32, [None] + self.x_shape, name="x")
    self.y_ph = tf.placeholder(tf.int32, [None], name="y")
    self.label_flag_ph = tf.placeholder(tf.bool, [None], name="label_flag")

    # make_template shares variables across calls of each builder.
    # NOTE(review): self.is_train is presumably created by LiteBaseModel.__init__ — confirm.
    classifier_scope = "main_classifier"
    self.main_classifier_name = classifier_scope
    self.main_classifier_fn = tf.make_template(
        classifier_scope, main_classifier, is_train=self.is_train)

    perturber_scope = "input_perturber"
    self.input_perturber_name = perturber_scope
    self.input_perturber_fn = tf.make_template(
        perturber_scope, input_perturber, is_train=self.is_train)

    supported_modes = ['mse', 'kld', 'rev_kld', '2rand']
    assert cons_mode in supported_modes, \
        "'cons_mode' must be in {}. Found {}!".format(supported_modes, cons_mode)

    self.cons_mode = cons_mode
    self.cons_4_unlabeled_only = cons_4_unlabeled_only
    self.same_perturbed_inputs = same_perturbed_inputs
    self.weight_decay = weight_decay

    # Echo the effective configuration.
    summary_lines = (
        "In class [{}]:".format(self.__class__.__name__),
        "cons_mode: {}".format(self.cons_mode),
        "cons_4_unlabeled_only: {}".format(self.cons_4_unlabeled_only),
        "same_perturbed_inputs: {}".format(self.same_perturbed_inputs),
        "weight_decay: {}".format(self.weight_decay),
    )
    for text in summary_lines:
        print(text)
def generate_images(self, save_file, sess, z, block_shape, batch_size=20,
                    dec_output_2_img_func=None, **kwargs):
    """Decode latent codes ``z`` and save the generated samples as one image grid.

    Args:
        save_file: Destination passed to ``save_img_block``.
        sess: Session forwarded to ``self.decode``.
        z: Latent codes. ``len(z)`` is used for batching, and ``z`` is sliced
            with the batch ids yielded by ``iterate_data``.
        block_shape: Grid layout. The decoded output is reshaped to
            ``to_list(block_shape) + self.x_shape``, so the element counts
            must match that shape.
        batch_size: Number of codes per ``self.decode`` call. A negative
            value, zero, or None decodes all of ``z`` in a single call.
        dec_output_2_img_func: Optional function applied to the decoded
            output before reshaping.
        **kwargs: Extra keyword arguments forwarded to ``self.decode``.
    """
    if batch_size is None or batch_size <= 0:
        # No usable batch size: decode everything in one call. A zero batch
        # size cannot partition the data, so it is treated like "all".
        x1_gen = self.decode(sess, z, **kwargs)
    else:
        x1_gen = np.concatenate(
            [self.decode(sess, z[batch_ids], **kwargs)
             for batch_ids in iterate_data(len(z), batch_size, shuffle=False)],
            axis=0)

    if dec_output_2_img_func is not None:
        x1_gen = dec_output_2_img_func(x1_gen)
    x1_gen = np.reshape(x1_gen, to_list(block_shape) + self.x_shape)
    save_img_block(save_file, x1_gen)