Ejemplo n.º 1
0
    def testCycleGAN(self):
        """Smoke-test one CycleGAN optimization step on a constant dummy
        image and print how long the step took."""
        hyper = edict({
            'width': 128,
            'height': 128,
            'load_size': 142,
            'epochs': 100,
            'batch_size': 1,
            'lr': 2e-4,
            'beta1': 0.5,
            'pool_size': 50,
            'input_nc': 3,
            'output_nc': 3,
            'use_lsgan': True,
            'train': True,
            'seed': 42,
            'cuda': False,
            'num_gpu': 1,
            'save_interval': 1000,
            'vis_interval': 500,
            'log_interval': 50,
            'num_workers': 2
        })

        def _seed_everything(seed, use_cuda):
            # make the single optimization step reproducible
            random.seed(seed)
            torch.manual_seed(seed)
            if use_cuda:
                torch.cuda.manual_seed_all(seed)

        _seed_everything(hyper.seed, hyper.cuda)

        dummy = torch.ones([1, hyper.input_nc, hyper.height, hyper.width])
        net = CycleGAN(params=hyper)
        # feed the same all-ones image as both domains A and B
        net.real_A.data.resize_(dummy.size()).copy_(dummy)
        net.real_B.data.resize_(dummy.size()).copy_(dummy)
        tic = time.time()
        net.optimize_parameters()
        toc = time.time()
        print(toc - tic)
Ejemplo n.º 2
0
    # NOTE(review): these mkdir calls are the body of a conditional whose
    # header lies outside this chunk — presumably "create run dirs if new".
    os.mkdir(os.path.join(RUN_FOLDER, 'viz'))
    os.mkdir(os.path.join(RUN_FOLDER, 'images'))
    os.mkdir(os.path.join(RUN_FOLDER, 'weights'))

# 'build' starts from freshly initialized weights (and saves them below);
# any other value reloads previously trained weights instead.
mode = 'build'  # 'build' #

IMAGE_SIZE = 128

data_loader = DataLoader(dataset_name=DATA_NAME,
                         img_res=(IMAGE_SIZE, IMAGE_SIZE))

# CycleGAN with U-Net generators; lambda_* weight the adversarial,
# cycle-reconstruction and identity terms of the generator loss.
gan = CycleGAN(input_dim=(IMAGE_SIZE, IMAGE_SIZE, 3),
               learning_rate=0.0002,
               buffer_max_length=50,
               lambda_validation=1,
               lambda_reconstr=10,
               lambda_id=2,
               generator_type='unet',
               gen_n_filters=32,
               disc_n_filters=32)

if mode == 'build':
    gan.save(RUN_FOLDER)
else:
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))

BATCH_SIZE = 1
EPOCHS = 200
PRINT_EVERY_N_BATCHES = 10

# sample image used later for visual inspection of the trained generators
TEST_A_FILE = 'n07740461_14740.jpg'
Ejemplo n.º 3
0
 def __init__(self, 
              X_source_1, X_target_1, source_name_1, target_name_1, 
              X_source_2, X_target_2, source_name_2, target_name_2,
              X_source_3, X_target_3, source_name_3, target_name_3,
              X_source_4, X_target_4, source_name_4, target_name_4,
              session_id, dim_g, dim_d, dim_c,
              save_path, GPU_CONFIG, 
              lr, keep_probs,batch_size, train_epochs,
              lambda_cyc, lambda_cls, params=None):
     """Store configuration and build one CycleGAN per domain pair.

     Each of the four (source, target) pairs gets its own CycleGAN,
     saved under ``save_path/<target_name>/`` and configured with the
     ``params`` entry keyed ``'<source_name>2<target_name>'``.
     """
     # shared configuration
     self.session_id = session_id
     self.dim_g = dim_g
     self.dim_d = dim_d
     self.dim_c = dim_c
     self.save_path = save_path
     self.GPU_CONFIG = GPU_CONFIG
     self.lr = lr
     self.keep_probs = keep_probs
     self.batch_size = batch_size
     self.train_epochs = train_epochs
     self.lambda_cyc = lambda_cyc
     self.lambda_cls = lambda_cls
     self.params = params
     # per-pair data plus the CycleGAN built from it; attribute names
     # (X_source_k, ..., cgan_k) match the original one-by-one layout
     pairs = [
         (X_source_1, X_target_1, source_name_1, target_name_1),
         (X_source_2, X_target_2, source_name_2, target_name_2),
         (X_source_3, X_target_3, source_name_3, target_name_3),
         (X_source_4, X_target_4, source_name_4, target_name_4),
     ]
     for k, (src, tgt, src_name, tgt_name) in enumerate(pairs, start=1):
         setattr(self, 'X_source_%d' % k, src)
         setattr(self, 'X_target_%d' % k, tgt)
         setattr(self, 'source_name_%d' % k, src_name)
         setattr(self, 'target_name_%d' % k, tgt_name)
         setattr(self, 'cgan_%d' % k, CycleGAN(
             src, tgt, src_name, tgt_name,
             self.session_id, self.dim_g, self.dim_d,
             save_path=self.save_path + '/' + tgt_name + '/',
             GPU_CONFIG=self.GPU_CONFIG,
             batch_size=self.batch_size,
             train_epochs=self.train_epochs,
             lambda_cyc=self.lambda_cyc,
             params=self.params[src_name + '2' + tgt_name]))
