Exemplo n.º 1
0
    def __init__(self, args):
        """Create the color/disparity transforms, entropy models and training state.

        Args:
            args: configuration object, kept on ``self.args``; its
                ``checkpoint_dir`` attribute is the directory handed to
                ``tf.summary.create_file_writer``.
        """
        super().__init__()
        self.args = args

        # Analysis side: one color transform, then disp1 ... disp8, created in
        # that order and stored as disp{k}_analysis_transform attributes.
        self.color_analysis_transform = ColorAnalysisTransform()
        for k in range(1, 9):
            setattr(self, f"disp{k}_analysis_transform",
                    DispAnalysisTransform())

        # Synthesis side mirrors the analysis side, in the same order.
        self.color_synthesis_transform = ColorSynthesisTransform()
        for k in range(1, 9):
            setattr(self, f"disp{k}_synthesis_transform",
                    DispSynthesisTransform())

        # Separate factorized entropy bottlenecks for the color branch
        # (batch shape 320) and the disparity branch (batch shape 8).
        self.entropy_bottleneck_color = tfc.NoisyDeepFactorized(batch_shape=[320])
        self.entropy_bottleneck_disp = tfc.NoisyDeepFactorized(batch_shape=[8])

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

        # Make the optimizer's iteration counter the default tf.summary step.
        tf.summary.experimental.set_step(self.optimizer.iterations)
        self.writer = tf.summary.create_file_writer(self.args.checkpoint_dir)
Exemplo n.º 2
0
 def __init__(self, lmbda,
              num_filters, latent_depth, hyperprior_depth,
              num_slices, max_support_slices,
              num_scales, scale_min, scale_max):
   """Create the transforms, hyperprior and scale table of the slice model.

   Args:
     lmbda: stored as ``self.lmbda`` (presumably the rate/distortion weight
       used by the training loss -- not visible here).
     num_filters: accepted for interface compatibility; this constructor
       does not use it.
     latent_depth: passed to ``AnalysisTransform`` and to every
       ``SliceTransform``.
     hyperprior_depth: passed to ``HyperAnalysisTransform`` and used as the
       batch shape of the factorized hyperprior.
     num_slices: number of ``SliceTransform`` instances in each of
       ``cc_mean_transforms``, ``cc_scale_transforms`` and
       ``lrp_transforms``; it also fixes the signature of ``decompress``.
     max_support_slices: stored as ``self.max_support_slices``.
     num_scales: number of entries of the scale table; must be >= 2.
     scale_min: positive value returned by ``scale_fn(0)``.
     scale_max: positive value returned by ``scale_fn(num_scales - 1)``;
       intermediate indices are spaced evenly in log space.

   Raises:
     ValueError: if ``num_scales < 2`` or a scale bound is not positive,
       which would otherwise yield an inf/NaN scale table.
   """
   # scale_fn divides by (num_scales - 1) and takes logs of both bounds, so
   # reject inputs that would silently produce inf/NaN scales.
   if num_scales < 2:
     raise ValueError(f"num_scales must be at least 2, got {num_scales}.")
   if scale_min <= 0 or scale_max <= 0:
     raise ValueError(
         "scale_min and scale_max must be positive, "
         f"got {scale_min} and {scale_max}.")
   super().__init__()
   self.lmbda = lmbda
   self.num_scales = num_scales
   self.num_slices = num_slices
   self.max_support_slices = max_support_slices
   # Log-linear interpolation between scale_min (i = 0) and
   # scale_max (i = num_scales - 1).
   offset = tf.math.log(scale_min)
   factor = (tf.math.log(scale_max) - tf.math.log(scale_min)) / (
       num_scales - 1.)
   self.scale_fn = lambda i: tf.math.exp(offset + factor * i)
   self.analysis_transform = AnalysisTransform(latent_depth)
   self.synthesis_transform = SynthesisTransform()
   self.hyper_analysis_transform = HyperAnalysisTransform(hyperprior_depth)
   self.hyper_synthesis_mean_transform = HyperSynthesisTransform()
   self.hyper_synthesis_scale_transform = HyperSynthesisTransform()
   self.cc_mean_transforms = [
       SliceTransform(latent_depth, num_slices) for _ in range(num_slices)]
   self.cc_scale_transforms = [
       SliceTransform(latent_depth, num_slices) for _ in range(num_slices)]
   self.lrp_transforms = [
       SliceTransform(latent_depth, num_slices) for _ in range(num_slices)]
   self.hyperprior = tfc.NoisyDeepFactorized(batch_shape=[hyperprior_depth])
   self.build((None, None, None, 3))
   # The call signature of decompress() depends on the number of slices, so we
   # need to compile the function dynamically: three int32 shape vectors of
   # length 2, followed by num_slices + 1 scalar-string tensors.
   self.decompress = tf.function(
       input_signature=3 * [tf.TensorSpec(shape=(2,), dtype=tf.int32)] +
       (num_slices + 1) * [tf.TensorSpec(shape=(1,), dtype=tf.string)]
   )(self.decompress)
