Ejemplo n.º 1
0
    def __init__(self,
                 image_size,
                 learning_rate=2e-5,
                 batch_size=1,
                 ngf=64,
                 units=4096
                 ):
        """
        Store the hyper-parameters and instantiate the sub-networks.

        Args:
          image_size: list [H, W, C]; the first three entries are read
          learning_rate: float, initial learning rate for Adam
          batch_size: integer, batch size; the leading dimension of
            `input_shape` is int(batch_size / 4)
          ngf: number of gen filters in first conv layer
          units: integer, forwarded as `units` to VEncoder and VDecoder
        """
        self.learning_rate = learning_rate
        # NOTE(review): the batch is divided by 4 (presumably one slice per
        # device -- confirm with the training code). With the default
        # batch_size=1 the leading dimension evaluates to 0.
        self.input_shape = [int(batch_size / 4), image_size[0], image_size[1], image_size[2]]
        self.ones = tf.ones(self.input_shape, name="ones")
        # Attribute name (sic) kept as-is: code outside this method may use it.
        self.tenaor_name = {}

        self.EC_S = VEncoder('EC_S', ngf=ngf, units=units, keep_prob=0.85)
        self.DC_S = VDecoder('DC_S', ngf=ngf, output_channl=2, units=units)

        # Floor division keeps the filter count an integer; `ngf / 2` would be
        # a float (e.g. 32.0) under Python 3 true division.
        self.G_M = Unet('G_M', ngf=ngf // 2, keep_prob=0.9, output_channl=2)

        self.D_S = Discriminator('D_S', ngf=ngf, keep_prob=0.85)
        self.FD_Z = FeatureDiscriminator('FD_Z', ngf=ngf)
    def __init__(self,
                 image_size,
                 learning_rate=2e-5,
                 batch_size=1,
                 ngf=64,
                 ):
        """
        Record the hyper-parameters, create the constant all-ones tensors and
        instantiate the encoder / decoder / discriminator sub-networks.

        Args:
          image_size: list [H, W, C]; the first three entries are read
          learning_rate: float, initial learning rate for Adam
          batch_size: integer, batch size; the leading dimension of both
            shapes is int(batch_size / 4)
          ngf: number of gen filters in first conv layer
        """
        # NOTE(review): the batch is divided by 4 (presumably one slice per
        # device -- confirm). With the default batch_size=1 this is 0.
        sub_batch = int(batch_size / 4)
        height, width, channels = image_size[0], image_size[1], image_size[2]

        self.learning_rate = learning_rate

        # Image-sized shape and the 4x spatially reduced, 4-channel code shape.
        self.input_shape = [sub_batch, height, width, channels]
        self.code_shape = [sub_batch, int(height / 4), int(width / 4), 4]

        # All-ones constants matching the two shapes above.
        self.ones = tf.ones(self.input_shape, name="ones")
        self.ones_code = tf.ones(self.code_shape, name="ones_code")

        # Empty registries; nothing is stored in them by this method.
        self.image_list = {}
        self.prob_list = {}
        self.code_list = {}
        self.judge_list = {}
        self.tenaor_name = {}

        # Sub-networks, all built with the same base filter count.
        self.EC_L = Encoder('EC_L', ngf=ngf)
        self.DC_L = Decoder('DC_L', ngf=ngf, output_channl=5)

        self.EC_R = Encoder('EC_R', ngf=ngf)
        self.EC_M = Encoder('EC_M', ngf=ngf)
        self.DC_M = Decoder('DC_M', ngf=ngf)

        self.D_M = Discriminator('D_M', ngf=ngf)
        self.FD_R = FeatureDiscriminator('FD_R', ngf=ngf)