Ejemplo n.º 4
0
class QuadCycleGAN():
    """Trains four CycleGANs jointly, one per (source, target) domain pair.

    The four models share a discriminator optimizer and a generator
    optimizer; an optional classifier over the four encodings adds a
    classification term (weighted by ``lambda_cls``) to the generator loss.
    TensorFlow 1.x graph-mode code: ``build`` defines the graph,
    ``learn_representation`` runs the session.
    """

    def __init__(self, 
                 X_source_1, X_target_1, source_name_1, target_name_1, 
                 X_source_2, X_target_2, source_name_2, target_name_2,
                 X_source_3, X_target_3, source_name_3, target_name_3,
                 X_source_4, X_target_4, source_name_4, target_name_4,
                 session_id, dim_g, dim_d, dim_c,
                 save_path, GPU_CONFIG, 
                 lr, keep_probs,batch_size, train_epochs,
                 lambda_cyc, lambda_cls, params=None):
        """Store configuration and construct one CycleGAN per domain pair.

        Each sub-model saves under ``save_path/<target_name>/`` and takes
        its hyper-parameters from ``params['<source_name>2<target_name>']``.
        NOTE(review): ``params`` defaults to None but is subscripted
        unconditionally below — the default would raise TypeError; confirm
        callers always pass a dict.
        """
        self.X_source_1 = X_source_1
        self.X_target_1 = X_target_1
        self.source_name_1 = source_name_1
        self.target_name_1 = target_name_1
        self.X_source_2 = X_source_2
        self.X_target_2 = X_target_2
        self.source_name_2 = source_name_2
        self.target_name_2 = target_name_2
        self.X_source_3 = X_source_3
        self.X_target_3 = X_target_3
        self.source_name_3 = source_name_3
        self.target_name_3 = target_name_3
        self.X_source_4 = X_source_4
        self.X_target_4 = X_target_4
        self.source_name_4 = source_name_4
        self.target_name_4 = target_name_4
        self.session_id = session_id
        self.dim_g = dim_g
        self.dim_d = dim_d
        self.dim_c = dim_c
        self.save_path = save_path
        self.GPU_CONFIG = GPU_CONFIG
        self.lr = lr
        self.keep_probs = keep_probs
        self.batch_size = batch_size
        self.train_epochs = train_epochs
        self.lambda_cyc = lambda_cyc
        self.lambda_cls = lambda_cls
        self.params = params
        self.cgan_1 = CycleGAN(self.X_source_1, self.X_target_1, 
                               self.source_name_1, self.target_name_1, 
                               self.session_id, self.dim_g, self.dim_d, 
                               save_path=self.save_path+'/'+ \
                                           self.target_name_1+'/', 
                               GPU_CONFIG=self.GPU_CONFIG, 
                               batch_size=self.batch_size,
                               train_epochs=self.train_epochs,
                               lambda_cyc=self.lambda_cyc,
                               params=self.params[
                                   self.source_name_1+'2'+self.target_name_1
                               ])
        self.cgan_2 = CycleGAN(self.X_source_2, self.X_target_2, 
                               self.source_name_2, self.target_name_2, 
                               self.session_id, self.dim_g, self.dim_d, 
                               save_path=self.save_path+'/'+ \
                                           self.target_name_2+'/', 
                               GPU_CONFIG=self.GPU_CONFIG, 
                               batch_size=self.batch_size,
                               train_epochs=self.train_epochs,
                               lambda_cyc=self.lambda_cyc,
                               params=self.params[
                                   self.source_name_2+'2'+self.target_name_2
                               ])
        self.cgan_3 = CycleGAN(self.X_source_3, self.X_target_3, 
                               self.source_name_3, self.target_name_3, 
                               self.session_id, self.dim_g, self.dim_d, 
                               save_path=self.save_path+'/'+ \
                                           self.target_name_3+'/', 
                               GPU_CONFIG=self.GPU_CONFIG, 
                               batch_size=self.batch_size,
                               train_epochs=self.train_epochs,
                               lambda_cyc=self.lambda_cyc,
                               params=self.params[
                                   self.source_name_3+'2'+self.target_name_3
                               ])
        self.cgan_4 = CycleGAN(self.X_source_4, self.X_target_4, 
                               self.source_name_4, self.target_name_4, 
                               self.session_id, self.dim_g, self.dim_d, 
                               save_path=self.save_path+'/'+ \
                                           self.target_name_4+'/', 
                               GPU_CONFIG=self.GPU_CONFIG, 
                               batch_size=self.batch_size,
                               train_epochs=self.train_epochs, 
                               lambda_cyc=self.lambda_cyc,
                               params=self.params[
                                   self.source_name_4+'2'+self.target_name_4
                               ])

    def build(self):
        """Build the joint graph: four CycleGAN sub-graphs with uniquely
        named components, summed D/G losses, and (optionally) the
        classification loss over the four A-encodings.
        """
        # learning-rate placeholders so the rates can be decayed per epoch
        self.lr_d = tf.placeholder(tf.float32)
        self.lr_g = tf.placeholder(tf.float32)
        self.cgan_1.build(encoder_id='encoder_1', 
                          decoder_id='decoder_1', 
                          discriminator_id_A='discriminator_A_1', 
                          discriminator_id_B='discriminator_B_1')
        self.cgan_2.build(encoder_id='encoder_2', 
                          decoder_id='decoder_2', 
                          discriminator_id_A='discriminator_A_2', 
                          discriminator_id_B='discriminator_B_2')
        self.cgan_3.build(encoder_id='encoder_3', 
                          decoder_id='decoder_3', 
                          discriminator_id_A='discriminator_A_3', 
                          discriminator_id_B='discriminator_B_3')
        self.cgan_4.build(encoder_id='encoder_4', 
                          decoder_id='decoder_4', 
                          discriminator_id_A='discriminator_A_4', 
                          discriminator_id_B='discriminator_B_4')
        # joint losses are plain sums over the four sub-models
        self.d_loss = self.cgan_1.d_loss + self.cgan_2.d_loss + \
                        self.cgan_3.d_loss + self.cgan_4.d_loss
        self.g_loss = self.cgan_1.g_loss + self.cgan_2.g_loss + \
                        self.cgan_3.g_loss + self.cgan_4.g_loss
        # define classification loss
        if self.lambda_cls > 0:
            self.keep_prob_cls = tf.placeholder(tf.float32)
            # classify which domain pair each encoding came from; batches
            # are concatenated in pair order 1..4 (labels must match)
            syn = tf.concat([self.cgan_1.encoding_A, 
                             self.cgan_2.encoding_A, 
                             self.cgan_3.encoding_A, 
                             self.cgan_4.encoding_A], axis=0) 
            logits = discriminator(syn, self.dim_c, 'classifier', 
                                   keep_prob=self.keep_prob_cls)
            self.labels = tf.placeholder(tf.int64, shape=(None,))
            self.c_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=logits, labels=self.labels
                )
            )
            self.g_loss += self.lambda_cls*self.c_loss

    def set_optimizer(self):
        """Create the saver and the Adam optimizers for D and G.

        Variable split: D trains discriminator+encoder variables, G trains
        encoder+decoder+classifier variables (encoders appear in both).
        """
        t_vars = tf.trainable_variables()
        self.saver = tf.train.Saver(t_vars)
