Example #1
    def __init__(self, epochs, batch_size):
        self.epochs = epochs
        self.batch_size = batch_size

        # Generators: a three-stage cascade (each stage's output feeds the next)
        filter_kernal = [[32, 128, 128, 128, 256],
                         [32, 128, 128, 128, 128, 256, 512],
                         [32, 128, 128, 128, 128, 128, 128, 512, 512]]

        self.g_1 = generator(3, filter_kernal[2], in_1_size, 'Adam',
                             'mse').generator_model()
        self.g_2 = generator(2, filter_kernal[1], in_2_size, 'Adam',
                             'mse').generator_model()
        self.g_3 = generator(1, filter_kernal[0], in_3_size, 'Adam',
                             'mse').generator_model()

        # Discriminator
        self.D = discriminator(out).modified_vgg()
        self.D.compile(loss='mse', optimizer='Adam', metrics=['accuracy'])

        # Images
        lr = Input(shape=in_1_size)
        hr_1 = Input(shape=in_2_size)
        hr_2 = Input(shape=in_3_size)
        hr_3 = Input(shape=out)

        fake_hr_image_1 = self.g_1(lr)
        fake_hr_image_2 = self.g_2(fake_hr_image_1)
        fake_hr_image_3 = self.g_3(fake_hr_image_2)

        # VGG feature extractors for the perceptual (feature-matching) loss at each scale
        self.vgg_1 = vgg(in_2_size).vgg_loss_model()
        self.vgg_2 = vgg(in_3_size).vgg_loss_model()
        self.vgg_3 = vgg(out).vgg_loss_model()

        fake_hr_feat_1 = self.vgg_1(fake_hr_image_1)
        fake_hr_feat_2 = self.vgg_2(fake_hr_image_2)
        fake_hr_feat_3 = self.vgg_3(fake_hr_image_3)

        # Freeze D inside the combined model so only the generators are updated
        self.D.trainable = False
        valid = self.D(fake_hr_image_3)

        self.merged_net = Model(
            inputs=[lr, hr_1, hr_2, hr_3],
            outputs=[valid, fake_hr_feat_1, fake_hr_feat_2, fake_hr_feat_3])
        self.merged_net.compile(
            loss=['binary_crossentropy', 'mse', 'mse', 'mse'],
            loss_weights=[1e-3, .3, .3, .3],
            optimizer='Adam')
        self.merged_net.summary()
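
This constructor leans on names defined elsewhere in the project. A minimal sketch of the assumed imports and shape constants follows; the module layout and the concrete sizes are assumptions, not taken from the source:

# Assumed context for Example #1 (hypothetical; the real project defines these).
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
# generator, discriminator and vgg are the project's own wrapper classes,
# e.g. from model import generator, discriminator, vgg

# Shapes for the three-stage cascade; a 2x-per-stage guess consistent
# with Example #3.
in_1_size = (256, 256, 3)    # low-resolution input
in_2_size = (512, 512, 3)    # after stage 1
in_3_size = (1024, 1024, 3)  # after stage 2
out = (2048, 2048, 3)        # final output / discriminator input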
Example #2
def gen_model(param_object):
    x_train, y_train = load_data(flag=0)
    _G = generator(param_object.gen_type,
                   param_object.filter_kernal[param_object.filter_kernal_no],
                   param_object.input_shape, 'Adam', mean_squared_error,
                   tf.zeros(shape=param_object.inflow_shape,
                            dtype=tf.float32)).generator_model()
    opti_trick = ReduceLROnPlateau(monitor='val_loss',
                                   factor=0.2,
                                   patience=100,
                                   verbose=1)
    # NOTE: `period` is the legacy Keras argument; newer tf.keras releases
    # use `save_freq` instead.
    ckpt = ModelCheckpoint(ckpt_path,
                           monitor='val_loss',
                           verbose=1,
                           save_best_only=True,
                           save_weights_only=True,
                           period=1)
    _G.fit(x_train,
           y_train,
           epochs=param_object.epochs,
           batch_size=param_object.batch_size,
           callbacks=[opti_trick, ckpt],
           validation_split=0.2,
           shuffle=True)
    return _G
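
gen_model also assumes load_data and ckpt_path are defined at module level, and that mean_squared_error comes from tensorflow.keras.losses. A hedged sketch of a parameter object carrying exactly the attributes the function reads (names inferred from the accesses above, values purely illustrative):

# Hypothetical parameter container for gen_model; attribute names are taken
# from the accesses in the snippet, the values are placeholders.
from tensorflow.keras.losses import mean_squared_error  # the loss the snippet passes

class TrainParams:
    gen_type = 3
    filter_kernal = [[32, 128, 128, 128, 256],
                     [32, 128, 128, 128, 128, 256, 512],
                     [32, 128, 128, 128, 128, 128, 128, 512, 512]]
    filter_kernal_no = 2
    input_shape = (256, 256, 3)      # assumed low-res input shape
    inflow_shape = (1, 256, 256, 3)  # assumed inflow tensor shape
    epochs = 100
    batch_size = 4

# _G = gen_model(TrainParams())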
Example #3
 def feed_forward(self):
     in_1 = [256, 256, 3]
     in_2 = [512, 512, 3]
     in_3 = [1024, 1024, 3]
     out_ = [128, 128, 3]
     #/home/mdo2/sid_codes/newew/ee2.jpg
     #/home/mdo2/sid_codes/datasets_big/qhd_no_pad/train/1.png
     img = cv2.resize(
         cv2.imread('/home/mdo2/sid_codes/datasets_big/qhd_no_pad/train/1.png', 1),
         (512 * 2, 512 * 2))
     img = img.reshape(1, 512 * 2, 512 * 2, 3)
     filter_kernal = [[32, 128, 128, 128, 256],
                      [32, 128, 128, 128, 128, 256, 512],
                      [32, 128, 128, 128, 128, 128, 128, 128, 512]]
     g3 = generator(3, filter_kernal[2], in_3, 'Adam',
                    path='/home/mdo2/sid_codes/new_codes/gen_1.h5').generator_model()
     #g2 = generator(2,filter_kernal[1],in_2,'Adam',path = '/home/mdo2/sid_codes/new_codes/new_save/gen_2.h5').generator_model()
     #g1 = generator(1,filter_kernal[0],in_1,'Adam',path = '/home/mdo2/sid_codes/new_codes/new_save/gen_3.h5').generator_model()
     img = g3.predict(img)
     #img = g2.predict(img)
     # Cast assumes the network already outputs pixel values in [0, 255].
     img = np.array(img).astype('uint8')
     img = img.reshape(1024 * 2, 1024 * 2, 3)
     cv2.imshow('img', img)
     cv2.waitKey(0)
     cv2.imwrite('ee3.png', img)
     # decode_image expects raw bytes, not a path, so read the file first.
     im1 = tf.io.decode_image(tf.io.read_file('/home/mdo2/sid_codes/newew/ee3.png'))
     #im1 = tf.image.convert_image_dtype(im1, tf.float32)
     im2 = tf.io.decode_image(tf.io.read_file('/home/mdo2/sid_codes/datasets_big/qhd_no_pad/train/1.png'))
     #im2 = tf.image.convert_image_dtype(im2, tf.float32)
     # Compare the generated image against the reference (the original code
     # compared im2 with itself); max_val is 255 for uint8 images.
     psnr1 = tf.image.psnr(im1, im2, max_val=255)
     print(psnr1)
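
tf.image.psnr requires both tensors to have the same shape, and the generated ee3.png is a fixed 2048x2048 while the reference may not be. A hedged sketch of a shape-safe comparison; tf.image.resize and convert_image_dtype are standard TensorFlow calls, the file paths are the ones from the example:

# Shape-safe PSNR: resize the reference to the generated image's size and
# compare in float32 with max_val=1.0. A sketch, assuming the paths above.
import tensorflow as tf

gen = tf.image.convert_image_dtype(
    tf.io.decode_image(tf.io.read_file('ee3.png')), tf.float32)
ref = tf.image.convert_image_dtype(
    tf.io.decode_image(
        tf.io.read_file('/home/mdo2/sid_codes/datasets_big/qhd_no_pad/train/1.png')),
    tf.float32)
ref = tf.image.resize(ref, tf.shape(gen)[:2])  # match spatial dims
print(tf.image.psnr(gen, ref, max_val=1.0))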
Example #4
 def feed_forward(self):
     in_1 = [64, 64, 3]
     in_2 = [64, 64, 3]
     in_3 = [64 * 4, 64 * 4, 3]
     out_ = [128, 128, 3]
     img = cv2.resize(
         cv2.imread(
             '/home/mdo2/sid_codes/datasets_big/qhd_no_pad/train/1.png', 1),
         (64, 64))
     img = img.reshape(1, 64, 64, 3)
     filter_kernal = [[32, 128, 128, 128, 256],
                      [32, 128, 128, 128, 128, 256, 512],
                      [32, 128, 128, 128, 128, 128, 128, 128, 512]]
     g3 = generator(
         3,
         filter_kernal[2],
         in_1,
         'Adam',
         path='/home/mdo2/sid_codes/new_codes/gen_1.h5').generator_model()
     #g2 = generator(2,filter_kernal[1],in_2,'Adam',path = '/home/mdo2/sid_codes/new_codes/gen_2.h5').generator_model()
     #g1 = generator(1,filter_kernal[0],in_1,'Adam',path = '/home/mdo2/sid_codes/new_codes/gen_3.h5').generator_model()
     img = g3.predict(img)
     #img = g2.predict(img)
     img = np.array(img).astype('uint8')
     img = img.reshape(128, 128, 3)
     cv2.imshow('img', img)
     cv2.waitKey(0)
     cv2.imwrite('ee.jpg', img)
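
The commented-out g1/g2 lines in Examples #3 and #4 suggest the intended full pipeline chains all three generators. A hedged sketch of that chain; the weight paths come from the commented code, while filter_kernal, in_1..in_3 and the 2x-per-stage upscaling are assumptions:

# Hypothetical full three-stage inference chain (a sketch, not from source;
# assumes filter_kernal, in_1, in_2, in_3 are defined as in the examples).
def cascade_upscale(img):
    g1 = generator(1, filter_kernal[0], in_1, 'Adam',
                   path='/home/mdo2/sid_codes/new_codes/gen_3.h5').generator_model()
    g2 = generator(2, filter_kernal[1], in_2, 'Adam',
                   path='/home/mdo2/sid_codes/new_codes/gen_2.h5').generator_model()
    g3 = generator(3, filter_kernal[2], in_3, 'Adam',
                   path='/home/mdo2/sid_codes/new_codes/gen_1.h5').generator_model()
    # Each stage is assumed to upscale 2x, as in the single-stage examples.
    out = g3.predict(g2.predict(g1.predict(img)))
    return np.clip(out, 0, 255).astype('uint8')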
    def __init__(self, epochs, batch_size):
        self.epochs = epochs
        self.batch_size = batch_size

        # Generator
        filter_kernal = [[32, 128, 128, 128, 256],
                         [32, 128, 128, 128, 128, 256, 512],
                         [32, 128, 128, 128, 128, 128, 128, 128, 512]]

        #self.g_1 = generator(3,filter_kernal[2],in_1_size,'Adam','mse',path = '/home/mdo2/sid_codes/new_codes/gen_1.h5').generator_model()
        #self.g_2 = generator(2,filter_kernal[1],in_2_size,'Adam','mse',path = '/home/mdo2/sid_codes/new_codes/gen_2.h5').generator_model()
        self.g_3 = generator(1, filter_kernal[0], in_3_size, 'Adam', 'mse',
                             path='/home/mdo2/sid_codes/new_codes/gen_3.h5').generator_model()

        # Discriminator
        #self.D1 = discriminator(in_2_size).modified_vgg()
        #self.D1.compile(loss='mse',optimizer='Adam',metrics=['accuracy'])
        #self.D2 = discriminator(in_3_size).modified_vgg()
        #self.D2.compile(loss='mse',optimizer='Adam',metrics=['accuracy'])
        self.D = discriminator(out, '/home/mdo2/sid_codes/new_codes/D.h5').modified_vgg()
        self.D.compile(loss='mse', optimizer='Adam', metrics=['accuracy'])

        self.g = 0
        # Images
        lr = Input(shape=in_1_size)
        hr_1 = Input(shape=in_2_size)
        hr_2 = Input(shape=in_2_size)
        hr_3 = Input(shape=out)
        #fake_hr_image_1 = self.g_1(lr)
        #fake_hr_image_2 = self.g_2(lr)
        fake_hr_image_3 = self.g_3(lr)

        #self.vgg_1 = vgg(in_2_size).vgg_loss_model()
        #self.vgg_2 = vgg(in_2_size).vgg_loss_model()
        self.vgg_3 = vgg(in_2_size).vgg_loss_model()

        #fake_hr_feat_1 = self.vgg_1(fake_hr_image_1)
        #fake_hr_feat_2 = self.vgg_2(fake_hr_image_2)
        fake_hr_feat_3 = self.vgg_3(fake_hr_image_3)

        self.D.trainable = False
        #self.D1.trainable = False
        #self.D2.trainable = False
        validity = self.D(fake_hr_image_3)
        #self.D.load_weights('/media/arjun/119GB/attachments/D_best.h5')

        self.merged_net = Model(inputs=[lr, hr_2],
                                outputs=[validity, fake_hr_feat_3])
        self.merged_net.compile(loss=['binary_crossentropy', 'mse'],
                                loss_weights=[1e-3, .3],
                                optimizer='Adam')
        self.merged_net.summary()
        #self.merged_net.load_weights('/home/mdo2/sid_codes/newew/my_model.h5')
        print('loaded')
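
The merged models above are only compiled, never fitted in these excerpts. A minimal hedged sketch of the alternating adversarial update such a setup typically uses; the label shapes, the train_on_batch calls and the batch arguments are assumptions, not taken from the source:

# Hypothetical training step for the combined model (a sketch; assumes
# self.D outputs a map of validity scores and batches are supplied).
import numpy as np

def train_step(self, lr_batch, hr_batch):
    valid = np.ones((self.batch_size,) + self.D.output_shape[1:])
    fake = np.zeros((self.batch_size,) + self.D.output_shape[1:])

    # 1) Train the discriminator on real vs. generated images.
    fake_hr = self.g_3.predict(lr_batch)
    self.D.trainable = True
    d_loss_real = self.D.train_on_batch(hr_batch, valid)
    d_loss_fake = self.D.train_on_batch(fake_hr, fake)

    # 2) Train the generator through the frozen combined model; the
    #    second target is the VGG feature map of the real images.
    self.D.trainable = False
    hr_feat = self.vgg_3.predict(hr_batch)
    g_loss = self.merged_net.train_on_batch([lr_batch, hr_batch],
                                            [valid, hr_feat])
    return d_loss_real, d_loss_fake, g_loss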