Exemplo n.º 3
0
 def __init__(self, lmbda, num_filters):
   """Create the analysis/synthesis transforms and the factorized prior.

   Args:
     lmbda: stored as ``self.lmbda`` (presumably the rate/distortion weight;
       its use is outside this method).
     num_filters: width handed to both transforms and the batch shape of
       the ``NoisyDeepFactorized`` prior.
   """
   super().__init__()
   self.lmbda = lmbda
   self.analysis_transform = AnalysisTransform(num_filters)
   self.synthesis_transform = SynthesisTransform(num_filters)
   prior_batch_shape = (num_filters,)
   self.prior = tfc.NoisyDeepFactorized(batch_shape=prior_batch_shape)
   # First three dims left unspecified, last dim 3 -- presumably NHWC RGB
   # images; confirm against the training pipeline.
   input_shape = (None, None, None, 3)
   self.build(input_shape)
Exemplo n.º 4
0
 def __init__(self, lmbda, context_length, num_filters, num_scales,
              scale_min, scale_max):
     """Create the context model, transforms and (non-compressing) entropy models.

     Args:
       lmbda: stored as ``self.lmbda`` (presumably the rate/distortion
         weight; its use is outside this method).
       context_length: forwarded as a keyword argument to ``ContextModel``
         and to all four transforms.
       num_filters: forwarded as a keyword argument to ``ContextModel`` and
         to all four transforms; also the batch shape of the hyperprior.
       num_scales: number of entries of the scale table; must be >= 2.
       scale_min: positive value returned by ``scale_fn(0)``.
       scale_max: positive value returned by ``scale_fn(num_scales - 1)``;
         intermediate indices are spaced evenly in log space.

     Raises:
       ValueError: if ``num_scales < 2`` or a scale bound is not positive,
         which would otherwise yield an inf/NaN scale table.
     """
     # scale_fn divides by (num_scales - 1) and takes logs of both bounds, so
     # reject inputs that would silently produce inf/NaN scales.
     if num_scales < 2:
         raise ValueError(f"num_scales must be at least 2, got {num_scales}.")
     if scale_min <= 0 or scale_max <= 0:
         raise ValueError(
             "scale_min and scale_max must be positive, "
             f"got {scale_min} and {scale_max}.")
     super().__init__()
     self.lmbda = lmbda
     self.num_scales = num_scales
     # Log-linear interpolation between scale_min (i = 0) and
     # scale_max (i = num_scales - 1).
     offset = tf.math.log(scale_min)
     factor = (tf.math.log(scale_max) -
               tf.math.log(scale_min)) / (num_scales - 1.)
     self.context_model = ContextModel(context_length=context_length,
                                       num_filters=num_filters)
     self.scale_fn = lambda i: tf.math.exp(offset + factor * i)
     self.hyperprior = tfc.NoisyDeepFactorized(batch_shape=(num_filters, ))
     self.analysis_transform = AnalysisTransform(
         context_length=context_length, num_filters=num_filters)
     self.hyper_analysis_transform = HyperAnalysisTransform(
         context_length=context_length, num_filters=num_filters)
     self.hyper_synthesis_transform = HyperSynthesisTransform(
         context_length=context_length, num_filters=num_filters)
     self.synthesis_transform = SynthesisTransform(
         context_length=context_length, num_filters=num_filters)
     # Both entropy models are created with compression=False and
     # coding_rank=3; the main one indexes NoisyNormal scales via scale_fn.
     self.entropy_model = tfc.LocationScaleIndexedEntropyModel(
         tfc.NoisyNormal,
         self.num_scales,
         self.scale_fn,
         coding_rank=3,
         compression=False)
     self.side_entropy_model = tfc.ContinuousBatchedEntropyModel(
         self.hyperprior, coding_rank=3, compression=False)