#         pprint.pprint(t_vars)
        d_vars = [var for var in t_vars 
                  if 'discriminator' in var.name or 'encoder' in var.name]
        g_vars = [var for var in t_vars 
                  if 'encoder' in var.name or 'decoder' in var.name or \
                  'classifier' in var.name]
        self.D_optimizer = tf.train.AdamOptimizer(
            self.lr_d).minimize(self.d_loss, var_list=d_vars)
        self.G_optimizer = tf.train.AdamOptimizer(
            self.lr_g).minimize(self.g_loss, var_list=g_vars)

    def feed_losses(self, sess, losses, epoch):
        """Evaluate each sub-model's D/G loss on its FULL source/target
        sets (keep_prob=1.0), print them, append to ``losses`` and return
        the updated dict.
        """
        d_l_1, g_l_1 = sess.run([self.cgan_1.d_loss, self.cgan_1.g_loss], 
                                feed_dict={
                                    self.cgan_1.input_A: self.X_source_1, 
                                    self.cgan_1.input_B: self.X_target_1, 
                                    self.cgan_1.keep_prob: 1.0
                                }
                               )
        d_l_2, g_l_2 = sess.run([self.cgan_2.d_loss, self.cgan_2.g_loss], 
                                feed_dict={
                                    self.cgan_2.input_A: self.X_source_2, 
                                    self.cgan_2.input_B: self.X_target_2, 
                                    self.cgan_2.keep_prob: 1.0
                                }
                               )
        d_l_3, g_l_3 = sess.run([self.cgan_3.d_loss, self.cgan_3.g_loss], 
                                feed_dict={
                                    self.cgan_3.input_A: self.X_source_3, 
                                    self.cgan_3.input_B: self.X_target_3, 
                                    self.cgan_3.keep_prob: 1.0
                                }
                               )
        d_l_4, g_l_4 = sess.run([self.cgan_4.d_loss, self.cgan_4.g_loss], 
                                feed_dict={
                                    self.cgan_4.input_A: self.X_source_4, 
                                    self.cgan_4.input_B: self.X_target_4, 
                                    self.cgan_4.keep_prob: 1.0
                                }
                               )
        print('CycleGAN_1 - Epoch %i: D Loss: %f, G Loss: %f;' % (epoch, 
                                                                  d_l_1, 
                                                                  g_l_1)) 
        print('CycleGAN_2 - Epoch %i: D Loss: %f, G Loss: %f;' % (epoch, 
                                                                  d_l_2, 
                                                                  g_l_2)) 
        print('CycleGAN_3 - Epoch %i: D Loss: %f, G Loss: %f;' % (epoch, 
                                                                  d_l_3, 
                                                                  g_l_3)) 
        print('CycleGAN_4 - Epoch %i: D Loss: %f, G Loss: %f;' % (epoch, 
                                                                  d_l_4, 
                                                                  g_l_4)) 
        losses['d1'].append(d_l_1)
        losses['g1'].append(g_l_1)
        losses['d2'].append(d_l_2)
        losses['g2'].append(g_l_2)
        losses['d3'].append(d_l_3)
        losses['g3'].append(g_l_3)
        losses['d4'].append(d_l_4)
        losses['g4'].append(g_l_4)
        return losses

    def generate_synthetic_samples(self, sess, save=True):
        """Encode each source set and (optionally) persist the encodings.

        All four encodings go into one HDF5 file, one key per target name.
        NOTE(review): paths here are built as ``save_path+'syn_samples/...'``
        with no '/' separator, while __init__ appends '/' explicitly —
        assumes ``save_path`` already ends with '/'; confirm.
        """
        # synthetic samples for source 1
        enc_1 = sess.run(self.cgan_1.encoding_A, 
                         feed_dict={self.cgan_1.input_A: self.X_source_1, 
                                    self.cgan_1.input_B: self.X_target_1, 
                                    self.cgan_1.keep_prob: 1.0}
                        )
        enc_1_df = pd.DataFrame(
            enc_1, index=[v+"_syn_"+self.target_name_1 
                          for v in self.X_source_1.index.values])
        enc_1_df.columns = self.X_source_1.columns.values
        if save:
            enc_1_df.to_hdf(
                self.save_path+'syn_samples/syn_'+self.session_id+'.h5', 
                key=self.target_name_1
            )
        # synthetic samples for source 2
        enc_2 = sess.run(self.cgan_2.encoding_A, 
                         feed_dict={self.cgan_2.input_A: self.X_source_2, 
                                    self.cgan_2.input_B: self.X_target_2, 
                                    self.cgan_2.keep_prob: 1.0}
                        )
        enc_2_df = pd.DataFrame(
            enc_2, index=[v+"_syn_"+self.target_name_2 
                          for v in self.X_source_2.index.values])
        enc_2_df.columns = self.X_source_2.columns.values
        if save:
            enc_2_df.to_hdf(
                self.save_path+'syn_samples/syn_'+self.session_id+'.h5', 
                key=self.target_name_2
            )
        # synthetic samples for source 3
        enc_3 = sess.run(self.cgan_3.encoding_A, 
                         feed_dict={self.cgan_3.input_A: self.X_source_3, 
                                    self.cgan_3.input_B: self.X_target_3, 
                                    self.cgan_3.keep_prob: 1.0})
        enc_3_df = pd.DataFrame(
            enc_3, index=[v+"_syn_"+self.target_name_3 
                          for v in self.X_source_3.index.values])
        enc_3_df.columns = self.X_source_3.columns.values
        if save:
            enc_3_df.to_hdf(
                self.save_path+'syn_samples/syn_'+self.session_id+'.h5', 
                key=self.target_name_3)
        # synthetic samples for source 4
        enc_4 = sess.run(self.cgan_4.encoding_A, 
                         feed_dict={self.cgan_4.input_A: self.X_source_4, 
                                    self.cgan_4.input_B: self.X_target_4, 
                                    self.cgan_4.keep_prob: 1.0}
                        )
        enc_4_df = pd.DataFrame(
            enc_4, index=[v+"_syn_"+self.target_name_4 
                          for v in self.X_source_4.index.values])
        enc_4_df.columns = self.X_source_4.columns.values
        if save:
            enc_4_df.to_hdf(
                self.save_path+'syn_samples/syn_'+self.session_id+'.h5', 
                key=self.target_name_4)

    def learn_representation(self):
        """Full training loop.

        Per epoch: reshuffle all eight datasets, decay both learning rates
        by 0.8 every 50 epochs, then for each common batch run one D step
        followed by two G steps. Losses are logged every 10 epochs, saved
        to HDF5 at the end, and synthetic samples are generated last.
        """
        tf.reset_default_graph()        
        self.build()
        self.set_optimizer()
        t_vars = tf.trainable_variables()
        pprint.pprint(t_vars)
        # start training
        losses = {'d1': [], 'g1': [], 'd2': [], 'g2': [], 'd3': [], 'g3': [], 
                  'd4': [], 'g4': [],}
        init = tf.global_variables_initializer() 
        with tf.Session(config=self.GPU_CONFIG) as sess:
            sess.run(init)
            for epoch in range(self.train_epochs+1):
                # fresh shuffle of every dataset each epoch
                X_source_1_shuffle = self.X_source_1.sample(frac=1)
                X_target_1_shuffle = self.X_target_1.sample(frac=1)
                X_source_2_shuffle = self.X_source_2.sample(frac=1)
                X_target_2_shuffle = self.X_target_2.sample(frac=1)
                X_source_3_shuffle = self.X_source_3.sample(frac=1)
                X_target_3_shuffle = self.X_target_3.sample(frac=1)
                X_source_4_shuffle = self.X_source_4.sample(frac=1)
                X_target_4_shuffle = self.X_target_4.sample(frac=1)
                length_source_1 = int(
                    len(X_source_1_shuffle) / self.batch_size)
                length_target_1 = int(
                    len(X_target_1_shuffle) / self.batch_size)
                length_source_2 = int(
                    len(X_source_2_shuffle) / self.batch_size)
                length_target_2 = int(
                    len(X_target_2_shuffle) / self.batch_size)
                length_source_3 = int(
                    len(X_source_3_shuffle) / self.batch_size)
                length_target_3 = int(
                    len(X_target_3_shuffle) / self.batch_size)
                length_source_4 = int(
                    len(X_source_4_shuffle) / self.batch_size)
                length_target_4 = int(
                    len(X_target_4_shuffle) / self.batch_size)
                # step decay: multiply by 0.8 every 50 epochs
                lr_d_current = self.lr['d']*(0.8**int(epoch / 50))
                lr_g_current = self.lr['g']*(0.8**int(epoch / 50))
                if self.lambda_cls > 0:
                    # domain labels 0..3, one block per sub-model, matching
                    # the concat order in build()
                    labels = np.concatenate(
                        [[i]*self.batch_size for i in range(4)]
                    )
                # iterate only as many batches as the smallest dataset allows
                for i in range(min([length_source_1, length_target_1, 
                                    length_source_2, length_target_2, 
                                    length_source_3, length_target_3, 
                                    length_source_4, length_target_4])):
                    X_source_1_batch = X_source_1_shuffle.iloc[
                        i*self.batch_size : (i+1)*self.batch_size]
                    X_target_1_batch = X_target_1_shuffle.iloc[
                        i*self.batch_size : (i+1)*self.batch_size]
                    X_source_2_batch = X_source_2_shuffle.iloc[
                        i*self.batch_size : (i+1)*self.batch_size]
                    X_target_2_batch = X_target_2_shuffle.iloc[
                        i*self.batch_size : (i+1)*self.batch_size]
                    X_source_3_batch = X_source_3_shuffle.iloc[
                        i*self.batch_size : (i+1)*self.batch_size]
                    X_target_3_batch = X_target_3_shuffle.iloc[
                        i*self.batch_size : (i+1)*self.batch_size]
                    X_source_4_batch = X_source_4_shuffle.iloc[
                        i*self.batch_size : (i+1)*self.batch_size]
                    X_target_4_batch = X_target_4_shuffle.iloc[
                        i*self.batch_size : (i+1)*self.batch_size]
                    # one discriminator update
                    sess.run([self.D_optimizer], 
                             feed_dict={
                                 self.cgan_1.input_A: X_source_1_batch, 
                                 self.cgan_1.input_B: X_target_1_batch,
                                 self.cgan_2.input_A: X_source_2_batch, 
                                 self.cgan_2.input_B: X_target_2_batch,
                                 self.cgan_3.input_A: X_source_3_batch, 
                                 self.cgan_3.input_B: X_target_3_batch,
                                 self.cgan_4.input_A: X_source_4_batch, 
                                 self.cgan_4.input_B: X_target_4_batch,
                                 self.cgan_1.keep_prob: self.keep_probs['d'], 
                                 self.cgan_2.keep_prob: self.keep_probs['d'],
                                 self.cgan_3.keep_prob: self.keep_probs['d'], 
                                 self.cgan_4.keep_prob: self.keep_probs['d'], 
                                 self.lr_d: lr_d_current
                             })
                    # two generator updates per discriminator update
                    for _ in range(2):
                        feed_dict={
                            self.cgan_1.input_A: X_source_1_batch, 
                            self.cgan_1.input_B: X_target_1_batch, 
                            self.cgan_2.input_A: X_source_2_batch, 
                            self.cgan_2.input_B: X_target_2_batch,
                            self.cgan_3.input_A: X_source_3_batch, 
                            self.cgan_3.input_B: X_target_3_batch,
                            self.cgan_4.input_A: X_source_4_batch, 
                            self.cgan_4.input_B: X_target_4_batch,
                            self.cgan_1.keep_prob: self.keep_probs['g'],
                            self.cgan_2.keep_prob: self.keep_probs['g'],
                            self.cgan_3.keep_prob: self.keep_probs['g'],
                            self.cgan_4.keep_prob: self.keep_probs['g'],
                            self.lr_g: lr_g_current
                        }
                        # both branches issue the same run call; the `if`
                        # only augments the feed dict with classifier inputs
                        if self.lambda_cls > 0:
                            feed_dict[self.keep_prob_cls] = self.keep_probs['c']
                            feed_dict[self.labels] = labels
                            sess.run([self.G_optimizer], feed_dict=feed_dict)
                        else:
                            sess.run([self.G_optimizer], feed_dict=feed_dict)

                if epoch % 10 == 0:
                    losses = self.feed_losses(sess, losses, epoch)
            # save losses  
            pd.DataFrame(losses).to_hdf(
                self.save_path+'losses/losses_'+self.session_id+'.h5', 
                key='quad_cycle')
            # generate synthetic samples with transferred emotions
            self.generate_synthetic_samples(sess)
            
