Exemplo n.º 1
0
    def __init__(self, data):
        """Build the Wasserstein (gradient-penalty) InfoGAN training graph.

        All ops are created inside a dedicated ``tf.Graph`` (``self.graph``)
        and executed by ``self.sess``, which is bound to that graph. The
        constructor builds the input pipeline, the generator and
        discriminator, the losses, TensorBoard summaries and one Adam
        optimizer per objective. Variables are NOT initialized here;
        ``self.initializer`` is stored but not run.

        Args:
            data: dataset/configuration object. Its ``cat_dim``,
                ``code_con_dim``, ``total_con_dim``, ``channel``, ``path``,
                ``name``, ``split_name`` and ``batch_size`` attributes are
                read here.
        """
        self.graph = tf.Graph()
        # Create exactly one session, bound to self.graph. In TF1 the first
        # session in a process initializes the GPU devices, so an extra
        # unconfigured session would leak and could reserve GPU memory
        # regardless of allow_growth below.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(graph=self.graph,
                               config=tf.ConfigProto(gpu_options=gpu_options))

        # Network-building callables (names defined elsewhere in the module).
        self.generator = generator
        self.discriminator = discriminator
        self.data = data

        # data
        self.cat_dim = self.data.cat_dim
        self.code_con_dim = self.data.code_con_dim
        self.total_con_dim = self.data.total_con_dim
        self.channel = self.data.channel
        self.dataset_path = self.data.path
        self.dataset_name = self.data.name
        self.split_name = self.data.split_name
        self.batch_size = self.data.batch_size
        with self.graph.as_default():
            # NOTE(review): slim's QueueRunners context starts the queue
            # runners that exist when it is entered, and the graph is still
            # empty at that point. The input-pipeline runners are therefore
            # started by the explicit start_queue_runners call below, which
            # uses no Coordinator.
            with slim.queues.QueueRunners(self.sess):
                self.dataset, self.real_data, self.labels = load_batch(
                    self.dataset_path, self.dataset_name, self.split_name,
                    self.batch_size)
                tf.train.start_queue_runners(self.sess)
                # Generator inputs: noise and the InfoGAN latent code.
                self.gen_input_noise, self.gen_input_code = get_infogan_noise(
                    self.batch_size, self.cat_dim, self.code_con_dim,
                    self.total_con_dim)

                with variable_scope.variable_scope(
                        'generator') as self.gen_scope:
                    # Generated (fake) samples.
                    self.gen_data = self.generator(self.gen_input_noise,
                                                   self.gen_input_code)

                with variable_scope.variable_scope(
                        'discriminator') as self.dis_scope:
                    # Critic output on fake samples, plus the Q-network output
                    # consumed by the mutual-information penalty below.
                    self.dis_gen_data, self.Q_net = self.discriminator(
                        self.gen_data, self.cat_dim, self.code_con_dim)
                # Re-enter the discriminator scope with reuse=True so real
                # data is scored with the same weights as fake data.
                with variable_scope.variable_scope(self.dis_scope.name,
                                                   reuse=True):
                    self.real_data = ops.convert_to_tensor(self.real_data)
                    self.dis_real_data, _ = self.discriminator(
                        self.real_data, self.cat_dim, self.code_con_dim)

                # Trainable variables of each network, selected by scope name.
                self.dis_var = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    scope=self.dis_scope.name)
                self.gen_var = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    scope=self.gen_scope.name)

                # losses
                self.D_loss = losses_fn.wasserstein_discriminator_loss(
                    self.dis_real_data, self.dis_gen_data)
                self.G_loss = losses_fn.wasserstein_generator_loss(
                    self.dis_gen_data)
                # NOTE(review): the model itself is passed first, presumably
                # so the penalty can call self.discriminator on interpolated
                # samples; confirm in losses_fn.
                self.wasserstein_gradient_penalty_loss = (
                    losses_fn.wasserstein_gradient_penalty_infogan(
                        self, self.real_data, self.gen_data))
                self.mutual_information_loss = (
                    losses_fn.mutual_information_penalty(
                        self.gen_input_code, self.Q_net))
                # Full critic objective, shared by the summary and D_solver.
                d_total_loss = (self.D_loss +
                                self.wasserstein_gradient_penalty_loss)

                tf.summary.scalar('D_loss', d_total_loss)
                tf.summary.scalar('G_loss', self.G_loss)
                tf.summary.scalar('Mutual_information_loss',
                                  self.mutual_information_loss)
                self.merged = tf.summary.merge_all()

                # Only D_solver increments this, so it counts critic updates.
                self.global_step = tf.Variable(0,
                                               name='global_step',
                                               trainable=False)

                # Solvers: one Adam optimizer per objective. The
                # mutual-information term updates both networks.
                self.D_solver = tf.train.AdamOptimizer(
                    0.001, beta1=0.5).minimize(d_total_loss,
                                               var_list=self.dis_var,
                                               global_step=self.global_step)
                self.G_solver = tf.train.AdamOptimizer(
                    0.0001, beta1=0.5).minimize(self.G_loss,
                                                var_list=self.gen_var)
                self.mutual_information_solver = tf.train.AdamOptimizer(
                    0.0001,
                    beta1=0.5).minimize(self.mutual_information_loss,
                                        var_list=self.gen_var + self.dis_var)
                # Created last so the Saver and the initializer also cover
                # global_step and the optimizers' slot variables.
                self.saver = tf.train.Saver()
                self.initializer = tf.global_variables_initializer()