Exemplo n.º 5
0
 def __init__(self, lmbda, num_filters, num_scales, scale_min, scale_max):
     """Create the transforms, hyperprior and scale table, then build the model.

     Args:
       lmbda: stored as ``self.lmbda`` (presumably the rate/distortion
         weight; its use is outside this method).
       num_filters: width handed to all four transforms and the batch shape
         of the ``NoisyDeepFactorized`` hyperprior.
       num_scales: number of entries of the scale table; must be >= 2.
       scale_min: positive value returned by ``scale_fn(0)``.
       scale_max: positive value returned by ``scale_fn(num_scales - 1)``;
         intermediate indices are spaced evenly in log space.

     Raises:
       ValueError: if ``num_scales < 2`` or a scale bound is not positive,
         which would otherwise yield an inf/NaN scale table.
     """
     # scale_fn divides by (num_scales - 1) and takes logs of both bounds, so
     # reject inputs that would silently produce inf/NaN scales.
     if num_scales < 2:
         raise ValueError(f"num_scales must be at least 2, got {num_scales}.")
     if scale_min <= 0 or scale_max <= 0:
         raise ValueError(
             "scale_min and scale_max must be positive, "
             f"got {scale_min} and {scale_max}.")
     super().__init__()
     self.lmbda = lmbda
     self.num_scales = num_scales
     # Log-linear interpolation between scale_min (i = 0) and
     # scale_max (i = num_scales - 1).
     offset = tf.math.log(scale_min)
     factor = (tf.math.log(scale_max) -
               tf.math.log(scale_min)) / (num_scales - 1.)
     self.scale_fn = lambda i: tf.math.exp(offset + factor * i)
     self.analysis_transform = AnalysisTransform(num_filters)
     self.synthesis_transform = SynthesisTransform(num_filters)
     self.hyper_analysis_transform = HyperAnalysisTransform(num_filters)
     self.hyper_synthesis_transform = HyperSynthesisTransform(num_filters)
     self.hyperprior = tfc.NoisyDeepFactorized(batch_shape=(num_filters, ))
     self.build((None, None, None, 3))
Exemplo n.º 6
0
    def __init__(self, args):
        """Create the transforms, per-slice heads, hyperprior and training state.

        Args:
            args: configuration object, kept on ``self.args``; its
                ``checkpoint_dir`` attribute is the directory handed to
                ``tf.summary.create_file_writer``.
        """
        super().__init__()
        self.args = args

        # Main and hyper transforms (mean and scale get separate hyper-synthesis
        # transforms).
        self.analysis_transform = AnalysisTransform()
        self.synthesis_transform = SynthesisTransform()
        self.hyper_analysis_transform = HyperAnalysisTransform()
        self.hyper_synthesis_mean_transform = HyperSynthesisTransform()
        self.hyper_synthesis_scale_transform = HyperSynthesisTransform()

        def fresh_slice_transforms():
            # NUM_SLICES independent SliceTransform instances.
            return [SliceTransform() for _ in range(NUM_SLICES)]

        self.cc_mean_transforms = fresh_slice_transforms()
        self.cc_scale_transforms = fresh_slice_transforms()
        self.lrp_transforms = fresh_slice_transforms()

        self.entropy_bottleneck = tfc.NoisyDeepFactorized(
            batch_shape=[HYPERPRIOR_DEPTH])
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

        # Make the optimizer's iteration counter the default tf.summary step.
        tf.summary.experimental.set_step(self.optimizer.iterations)
        self.writer = tf.summary.create_file_writer(args.checkpoint_dir)