def _collect_piggyback_filters(layers, filter_list, weights, task_num):
    """Freeze and harvest the piggyback conv filters from a generator.

    For every Piggyback(Transpose)Conv layer: task 1 starts a new entry
    per layer in `filter_list`; task 2 appends the filter and starts the
    per-layer `weights` entries; later tasks append both. Mutates
    `filter_list` and `weights` in place.
    """
    conv_idx = 0
    for layer in layers:
        if isinstance(layer, (PiggybackConv, PiggybackTransposeConv)):
            layer.unc_filt.requires_grad = False
            if task_num == 1:
                filter_list.append([layer.unc_filt.detach().cpu()])
            elif task_num == 2:
                filter_list[conv_idx].append(layer.unc_filt.detach().cpu())
                weights.append([layer.weights_mat.detach().cpu()])
                conv_idx += 1
            else:
                filter_list[conv_idx].append(layer.unc_filt.detach().cpu())
                weights[conv_idx].append(layer.weights_mat.detach().cpu())
                conv_idx += 1


def train(gpu, opt):
    """Train one CycleGAN task under DistributedDataParallel.

    Args:
        gpu: local GPU index for this process; negative selects CPU.
            NOTE(review): the CPU path never builds `model`/`train_loader`,
            so the loop below would raise NameError — confirm this is only
            launched with gpu >= 0.
        opt: option namespace (world_size, nr, gpu_ids, batch_size, epoch
            counts, task_num, ckpt_save_path, and the piggyback filter /
            weight lists that rank 0 extends after training).
    """
    device = torch.device('cuda:{}'.format(gpu)) if gpu>=0 else torch.device('cpu')
    if gpu >= 0:
        torch.cuda.set_device(gpu)
        # global rank of this process across all nodes
        rank = opt.nr * len(opt.gpu_ids) + gpu
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=opt.world_size,
            rank=rank
        )
        model = CycleGAN(opt, device)
        model = model.to(device)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
        train_dataset = UnalignedDataset(opt)
        # the sampler shards and shuffles per replica, so the loader must not
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=opt.world_size,
            rank=rank
        )
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=opt.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            sampler=train_sampler)

    for epoch in range(opt.start_epoch, opt.n_epochs + opt.n_epochs_decay + 1):
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        print("Length of loader is ", len(train_loader))
        for i, data in enumerate(train_loader):
            model.module.set_input(data)
            model.module.optimize_parameters()
            # log every 50 iterations, rank 0 only
            if (i+1) % 50 == 0 and gpu <= 0:
                print_str = (
                    f"Task: {opt.task_num} | Epoch: {epoch} | Iter: {i+1} | G_A: {model.module.loss_G_A:.5f} | "
                    f"G_B: {model.module.loss_G_B:.5f} | cycle_A: {model.module.loss_cycle_A:.5f} | "
                    f"cycle_B: {model.module.loss_cycle_B:.5f} | idt_A: {model.module.loss_idt_A:.5f} | "
                    f"idt_B: {model.module.loss_idt_B:.5f} | D_A: {model.module.loss_D_A:.5f} | "
                    # BUG FIX: D_B previously printed loss_D_A again
                    f"D_B: {model.module.loss_D_B:.5f}"
                )
                print(print_str)

        model.module.update_learning_rate()

        # rank 0 checkpoints; all ranks sync before the next epoch
        if gpu <= 0:
            model.module.save_train_images(epoch)
            save_dict = {'model': model.state_dict(),
                         'epoch': epoch
                         }
            torch.save(save_dict, opt.ckpt_save_path+'/latest_checkpoint.pt')
        dist.barrier()

    if gpu <= 0:
        # harvest the task-specific filters from both generators so the
        # next task can piggyback on them
        netG_A_layer_list = list(model.module.netG_A)
        netG_B_layer_list = list(model.module.netG_B)

        _collect_piggyback_filters(netG_A_layer_list, opt.netG_A_filter_list,
                                   opt.weights_A, opt.task_num)
        _collect_piggyback_filters(netG_B_layer_list, opt.netG_B_filter_list,
                                   opt.weights_B, opt.task_num)

        savedict_task = {'netG_A_filter_list': opt.netG_A_filter_list,
                         'netG_B_filter_list': opt.netG_B_filter_list,
                         'weights_A': opt.weights_A,
                         'weights_B': opt.weights_B
                         }

        torch.save(savedict_task, opt.ckpt_save_path+'/filters.pt')

        # release the large filter lists before the next task
        del netG_A_layer_list
        del netG_B_layer_list
        del opt.netG_A_filter_list
        del opt.netG_B_filter_list
        del opt.weights_A
        del opt.weights_B

    dist.barrier()
    del model

    dist.destroy_process_group()
