Esempio n. 1
0
    def __init__(self, sess):
        """Build the two-stage frame-interpolation graph (U-Net flow + ConvLSTM refinement).

        Stage 1 ("first_unet") predicts 4 flow channels from frame0/frame2, split
        into flow 0->1 and 1->0; stage 2 ("second_unet") feeds frames, stage-1
        flows and warped images to a ConvLSTM that unrolls over the batch axis
        (treated as time, batch of 1) and adds its output to the stage-1
        intermediate-time flows.  Every stage is also built for a parallel
        "_eval" input at twice the training height/width; the eval ops are
        created in the same AUTO_REUSE scopes, presumably sharing weights
        (TODO confirm against my_unet's variable naming).  A discriminator "D1"
        is built on second_output (fake) and frame1 (real), and Adam train ops
        for G and D plus weight-clipping ops for D are created.  All global
        variables are initialized at the end.

        Args:
            sess: tf.Session stored on self and used to run the initializer.

        Relies on module-level names defined elsewhere: img_size, batchsize,
        batchsize_test, img_size_h, img_size_w, G_group_img_num, img_channel,
        base_lr, decay_steps, decay_rate, beta1, weightclip_min,
        weightclip_max, cdata, my_unet, Discriminator_net.
        """
        self.sess=sess
        self.global_step = tf.Variable(0.0, name='global_step',dtype=tf.float32, trainable=False)
        
        # data input pipelines (train / test)
        self.pipline_data_train=cdata.get_pipline_data_train(img_size, batchsize)
        self.pipline_data_test=cdata.get_pipline_data_test(img_size, batchsize_test)
        
        # Input placeholders: channel-stacked frames (train size and 2x eval size),
        # training flag, and one time rate per batch element.
        self.imgs_pla = tf.placeholder(tf.float32, [batchsize, img_size_h, img_size_w, G_group_img_num*img_channel], name='imgs_in')
        self.imgs_pla_eval = tf.placeholder(tf.float32, [batchsize, img_size_h*2, img_size_w*2, G_group_img_num*img_channel], name='imgs_in_eval')
        self.training=tf.placeholder(tf.bool, name='training_in')  # author's note: effectively unused, kept for compatibility (still passed to my_unet / Discriminator_net)
        self.timerates_pla=tf.placeholder(tf.float32, [batchsize], name='timerates_in')
        # Expand time rates to [batchsize, 1, 1, 1] so they broadcast over H, W, C.
        self.timerates_expand=tf.expand_dims(self.timerates_pla, -1)
        self.timerates_expand=tf.expand_dims(self.timerates_expand, -1)
        self.timerates_expand=tf.expand_dims(self.timerates_expand, -1) #12*1*1*1
        
        print ('placeholders:\n','img_placeholder:',self.imgs_pla,self.timerates_pla)
        #img_placeholder: Tensor("imgs_in:0", shape=(10, 180, 320, 9), dtype=float32) Tensor("timerates_in:0", shape=(10,), dtype=float32)
        
        # Split the channel-stacked input into frame0, frame1 (used below as the
        # real sample for D1) and frame2.
        self.frame0=self.imgs_pla[:,:,:,:img_channel]
        self.frame1=self.imgs_pla[:,:,:,img_channel:img_channel*2]
        self.frame2=self.imgs_pla[:,:,:,img_channel*2:]
        
        # Eval-resolution frames: only used to produce the final output image.
        self.frame0_eval=self.imgs_pla_eval[:,:,:,:img_channel]
        self.frame1_eval=self.imgs_pla_eval[:,:,:,img_channel:img_channel*2]
        self.frame2_eval=self.imgs_pla_eval[:,:,:,img_channel*2:]
        
        with tf.variable_scope("first_unet",  reuse=tf.AUTO_REUSE) as scope:
            firstinput=tf.concat([self.frame0, self.frame2], -1)
            self.first_opticalflow=my_unet( firstinput, 4,training=self.training , withbias=True, withbn=False)  # the U-Net output is used directly as optical flow
            
            firstinput=tf.concat([self.frame0_eval, self.frame2_eval], -1)
            self.first_opticalflow_eval=my_unet( firstinput, 4,training=self.training , withbias=True, withbn=False)  # the U-Net output is used directly as optical flow
            #self.first_opticalflow=my_unet_split( firstinput, 4,training=self.training , withbias=True, withbn=True)  # (disabled alternative) output used directly as optical flow
            
        # Channels 0:2 are flow 0->1, channels 2:4 are flow 1->0.
        self.first_opticalflow_0_1=self.first_opticalflow[:, :, :, :2]
        self.first_opticalflow_0_1=tf.identity(self.first_opticalflow_0_1, name="first_opticalflow_0_1")
        print ('first_opticalflow_0_1:',self.first_opticalflow_0_1)
        self.first_opticalflow_1_0=self.first_opticalflow[:, :, :, 2:]
        self.first_opticalflow_1_0=tf.identity(self.first_opticalflow_1_0, name="first_opticalflow_1_0")
        print ('first_opticalflow_1_0:',self.first_opticalflow_1_0)
        #first_opticalflow_0_1: Tensor("first_opticalflow_0_1:0", shape=(10, 180, 320, 2), dtype=float32)
        #first_opticalflow_1_0: Tensor("first_opticalflow_1_0:0", shape=(10, 180, 320, 2), dtype=float32)
        self.first_opticalflow_0_1_eval=self.first_opticalflow_eval[:, :, :, :2]
        self.first_opticalflow_1_0_eval=self.first_opticalflow_eval[:, :, :, 2:]
        
        # optical-flow output shape
        self.flow_size_h=self.first_opticalflow_0_1.get_shape().as_list()[1]
        self.flow_size_w=self.first_opticalflow_0_1.get_shape().as_list()[2]
        self.flow_channel=self.first_opticalflow_0_1.get_shape().as_list()[-1]
        
        # eval-resolution flow shape
        self.flow_size_h_eval=self.first_opticalflow_eval.get_shape().as_list()[1]
        self.flow_size_w_eval=self.first_opticalflow_eval.get_shape().as_list()[2]
        
        ########################################################
        # Number of channels produced by the stage-2 ConvLSTM.
        self.step2_flow_channel=4
        self.flow_shape=[ self.flow_size_h, self.flow_size_w, self.step2_flow_channel]
        
        # Shape of the LSTM state: (c, h) stacked on axis 0, each with a batch of 1.
        self.state_shape=[2, 1, self.flow_size_h, self.flow_size_w, self.step2_flow_channel]
        self.state_shape_eval=[2, 1, self.flow_size_h_eval, self.flow_size_w_eval, self.step2_flow_channel]
        
        # Host-side (numpy) values used while fetching data: all-zero initial
        # LSTM states for train and eval resolution.
        self.last_flow_init_np=np.zeros(self.state_shape, dtype=np.float32)
        self.last_flow_init_np_eval=np.zeros(self.state_shape_eval, dtype=np.float32)
        print (self.last_flow_init_np.shape, self.last_flow_init_np_eval.shape) # logged (2, 1, 180, 320, 5) (2, 1, 360, 640, 5) in an older run; last dim is now step2_flow_channel (4)
        ##############################################################
        
        # Placeholders for the LSTM state carried over from the previous batch.
        self.last_optical_flow=tf.placeholder(tf.float32, self.state_shape, name='second_last_flow')
        self.last_optical_flow_eval=tf.placeholder(tf.float32, self.state_shape_eval, name='second_last_flow_eval')
        
        # Initial all-zero state for the train and test streams.
        # NOTE(review): the test stream also starts from the train-resolution
        # zeros, not last_flow_init_np_eval - confirm which placeholder it feeds.
        self.last_flow_new_train=self.last_flow_init_np
        self.last_flow_new_test=self.last_flow_init_np
        
        # Flows from time t back to frame0 / frame2, approximated from the
        # bidirectional flows (Super SloMo-style linear combination):
        #   F_t0 = -(1-t)*t*F_01 + t^2*F_10,  F_t2 = (1-t)^2*F_01 - t*(1-t)*F_10
        self.first_opticalflow_t_0=tf.add( -(1-self.timerates_expand)*self.timerates_expand*self.first_opticalflow_0_1 ,\
                                      self.timerates_expand*self.timerates_expand*self.first_opticalflow_1_0 , name="first_opticalflow_t_0")
        self.first_opticalflow_t_2=tf.add( (1-self.timerates_expand)*(1-self.timerates_expand)*self.first_opticalflow_0_1 ,\
                                      self.timerates_expand*(self.timerates_expand-1)*self.first_opticalflow_1_0, name="first_opticalflow_t_2")
        
        # Same intermediate-time flows for the eval branch.
        self.first_opticalflow_t_0_eval=tf.add( -(1-self.timerates_expand)*self.timerates_expand*self.first_opticalflow_0_1_eval ,\
                                      self.timerates_expand*self.timerates_expand*self.first_opticalflow_1_0_eval , name="first_opticalflow_t_0_eval")
        self.first_opticalflow_t_2_eval=tf.add( (1-self.timerates_expand)*(1-self.timerates_expand)*self.first_opticalflow_0_1_eval ,\
                                      self.timerates_expand*(self.timerates_expand-1)*self.first_opticalflow_1_0_eval, name="first_opticalflow_t_2_eval")

        # Synthesize frame t two ways: by warping frame2 and by warping frame0.
        self.first_img_flow_2_t=self.warp_op(self.frame2, -self.first_opticalflow_t_2) #!!!
        self.first_img_flow_0_t=self.warp_op(self.frame0, -self.first_opticalflow_t_0) #!!!
        
        # Same two warps for the eval branch.
        self.first_img_flow_2_t_eval=self.warp_op(self.frame2_eval, -self.first_opticalflow_t_2_eval) #!!!
        self.first_img_flow_0_t_eval=self.warp_op(self.frame0_eval, -self.first_opticalflow_t_0_eval) #!!!
        
        # The paper does not use stage 1's interpolated frame; it is output here
        # only to inspect quality.
        self.first_output=tf.add( self.timerates_expand*self.first_img_flow_2_t, (1-self.timerates_expand)*self.first_img_flow_0_t , name="first_outputimg")
        print ('first output img:',self.first_output)
        #first output img: Tensor("first_outputimg:0", shape=(10, 180, 320, 3), dtype=float32)
        
        # Use the flows to reconstruct each endpoint frame from the other one.
        self.first_img_flow_2_0=self.warp_op(self.frame2, -self.first_opticalflow_0_1)  #frame2->frame0
        self.first_img_flow_0_2=self.warp_op(self.frame0, -self.first_opticalflow_1_0)  #frame0->frame2
        
        ####################################################################################################################3
        # Second stage (scope "second_unet", implemented with a ConvLSTM).
        with tf.variable_scope("second_unet",  reuse=tf.AUTO_REUSE) as scope:
            secinput=tf.concat([self.frame0, self.frame2, \
                                self.first_opticalflow_0_1, self.first_opticalflow_1_0, \
                                self.first_opticalflow_t_2, self.first_opticalflow_t_0,\
                                self.first_img_flow_2_t, self.first_img_flow_0_t,\
                                ], -1) #self.last_optical_flow     
            # Add a leading batch axis of 1 so the original batch becomes the time axis.
            secinput=tf.expand_dims(secinput, 0)
            print ("secinput:",secinput)#secinput: Tensor("second_unet/ExpandDims:0", shape=(1, 10, 180, 320, 20), dtype=float32)  
            
            secinput_eval=tf.concat([self.frame0_eval, self.frame2_eval, \
                                self.first_opticalflow_0_1_eval, self.first_opticalflow_1_0_eval, \
                                self.first_opticalflow_t_2_eval, self.first_opticalflow_t_0_eval,\
                                self.first_img_flow_2_t_eval, self.first_img_flow_0_t_eval,\
                                ], -1) #self.last_optical_flow     
            secinput_eval=tf.expand_dims(secinput_eval, 0)
            print ("secinput_eval:",secinput_eval)#secinput_eval: Tensor("second_unet/ExpandDims_1:0", shape=(1, 10, 360, 640, 20), dtype=float32)
                
            
            
            lstm_input_channel=secinput.get_shape().as_list()[-1]
            # Two cells (train / eval spatial size); both run under the same
            # AUTO_REUSE scope, presumably sharing kernels - TODO confirm.
            self.cell = tf.contrib.rnn.ConvLSTMCell(conv_ndims=2, input_shape=[self.flow_size_h, self.flow_size_w, lstm_input_channel], \
                                                    output_channels=self.step2_flow_channel, kernel_shape=[5, 5])
            
            self.cell_eval = tf.contrib.rnn.ConvLSTMCell(conv_ndims=2, input_shape=[self.flow_size_h_eval, self.flow_size_w_eval, lstm_input_channel], \
                                                    output_channels=self.step2_flow_channel, kernel_shape=[5, 5])
            
            # The initial state (c, h) comes from the carried-over placeholder.
            lstm_outputs, lstm_state_final = tf.nn.dynamic_rnn(self.cell, inputs =secinput , \
                                            initial_state = tf.nn.rnn_cell.LSTMStateTuple(self.last_optical_flow[0], self.last_optical_flow[1]), time_major = False)
            
            lstm_outputs_eval, lstm_state_final_eval = tf.nn.dynamic_rnn(self.cell_eval, inputs =secinput_eval , \
                                            initial_state = tf.nn.rnn_cell.LSTMStateTuple(self.last_optical_flow_eval[0], self.last_optical_flow_eval[1]), time_major = False)
            
            
            
            # Final (c, h) stacked on axis 0; fed back as last_optical_flow next batch.
            self.second_batch_last_flow=tf.stack([lstm_state_final.c, lstm_state_final.h], 0)
            self.second_batch_last_flow=tf.identity(self.second_batch_last_flow, name="second_batch_last_flow")
            print ("second_batch_last_flow:",self.second_batch_last_flow) 
            # older log (last dim now 4): second_batch_last_flow: Tensor("second_unet/second_batch_last_flow:0", shape=(2, 1, 180, 320, 5), dtype=float32)
            
            self.second_batch_last_flow_eval=tf.stack([lstm_state_final_eval.c, lstm_state_final_eval.h], 0)
            self.second_batch_last_flow_eval=tf.identity(self.second_batch_last_flow_eval, name="second_batch_last_flow_eval")
            print ("second_batch_last_flow_eval:",self.second_batch_last_flow_eval)
            # older log (last dim now 4): Tensor("second_unet/second_batch_last_flow_eval:0", shape=(2, 1, 360, 640, 5), dtype=float32)
            
            # Drop the batch-of-1 axis: one output per time step (= original batch element).
            self.second_opticalflow=lstm_outputs[0]  
            print ("self.second_opticalflow:",self.second_opticalflow) 
            # older log (last dim now 4): self.second_opticalflow: Tensor("second_unet/strided_slice_2:0", shape=(10, 180, 320, 5), dtype=float32)
            
            self.second_opticalflow_eval=lstm_outputs_eval[0]  
            print ("self.second_opticalflow_eval:",self.second_opticalflow_eval)
            # older log (last dim now 4): Tensor("second_unet/strided_slice_5:0", shape=(10, 360, 640, 5), dtype=float32)
            
        # Stage-2 output is a residual added to the stage-1 intermediate-time flows.
        self.second_opticalflow_t_0=tf.add( self.second_opticalflow[:,:,:,:2],  self.first_opticalflow_t_0, name="second_opticalflow_t_0")
        self.second_opticalflow_t_1=tf.add( self.second_opticalflow[:,:,:,2:4], self.first_opticalflow_t_2, name="second_opticalflow_t_1")
        print ('second_opticalflow_t_0:',self.second_opticalflow_t_0)
        print ('second_opticalflow_t_1:',self.second_opticalflow_t_1)
        #second_opticalflow_t_0: Tensor("second_opticalflow_t_0:0", shape=(10, 180, 320, 2), dtype=float32)
        #second_opticalflow_t_1: Tensor("second_opticalflow_t_1:0", shape=(10, 180, 320, 2), dtype=float32)
        
        self.second_opticalflow_t_0_eval=tf.add( self.second_opticalflow_eval[:,:,:,:2],  self.first_opticalflow_t_0_eval, name="second_opticalflow_t_0_eval")
        self.second_opticalflow_t_1_eval=tf.add( self.second_opticalflow_eval[:,:,:,2:4], self.first_opticalflow_t_2_eval, name="second_opticalflow_t_1_eval")
        print ('second_opticalflow_t_0_eval:',self.second_opticalflow_t_0_eval)
        print ('second_opticalflow_t_1_eval:',self.second_opticalflow_t_1_eval)
        #second_opticalflow_t_0_eval: Tensor("second_opticalflow_t_0_eval:0", shape=(10, 360, 640, 2), dtype=float32)
        #second_opticalflow_t_1_eval: Tensor("second_opticalflow_t_1_eval:0", shape=(10, 360, 640, 2), dtype=float32)
        # Disabled visibility-map code, kept as an inert string literal:
        '''
        self.vmap_t_0=tf.expand_dims( tf.sigmoid(self.second_opticalflow[:,:,:,-1])  , -1)
        self.vmap_t_1=1-self.vmap_t_0
        
        self.vmap_t_0_eval=tf.expand_dims( tf.sigmoid(self.second_opticalflow_eval[:,:,:,-1])  , -1)
        self.vmap_t_1_eval=1-self.vmap_t_0_eval
        '''
        # Synthesize frame t two ways with the refined flows.
        self.second_img_flow_1_t=self.warp_op(self.frame2, -self.second_opticalflow_t_1) #!!!
        self.second_img_flow_0_t=self.warp_op(self.frame0, -self.second_opticalflow_t_0) #!!!
        
        # Same two warps for the eval branch.
        self.second_img_flow_1_t_eval=self.warp_op(self.frame2_eval, -self.second_opticalflow_t_1_eval) #!!!
        self.second_img_flow_0_t_eval=self.warp_op(self.frame0_eval, -self.second_opticalflow_t_0_eval) #!!!
        
        # Final output image: time-weighted blend of the two warped frames.
        print (self.timerates_expand, self.second_img_flow_0_t)
        # older log (included a since-disabled vmap tensor):
        #Tensor("ExpandDims_2:0", shape=(10, 1, 1, 1), dtype=float32) Tensor("ExpandDims_3:0", shape=(10, 180, 320, 1), dtype=float32) 
        #Tensor("dense_image_warp_5/Reshape_1:0", shape=(10, 180, 320, 3), dtype=float32)
        self.second_output=tf.add(   (1-self.timerates_expand)*self.second_img_flow_0_t, self.timerates_expand*self.second_img_flow_1_t , name="second_outputimg" )
        print ('second output img:',self.second_output)
        #second output img: Tensor("second_outputimg:0", shape=(10, 180, 320, 3), dtype=float32)
        
        self.second_output_eval=tf.add(   (1-self.timerates_expand)*self.second_img_flow_0_t_eval, self.timerates_expand*self.second_img_flow_1_t_eval , name="second_outputimg_eval" )
        print ('second output img_eval:',self.second_output_eval)
        #second output img_eval: Tensor("second_outputimg_eval:0", shape=(10, 360, 640, 3), dtype=float32)
        
        # Discriminator D1 on the generated frame (F) and the real frame1 (T).
        self.D_1_net_F, self.D_1_net_F_logit=Discriminator_net(self.second_output, name="D1", training=self.training)
        self.D_1_net_T, self.D_1_net_T_logit=Discriminator_net(self.frame1, name="D1", training=self.training)
        # discriminator loss
        self.D_1_net_loss_sum, _, _=self.D_loss_TandF_logits(self.D_1_net_T_logit, self.D_1_net_F_logit, "D_1_net")
        
        # compute all generator-side losses and metrics
        self.second_L1_loss_interframe,self.first_warp_loss,self.second_contex_loss,self.second_local_var_loss_all,self.second_global_var_loss_all,self.second_ssim,self.second_psnr,\
                self.first_L1_loss_interframe, self.first_ssim, self.first_psnr, self.second_GAN_loss_mean_D1=self.loss_cal_all()
                
        # Total generator loss (the GAN term is currently disabled).
        self.G_loss_all=204 * self.second_L1_loss_interframe + 102 *  self.first_warp_loss  + 0.005 * self.second_contex_loss \
                    +self.second_global_var_loss_all
                    #+ self.second_GAN_loss_mean_D1*0.03   
        
        # total discriminator loss
        self.D_loss_all=self.D_1_net_loss_sum
        
        
        #####################################
        # Host-side bookkeeping for the data pipeline (last sequence label,
        # random crop offsets and flip state for train and test).
        self.last_label_train='#'
        self.last_label_test='#'
        self.state_random_row_train=0
        self.state_random_col_train=0
        self.state_flip_train=False
        
        self.state_random_row_test=0
        self.state_random_col_test=0
        self.state_flip_test=False
        
        # for compatibility with the other model classes
        self.batchsize_inputimg=batchsize
        self.img_size_w=img_size_w
        self.img_size_h=img_size_h
        
        # Partition trainable variables by name prefix.
        t_vars=tf.trainable_variables()
        print ("trainable vars cnt:",len(t_vars))
        self.first_para=[var for var in t_vars if var.name.startswith('first')]
        self.sec_para=[var for var in t_vars if var.name.startswith('second')]
        self.vgg_para=[var for var in t_vars if var.name.startswith('VGG')]
        self.D_para=[var for var in t_vars if var.name.startswith('D')]
        print ("first param len:",len(self.first_para))
        print ("second param len:",len(self.sec_para))
        print ("VGG param len:",len(self.vgg_para))
        print ("D param len:",len(self.D_para))
        print (self.vgg_para)
        # Variable counts from an earlier run (may be stale), kept as an inert string literal:
        '''
        trainable vars cnt: 114
        first param len: 46
        second param len: 2
        VGG param len: 52
        D param len: 14
        '''
        
        # Generator training op: updates only first_* and second_* variables.
        self.lr_rate = tf.train.exponential_decay(base_lr,  global_step=self.global_step, decay_steps=decay_steps, decay_rate=decay_rate)
        self.train_op_G = tf.train.AdamOptimizer(self.lr_rate, beta1=beta1, name="superslomo_adam_G").minimize(self.G_loss_all,  \
                                                                                              global_step=self.global_step  , var_list=self.first_para+self.sec_para  )
        
        # Weight-clipping assign ops for D's variables; they are not attached to
        # train_op_D, so callers must run self.clip_D explicitly.
        self.clip_D = [p.assign(tf.clip_by_value(p, weightclip_min, weightclip_max)) for p in self.D_para]
        
        # Discriminator training op (does not advance global_step).
        self.train_op_D= tf.train.AdamOptimizer(self.lr_rate  , beta1=beta1, name="superslomo_adam_D").minimize(self.D_loss_all, var_list=self.D_para)
        
        # Initialize all variables once the graph is fully built.
        self.sess.run(tf.global_variables_initializer())
    def __init__(self, sess):
        """Load a pretrained stage-1 graph and add a recurrent STEP2 flow-refinement stage.

        The stage-1 graph is imported from ``op.join(modelpath, meta_name)`` and
        its input/output tensors are looked up by name.  A STEP2 network is then
        unrolled over the batch axis, which is treated as the time axis.  At each
        step it receives:
          * frame0 and frame2 of that step,
          * the stage-1 flows of that step,
          * the flow produced by the previous step (``self.last_optical_flow``
            for the first step).
        Losses for both stages are combined and minimized over the STEP2 and G
        variables.

        The pretrained checkpoint is restored *after*
        ``tf.global_variables_initializer()`` runs.  The initializer also covers
        the variables imported with the meta graph, so restoring first (as
        before) silently reset the pretrained stage-1 weights.

        Args:
            sess: tf.Session used to initialize variables and restore the
                checkpoint; stored on ``self.sess``.

        Relies on module-level names defined elsewhere: modelpath, meta_name,
        img_channel, cdata, weightclip_min, weightclip_max, LR, LR_step1,
        decay_steps, decay_rate.
        """
        self.sess = sess

        # Import the original (stage-1) model graph.  Its weights are restored at
        # the very end of __init__, after variable initialization, so that the
        # global initializer does not overwrite them.
        saver = tf.train.import_meta_graph(op.join(modelpath, meta_name))
        self.graph = tf.get_default_graph()
        self.global_step = tf.Variable(0.0,
                                       name='step2_global_step',
                                       dtype=tf.float32,
                                       trainable=False)

        # Placeholders of the imported stage-1 graph.
        self.imgs_pla = self.graph.get_tensor_by_name('imgs_in:0')
        self.training = self.graph.get_tensor_by_name("training_in:0")
        self.timerates = self.graph.get_tensor_by_name("timerates_in:0")

        # Reshape time rates to [batchsize, 1, 1, 1] so they broadcast over H, W, C.
        self.timerates_expand = tf.expand_dims(self.timerates, -1)
        self.timerates_expand = tf.expand_dims(self.timerates_expand, -1)
        self.timerates_expand = tf.expand_dims(self.timerates_expand, -1)

        # Tensors of the imported graph: the stage-1 output flows and image.
        self.optical_0_1 = self.graph.get_tensor_by_name("G_opticalflow_0_2:0")
        self.optical_1_0 = self.graph.get_tensor_by_name("G_opticalflow_2_0:0")
        self.outimg = self.graph.get_tensor_by_name("G_net_generate:0")

        # Stage-1 batch size; STEP2 uses it as the number of time steps.
        self.batchsize_inputimg = self.imgs_pla.get_shape().as_list()[0]

        # Input image shape.
        self.img_size_w = self.imgs_pla.get_shape().as_list()[2]
        self.img_size_h = self.imgs_pla.get_shape().as_list()[1]
        self.img_size = [self.img_size_h, self.img_size_w]
        print(self.imgs_pla)  # e.g. Tensor("imgs_in:0", shape=(12, 180, 320, 9), dtype=float32)
        self.frame0 = self.imgs_pla[:, :, :, :img_channel]
        self.frame1 = self.imgs_pla[:, :, :, img_channel:img_channel * 2]
        self.frame2 = self.imgs_pla[:, :, :, img_channel * 2:]

        self.pipline_data_train = cdata.get_pipline_data_train(
            self.img_size, self.batchsize_inputimg)
        self.pipline_data_test = cdata.get_pipline_data_test(
            self.img_size, self.batchsize_inputimg)

        # Optical-flow shape; the carried-over flow holds both directions.
        self.flow_size_h = self.optical_0_1.get_shape().as_list()[1]
        self.flow_size_w = self.optical_0_1.get_shape().as_list()[2]
        self.flow_channel = self.optical_0_1.get_shape().as_list()[-1]

        self.flow_shape = [
            self.flow_size_h, self.flow_size_w, self.flow_channel * 2
        ]

        # Flow carried over from the last step of the previous batch.
        self.last_optical_flow = tf.placeholder(tf.float32,
                                                self.flow_shape,
                                                name='step2_last_flow')

        # First time step: concatenate frame0/frame2 and both stage-1 flows of
        # batch element 0 with the carried-over flow.
        input_pla = tf.concat([
            self.frame0[0], self.frame2[0], self.optical_0_1[0],
            self.optical_1_0[0], self.last_optical_flow
        ], -1)
        print(input_pla)  # e.g. Tensor("concat_9:0", shape=(180, 320, 14), dtype=float32)

        with tf.variable_scope("STEP2", reuse=tf.AUTO_REUSE) as scopevar:
            new_flow = self.step2_network(input_pla, training=self.training)
            kep_new_flow = [new_flow]

            # Remaining time steps: each one consumes the previous step's flow.
            for ti in range(1, self.batchsize_inputimg):
                input_pla = tf.concat([
                    self.frame0[ti], self.frame2[ti], self.optical_0_1[ti],
                    self.optical_1_0[ti], new_flow
                ], -1)  # 14 channels for 3-channel frames
                new_flow = self.step2_network(input_pla,
                                              training=self.training)
                kep_new_flow.append(new_flow)

            # Last step's flow, to be fed back as last_optical_flow next batch.
            self.flow_next = new_flow
            self.flow_after = tf.stack(kep_new_flow,
                                       axis=0,
                                       name='step2_opticalflow')
            print('self.flow_after:', self.flow_after)  # e.g. shape=(12, 180, 320, 4)
            print('self.flow_next:', self.flow_next)  # e.g. shape=(180, 320, 4)

        self.last_flow_init_np = np.zeros(self.flow_shape, dtype=np.float32)
        print(self.last_flow_init_np.shape)  # e.g. (180, 320, 4)
        # All-zero initial carried-over flow for the train and test streams.
        self.last_flow_new_train = self.last_flow_init_np
        self.last_flow_new_test = self.last_flow_init_np

        #########################################################################
        # Split the refined flows back into the two directions.
        self.opticalflow_0_2 = tf.slice(self.flow_after, [0, 0, 0, 0],
                                        [-1, -1, -1, 2],
                                        name='step2_opticalflow_0_2')
        self.opticalflow_2_0 = tf.slice(self.flow_after, [0, 0, 0, 2],
                                        [-1, -1, -1, 2],
                                        name='step2_opticalflow_2_0')
        print('original flow:', self.opticalflow_0_2, self.opticalflow_2_0)

        # Host-side bookkeeping used while fetching data (sequence label,
        # random crop offsets and flip state for train and test).
        self.last_label_train = '#'
        self.last_label_test = '#'
        self.state_random_row_train = 0
        self.state_random_col_train = 0
        self.state_flip_train = False

        self.state_random_row_test = 0
        self.state_random_col_test = 0
        self.state_flip_test = False

        # Partition trainable variables by name prefix.
        t_vars = tf.trainable_variables()
        print("trainable vars cnt:", len(t_vars))
        self.G_para = [var for var in t_vars if var.name.startswith('G')]
        self.D_para = [var for var in t_vars if var.name.startswith('D')]
        self.STEP2_para = [
            var for var in t_vars if var.name.startswith('STEP2')
        ]
        print("G param len:", len(self.G_para))
        print("D param len:", len(self.D_para))
        print("STEP2 param len:", len(self.STEP2_para))
        print(self.STEP2_para)
        # Counts from an earlier run: 184 trainable vars = G 60 + D 16 +
        # STEP2 56 + 52 VGG.  The VGG variables should be shared with the
        # stage-1 VGG, otherwise memory is wasted on a duplicate copy.

        # Weight-clipping assign ops for D's variables (run explicitly by callers).
        self.clip_D = [
            p.assign(tf.clip_by_value(p, weightclip_min, weightclip_max))
            for p in self.D_para
        ]

        self.step1_L1_loss_all,self.step1_contex_loss,self.step1_local_var_loss_all,self.step1_global_var_loss_all,self.step1_G_loss_all,\
        self.step1_ssim,self.step1_psnr,self.step1_G_net=self.loss_cal(self.optical_0_1, self.optical_1_0, LR_step1, self.G_para, scopevar.name+"_step1_losscal")

        self.step2_L1_loss_all,self.step2_contex_loss,self.step2_local_var_loss_all, self.step2_global_var_loss_all, self.step2_G_loss_all,\
        self.step2_ssim,self.step2_psnr,self.step2_G_net=self.loss_cal(self.opticalflow_0_2, self.opticalflow_2_0, LR, self.STEP2_para, scopevar.name+"_step2_losscal")

        Loss_merge=10*self.step1_L1_loss_all + 5*self.step1_contex_loss +\
                   102*self.step2_L1_loss_all + 10*self.step2_contex_loss + self.step2_global_var_loss_all*0.1

        # Joint training op over the STEP2 and stage-1 G variables.
        self.lr_rate = tf.train.exponential_decay(LR,
                                                  global_step=self.global_step,
                                                  decay_steps=decay_steps,
                                                  decay_rate=decay_rate)
        self.train_op = tf.train.AdamOptimizer(
            self.lr_rate, name="step2_v2_adam").minimize(
                Loss_merge,
                global_step=self.global_step,
                var_list=self.STEP2_para + self.G_para)

        # Initialize every variable (including new STEP2/optimizer variables)
        # once the graph is fully built, THEN restore the pretrained stage-1
        # weights on top.  Restoring before this initializer would be undone by it.
        self.sess.run(tf.global_variables_initializer())
        saver.restore(self.sess, tf.train.latest_checkpoint(modelpath))
    def __init__(self, sess):
        self.sess = sess
        self.global_step = tf.Variable(0.0,
                                       name='global_step',
                                       dtype=tf.float32,
                                       trainable=False)

        #for data input
        self.pipline_data_train = cdata.get_pipline_data_train(
            img_size, batchsize)
        self.pipline_data_test = cdata.get_pipline_data_test(
            img_size, batchsize_test)

        #3个placeholder, img和noise,training
        self.imgs_pla = tf.placeholder(
            tf.float32,
            [batchsize, img_size_h, img_size_w, G_group_img_num * img_channel],
            name='imgs_in')
        self.training = tf.placeholder(tf.bool,
                                       name='training_in')  #这里没用上但是为了兼容就保留了
        self.timerates_pla = tf.placeholder(tf.float32, [batchsize],
                                            name='timerates_in')
        self.timerates_expand = tf.expand_dims(self.timerates_pla, -1)
        self.timerates_expand = tf.expand_dims(self.timerates_expand, -1)
        self.timerates_expand = tf.expand_dims(self.timerates_expand,
                                               -1)  #12*1*1*1

        print('placeholders:\n', 'img_placeholder:', self.imgs_pla,
              self.timerates_pla)
        #img_placeholder: Tensor("imgs_in:0", shape=(10, 180, 320, 9), dtype=float32) Tensor("timerates_in:0", shape=(10,), dtype=float32)

        self.frame0 = self.imgs_pla[:, :, :, :img_channel]
        self.frame1 = self.imgs_pla[:, :, :, img_channel:img_channel * 2]
        self.frame2 = self.imgs_pla[:, :, :, img_channel * 2:]

        with tf.variable_scope("first_unet", reuse=tf.AUTO_REUSE) as scope:
            firstinput = tf.concat([self.frame0, self.frame2], -1)
            self.first_opticalflow = my_unet(
                firstinput, 4, withbias=True)  #注意这里是直接作为optical flow

        self.first_opticalflow_0_1 = self.first_opticalflow[:, :, :, :2]
        self.first_opticalflow_0_1 = tf.identity(self.first_opticalflow_0_1,
                                                 name="first_opticalflow_0_1")
        print('first_opticalflow_0_1:', self.first_opticalflow_0_1)
        self.first_opticalflow_1_0 = self.first_opticalflow[:, :, :, 2:]
        self.first_opticalflow_1_0 = tf.identity(self.first_opticalflow_1_0,
                                                 name="first_opticalflow_1_0")
        print('first_opticalflow_1_0:', self.first_opticalflow_1_0)
        #first_opticalflow_0_1: Tensor("first_opticalflow_0_1:0", shape=(10, 180, 320, 2), dtype=float32)
        #first_opticalflow_1_0: Tensor("first_opticalflow_1_0:0", shape=(10, 180, 320, 2), dtype=float32)

        #输出光流形状
        self.flow_size_h = self.first_opticalflow_0_1.get_shape().as_list()[1]
        self.flow_size_w = self.first_opticalflow_0_1.get_shape().as_list()[2]
        self.flow_channel = self.first_opticalflow_0_1.get_shape().as_list(
        )[-1]

        self.flow_shape = [
            self.flow_size_h, self.flow_size_w, self.flow_channel * 2
        ]

        #反向光流算中间帧
        self.first_opticalflow_t_0=tf.add( -(1-self.timerates_expand)*self.timerates_expand*self.first_opticalflow_0_1 ,\
                                      self.timerates_expand*self.timerates_expand*self.first_opticalflow_1_0 , name="first_opticalflow_t_0")
        self.first_opticalflow_t_2=tf.add( (1-self.timerates_expand)*(1-self.timerates_expand)*self.first_opticalflow_0_1 ,\
                                      self.timerates_expand*(self.timerates_expand-1)*self.first_opticalflow_1_0, name="first_opticalflow_t_2")

        #2种方法合成t时刻的帧
        self.first_img_flow_2_t = self.warp_op(
            self.frame2, -self.first_opticalflow_t_2)  #!!!
        self.first_img_flow_0_t = self.warp_op(
            self.frame0, -self.first_opticalflow_t_0)  #!!!

        #虽然论文里用不到第一步的输出中间帧,但是这里也给他输出看看效果
        self.first_output = tf.add(
            self.timerates_expand * self.first_img_flow_2_t,
            (1 - self.timerates_expand) * self.first_img_flow_0_t,
            name="first_outputimg")
        print('first output img:', self.first_output)
        #first output img: Tensor("first_outputimg:0", shape=(10, 180, 320, 3), dtype=float32)

        #利用光流前后帧互相合成
        self.first_img_flow_2_0 = self.warp_op(
            self.frame2, -self.first_opticalflow_0_1)  #frame2->frame0
        self.first_img_flow_0_2 = self.warp_op(
            self.frame0, -self.first_opticalflow_1_0)  #frame0->frame2

        ####################################################################################################################3
        #第二个unet
        with tf.variable_scope("second_unet", reuse=tf.AUTO_REUSE) as scope:
            secinput=tf.concat([self.frame0, self.frame2, \
                                self.first_opticalflow_0_1, self.first_opticalflow_1_0, \
                                self.first_opticalflow_t_2, self.first_opticalflow_t_0,\
                                self.first_img_flow_2_t, self.first_img_flow_0_t], -1)
            print(secinput)
            self.second_opticalflow = my_unet(
                secinput, 5, withbias=True)  #注意这里是直接作为optical flow
        self.second_opticalflow_t_0 = tf.add(
            self.second_opticalflow[:, :, :, :2],
            self.first_opticalflow_t_0,
            name="second_opticalflow_t_0")
        self.second_opticalflow_t_1 = tf.add(self.second_opticalflow[:, :, :,
                                                                     2:4],
                                             self.first_opticalflow_t_2,
                                             name="second_opticalflow_t_1")
        print('second_opticalflow_t_0:', self.second_opticalflow_t_0)
        print('second_opticalflow_t_1:', self.second_opticalflow_t_1)
        #second_opticalflow_t_0: Tensor("second_opticalflow_t_0:0", shape=(10, 180, 320, 2), dtype=float32)
        #second_opticalflow_t_1: Tensor("second_opticalflow_t_1:0", shape=(10, 180, 320, 2), dtype=float32)

        self.vmap_t_0 = tf.expand_dims(
            tf.sigmoid(self.second_opticalflow[:, :, :, -1]), -1)
        self.vmap_t_1 = 1 - self.vmap_t_0

        #2种方法合成t时刻的帧
        self.second_img_flow_1_t = self.warp_op(
            self.frame2, -self.second_opticalflow_t_1)  #!!!
        self.second_img_flow_0_t = self.warp_op(
            self.frame0, -self.second_opticalflow_t_0)  #!!!

        #最终输出的图
        print(self.timerates_expand, self.vmap_t_0, self.second_img_flow_0_t)
        #Tensor("ExpandDims_2:0", shape=(6, 1, 1, 1), dtype=float32) Tensor("Sigmoid:0", shape=(6, 180, 320, 1), dtype=float32)
        #Tensor("dense_image_warp_5/Reshape_1:0", shape=(6, 180, 320, 3), dtype=float32)
        self.second_output=tf.div(  ( (1-self.timerates_expand)*self.vmap_t_0*self.second_img_flow_0_t+self.timerates_expand*self.vmap_t_1*self.second_img_flow_1_t),  \
                             ((1-self.timerates_expand)*self.vmap_t_0+self.timerates_expand*self.vmap_t_1) , name="second_outputimg" )
        print('second output img:', self.second_output)
        #second output img: Tensor("second_outputimg:0", shape=(10, 180, 320, 3), dtype=float32)

        #计算loss
        self.second_L1_loss_interframe,self.first_warp_loss,self.second_contex_loss,self.second_local_var_loss_all,self.second_global_var_loss_all,self.second_ssim,self.second_psnr,\
                self.first_L1_loss_interframe, self.first_ssim, self.first_psnr=self.loss_cal_all()

        self.G_loss_all = 204 * self.second_L1_loss_interframe + 102 * self.first_warp_loss + 0.005 * self.second_contex_loss + self.second_global_var_loss_all

        #获取数据时的一些cpu上的参数,用于扩张数据和判定时序
        self.last_flow_init_np = np.zeros(self.flow_shape, dtype=np.float32)
        print(self.last_flow_init_np.shape)  #(180, 320, 4)

        #初始化train和test的初始0状态
        self.last_flow_new_train = self.last_flow_init_np
        self.last_flow_new_test = self.last_flow_init_np

        self.last_label_train = '#'
        self.last_label_test = '#'
        self.state_random_row_train = 0
        self.state_random_col_train = 0
        self.state_flip_train = False

        self.state_random_row_test = 0
        self.state_random_col_test = 0
        self.state_flip_test = False

        #为了兼容性
        self.batchsize_inputimg = batchsize
        self.img_size_w = img_size_w
        self.img_size_h = img_size_h

        t_vars = tf.trainable_variables()
        print("trainable vars cnt:", len(t_vars))
        self.first_para = [
            var for var in t_vars if var.name.startswith('first')
        ]
        self.sec_para = [
            var for var in t_vars if var.name.startswith('second')
        ]
        self.vgg_para = [var for var in t_vars if var.name.startswith('VGG')]
        print("first param len:", len(self.first_para))
        print("second param len:", len(self.sec_para))
        print("VGG param len:", len(self.vgg_para))
        print(self.vgg_para)
        '''
        trainable vars cnt: 144
        first param len: 46
        second param len: 46
        VGG param len: 52
        '''

        #训练过程
        self.lr_rate = tf.train.exponential_decay(base_lr,
                                                  global_step=self.global_step,
                                                  decay_steps=decay_steps,
                                                  decay_rate=decay_rate)
        self.train_op = tf.train.AdamOptimizer(self.lr_rate, name="superslomo_adam").minimize(self.G_loss_all,  \
                                                                                              global_step=self.global_step  , var_list=self.first_para+self.sec_para  )

        # weight clipping
        #self.clip_D = [p.assign(tf.clip_by_value(p, weightclip_min, weightclip_max)) for p in self.D_para]

        #最后构建完成后初始化参数
        self.sess.run(tf.global_variables_initializer())
    def __init__(self, sess):
        """Build the stage-2 ConvLSTM flow-refinement graph on a restored stage-1 model.

        The stage-1 graph is imported from ``op.join(modelpath, meta_name)`` and
        its weights are restored from the latest checkpoint in ``modelpath``.
        Stage-1 placeholders and flow/output tensors are then fetched by name.
        A ``ConvLSTMCell`` (variable scope ``STEP2``) is run over the stage-1
        batch treated as one time sequence, and ``self.final_convlayer`` maps
        the LSTM outputs to refined bidirectional flows. Loss tensors and train
        ops for both stages are built with ``self.loss_cal``.

        Only the variables created after the restore are initialized at the
        end, so the restored stage-1 weights are not overwritten.

        Args:
            sess: session used to restore the checkpoint and to initialize the
                newly created variables.
        """
        self.sess = sess

        # Load the original (stage-1) model and restore its trained weights.
        saver = tf.train.import_meta_graph(op.join(modelpath, meta_name) )
        saver.restore(self.sess, tf.train.latest_checkpoint(modelpath))
        self.graph = tf.get_default_graph()
        # Every variable that exists at this point came in with the imported
        # graph and has just been restored from the checkpoint. The final
        # initialization must skip these: a blanket
        # global_variables_initializer() would reset them and discard the
        # checkpoint.
        restored_var_names = {var.name for var in tf.global_variables()}
        self.global_step = tf.Variable(0.0, name='step2_global_step',dtype=tf.float32, trainable=False)

        # Stage-1 placeholders
        self.imgs_pla= self.graph.get_tensor_by_name('imgs_in:0')
        self.training= self.graph.get_tensor_by_name("training_in:0")
        self.timerates= self.graph.get_tensor_by_name("timerates_in:0")

        # Stage-1 output tensors: the two flows and the generated frame
        self.optical_0_1=self.graph.get_tensor_by_name("G_opticalflow_0_2:0")
        self.optical_1_0=self.graph.get_tensor_by_name("G_opticalflow_2_0:0")
        self.outimg = self.graph.get_tensor_by_name("G_net_generate:0")

        # The LSTM input is built directly from the stage-1 outputs:
        # frame0, flow 0->2, flow 2->0 and frame2 stacked on the channel axis.
        self.input_pla=tf.concat([self.imgs_pla[:,:,:,:img_channel], self.optical_0_1, self.optical_1_0, self.imgs_pla[:,:,:, img_channel*2:]], -1)
        self.lstm_input_channel=self.input_pla.get_shape().as_list()[-1]
        print (self.input_pla)  #Tensor("concat_9:0", shape=(12, 180, 320, ?), dtype=float32)

        # Stage-1 batch size; the stage-1 batch is used as the LSTM time axis.
        self.batchsize_inputimg=self.imgs_pla.get_shape().as_list()[0]

        # Input image shape
        self.img_size_w=self.imgs_pla.get_shape().as_list()[2]
        self.img_size_h=self.imgs_pla.get_shape().as_list()[1]
        self.img_size=[self.img_size_h, self.img_size_w]
        print (self.imgs_pla) #Tensor("imgs_in:0", shape=(12, 180, 320, 9), dtype=float32)

        self.pipline_data_train=cdata.get_pipline_data_train(self.img_size, self.batchsize_inputimg)
        self.pipline_data_test=cdata.get_pipline_data_test(self.img_size, self.batchsize_inputimg)
        # Output optical-flow shape
        self.flow_size_h=self.optical_0_1.get_shape().as_list()[1]
        self.flow_size_w=self.optical_0_1.get_shape().as_list()[2]
        self.flow_channel=self.optical_0_1.get_shape().as_list()[-1]

        # Shape of each LSTM state tensor (c and h). The RNN below runs with a
        # batch of 1, so the configured `batchsize` must be 1 for these
        # placeholders to match.
        self.state_shape=[batchsize, self.flow_size_h, self.flow_size_w, output_channel]

        # LSTM initial-state placeholders (c, h)
        self.state_pla_c = tf.placeholder(tf.float32, self.state_shape, name='step2_state_in_c')
        self.state_pla_h = tf.placeholder(tf.float32, self.state_shape, name='step2_state_in_h')

        # CPU-side bookkeeping used while fetching data: augmentation state
        # (random row/col offsets, flip) and the last label seen, used to
        # judge temporal continuity between batches.
        self.last_label_train='#'
        self.last_label_test='#'
        self.state_random_row_train=0
        self.state_random_col_train=0
        self.state_flip_train=False

        self.state_random_row_test=0
        self.state_random_col_test=0
        self.state_flip_test=False

        self.frame0=self.imgs_pla[:,:,:,:img_channel]
        self.frame1=self.imgs_pla[:,:,:,img_channel:img_channel*2]
        self.frame2=self.imgs_pla[:,:,:,img_channel*2:]

        # Time rates expanded to (batch, 1, 1, 1) so they broadcast over images
        self.timerates_expand=tf.expand_dims(self.timerates, -1)
        self.timerates_expand=tf.expand_dims(self.timerates_expand, -1)
        self.timerates_expand=tf.expand_dims(self.timerates_expand, -1) #batchsize*1*1*1

        with tf.variable_scope("STEP2",  reuse=tf.AUTO_REUSE) as scopevar:
            self.cell = tf.contrib.rnn.ConvLSTMCell(conv_ndims=2, input_shape=[self.flow_size_h, self.flow_size_w, self.lstm_input_channel], \
                                                    output_channels=output_channel, kernel_shape=[kernel_len, kernel_len])

            self.state_init = self.cell.zero_state(batch_size=batchsize, dtype=tf.float32)
            self.state_init_np=( np.zeros(self.state_shape, dtype=np.float32), np.zeros(self.state_shape, dtype=np.float32) )
            print (self.state_init) #LSTMStateTuple(c=<tf.Tensor 'ConvLSTMCellZeroState/zeros:0' shape=(2, 180, 320, 12) dtype=float32>, h=<tf.Tensor 'ConvLSTMCellZeroState/zeros_1:0' shape=(2, 180, 320, 12) dtype=float32>)
            # All-zero initial LSTM state for train and test
            self.state_new_train=self.state_init_np
            self.state_new_test=self.state_init_np

            # Run the ConvLSTM. The whole stage-1 batch becomes a single
            # sequence: LSTM batch size 1, time axis = stage-1 batch.
            self.input_dynamic_lstm=tf.expand_dims(self.input_pla, 0)
            print (self.input_dynamic_lstm)  #Tensor("ExpandDims_6:0", shape=(1, 12, 180, 320, 4), dtype=float32)
            self.outputs, self.state_final = tf.nn.dynamic_rnn(self.cell, inputs =self.input_dynamic_lstm , initial_state = tf.nn.rnn_cell.LSTMStateTuple(self.state_pla_c, self.state_pla_h), time_major = False)

            print (self.outputs,self.state_final)
            #Tensor("rnn/transpose_1:0", shape=(1, 12, 180, 320, 12), dtype=float32)
            #LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(1, 180, 320, 12) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(1, 180, 320, 12) dtype=float32>)

            # Drop the LSTM batch axis; the time steps become the batch again.
            self.flow_after=self.final_convlayer(self.outputs[0])
            print (self.flow_after) #Tensor("final_convlayer/BiasAdd:0", shape=(12, 180, 320, 4), dtype=float32)

        # Split the refined 4-channel flow into the two 2-channel directions.
        self.opticalflow_0_2=tf.slice(self.flow_after, [0, 0, 0, 0], [-1, -1, -1, 2], name='step2_opticalflow_0_2')
        self.opticalflow_2_0=tf.slice(self.flow_after, [0, 0, 0, 2], [-1, -1, -1, 2], name='step2_opticalflow_2_0')
        print ('original flow:',self.opticalflow_0_2, self.opticalflow_2_0)
        #original flow: Tensor("step2_opticalflow_0_2:0", shape=(12, 180, 320, 2), dtype=float32) Tensor("step2_opticalflow_2_0:0", shape=(12, 180, 320, 2), dtype=float32)

        # Group trainable variables by name prefix: stage-1 generator (G*),
        # discriminator (D*) and the new stage-2 LSTM/conv weights (STEP2*).
        t_vars=tf.trainable_variables()
        print ("trainable vars cnt:",len(t_vars))
        self.G_para=[var for var in t_vars if var.name.startswith('G')]
        self.D_para=[var for var in t_vars if var.name.startswith('D')]
        self.STEP2_para=[var for var in t_vars if var.name.startswith('STEP2')]
        print ("G param len:",len(self.G_para))
        print ("D param len:",len(self.D_para))
        print ("STEP2 param len:",len(self.STEP2_para))
        print (self.STEP2_para)
        # Sample output from an earlier run:
        #   trainable vars cnt: 186, G param len: 60, D param len: 16,
        #   STEP2 param len: 6
        # Without stage 2 there were 128. The VGG used in the losses should
        # share parameters with the stage-1 VGG, otherwise memory is wasted.

        # Weight-clipping ops for the discriminator (D*) variables
        self.clip_D = [p.assign(tf.clip_by_value(p, weightclip_min, weightclip_max)) for p in self.D_para]

        # Stage-1 losses / train op on the restored stage-1 flows (G_para).
        self.step1_train_op,\
        self.step1_L1_loss_all,self.step1_contex_loss,self.step1_local_var_loss_all,self.step1_G_loss_all,\
        self.step1_ssim,self.step1_psnr,self.step1_G_net=self.loss_cal(self.optical_0_1, self.optical_1_0, LR_step1, self.G_para, scopevar.name)

        # Stage-2 losses / train op on the refined flows (STEP2_para).
        self.step2_train_op,\
        self.step2_L1_loss_all,self.step2_contex_loss,self.step2_local_var_loss_all,self.step2_G_loss_all,\
        self.step2_ssim,self.step2_psnr,self.step2_G_net=self.loss_cal(self.opticalflow_0_2, self.opticalflow_2_0, LR, self.STEP2_para, scopevar.name)

        # Initialize only the variables created after the restore (global
        # step, STEP2 weights, optimizer slots, ...). The restored stage-1
        # weights keep their checkpoint values. Graph op names are unique, so
        # matching by name is exact.
        new_vars = [var for var in tf.global_variables()
                    if var.name not in restored_var_names]
        self.sess.run(tf.variables_initializer(new_vars))
    def __init__(self, sess):
        """Build the two-stage frame-interpolation graph with a discriminator.

        Stage 1 (variable scope ``first_unet``):
        ``my_unet_split`` predicts a 4-channel flow from frame0 and frame2.
        The flow is split into ``first_opticalflow_0_1`` (channels 0:2) and
        ``first_opticalflow_1_0`` (channels 2:4). Flows from time t to both
        endpoint frames are derived from these and used to warp frame0 and
        frame2 to time t.

        Stage 2 (variable scope ``second_unet``):
        ``my_unet`` is applied to one batch element at a time. Each call sees:

        - that element's frames,
        - its stage-1 flows and warped frames,
        - the previous element's 5-channel output (the ``second_last_flow``
          placeholder for element 0).

        This chains information across the batch. Output channels 0:2 / 2:4
        are added to the stage-1 t->0 / t->2 flows. The last channel goes
        through a sigmoid to form the blending weight ``vmap_t_0``.

        A discriminator ``D1`` scores the final frame against frame1. Loss
        terms come from ``self.loss_cal_all``. Two Adam train ops are built:
        one for the generator (``first*`` / ``second*`` variables) and one
        for the discriminator (``D*`` variables). Weight-clipping ops for
        ``D*`` are also built. All global variables are initialized at the
        end.

        Args:
            sess: session used to run the variable initializer.
        """
        self.sess = sess
        self.global_step = tf.Variable(0.0,
                                       name='global_step',
                                       dtype=tf.float32,
                                       trainable=False)

        #for data input
        self.pipline_data_train = cdata.get_pipline_data_train(
            img_size, batchsize)
        self.pipline_data_test = cdata.get_pipline_data_test(
            img_size, batchsize_test)

        # Placeholders: stacked input frames, the training flag and
        # per-sample time rates.
        self.imgs_pla = tf.placeholder(
            tf.float32,
            [batchsize, img_size_h, img_size_w, G_group_img_num * img_channel],
            name='imgs_in')
        # NOTE(review): the original comment said this flag is unused and kept
        # only for compatibility. It is, however, passed to my_unet_split,
        # the first second-stage my_unet call and Discriminator_net below.
        self.training = tf.placeholder(tf.bool,
                                       name='training_in')
        # Time rates expanded to shape (batchsize, 1, 1, 1) so they broadcast
        # over images and flows.
        self.timerates_pla = tf.placeholder(tf.float32, [batchsize],
                                            name='timerates_in')
        self.timerates_expand = tf.expand_dims(self.timerates_pla, -1)
        self.timerates_expand = tf.expand_dims(self.timerates_expand, -1)
        self.timerates_expand = tf.expand_dims(self.timerates_expand,
                                               -1)  #12*1*1*1

        print('placeholders:\n', 'img_placeholder:', self.imgs_pla,
              self.timerates_pla)
        #img_placeholder: Tensor("imgs_in:0", shape=(10, 180, 320, 9), dtype=float32) Tensor("timerates_in:0", shape=(10,), dtype=float32)

        # The channel axis holds frame0 | frame1 (ground-truth middle) | frame2
        self.frame0 = self.imgs_pla[:, :, :, :img_channel]
        self.frame1 = self.imgs_pla[:, :, :, img_channel:img_channel * 2]
        self.frame2 = self.imgs_pla[:, :, :, img_channel * 2:]

        with tf.variable_scope("first_unet", reuse=tf.AUTO_REUSE) as scope:
            firstinput = tf.concat([self.frame0, self.frame2], -1)
            # Previous variant: self.first_opticalflow=my_unet( firstinput, 4,training=self.training , withbias=True, withbn=False)
            # The 4-channel U-Net output is used directly as optical flow.
            self.first_opticalflow = my_unet_split(
                firstinput,
                4,
                training=self.training,
                withbias=True,
                withbn=True)

        # Split the stage-1 flow into its two 2-channel directions.
        self.first_opticalflow_0_1 = self.first_opticalflow[:, :, :, :2]
        self.first_opticalflow_0_1 = tf.identity(self.first_opticalflow_0_1,
                                                 name="first_opticalflow_0_1")
        print('first_opticalflow_0_1:', self.first_opticalflow_0_1)
        self.first_opticalflow_1_0 = self.first_opticalflow[:, :, :, 2:]
        self.first_opticalflow_1_0 = tf.identity(self.first_opticalflow_1_0,
                                                 name="first_opticalflow_1_0")
        print('first_opticalflow_1_0:', self.first_opticalflow_1_0)
        #first_opticalflow_0_1: Tensor("first_opticalflow_0_1:0", shape=(10, 180, 320, 2), dtype=float32)
        #first_opticalflow_1_0: Tensor("first_opticalflow_1_0:0", shape=(10, 180, 320, 2), dtype=float32)

        # Static shape of the output optical flow
        self.flow_size_h = self.first_opticalflow_0_1.get_shape().as_list()[1]
        self.flow_size_w = self.first_opticalflow_0_1.get_shape().as_list()[2]
        self.flow_channel = self.first_opticalflow_0_1.get_shape().as_list(
        )[-1]

        ########################################################
        # Stage-2 output: 2 + 2 flow-residual channels + 1 blending channel
        self.step2_flow_channel = 5
        self.flow_shape = [
            self.flow_size_h, self.flow_size_w, self.step2_flow_channel
        ]
        # CPU-side state used while fetching data (augmentation and temporal
        # ordering). Zero initial value of the recurrent stage-2 flow.
        self.last_flow_init_np = np.zeros(self.flow_shape, dtype=np.float32)
        print(self.last_flow_init_np.shape)  #(180, 320, 5)
        ##############################################################

        # Recurrent input for batch element 0 of the second U-Net.
        # NOTE(review): presumably fed with the previous batch's
        # second_batch_last_flow (or zeros) by the caller -- verify.
        self.last_optical_flow = tf.placeholder(tf.float32,
                                                self.flow_shape,
                                                name='second_last_flow')

        # Initial all-zero recurrent flow for train and test
        self.last_flow_new_train = self.last_flow_init_np
        self.last_flow_new_test = self.last_flow_init_np

        # Approximate flows from time t to the endpoints from the stage-1 flows:
        #   F_t_0 = -(1-t)*t*F_0_1 + t*t*F_1_0
        #   F_t_2 = (1-t)*(1-t)*F_0_1 + t*(t-1)*F_1_0
        self.first_opticalflow_t_0=tf.add( -(1-self.timerates_expand)*self.timerates_expand*self.first_opticalflow_0_1 ,\
                                      self.timerates_expand*self.timerates_expand*self.first_opticalflow_1_0 , name="first_opticalflow_t_0")
        self.first_opticalflow_t_2=tf.add( (1-self.timerates_expand)*(1-self.timerates_expand)*self.first_opticalflow_0_1 ,\
                                      self.timerates_expand*(self.timerates_expand-1)*self.first_opticalflow_1_0, name="first_opticalflow_t_2")

        # Synthesize the time-t frame in two ways: by warping frame2 and by
        # warping frame0. NOTE(review): the flows are negated before warp_op;
        # the sign convention is defined by warp_op (not visible here).
        self.first_img_flow_2_t = self.warp_op(
            self.frame2, -self.first_opticalflow_t_2)
        self.first_img_flow_0_t = self.warp_op(
            self.frame0, -self.first_opticalflow_t_0)

        # The paper does not use the stage-1 intermediate frame, but it is
        # produced here (a time-rate-weighted average) for inspection.
        self.first_output = tf.add(
            self.timerates_expand * self.first_img_flow_2_t,
            (1 - self.timerates_expand) * self.first_img_flow_0_t,
            name="first_outputimg")
        print('first output img:', self.first_output)
        #first output img: Tensor("first_outputimg:0", shape=(10, 180, 320, 3), dtype=float32)

        # Use the flows to synthesize each endpoint frame from the other.
        self.first_img_flow_2_0 = self.warp_op(
            self.frame2, -self.first_opticalflow_0_1)  #frame2->frame0
        self.first_img_flow_0_2 = self.warp_op(
            self.frame0, -self.first_opticalflow_1_0)  #frame0->frame2

        ####################################################################################################################
        # Second U-Net, applied sequentially over the batch. Element 0 takes
        # the last_optical_flow placeholder as its recurrent input.
        with tf.variable_scope("second_unet", reuse=tf.AUTO_REUSE) as scope:
            secinput=tf.concat([self.frame0[0], self.frame2[0], \
                                self.first_opticalflow_0_1[0], self.first_opticalflow_1_0[0], \
                                self.first_opticalflow_t_2[0], self.first_opticalflow_t_0[0],\
                                self.first_img_flow_2_t[0], self.first_img_flow_0_t[0],\
                                self.last_optical_flow], -1)
            secinput = tf.expand_dims(secinput, 0)
            print(
                "secinput:", secinput
            )  #secinput: Tensor("second_unet/ExpandDims:0", shape=(1, 180, 320, 25), dtype=float32)

            # The U-Net output is used directly as flow residuals + blend logit.
            step2_withbn = False
            new_step2_flow = my_unet(
                secinput,
                self.step2_flow_channel,
                training=self.training,
                withbias=True,
                withbn=step2_withbn)
            kep_step2_flow = [new_step2_flow]
            print("new_step2_flow:", new_step2_flow)
            #new_step2_flow: Tensor("second_unet/unet_end0_relu/LeakyRelu:0", shape=(1, 180, 320, 5), dtype=float32)

            # Elements 1..batchsize-1 each take the previous element's output
            # as their recurrent input. AUTO_REUSE shares the U-Net weights
            # across iterations.
            # NOTE(review): unlike the call above, these calls do not pass
            # training=self.training -- confirm this is intended.
            for ti in range(1, batchsize):
                secinput=tf.concat([self.frame0[ti], self.frame2[ti], \
                                self.first_opticalflow_0_1[ti], self.first_opticalflow_1_0[ti], \
                                self.first_opticalflow_t_2[ti], self.first_opticalflow_t_0[ti],\
                                self.first_img_flow_2_t[ti], self.first_img_flow_0_t[ti],\
                                new_step2_flow[0] ], -1)
                secinput = tf.expand_dims(secinput, 0)
                new_step2_flow = my_unet(secinput,
                                         self.step2_flow_channel,
                                         withbias=True,
                                         withbn=step2_withbn)
                kep_step2_flow.append(new_step2_flow)

            # Output of the last batch element, shape (h, w, 5).
            self.second_batch_last_flow = new_step2_flow[0]
            #self.second_batch_last_flow=tf.identity(self.second_batch_last_flow, name="second_batch_last_flow")
            print(
                "second_batch_last_flow:", self.second_batch_last_flow
            )  #Tensor("second_unet/strided_slice_89:0", shape=(180, 320, 5), dtype=float32)
            self.second_opticalflow = tf.concat(kep_step2_flow, 0)
            print(
                "self.second_opticalflow:", self.second_opticalflow
            )  #self.second_opticalflow: Tensor("second_unet/concat_60:0", shape=(10, 180, 320, 5), dtype=float32)
        # Refined flows = stage-2 residual + stage-1 approximation.
        self.second_opticalflow_t_0 = tf.add(
            self.second_opticalflow[:, :, :, :2],
            self.first_opticalflow_t_0,
            name="second_opticalflow_t_0")
        self.second_opticalflow_t_1 = tf.add(self.second_opticalflow[:, :, :,
                                                                     2:4],
                                             self.first_opticalflow_t_2,
                                             name="second_opticalflow_t_1")
        print('second_opticalflow_t_0:', self.second_opticalflow_t_0)
        print('second_opticalflow_t_1:', self.second_opticalflow_t_1)
        #second_opticalflow_t_0: Tensor("second_opticalflow_t_0:0", shape=(10, 180, 320, 2), dtype=float32)
        #second_opticalflow_t_1: Tensor("second_opticalflow_t_1:0", shape=(10, 180, 320, 2), dtype=float32)

        # Blending weights from the last stage-2 channel (named vmap,
        # presumably visibility maps). The two weights sum to 1.
        self.vmap_t_0 = tf.expand_dims(
            tf.sigmoid(self.second_opticalflow[:, :, :, -1]), -1)
        self.vmap_t_1 = 1 - self.vmap_t_0

        # Synthesize the time-t frame in two ways using the refined flows.
        self.second_img_flow_1_t = self.warp_op(
            self.frame2, -self.second_opticalflow_t_1)
        self.second_img_flow_0_t = self.warp_op(
            self.frame0, -self.second_opticalflow_t_0)

        # Final output frame:
        #   ((1-t)*V0*I0 + t*V1*I2) / ((1-t)*V0 + t*V1)
        print(self.timerates_expand, self.vmap_t_0, self.second_img_flow_0_t)
        #Tensor("ExpandDims_2:0", shape=(6, 1, 1, 1), dtype=float32) Tensor("Sigmoid:0", shape=(6, 180, 320, 1), dtype=float32)
        #Tensor("dense_image_warp_5/Reshape_1:0", shape=(6, 180, 320, 3), dtype=float32)
        self.second_output=tf.div(  ( (1-self.timerates_expand)*self.vmap_t_0*self.second_img_flow_0_t+self.timerates_expand*self.vmap_t_1*self.second_img_flow_1_t),  \
                             ((1-self.timerates_expand)*self.vmap_t_0+self.timerates_expand*self.vmap_t_1) , name="second_outputimg" )
        print('second output img:', self.second_output)
        #second output img: Tensor("second_outputimg:0", shape=(10, 180, 320, 3), dtype=float32)

        # Build the discriminator: generated frame (fake) vs frame1 (real).
        # Both calls use name="D1".
        self.D_1_net_F, self.D_1_net_F_logit = Discriminator_net(
            self.second_output, name="D1", training=self.training)
        self.D_1_net_T, self.D_1_net_T_logit = Discriminator_net(
            self.frame1, name="D1", training=self.training)
        # Discriminator loss
        self.D_1_net_loss_sum, _, _ = self.D_loss_TandF_logits(
            self.D_1_net_T_logit, self.D_1_net_F_logit, "D_1_net")

        # Compute all loss / metric tensors
        self.second_L1_loss_interframe,self.first_warp_loss,self.second_contex_loss,self.second_local_var_loss_all,self.second_global_var_loss_all,self.second_ssim,self.second_psnr,\
                self.first_L1_loss_interframe, self.first_ssim, self.first_psnr, self.second_GAN_loss_mean_D1=self.loss_cal_all()

        # Total generator loss.
        # NOTE: the global-var and GAN terms are disabled (commented out
        # below). As written, second_GAN_loss_mean_D1 is computed but does
        # not contribute to G_loss_all / train_op_G.
        self.G_loss_all = 100 * self.second_L1_loss_interframe + 30 * (
            self.first_L1_loss_interframe +
            self.first_warp_loss) + 0.05 * self.second_contex_loss
        #self.second_global_var_loss_all
        #+ self.second_GAN_loss_mean_D1*0.03

        # Total discriminator loss
        self.D_loss_all = self.D_1_net_loss_sum

        #####################################
        # CPU-side augmentation / sequence bookkeeping for train and test
        self.last_label_train = '#'
        self.last_label_test = '#'
        self.state_random_row_train = 0
        self.state_random_col_train = 0
        self.state_flip_train = False

        self.state_random_row_test = 0
        self.state_random_col_test = 0
        self.state_flip_test = False

        # For compatibility with other model classes' attribute names
        self.batchsize_inputimg = batchsize
        self.img_size_w = img_size_w
        self.img_size_h = img_size_h

        # Group trainable variables by name prefix.
        t_vars = tf.trainable_variables()
        print("trainable vars cnt:", len(t_vars))
        self.first_para = [
            var for var in t_vars if var.name.startswith('first')
        ]
        self.sec_para = [
            var for var in t_vars if var.name.startswith('second')
        ]
        self.vgg_para = [var for var in t_vars if var.name.startswith('VGG')]
        self.D_para = [var for var in t_vars if var.name.startswith('D')]
        print("first param len:", len(self.first_para))
        print("second param len:", len(self.sec_para))
        print("VGG param len:", len(self.vgg_para))
        print("D param len:", len(self.D_para))
        print(self.vgg_para)
        # NOTE(review): the string below records counts from an earlier
        # configuration (it has no D count) and is likely stale.
        '''
        trainable vars cnt: 144
        first param len: 46
        second param len: 46
        VGG param len: 52
        '''

        # Generator training op: Adam with an exponentially decayed learning
        # rate. It updates only the first*/second* variables and advances
        # global_step.
        self.lr_rate = tf.train.exponential_decay(base_lr,
                                                  global_step=self.global_step,
                                                  decay_steps=decay_steps,
                                                  decay_rate=decay_rate)
        self.train_op_G = tf.train.AdamOptimizer(self.lr_rate, beta1=beta1, name="superslomo_adam_G").minimize(self.G_loss_all,  \
                                                                                              global_step=self.global_step  , var_list=self.first_para+self.sec_para  )

        # Weight clipping: ops that clamp each D* variable to
        # [weightclip_min, weightclip_max].
        self.clip_D = [
            p.assign(tf.clip_by_value(p, weightclip_min, weightclip_max))
            for p in self.D_para
        ]

        # Discriminator training op (D* variables only; does not advance
        # global_step).
        self.train_op_D = tf.train.AdamOptimizer(
            self.lr_rate, beta1=beta1,
            name="superslomo_adam_D").minimize(self.D_loss_all,
                                               var_list=self.D_para)

        # Initialize all global variables once the graph is built
        self.sess.run(tf.global_variables_initializer())