Exemplo n.º 2
0
    def __init__(self, data):
        """Build the InfoGAN training graph with an added visual-prior penalty.

        All ops are created inside a dedicated ``tf.Graph`` (``self.graph``)
        and executed by ``self.sess``, which is bound to that graph. Besides
        the input pipeline, generator/discriminator, Wasserstein losses with
        gradient penalty and the mutual-information term, this loads
        reference images from ``data.visual_prior_path``. It then adds
        ``losses_fn.visual_prior_penalty`` to the objective that updates
        both networks. Variables are NOT initialized here;
        ``self.initializer`` is stored but not run.

        Args:
            data: dataset/configuration object. Its ``cat_dim``,
                ``code_con_dim``, ``total_con_dim``, ``channel``, ``path``,
                ``name``, ``split_name``, ``batch_size`` and
                ``visual_prior_path`` attributes are read here.

        Raises:
            IOError: if a file in a visual-prior directory cannot be decoded
                by ``cv2.imread``.
        """
        self.graph = tf.Graph()
        # Create exactly one session, bound to self.graph. In TF1 the first
        # session in a process initializes the GPU devices, so an extra
        # unconfigured session would leak and could reserve GPU memory
        # regardless of allow_growth below.
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(graph=self.graph,
                               config=tf.ConfigProto(gpu_options=gpu_options))

        # Network-building callables (names defined elsewhere in the module).
        self.generator = generator
        self.discriminator = discriminator
        self.data = data

        # data
        self.cat_dim = self.data.cat_dim
        self.code_con_dim = self.data.code_con_dim
        self.total_con_dim = self.data.total_con_dim
        self.channel = self.data.channel
        self.dataset_path = self.data.path
        self.dataset_name = self.data.name
        self.split_name = self.data.split_name
        self.batch_size = self.data.batch_size
        self.visual_prior_path = self.data.visual_prior_path
        with self.graph.as_default():
            # NOTE(review): slim's QueueRunners context starts the queue
            # runners that exist when it is entered, and the graph is still
            # empty at that point. The input-pipeline runners are therefore
            # started by the explicit start_queue_runners call below, which
            # uses no Coordinator.
            with slim.queues.QueueRunners(self.sess):
                self.dataset, self.real_data, self.labels = load_batch(
                    self.dataset_path, self.dataset_name, self.split_name,
                    self.batch_size)

                # Prior name -> attribute values. Each value names the
                # directory <visual_prior_path>/<prior>/<value>/ of images.
                visual_prior = {'category': list(range(10)),
                                'rotation': ['min', 'max'],
                                'width': ['min', 'max']}
                self.visual_prior_images = self._load_visual_prior_images(
                    visual_prior)

                # Every prior except 'category', in sorted order.
                self.variation_key = sorted(
                    key_name for key_name in self.visual_prior_images
                    if key_name != 'category')
                print(self.variation_key)

                tf.train.start_queue_runners(self.sess)
                # Generator inputs: noise and the InfoGAN latent code.
                self.gen_input_noise, self.gen_input_code = get_infogan_noise(
                    self.batch_size, self.cat_dim, self.code_con_dim,
                    self.total_con_dim)

                with variable_scope.variable_scope(
                        'generator') as self.gen_scope:
                    # Generated (fake) samples.
                    self.gen_data = self.generator(self.gen_input_noise,
                                                   self.gen_input_code)

                with variable_scope.variable_scope(
                        'discriminator') as self.dis_scope:
                    # Critic output on fake samples, plus the Q-network output
                    # consumed by the mutual-information penalty below.
                    self.dis_gen_data, self.Q_net = self.discriminator(
                        self.gen_data, self.cat_dim, self.code_con_dim)
                # Re-enter the discriminator scope with reuse=True so real
                # data is scored with the same weights as fake data.
                with variable_scope.variable_scope(self.dis_scope.name,
                                                   reuse=True):
                    self.real_data = ops.convert_to_tensor(self.real_data)
                    self.dis_real_data, _ = self.discriminator(
                        self.real_data, self.cat_dim, self.code_con_dim)

                # Trainable variables of each network, selected by scope name.
                self.dis_var = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    scope=self.dis_scope.name)
                self.gen_var = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES,
                    scope=self.gen_scope.name)

                # losses
                self.D_loss = losses_fn.wasserstein_discriminator_loss(
                    self.dis_real_data, self.dis_gen_data)
                self.G_loss = losses_fn.wasserstein_generator_loss(
                    self.dis_gen_data)
                # NOTE(review): the model itself is passed first, presumably
                # so the penalties can call self.discriminator/self.generator;
                # confirm in losses_fn.
                self.wasserstein_gradient_penalty_loss = (
                    losses_fn.wasserstein_gradient_penalty_infogan(
                        self, self.real_data, self.gen_data))
                self.mutual_information_loss = (
                    losses_fn.mutual_information_penalty(
                        self.gen_input_code, self.Q_net))
                self.visual_prior_penalty = losses_fn.visual_prior_penalty(
                    self, self.visual_prior_images)
                # Full critic objective, shared by the summary and D_solver.
                d_total_loss = (self.D_loss +
                                self.wasserstein_gradient_penalty_loss)

                tf.summary.scalar('D_loss', d_total_loss)
                tf.summary.scalar('G_loss', self.G_loss)
                tf.summary.scalar('Mutual_information_loss',
                                  self.mutual_information_loss)
                tf.summary.scalar('visual_prior_loss',
                                  self.visual_prior_penalty)
                self.merged = tf.summary.merge_all()

                # Only D_solver increments this, so it counts critic updates.
                self.global_step = tf.Variable(0,
                                               name='global_step',
                                               trainable=False)

                # Solvers: one Adam optimizer per objective. The
                # mutual-information + visual-prior objective updates both
                # networks.
                self.D_solver = tf.train.AdamOptimizer(
                    0.001, beta1=0.5).minimize(d_total_loss,
                                               var_list=self.dis_var,
                                               global_step=self.global_step)
                self.G_solver = tf.train.AdamOptimizer(
                    0.0001, beta1=0.5).minimize(self.G_loss,
                                                var_list=self.gen_var)
                self.mutual_information_solver = tf.train.AdamOptimizer(
                    0.0001, beta1=0.5).minimize(
                        self.mutual_information_loss +
                        self.visual_prior_penalty,
                        var_list=self.gen_var + self.dis_var)
                # Created last so the Saver and the initializer also cover
                # global_step and the optimizers' slot variables.
                self.saver = tf.train.Saver()
                self.initializer = tf.global_variables_initializer()

    def _load_visual_prior_images(self, visual_prior):
        """Load the reference images for every visual-prior attribute.

        Images are read from
        ``<self.visual_prior_path>/<prior>/<str(attribute)>/``. Each image
        is converted to grayscale and mapped to ``(pixel - 128) / 128``.
        The images of one attribute are packed into a single tensor. This
        creates TF ops, so it must run with ``self.graph`` as the default
        graph.

        Args:
            visual_prior: mapping prior name -> iterable of attribute values.

        Returns:
            dict prior name -> {attribute value -> tensor packing that
            attribute's images}.

        Raises:
            IOError: if a file cannot be decoded by ``cv2.imread``.
        """
        images = {}
        for key, attributes in visual_prior.items():
            images[key] = {}
            for attribute in attributes:
                path = os.path.join(self.visual_prior_path, key,
                                    str(attribute))
                samples = []
                # os.listdir order is filesystem-dependent; sort it so the
                # packed tensor has a reproducible image order.
                for img_file in sorted(os.listdir(path)):
                    img_path = os.path.join(path, img_file)
                    sample = cv2.imread(img_path)
                    # cv2.imread returns None (rather than raising) for
                    # missing or undecodable files.
                    if sample is None:
                        raise IOError(
                            'Could not read visual prior image: %s' % img_path)
                    sample = cv2.cvtColor(sample, cv2.COLOR_BGR2GRAY)
                    sample = (tf.to_float(sample) - 128.0) / 128.0
                    # Requires exactly 28 * 28 pixels per grayscale image.
                    samples.append(tf.reshape(sample, (28, 28, 1)))
                images[key][attribute] = ops.convert_to_tensor(samples)
        return images