def test(opt, task_idx):
    """Run CPU inference for task `task_idx` and save up to 52 test images.

    Loads the piggyback filters/weights for the task into both generators
    before iterating the unaligned test set one image at a time.
    """
    opt.train = False
    device = torch.device('cpu')
    model = CycleGAN(opt, device)
    model = model.to(device)
    model.eval()

    # restore the task-specific piggyback filters into both generators
    model.netG_A = load_pb_conv(model.netG_A, opt.netG_A_filter_list, opt.weights_A, task_idx)
    model.netG_B = load_pb_conv(model.netG_B, opt.netG_B_filter_list, opt.weights_B, task_idx)

    dataset = UnalignedDataset(opt)
    loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=1,
            shuffle=False,
            num_workers=4,
            pin_memory=True)

    for idx, batch in enumerate(loader):
        model.set_input(batch)
        model.forward()
        model.save_test_images(idx)
        print(f"Task {opt.task_num} : Image {idx}")
        if idx > 50:
            break

    del model
Ejemplo n.º 7
0
def train():
    """Run the full CycleGAN training loop.

    Relies on module-level globals: `args` (hyper-parameters), `dataset`
    (paired loader), `visualizer`, `vutils`, `os`, `time`, `CycleGAN`.
    Logs, visualizes and checkpoints at the intervals given in `args`,
    then saves the final parameters.
    """
    model = CycleGAN(params=args)
    model.train()
    # optionally resume from a previously saved epoch
    if args.continue_epoch > -1:
        model.load_parameters(args.continue_epoch)

    for e in range(args.epochs):
        e_begin = time.time()
        for batch_idx, inputs in enumerate(dataset):
            model.set_inputs(inputs)
            model.optimize_parameters()
            if batch_idx % args.log_interval == 0:
                # fraction of the epoch completed, used as the plot x-offset
                # (hoisted into this branch: it is only needed when logging)
                e_fraction_passed = batch_idx * args.batch_size / len(
                    dataset.data_loader_A)
                err = model.get_errors()
                visualizer.plot_errors(err, e, e_fraction_passed)
                desc = model.get_errors_string()
                print(
                    'Epoch:[{}/{}] Batch:[{:10d}/{}] '.format(
                        e, args.epochs, batch_idx * args.batch_size,
                        len(dataset.data_loader_A)), desc)
            if batch_idx % args.vis_interval == 0:
                # dump a triple (real / fake / reconstructed) image grid
                imAB_gen_file = os.path.join(
                    args.save_path, 'imAB_gen_{}_{}.jpg'.format(e, batch_idx))
                vutils.save_image(model.get_AB_images_triple(),
                                  imAB_gen_file,
                                  normalize=True)
            if batch_idx % args.save_interval == 0:
                model.save_parameters(e)
        e_time = time.time() - e_begin
        print('End of epoch [{}/{}] Time taken: {:.4f} sec.'.format(
            e, args.epochs, e_time))

    # BUG FIX: message previously read 'paramaters'
    print('saving final model parameters')
    model.save_parameters(args.epochs)
Ejemplo n.º 8
0
# NOTE(review): `parser` is created (with the remaining arguments) outside
# this chunk; only the tail of the CLI definition is visible here.
parser.add_argument('--identity', default=0, type=int)
parser.add_argument('--num_test_iterations', default=5, type=int)
parser.add_argument('--phase', default='test', type=str)
args = parser.parse_args()

# echo the effective configuration
for k,v in vars(args).items():
    print('{} = {}'.format(k,v))

test_dir = os.path.join(args.save_path,'test')

# create output directories if missing
if not os.path.exists(args.save_path):
    os.makedirs(args.save_path)
if not os.path.exists(test_dir):
    os.makedirs(test_dir)

# restore the model at the requested epoch and run it over the test set
model = CycleGAN(args)
model.load_parameters(args.load_epoch)
model.print_model_desription()
data_loader = UnalignedDataLoader(args)
dataset = data_loader.load_data()

for batch_idx, inputs in enumerate(dataset):
    if batch_idx >= args.num_test_iterations:
        break
    model.set_inputs(inputs)
    model.test_model()

    # save a real/fake/reconstructed triple per processed batch
    imAB_gen_file = os.path.join(test_dir, 'imAB_gen_{}_{}_{}_test.jpg'.format(batch_idx, args.height, args.width))
    vutils.save_image(model.get_AB_images_triple(), imAB_gen_file, normalize=True)
    print('processed item with idx: {}'.format(batch_idx))