def res_block_lstm(x, hidden_state=None, keep_p=1.0, name="resnet_lstm"):
    orig_x = x
    if hidden_state is not None:
        hidden_state_1 = hidden_state[0]
        hidden_state_2 = hidden_state[1]
    else:
        hidden_state_1 = None
        hidden_state_2 = None
    filter_size = orig_x.get_shape().as_list()[-1]

    with tf.variable_scope(name + "_conv_LSTM_1",
                           initializer=tf.random_uniform_initializer(-0.01, 0.01)) as scope:
        lstm_cell_1 = BasicConvLSTMCell.BasicConvLSTMCell(
            [int(x.get_shape()[1]), int(x.get_shape()[2])], [3, 3], filter_size)
        if hidden_state_1 is None:
            batch_size = x.get_shape()[0]
            hidden_state_1 = lstm_cell_1.zero_state(batch_size, tf.float32)
        x_1, hidden_state_1 = lstm_cell_1(x, hidden_state_1, scope=scope)

    if keep_p < 1.0:
        x_1 = tf.nn.dropout(x_1, keep_prob=keep_p)

    with tf.variable_scope(name + "_conv_LSTM_2",
                           initializer=tf.random_uniform_initializer(-0.01, 0.01)) as scope:
        lstm_cell_2 = BasicConvLSTMCell.BasicConvLSTMCell(
            [int(x_1.get_shape()[1]), int(x_1.get_shape()[2])], [3, 3], filter_size)
        if hidden_state_2 is None:
            batch_size = x_1.get_shape()[0]
            hidden_state_2 = lstm_cell_2.zero_state(batch_size, tf.float32)
        x_2, hidden_state_2 = lstm_cell_2(x_1, hidden_state_2, scope=scope)

    return orig_x + x_2, [hidden_state_1, hidden_state_2]
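# A minimal usage sketch (not from the original source): unrolling res_block_lstm
# over a clip while threading the returned [hidden_state_1, hidden_state_2] back in.
# The clip shape, sequence length, and the reuse handling are assumptions.
import tensorflow as tf

frames = tf.placeholder(tf.float32, [4, 10, 32, 32, 16])  # [batch, time, h, w, c]
state = None
outputs = []
with tf.variable_scope("unroll") as scope:
    for t in range(10):
        y, state = res_block_lstm(frames[:, t], hidden_state=state, keep_p=0.9)
        outputs.append(y)
        if t == 0:
            scope.reuse_variables()  # share the two ConvLSTM cells' weights across steps
outputs = tf.stack(outputs, axis=1)  # back to [batch, time, h, w, c]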
def inference(self, videoslides, mask_in, mask_h):
    # videoslides: [batch, framenum, h, w, num_features]
    with tf.variable_scope('inference'):
        shapes = videoslides.get_shape().as_list()
        assert len(shapes) == 5
        self.batch_size = videoslides.get_shape()[0].value
        # self.framenum = videoslides.get_shape()[1].value
        # assert self.framenum % self.maximgbatch == 0  # framenum should be a multiple of maximgbatch
        scope = 'layer_1'
        with tf.variable_scope('conv_lstm', initializer=tf.random_uniform_initializer(-.01, 0.1)):
            # cell arguments: input size, filter size, input channels
            # (a 56x56 variant of the two cells is commented out in the original)
            cell_1 = BasicConvLSTMCell.BasicConvLSTMCell([28, 28], [3, 3], 128, state_is_tuple=False)
            cell_2 = BasicConvLSTMCell.BasicConvLSTMCell([28, 28], [3, 3], 128, state_is_tuple=False)
            new_state_1 = cell_1.zero_state(self.batch_size, 2, tf.float32)
            new_state_2 = cell_2.zero_state(self.batch_size, 2, tf.float32)
            for indexframe in range(self.framenum):
                frame = videoslides[:, indexframe, ...]
                frame_gap = videoslides[:, indexframe + self.gapnum, ...]
                Yolo_features = self.YOLO_tiny_inference(frame)
                Presalmap = self.Coarse_salmap(Yolo_features)
                if self.startflagcnn:
                    self.yolofeatures_colllection = self.pretrain_var_collection
                    self.pretrain_var_collection = []
                salmask = self._normlized_0to1(Presalmap)
                salmask = salmask * (1 - self.salmask_lb) + self.salmask_lb
                Flow_features = self.flownet_with_conv(frame, frame_gap, salmask)
                CNNout = self.Final_inference(Yolo_features, Flow_features)
                if self.startflagcnn:
                    self.flowfeatures_colllection = self.pretrain_var_collection
                y_1, new_state_1 = cell_1(CNNout, new_state_1, mask_in[..., 0:4], mask_h[..., 0:4],
                                          self.dp_in, self.dp_h, 'lstm_layer1')
                y_2, new_state_2 = cell_2(y_1, new_state_2, mask_in[..., 4:8], mask_h[..., 4:8],
                                          self.dp_in, self.dp_h, 'lstm_layer2')
                deconv = self.transpose_conv_layer('deconv', y_2, 4, 16, stride=2,
                                                   pretrain=False, trainable=True)
                deconv2 = self.transpose_conv_layer('deconv2', deconv, 4, 1, stride=2, linear=True,
                                                    pretrain=False, trainable=True)
                if self.startflagcnn:
                    tf.get_variable_scope().reuse_variables()
                    self.trainable_var_collection.extend(cell_1.trainable_var_collection)
                    self.trainable_var_collection.extend(cell_2.trainable_var_collection)
                    self.startflagcnn = False
                output = self._normlized_0to1(deconv2)
                norm_output = self._normlized(output)
                # a KL-style frame loss against a ground-truth map is commented out in the original:
                # frame_loss = norm_GT * tf.log(self.eps + norm_GT / (norm_output + self.eps))
                # frame_loss = tf.reduce_sum(frame_loss) / norm_GT.get_shape()[0].value
                # tf.add_to_collection('losses', frame_loss)
                output = tf.expand_dims(output, 1)
                if indexframe == 0:
                    tempout = output
                else:
                    tempout = tf.concat([tempout, output], axis=1)
        self.out = tempout
def __init__(self, input_dim=[64, 64, 2], att_inputs=[], att_nodes=1024,
             batch_size=32, layer={}, layer_param={}, input_steps=10,
             output_steps=10, weighted_loss=False, reg_lambda=0.02):
    self.input_row = input_dim[0]
    self.input_col = input_dim[1]
    self.input_channel = input_dim[2]
    self.att_inputs = att_inputs
    self.att_nodes = att_nodes
    self.att_layer = layer['attention']
    self.att_layer_param = layer_param['attention']
    self.batch_size = batch_size
    self.seq_length = input_steps + output_steps
    self.input_steps = input_steps
    self.output_steps = output_steps
    self.weighted_loss = weighted_loss
    self.reg_lambda = reg_lambda
    self.encoder_layer = layer['encoder']
    self.decoder_layer = layer['decoder']
    self.encoder_layer_param = layer_param['encoder']
    self.decoder_layer_param = layer_param['decoder']
    # initialize conv_lstm cells
    self.encoder_conv_lstm = []
    self.encoder_state = []
    for i in range(len(self.encoder_layer)):
        if self.encoder_layer[i] == 'conv_lstm':
            convLSTM = BasicConvLSTMCell.BasicConvLSTMCell(
                self.encoder_layer_param[i][0],
                self.encoder_layer_param[i][1],
                self.encoder_layer_param[i][2],
                state_is_tuple=True)
            self.encoder_conv_lstm.append(convLSTM)
            self.encoder_state.append(convLSTM.zero_state(self.batch_size))
    self.decoder_conv_lstm = []
    self.decoder_state = []
    for i in range(len(self.decoder_layer)):
        if self.decoder_layer[i] == 'conv_lstm':
            convLSTM = BasicConvLSTMCell.BasicConvLSTMCell(
                self.decoder_layer_param[i][0],
                self.decoder_layer_param[i][1],
                self.decoder_layer_param[i][2],
                state_is_tuple=True)
            self.decoder_conv_lstm.append(convLSTM)
            self.decoder_state.append(convLSTM.zero_state(self.batch_size))
    self.weight_initializer = tf.contrib.layers.xavier_initializer()
    self.const_initializer = tf.constant_initializer()
    self.x = tf.placeholder(tf.float32, [None, self.input_steps, self.input_row,
                                         self.input_col, self.input_channel])
    self.y = tf.placeholder(tf.float32, [None, self.output_steps, self.input_row,
                                         self.input_col, self.input_channel])
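# A hedged construction sketch for the class above. Its name is not shown in the
# snippet, so `Seq2SeqConvLSTM` is a stand-in; the 'conv_lstm' layer tag comes from
# the branches above, and the [shape, kernel, num_features] parameter layout is
# inferred from how encoder_layer_param[i][0..2] feeds BasicConvLSTMCell.
layer = {
    'encoder': ['conv_lstm', 'conv_lstm'],
    'decoder': ['conv_lstm', 'conv_lstm'],
    'attention': [],
}
layer_param = {
    'encoder': [[[64, 64], [3, 3], 32], [[64, 64], [3, 3], 32]],
    'decoder': [[[64, 64], [3, 3], 32], [[64, 64], [3, 3], 32]],
    'attention': [],
}
model = Seq2SeqConvLSTM(input_dim=[64, 64, 2], layer=layer, layer_param=layer_param,
                        batch_size=32, input_steps=10, output_steps=10)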
def RNN(x):
    x_dropout = tf.nn.dropout(x, keep_prob)
    x_unwrap = []
    # create network
    with tf.variable_scope('conv_lstm',
                           initializer=tf.random_uniform_initializer(-.01, 0.1)):
        cell = BasicConvLSTMCell.BasicConvLSTMCell([FLAGS.hight, FLAGS.width], [3, 3], 1)
        new_state = cell.zero_state(FLAGS.batch_size, tf.float32)
        # conv network
        for i in xrange(FLAGS.seq_length - 1):
            # layer arguments: inputs, kernel_size, stride, num_features, idx
            if i < FLAGS.seq_start:
                conv1 = ld.transpose_conv_layer(x_dropout[:, i, :, :, :], 3, 1, 1, "decode_1")
            else:
                conv1 = ld.transpose_conv_layer(x_1, 3, 1, 1, "decode_1")
            y_0 = conv1
            # conv lstm cell
            y_1, new_state = cell(y_0, new_state)
            x_1 = ld.conv_layer(y_1, 3, 1, 1, "encode_1", True)
            if i >= FLAGS.seq_start:
                x_unwrap.append(x_1)
            # set reuse to true after first go
            if i == 0:
                tf.get_variable_scope().reuse_variables()
    # pack them all together
    x_unwrap = tf.stack(x_unwrap)
    x_unwrap = tf.transpose(x_unwrap, [1, 0, 2, 3, 4])
    return x_unwrap
def network(inputs, hidden, lstm=True):
    # encoder
    conv1 = ld.conv_layer(inputs, 3, 2, 8, "encode_1")
    conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
    conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
    conv4 = ld.conv_layer(conv3, 3, 1, 8, "encode_4")
    conv5 = ld.conv_layer(conv4, 3, 2, 8, "encode_5")
    conv6 = ld.conv_layer(conv5, 1, 1, 4, "encode_6")
    y_0 = conv6
    # conv lstm cell
    with tf.variable_scope('conv_lstm',
                           initializer=tf.random_uniform_initializer(-.01, 0.1)):
        cell = BasicConvLSTMCell.BasicConvLSTMCell([8, 8], [3, 3], 4)
        if hidden is None:
            hidden = cell.zero_state(FLAGS.batch_size, tf.float32)
        y_1, hidden = cell(y_0, hidden)
    # decoder
    conv5 = ld.transpose_conv_layer(y_1, 1, 1, 8, "decode_5")
    conv6 = ld.transpose_conv_layer(conv5, 3, 2, 8, "decode_6")
    conv7 = ld.transpose_conv_layer(conv6, 3, 1, 8, "decode_7")
    conv8 = ld.transpose_conv_layer(conv7, 3, 2, 8, "decode_8")
    conv9 = ld.transpose_conv_layer(conv8, 3, 1, 8, "decode_9")
    x_1 = ld.transpose_conv_layer(conv9, 3, 2, 1, "decode_10", True)  # set activation to linear
    return x_1, hidden
def conv_lstm_cell(inputs, hidden):
    # some convolutional layers before the convLSTM cell
    conv1 = ld.conv_layer(inputs, 3, 2, 4, "encode_1")
    conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
    conv3 = ld.conv_layer(conv2, 3, 1, 16, "encode_3")
    # take output from the first conv layers as input to the convLSTM cell
    with tf.variable_scope('conv_lstm',
                           initializer=tf.random_uniform_initializer(-.01, 0.1)):
        cell = BasicConvLSTMCell.BasicConvLSTMCell([16, 16], [3, 3], 16)
        if hidden is None:
            hidden = cell.zero_state(FLAGS.batch_size, tf.float32)
        cell_output, hidden = cell(conv3, hidden)
    # some convolutional layers after the convLSTM cell
    conv5 = ld.transpose_conv_layer(cell_output, 1, 1, 16, "decode_5")
    conv6 = ld.transpose_conv_layer(conv5, 3, 1, 8, "decode_6")
    conv7 = ld.transpose_conv_layer(conv6, 3, 1, 4, "decode_7")
    # the last convolutional layer uses linear activations
    x_1 = ld.transpose_conv_layer(conv7, 3, 2, 1, "decode_8", True)
    # return the output of the last conv layer and the hidden cell state
    return x_1, hidden
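# A hedged sketch of unrolling conv_lstm_cell over a clip, mirroring the reuse
# pattern of the train() example later in this section. The FLAGS values, the
# 32x32 frame size (which the stride-2 "encode_1" layer reduces to the cell's
# 16x16 grid), and the placeholder itself are assumptions.
import tensorflow as tf

x = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.seq_length, 32, 32, 3])
hidden = None
preds = []
for t in range(FLAGS.seq_length):
    pred, hidden = conv_lstm_cell(x[:, t], hidden)
    preds.append(pred)
    if t == 0:
        tf.get_variable_scope().reuse_variables()  # share weights across time steps
preds = tf.stack(preds, axis=1)  # [batch, time, 32, 32, 1]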
def __call__(self, input_layer, output_dim, hidden, k_h=5, k_w=5, d_h=2, d_w=2,
             stddev=0.02, in_dim=None, padding='SAME', name="conv_lstm"):
    with tf.variable_scope(name,
                           initializer=tf.random_uniform_initializer(-.01, 0.1)):
        cell = BasicConvLSTMCell.BasicConvLSTMCell(
            [input_layer.shape[-3], input_layer.shape[-2]],
            [k_h, k_w, d_h, d_w], output_dim)
        if hidden is None:
            hidden = cell.zero_state(cfg.TRAIN.BATCH_SIZE, tf.float32)
        y_1, hidden = cell(input_layer, hidden)
        # earlier variants returned input_layer.with_tensor(tf.nn.bias_add(conv, biases))
        # with an explicit bias variable; both are commented out in the original
        return y_1, hidden
def SpatialConnectionAwareNetwork(inputs, hidden, batch_size, lstm=True):
    conv1 = conv_layer(inputs, 5, 1, 32, 1, 'encode_1')
    conv2 = conv_layer(conv1, 3, 2, 64, 1, 'encode_2')
    conv3 = conv_layer(conv2, 3, 1, 64, 1, 'encode_3')
    conv4 = conv_layer(conv3, 3, 2, 128, 1, 'encode_4')
    y_0 = conv4
    if lstm:
        # conv lstm cell
        with tf.variable_scope('conv_lstm',
                               initializer=tf.random_uniform_initializer(-.01, 0.1)):
            cell = cl.BasicConvLSTMCell([90, 108], [3, 3], 128)
            if hidden is None:
                hidden = cell.zero_state(batch_size, tf.float32)
            y_1, hidden = cell(y_0, hidden)
    else:
        y_1 = conv_layer(y_0, 3, 1, 128, 1, 'encode_5')
    conv6 = conv_layer(y_1, 3, 1, 128, 1, 'decode_6')
    conv7 = transpose_conv_layer(conv6, 4, 2, 64, 1, 'decode_7') + conv3
    conv8 = conv_layer(conv7, 3, 1, 64, 1, 'decode_8')
    conv9 = transpose_conv_layer(conv8, 4, 2, 32, 1, 'decode_9') + conv1
    conv10 = conv_layer(conv9, 3, 1, 64, 1, 'decode_10')
    # last layer has linear activation, with a skip from the first input channel
    conv11 = conv_layer(conv10, 5, 1, 1, 1, 'decode_11', True) + inputs[:, :, :, 0:1]
    x_1 = conv11
    return x_1, hidden
def convLSTM(input, hidden, filters, kernel, scope):
    with tf.variable_scope(scope,
                           initializer=tf.truncated_normal_initializer(stddev=0.1)):
        cell = BasicConvLSTMCell.BasicConvLSTMCell(
            [input.get_shape()[1], input.get_shape()[2]], kernel, filters)
        if hidden is None:
            hidden = cell.zero_state(input.get_shape()[0], tf.float32)
        y_, hidden = cell(input, hidden)
    return y_, hidden
def convLSTM(input, hidden, filters, kernel, scope):
    cell = BasicConvLSTMCell.BasicConvLSTMCell(
        [input.shape[1], input.shape[2]], kernel, filters)
    if hidden is None:
        hidden = cell.zero_state(input.shape[0], tf.float32)
    y_, hidden = cell(input, hidden)
    return y_, hidden
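# Both helpers above wrap a single cell per call; below is a hedged sketch (the
# `frames` placeholder, its shape, and `num_frames` are assumptions) of stacking
# two of them, each layer keeping its own lazily initialized state.
import tensorflow as tf

frames = tf.placeholder(tf.float32, [8, 10, 64, 64, 3])  # [batch, time, h, w, c]
num_frames = 10

h1 = h2 = None
outs = []
for t in range(num_frames):
    y, h1 = convLSTM(frames[:, t], h1, filters=16, kernel=[3, 3], scope='lstm1')
    y, h2 = convLSTM(y, h2, filters=16, kernel=[3, 3], scope='lstm2')
    outs.append(y)
    if t == 0:
        tf.get_variable_scope().reuse_variables()  # reuse weights on later steps
outs = tf.stack(outs, axis=1)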
def network(inputs, hidden):
    conv1 = ld.conv2d(inputs, (3, 3), (1, 1), 1, "encode_1")
    # (deeper encoder layers "encode_2".."encode_4" are commented out in the original)
    y_0 = conv1
    # conv lstm cells; note that the single `hidden` argument is threaded through
    # all three cells, so only the first `if hidden is None` check can ever fire
    with tf.variable_scope('conv_lstm',
                           initializer=tf.random_uniform_initializer(-.01, 0.1)):
        cell = BasicConvLSTMCell.BasicConvLSTMCell([16, 128], [8, 8], 16)
        if hidden is None:
            hidden = cell.zero_state(FLAGS.batch_size, tf.float32)
        y_1, hidden = cell(y_0, hidden)
    with tf.variable_scope('conv_lstm_2',
                           initializer=tf.random_uniform_initializer(-.01, 0.1)):
        cell2 = BasicConvLSTMCell.BasicConvLSTMCell([16, 128], [8, 8], 16)
        if hidden is None:
            hidden = cell2.zero_state(FLAGS.batch_size, tf.float32)
        y_2, hidden = cell2(y_1, hidden)
    with tf.variable_scope('conv_lstm_3',
                           initializer=tf.random_uniform_initializer(-.01, 0.1)):
        cell3 = BasicConvLSTMCell.BasicConvLSTMCell([16, 128], [8, 8], 16)
        if hidden is None:
            hidden = cell3.zero_state(FLAGS.batch_size, tf.float32)
        y_3, hidden = cell3(y_2, hidden)
    # (decoder layers "decode_5".."decode_7" are commented out in the original)
    conv7 = y_3
    x_1 = ld.transpose_conv_layer(conv7, (3, 3), (1, 1), 1, "decode_8", True)  # linear activation
    return x_1, hidden
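# Because network() above passes one `hidden` through all three cells, each layer's
# initial state is the previous layer's output state rather than zeros. A sketch of
# the more conventional per-cell bookkeeping with the same cell sizes; the
# list-based interface and scope names are assumptions, not the original API.
def network_per_cell_states(inputs, hiddens):
    """Variant of network() that keeps one state per ConvLSTM cell.
    `hiddens` is None or a list of three cell states."""
    if hiddens is None:
        hiddens = [None, None, None]
    y = ld.conv2d(inputs, (3, 3), (1, 1), 1, "encode_1")
    new_hiddens = []
    for idx in range(3):
        with tf.variable_scope('conv_lstm_%d' % idx,
                               initializer=tf.random_uniform_initializer(-.01, 0.1)):
            cell = BasicConvLSTMCell.BasicConvLSTMCell([16, 128], [8, 8], 16)
            h = hiddens[idx]
            if h is None:
                h = cell.zero_state(FLAGS.batch_size, tf.float32)
            y, h = cell(y, h)
            new_hiddens.append(h)
    x_1 = ld.transpose_conv_layer(y, (3, 3), (1, 1), 1, "decode_8", True)
    return x_1, new_hiddens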
def define_graph(self):
    """ define a bidirectional ConvLSTM and a concatenated feature for the MLP """
    lstm_forward = []
    lstm_backward = []
    lstm_forward_state = []
    lstm_backward_state = []
    full_connect_w = []
    full_connect_b = []
    with tf.name_scope('input'):
        # use a variable batch_size for more flexibility
        self.input_frames = tf.placeholder(
            tf.float32, shape=[None, self.length, self.height, self.width, 1])
        self.GT_label = tf.placeholder(tf.float32, shape=[self.batch_size, 2])
    with tf.variable_scope("D_scale_{}".format(self.scale_index)):
        for layer_id_, kernel_, kernel_num_ in zip(xrange(self.layer_num_lstm),
                                                   self.kernel_size, self.kernel_num):
            layer_name_forward = "conv_lstm_forward_{}".format(layer_id_)
            layer_name_backward = "conv_lstm_backward_{}".format(layer_id_)
            with tf.variable_scope('conv_lstm',
                                   initializer=tf.random_uniform_initializer(-.01, 0.1)):
                temp_cell = BasicConvLSTMCell.BasicConvLSTMCell(
                    [self.height, self.width], [kernel_, kernel_], kernel_num_,
                    layer_name_forward)
                temp_state = temp_cell.zero_state(self.batch_size, tf.float32)
                lstm_forward.append(temp_cell)
                lstm_forward_state.append(temp_state)
            with tf.variable_scope('conv_lstm',
                                   initializer=tf.random_uniform_initializer(-.01, 0.1)):
                temp_cell = BasicConvLSTMCell.BasicConvLSTMCell(
                    [self.height, self.width], [kernel_, kernel_], kernel_num_,
                    layer_name_backward)
                temp_state = temp_cell.zero_state(self.batch_size, tf.float32)
                lstm_backward.append(temp_cell)
                lstm_backward_state.append(temp_state)
        for layer_id_ in xrange(self.layer_num_full - 1):
            with tf.variable_scope("full_connect_{}".format(layer_id_)):
                full_connect_w.append(tf.get_variable(
                    "matrix", [self.full_size[layer_id_], self.full_size[layer_id_ + 1]],
                    initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1,
                                                             seed=None, dtype=tf.float32)))
                full_connect_b.append(tf.get_variable(
                    "bias", [self.full_size[layer_id_ + 1]],
                    initializer=tf.constant_initializer(0.01)))

        # connecting pipelines: one pass on frame 0 to create the variables
        input_ = self.input_frames[:, 0, :, :, :]
        for lstm_layer_id in xrange(self.layer_num_lstm):
            input_, lstm_forward_state[lstm_layer_id] = lstm_forward[lstm_layer_id](
                input_, lstm_forward_state[lstm_layer_id])
        forward_output_ = input_
        input_ = self.input_frames[:, 0, :, :, :]
        for lstm_layer_id in xrange(self.layer_num_lstm):
            input_, lstm_backward_state[lstm_layer_id] = lstm_backward[lstm_layer_id](
                input_, lstm_backward_state[lstm_layer_id])
        backward_output_ = input_
        input_mlp = tf.concat([forward_output_, backward_output_], 3)
        preds = tf.reshape(input_mlp, [self.batch_size, -1])
        for layer_id in range(self.layer_num_full - 1):
            preds = tf.matmul(preds, full_connect_w[layer_id]) + full_connect_b[layer_id]
            if layer_id == self.layer_num_full - 2:
                preds = tf.sigmoid(preds)
                preds_1 = tf.subtract(tf.ones([self.batch_size, 1]), preds)
                preds = tf.concat([preds, preds_1], 1)
            else:
                preds = tf.nn.relu(preds)
        self.scope.reuse_variables()

        def forward():
            """Make the forward pass over the whole sequence."""
            for frame_id in xrange(self.length):
                input_ = self.input_frames[:, frame_id, :, :, :]
                for lstm_layer_id in xrange(self.layer_num_lstm):
                    input_, lstm_forward_state[lstm_layer_id] = lstm_forward[lstm_layer_id](
                        input_, lstm_forward_state[lstm_layer_id])
            forward_output_ = input_
            # backward pass
            for frame_id in xrange(self.length - 1, -1, -1):
                input_ = self.input_frames[:, frame_id, :, :, :]
                for lstm_layer_id in xrange(self.layer_num_lstm):
                    input_, lstm_backward_state[lstm_layer_id] = lstm_backward[lstm_layer_id](
                        input_, lstm_backward_state[lstm_layer_id])
            backward_output_ = input_
            input_mlp = tf.concat([forward_output_, backward_output_], 3)
            preds = tf.reshape(input_mlp, [self.batch_size, -1])
            # MLP pass
            full_out_summary = []
            for layer_id in range(self.layer_num_full - 1):
                histogram_name = "histogram:" + str(layer_id)
                preds = tf.matmul(preds, full_connect_w[layer_id]) + full_connect_b[layer_id]
                temp = tf.reshape(preds, [-1])
                full_out_summary.append(tf.summary.histogram(histogram_name, temp))
                if layer_id == self.layer_num_full - 2:
                    preds = tf.sigmoid(preds)
                    preds_1 = tf.subtract(tf.ones([self.batch_size, 1]), preds)
                    preds = tf.concat([preds, preds_1], 1)
                    temp = tf.reshape(preds, [-1])
                    full_out_summary.append(tf.summary.histogram("output", temp))
                else:
                    preds = tf.nn.relu(preds)
            output = tf.clip_by_value(preds, 0.05, 0.95)
            output = tf.squeeze(output)
            return output, full_out_summary

        self.preds, full_out_summary = forward()
        # loss and training op
        self.loss = adv_loss(self.preds, self.GT_label)
        temp_op = tf.train.AdamOptimizer(FLAGS.lr)
        variable_collection = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                self.scope_string)
        gvs = temp_op.compute_gradients(self.loss, var_list=variable_collection)
        capped_gvs = [(tf.clip_by_norm(grad, FLAGS.clip), var) for grad, var in gvs]
        self.train_op = temp_op.apply_gradients(capped_gvs)
        # summary (per-layer weight/bias histograms are commented out in the original)
        loss_summary = tf.summary.scalar('loss_D', self.loss)
        self.summary = tf.summary.merge([loss_summary] + full_out_summary)
def define_graph(self):
    with tf.name_scope('input'):
        self.input_features = tf.placeholder(
            tf.float32,
            shape=[FLAGS.batch_size, self.length, self.height, self.width, 512])
        # self.ground_truth = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, self.length, 224, 224, 1])
        # self.initial_h_0 = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, self.height, self.width, 512])
        # self.initial_c_0 = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, self.height, self.width, 512])
        # self.initial_h_1 = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, self.height, self.width, 1])
        # self.initial_c_1 = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, self.height, self.width, 1])
    with tf.variable_scope(self.scope) as scope:
        lstms = []
        lstms_state = []
        for layer_id_, kernel_, kernel_num_ in zip(xrange(self.layer_num_lstm),
                                                   self.kernel_size, self.kernel_num):
            layer_name_encode = 'conv_lstm' + str(layer_id_) + 'enc'
            temp_cell = BasicConvLSTMCell.BasicConvLSTMCell(
                [self.height, self.width], [kernel_, kernel_], kernel_num_, layer_name_encode)
            if layer_id_ == 0:
                lstms_state.append(tf.concat([self.initial_c_0, self.initial_h_0], 3))
            else:
                lstms_state.append(tf.concat([self.initial_c_1, self.initial_h_1], 3))
            lstms.append(temp_cell)
        dec_conv_W = []
        dec_conv_b = []
        for layer_id_, kernel_, input_, output_ in zip(xrange(self.layer_num_cnn),
                                                       self.kernel_size_dec,
                                                       self.num_dec_input,
                                                       self.num_dec_output):
            with tf.variable_scope("dec_conv_{}".format(layer_id_)):
                dec_conv_W.append(tf.get_variable(
                    "matrix", shape=kernel_ + [output_, input_],
                    initializer=tf.random_uniform_initializer(-0.01, 0.01)))
                dec_conv_b.append(tf.get_variable(
                    "bias", shape=[output_],
                    initializer=tf.constant_initializer(0.01)))
        # one pass on frame 0 to create the variables, then enable reuse
        input_ = self.input_features[:, 0, :, :, :]
        input_, _ = lstms[0](input_, lstms_state[0])
        input_, _ = lstms[1](input_, lstms_state[1])
        for layer_id_, kernel_num_, num_input, num_output in zip(xrange(self.layer_num_cnn),
                                                                 self.kernel_size_dec,
                                                                 self.num_dec_input,
                                                                 self.num_dec_output):
            output_shape = tf.stack([tf.shape(input_)[0], tf.shape(input_)[1] * 2,
                                     tf.shape(input_)[2] * 2, num_output])
            input_ = tf.nn.conv2d_transpose(input_, dec_conv_W[layer_id_],
                                            output_shape=output_shape,
                                            strides=[1, 2, 2, 1],
                                            padding='SAME') + dec_conv_b[layer_id_]
            if layer_id_ == self.layer_num_cnn - 1:
                input_ = tf.nn.sigmoid(input_)
            else:
                input_ = tf.maximum(0.2 * input_, input_)  # leaky ReLU
        scope.reuse_variables()

        def forward():
            output = []
            for frame_id in xrange(self.length):
                input_ = self.input_features[:, frame_id, :, :, :]
                input_, lstms_state[0] = lstms[0](input_, lstms_state[0])
                input_, lstms_state[1] = lstms[1](input_, lstms_state[1])
                for layer_id_, kernel_num_, num_input, num_output in zip(
                        xrange(self.layer_num_cnn), self.kernel_size_dec,
                        self.num_dec_input, self.num_dec_output):
                    output_shape = tf.stack([tf.shape(input_)[0], tf.shape(input_)[1] * 2,
                                             tf.shape(input_)[2] * 2, num_output])
                    input_ = tf.nn.conv2d_transpose(input_, dec_conv_W[layer_id_],
                                                    output_shape=output_shape,
                                                    strides=[1, 2, 2, 1],
                                                    padding='SAME') + dec_conv_b[layer_id_]
                    if layer_id_ == self.layer_num_cnn - 1:
                        # kept as logits here; the sigmoid is applied at test time
                        output.append(input_)
                    else:
                        input_ = tf.maximum(0.2 * input_, input_)
            output = tf.stack(output, axis=1)
            return output

        self.logits = forward()  # note: outputs are logits; convert with a sigmoid in testing mode
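# As the trailing note above says, the network emits logits; at test time the
# per-pixel probabilities would be recovered along these lines (`model` is a
# hypothetical instance of the class above, not a name from the original):
probs = tf.nn.sigmoid(model.logits)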
def model(self, rain, is_training, reuse):
    with tf.variable_scope('model', reuse=reuse):
        # RNN part: one ConvLSTM per scale
        with tf.variable_scope('LSTM'):
            cell = BasicConvLSTMCell([self.weight * 4, self.height * 4], [3, 3], 256)
            rnn_state = cell.zero_state(batch_size=self.batch_size, dtype=tf.float32)
        with tf.variable_scope('LSTM2'):
            cell2 = BasicConvLSTMCell([self.weight * 2, self.height * 2], [3, 3], 128)
            rnn_state2 = cell2.zero_state(batch_size=self.batch_size, dtype=tf.float32)
        with tf.variable_scope('LSTM4'):
            cell4 = BasicConvLSTMCell([self.weight, self.height], [3, 3], 64)
            rnn_state4 = cell4.zero_state(batch_size=self.batch_size, dtype=tf.float32)
        self.down_2 = self.downscale2(rain)
        self.down_4 = self.downscale4(rain)
        with tf.variable_scope('rnn1'):
            rain1 = deconv_layer(rain, [3, 3, 64, 3],
                                 [self.batch_size, self.weight * 4, self.height * 4, 64], 1)
            rain1 = prelu(rain1)
        with tf.variable_scope('rnn2'):
            rain2 = deconv_layer(self.down_2, [3, 3, 64, 3],
                                 [self.batch_size, self.weight * 2, self.height * 2, 64], 1)
            rain2 = prelu(rain2)
        with tf.variable_scope('rnn4'):
            rain4 = deconv_layer(self.down_4, [3, 3, 64, 3],
                                 [self.batch_size, self.weight, self.height, 64], 1)
            rain4 = prelu(rain4)
        long_connection = rain1

        with tf.variable_scope('residual_memory'):  # RMM
            # ---- rain4 scale ----
            with tf.variable_scope('rain4_res1'):
                res4_lstmin = deconv_layer(rain4, [3, 3, 32, 64],
                                           [self.batch_size, self.weight, self.height, 32], 1)
                res4_lstmin = prelu(res4_lstmin)
            with tf.variable_scope('lstm_group4'):
                y_4, rnn_state4 = cell4(res4_lstmin, rnn_state4)
            with tf.variable_scope('rain4_res2'):
                res4_lstmout = deconv_layer(y_4, [3, 3, 64, 64],
                                            [self.batch_size, self.weight, self.height, 64], 1)
                res4_lstmout = prelu(res4_lstmout)
            res4_lstmout += rain4
            # ---- rain2 scale ----
            with tf.variable_scope('upRMM4_2'):
                rain4to2 = deconv_layer(res4_lstmout, [3, 3, 64, 64],
                                        [self.batch_size, self.weight * 2, self.height * 2, 64], 2)
                rain4to2 = prelu(rain4to2)
            rain2_e = rain2 + rain4to2
            with tf.variable_scope('rain2_res1'):
                res2_lstmin = deconv_layer(rain2_e, [3, 3, 32, 64],
                                           [self.batch_size, self.weight * 2, self.height * 2, 32], 1)
                res2_lstmin = prelu(res2_lstmin)
            with tf.variable_scope('lstm_group2'):
                y_2, rnn_state2 = cell2(res2_lstmin, rnn_state2)
            with tf.variable_scope('rain2_res2'):
                res2_lstmout = deconv_layer(y_2, [3, 3, 64, 128],
                                            [self.batch_size, self.weight * 2, self.height * 2, 64], 1)
                res2_lstmout = prelu(res2_lstmout)
            res2_lstmout += rain2
            # ---- full-resolution scale ----
            with tf.variable_scope('upRMM4_1'):
                rain4to2 = deconv_layer(rain4to2, [3, 3, 64, 64],
                                        [self.batch_size, self.weight * 4, self.height * 4, 64], 2)
                rain4to2 = prelu(rain4to2)
            with tf.variable_scope('upRMM2_1'):
                rain2to1 = deconv_layer(res2_lstmout, [3, 3, 64, 64],
                                        [self.batch_size, self.weight * 4, self.height * 4, 64], 2)
                rain2to1 = prelu(rain2to1)
            rain1_e = rain1 + rain2to1 + rain4to2
            with tf.variable_scope('rain_res1'):
                res_lstmin = deconv_layer(rain1_e, [3, 3, 32, 64],
                                          [self.batch_size, self.weight * 4, self.height * 4, 32], 1)
                res_lstmin = prelu(res_lstmin)
            with tf.variable_scope('lstm_group'):
                y_1, rnn_state = cell(res_lstmin, rnn_state)
            with tf.variable_scope('rain_res2'):
                res_lstmout = deconv_layer(y_1, [3, 3, 64, 256],
                                           [self.batch_size, self.weight * 4, self.height * 4, 64], 1)
                res_lstmout = prelu(res_lstmout)
            res_lstmout += rain1

        u4_long_connection = res4_in = res4_lstmout
        u2_long_connection = res2_in = res2_lstmout
        u_long_connection = res_in = res_lstmout

        for n in range(17):  # URAB blocks
            # ---- rain4 scale ----
            with tf.variable_scope('URAB4_{}'.format(n)):
                with tf.variable_scope('down4_{}'.format(n + 1)):
                    x_rnn = conv_layer(res4_in, [3, 3, 64, 64], 2)
                    x_rnn = prelu(x_rnn)
                res_short = res_input = x_rnn
                for m in range(1):
                    with tf.variable_scope('group_{}_RCAB{}'.format(n + 1, m + 1)):
                        res_input = self.RCAB(res_input, 4)
                with tf.variable_scope('up_{}'.format(n + 1)):
                    res_out4 = deconv_layer(tf.concat([res_short, res_input], 3),
                                            [3, 3, 64, 128],
                                            [self.batch_size, self.weight, self.height, 64], 2)
                    res_out4 = prelu(res_out4)
                res4_in += res_out4
            # ---- rain2 scale ----
            with tf.variable_scope('URAB2_{}'.format(n)):
                with tf.variable_scope('upURAB4_2_{}'.format(n)):
                    rain4_resto2 = deconv_layer(res4_in, [3, 3, 64, 64],
                                                [self.batch_size, self.weight * 2, self.height * 2, 64], 2)
                    rain4_resto2 = prelu(rain4_resto2)
                res2_in_e = res2_in + rain4_resto2
                with tf.variable_scope('down2_{}'.format(n + 1)):
                    x_rnn = conv_layer(res2_in_e, [3, 3, 64, 64], 2)
                    x_rnn = prelu(x_rnn)
                res_short = res_input = x_rnn
                for m in range(3):
                    with tf.variable_scope('group_{}_RCAB{}'.format(n + 1, m + 1)):
                        res_input = self.RCAB(res_input, 4)
                with tf.variable_scope('up_{}'.format(n + 1)):
                    res_out2 = deconv_layer(tf.concat([res_short, res_input], 3),
                                            [3, 3, 64, 128],
                                            [self.batch_size, self.weight * 2, self.height * 2, 64], 2)
                    res_out2 = prelu(res_out2)
                res2_in += res_out2
            # ---- full-resolution scale ----
            with tf.variable_scope('URAB_{}'.format(n)):
                with tf.variable_scope('upURAB4to1_{}'.format(n)):
                    rain4_resto1 = deconv_layer(rain4_resto2, [3, 3, 64, 64],
                                                [self.batch_size, self.weight * 4, self.height * 4, 64], 2)
                    rain4_resto1 = prelu(rain4_resto1)
                with tf.variable_scope('upURAB2to1_{}'.format(n)):
                    rain2_resto1 = deconv_layer(res2_in, [3, 3, 64, 64],
                                                [self.batch_size, self.weight * 4, self.height * 4, 64], 2)
                    rain2_resto1 = prelu(rain2_resto1)
                res_in_e = res_in + rain4_resto1 + rain2_resto1
                with tf.variable_scope('down_{}'.format(n + 1)):
                    x_rnn = conv_layer(res_in_e, [3, 3, 64, 64], 2)
                    x_rnn = prelu(x_rnn)
                res_short = res_input = x_rnn
                for m in range(3):
                    with tf.variable_scope('group_{}_RCAB{}'.format(n + 1, m + 1)):
                        res_input = self.RCAB(res_input, 4)
                with tf.variable_scope('up_{}'.format(n + 1)):
                    res_out = deconv_layer(tf.concat([res_short, res_input], 3),
                                           [3, 3, 64, 128],
                                           [self.batch_size, self.weight * 4, self.height * 4, 64], 2)
                    res_out = prelu(res_out)
                res_in += res_out

        # re-memory long connections
        with tf.variable_scope('rnn_recon4'):
            res4_in = deconv_layer(tf.concat([res4_in, res4_lstmout], 3), [3, 3, 64, 128],
                                   [self.batch_size, self.weight, self.height, 64], 1)
            res4_in = prelu(res4_in)
        with tf.variable_scope('rnn_recon2'):
            res2_in = deconv_layer(tf.concat([res2_in, res2_lstmout], 3), [3, 3, 64, 128],
                                   [self.batch_size, self.weight * 2, self.height * 2, 64], 1)
            res2_in = prelu(res2_in)
        with tf.variable_scope('rnn_recon'):
            res_in = deconv_layer(tf.concat([res_in, res_lstmout], 3), [3, 3, 64, 128],
                                  [self.batch_size, self.weight * 4, self.height * 4, 64], 1)
            res_in = prelu(res_in)
        with tf.variable_scope('up4_2'):
            res4_2 = deconv_layer(res4_in, [3, 3, 64, 64],
                                  [self.batch_size, self.weight * 2, self.height * 2, 64], 2)
            res4_2 = prelu(res4_2)
        with tf.variable_scope('up2_1'):
            res42_1 = deconv_layer(tf.concat([res4_2, res2_in], 3), [3, 3, 64, 128],
                                   [self.batch_size, self.weight * 4, self.height * 4, 64], 2)
            res42_1 = prelu(res42_1)
        res_mem_con = tf.concat([res_in, res42_1], 3)
        with tf.variable_scope('rnn5'):
            res_mem_con = deconv_layer(res_mem_con, [3, 3, 32, 128],
                                       [self.batch_size, self.weight * 4, self.height * 4, 32], 1)
            res_mem_con = prelu(res_mem_con)
        with tf.variable_scope('rnn7'):
            res_rainimage = deconv_layer(res_mem_con, [3, 3, 3, 32],
                                         [self.batch_size, self.weight * 4, self.height * 4, 3], 1)
    self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='model')
    return res_rainimage
def define_graph(self):
    """ define the multi-scale encoder/predictor/decoder ConvLSTM generator """
    lstm_encode = []
    lstm_predict = []
    lstm_decode = []
    lstm_encode_state = []
    lstm_predict_state = []
    lstm_decode_state = []
    with tf.name_scope('input'):
        # use a variable batch_size for more flexibility
        self.input_frames = tf.placeholder(
            tf.float32, shape=[None, self.length, self.height, self.width, 1])
        self.D_label = tf.placeholder(tf.float32, shape=[self.batch_size, 2])
        self.future_frames = tf.placeholder(
            tf.float32, shape=[None, self.future_seq_length, self.height, self.width, 1])
        self.input_frames_low_scale = tf.placeholder(
            tf.float32, shape=[None, self.length, self.height, self.width, 1])
        self.future_frames_low_scale = tf.placeholder(
            tf.float32, shape=[None, self.future_seq_length, self.height, self.width, 1])
    with tf.variable_scope("G_scale_{}".format(self.scale_index)):
        for layer_id_, kernel_, kernel_num_ in zip(xrange(self.layer_num_lstm),
                                                   self.kernel_size, self.kernel_num):
            layer_name_encode = "conv_lstm_encode_{}".format(layer_id_)
            temp_cell = BasicConvLSTMCell.BasicConvLSTMCell(
                [self.height, self.width], [kernel_, kernel_], kernel_num_, layer_name_encode)
            temp_state = temp_cell.zero_state(self.batch_size, tf.float32)
            lstm_encode.append(temp_cell)
            lstm_encode_state.append(temp_state)
        for layer_id_, kernel_, kernel_num_ in zip(xrange(self.layer_num_lstm),
                                                   self.kernel_size, self.kernel_num):
            layer_name_predict = "conv_lstm_predict_{}".format(layer_id_)
            temp_cell = BasicConvLSTMCell.BasicConvLSTMCell(
                [self.height, self.width], [kernel_, kernel_], kernel_num_, layer_name_predict)
            temp_state = temp_cell.zero_state(self.batch_size, tf.float32)
            lstm_predict.append(temp_cell)
            lstm_predict_state.append(temp_state)
        for layer_id_, kernel_, kernel_num_ in zip(xrange(self.layer_num_lstm),
                                                   self.kernel_size, self.kernel_num):
            layer_name_decode = "conv_lstm_decode_{}".format(layer_id_)
            temp_cell = BasicConvLSTMCell.BasicConvLSTMCell(
                [self.height, self.width], [kernel_, kernel_], kernel_num_, layer_name_decode)
            temp_state = temp_cell.zero_state(self.batch_size, tf.float32)
            lstm_decode.append(temp_cell)
            lstm_decode_state.append(temp_state)

        # one pass on frame 0 to create the variables, then reuse
        input_ = tf.concat([self.input_frames[:, 0, :, :, :],
                            self.input_frames_low_scale[:, 0, :, :, :]], 3)
        for lstm_layer_id in xrange(self.layer_num_lstm):
            input_, lstm_encode_state[lstm_layer_id] = lstm_encode[lstm_layer_id](
                input_, lstm_encode_state[lstm_layer_id])
        input_ = tf.concat([self.input_frames[:, 0, :, :, :],
                            self.input_frames_low_scale[:, 0, :, :, :]], 3)
        lstm_pyramid = []
        for lstm_layer_id in xrange(self.layer_num_lstm):
            input_, lstm_predict_state[lstm_layer_id] = lstm_predict[lstm_layer_id](
                input_, lstm_predict_state[lstm_layer_id])
            lstm_pyramid.append(input_)
        input_ = tf.concat([self.input_frames[:, 0, :, :, :],
                            self.input_frames_low_scale[:, 0, :, :, :]], 3)
        lstm_pyramid_de = []
        for lstm_layer_id in xrange(self.layer_num_lstm):
            input_, lstm_decode_state[lstm_layer_id] = lstm_decode[lstm_layer_id](
                input_, lstm_decode_state[lstm_layer_id])
            lstm_pyramid_de.append(input_)
        y_cat = tf.concat(lstm_pyramid, 3)
        temp = ld.transpose_conv_layer(y_cat, 1, 1, 1, "predict")
        y_cat_de = tf.concat(lstm_pyramid_de, 3)
        temp_de = ld.transpose_conv_layer(y_cat_de, 1, 1, 1, "decode")
        self.scope.reuse_variables()

        def forward():
            """Encode the input sequence, then predict forward and decode backward."""
            for frame_id in xrange(self.length):
                input_ = tf.concat([self.input_frames[:, frame_id, :, :, :],
                                    self.input_frames_low_scale[:, frame_id, :, :, :]], 3)
                for lstm_layer_id in xrange(self.layer_num_lstm):
                    input_, lstm_encode_state[lstm_layer_id] = lstm_encode[lstm_layer_id](
                        input_, lstm_encode_state[lstm_layer_id])
            for i in xrange(self.layer_num_lstm):
                lstm_predict_state[i] = lstm_encode_state[i]
                lstm_decode_state[i] = lstm_decode_state[i]  # self-assignment (no-op); possibly meant lstm_encode_state[i]
            predicts = []
            for frame_id in xrange(self.future_seq_length):
                if frame_id == 0:
                    input_ = tf.concat([self.input_frames[:, -1, :, :, :],
                                        self.future_frames_low_scale[:, frame_id, :, :, :]], 3)
                else:
                    input_ = tf.concat([y_out,
                                        self.future_frames_low_scale[:, frame_id, :, :, :]], 3)
                # concatenate all layer predictions together
                lstm_pyramid = []
                for lstm_layer_id in xrange(self.layer_num_lstm):
                    input_, lstm_predict_state[lstm_layer_id] = lstm_predict[lstm_layer_id](
                        input_, lstm_predict_state[lstm_layer_id])
                    lstm_pyramid.append(input_)
                y_cat = tf.concat(lstm_pyramid, 3)
                y_out = ld.transpose_conv_layer(y_cat, 1, 1, 1, "predict")
                predicts.append(y_out)
            # swap batch and time axes
            x_unwrap_gen = tf.stack(predicts)
            predicts = tf.transpose(x_unwrap_gen, [1, 0, 2, 3, 4])
            decodes_temp = []
            for frame_id in range(self.length, 0, -1):
                if frame_id == self.length:
                    input_ = tf.concat([self.future_frames[:, 0, :, :, :],
                                        self.input_frames_low_scale[:, frame_id - 1, :, :, :]], 3)
                else:
                    input_ = tf.concat([self.input_frames[:, frame_id, :, :, :],
                                        self.input_frames_low_scale[:, frame_id - 1, :, :, :]], 3)
                lstm_pyramid = []
                for lstm_layer_id in xrange(self.layer_num_lstm):
                    input_, lstm_decode_state[lstm_layer_id] = lstm_decode[lstm_layer_id](
                        input_, lstm_decode_state[lstm_layer_id])
                    lstm_pyramid.append(input_)
                y_cat = tf.concat(lstm_pyramid, 3)
                y_out = ld.transpose_conv_layer(y_cat, 1, 1, 1, "decode")
                decodes_temp.append(y_out)
            decodes = []
            for i in range(self.length):
                decodes.append(decodes_temp.pop())
            x_unwrap_de = tf.stack(decodes)
            decodes = tf.transpose(x_unwrap_de, [1, 0, 2, 3, 4])
            return predicts, decodes

        self.preds, self.decodes = forward()
        # loss and training op
        mean_loss = l2_loss(self.preds, self.future_frames) / self.future_seq_length
        mean_loss_de = l2_loss(self.decodes, self.input_frames) / self.length
        GT_label = tf.concat([tf.ones([self.batch_size, 1]),
                              tf.zeros([self.batch_size, 1])], 1)
        entropy_loss = adv_loss(self.D_label, GT_label)
        self.loss = mean_loss + entropy_loss + mean_loss_de
        temp_op = tf.train.AdamOptimizer(FLAGS.lr)
        variable_collection = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                self.scope_string)
        gvs = temp_op.compute_gradients(self.loss, var_list=variable_collection)
        capped_gvs = [(tf.clip_by_norm(grad, FLAGS.clip), var) for grad, var in gvs]
        self.train_op = temp_op.apply_gradients(capped_gvs)
        mean_loss_summary = tf.summary.scalar('loss_mean_pre', mean_loss)
        mean_loss_de_summary = tf.summary.scalar('loss_mean_dec', mean_loss_de)
        entropy_loss_summary = tf.summary.scalar('loss_entropy', entropy_loss)
        loss_summary = tf.summary.scalar('loss_G', self.loss)
        self.summary = tf.summary.merge([loss_summary, mean_loss_summary,
                                         entropy_loss_summary, mean_loss_de_summary])
def train():
    """Train ring_net for a number of steps."""
    with tf.Graph().as_default():
        # make inputs
        x = tf.placeholder(tf.float32, [None, FLAGS.seq_length, 32, 32, 3])
        # possible dropout inside
        keep_prob = tf.placeholder("float")
        x_dropout = tf.nn.dropout(x, keep_prob)
        # create network
        x_unwrap = []
        with tf.variable_scope('conv_lstm',
                               initializer=tf.random_uniform_initializer(-.01, 0.1)):
            cell = BasicConvLSTMCell.BasicConvLSTMCell([8, 8], [3, 3], 4)
            new_state = cell.zero_state(FLAGS.batch_size, tf.float32)
            # conv network
            for i in xrange(FLAGS.seq_length - 1):
                if i < FLAGS.seq_start:
                    conv1 = ld.conv_layer(x_dropout[:, i, :, :, :], 3, 2, 8, "encode_1")
                else:
                    conv1 = ld.conv_layer(x_1, 3, 2, 8, "encode_1")
                conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
                conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
                conv4 = ld.conv_layer(conv3, 1, 1, 4, "encode_4")
                y_0 = conv4
                # conv lstm cell
                y_1, new_state = cell(y_0, new_state)
                conv5 = ld.transpose_conv_layer(y_1, 1, 1, 8, "decode_5")
                conv6 = ld.transpose_conv_layer(conv5, 3, 2, 8, "decode_6")
                conv7 = ld.transpose_conv_layer(conv6, 3, 1, 8, "decode_7")
                x_1 = ld.transpose_conv_layer(conv7, 3, 2, 3, "decode_8", True)  # set activation to linear
                if i >= FLAGS.seq_start:
                    x_unwrap.append(x_1)
                # set reuse to true after first go
                if i == 0:
                    tf.get_variable_scope().reuse_variables()
            # pack them all together
            x_unwrap = tf.pack(x_unwrap)
            x_unwrap = tf.transpose(x_unwrap, [1, 0, 2, 3, 4])

            # this part will be used for generating video
            x_unwrap_gen = []
            new_state_gen = cell.zero_state(FLAGS.batch_size, tf.float32)
            for i in xrange(50):
                if i < FLAGS.seq_start:
                    conv1 = ld.conv_layer(x[:, i, :, :, :], 3, 2, 8, "encode_1")
                else:
                    conv1 = ld.conv_layer(x_1_gen, 3, 2, 8, "encode_1")
                conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
                conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
                conv4 = ld.conv_layer(conv3, 1, 1, 4, "encode_4")
                y_0 = conv4
                y_1, new_state_gen = cell(y_0, new_state_gen)
                conv5 = ld.transpose_conv_layer(y_1, 1, 1, 8, "decode_5")
                conv6 = ld.transpose_conv_layer(conv5, 3, 2, 8, "decode_6")
                conv7 = ld.transpose_conv_layer(conv6, 3, 1, 8, "decode_7")
                x_1_gen = ld.transpose_conv_layer(conv7, 3, 2, 3, "decode_8", True)
                if i >= FLAGS.seq_start:
                    x_unwrap_gen.append(x_1_gen)
            # pack the generated ones
            x_unwrap_gen = tf.pack(x_unwrap_gen)
            x_unwrap_gen = tf.transpose(x_unwrap_gen, [1, 0, 2, 3, 4])

        # calc total loss (compare x_t to x_t+1)
        loss = tf.nn.l2_loss(x[:, FLAGS.seq_start + 1:, :, :, :] - x_unwrap[:, :, :, :, :])
        tf.scalar_summary('loss', loss)
        # training
        train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)
        # list of all variables and a saver over them
        variables = tf.all_variables()
        saver = tf.train.Saver(tf.all_variables())
        # summary op
        summary_op = tf.merge_all_summaries()
        # build an initialization operation to run below
        init = tf.initialize_all_variables()
        # start running operations on the graph
        sess = tf.Session()
        # init since this is the very first time training
        print("init network from scratch")
        sess.run(init)
        # summary writer
        graph_def = sess.graph.as_graph_def(add_shapes=True)
        summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, graph_def=graph_def)

        for step in xrange(FLAGS.max_step):
            dat = generate_bouncing_ball_sample(FLAGS.batch_size, FLAGS.seq_length,
                                                32, FLAGS.num_balls)
            t = time.time()
            _, loss_r = sess.run([train_op, loss],
                                 feed_dict={x: dat, keep_prob: FLAGS.keep_prob})
            elapsed = time.time() - t
            if step % 100 == 0 and step != 0:
                summary_str = sess.run(summary_op,
                                       feed_dict={x: dat, keep_prob: FLAGS.keep_prob})
                summary_writer.add_summary(summary_str, step)
                print("time per batch is " + str(elapsed))
                print(step)
                print(loss_r)
            assert not np.isnan(loss_r), 'Model diverged with loss = NaN'
            if step % 1000 == 0:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
                print("saved to " + FLAGS.train_dir)
                # make video
                print("now generating video!")
                video = cv2.VideoWriter()
                success = video.open("generated_conv_lstm_video.mov", fourcc, 4,
                                     (180, 180), True)
                dat_gif = dat
                ims = sess.run([x_unwrap_gen],
                               feed_dict={x: dat_gif, keep_prob: FLAGS.keep_prob})
                ims = ims[0][0]
                print(ims.shape)
                for i in xrange(50 - FLAGS.seq_start):
                    x_1_r = np.uint8(np.maximum(ims[i, :, :, :], 0) * 255)
                    new_im = cv2.resize(x_1_r, (180, 180))
                    video.write(new_im)
                video.release()
def define_graph(self):
    with tf.name_scope('input'):
        self.input_frames = tf.placeholder(
            tf.float32, shape=[self.batch_size, self.length, self.img_height, self.img_width, 3])
        self.first_mask = tf.placeholder(
            tf.float32, shape=[self.batch_size, self.img_height, self.img_width, 1])
    with tf.variable_scope(self.scope) as scope:
        self.lstms = []
        self.lstms_state = []
        for layer_id_, kernel_, kernel_num_ in zip(xrange(self.layer_num_lstm),
                                                   self.kernel_size, self.kernel_num):
            layer_name_encode = 'conv_lstm' + str(layer_id_) + 'enc'
            temp_cell = BasicConvLSTMCell.BasicConvLSTMCell(
                [self.height, self.width], [kernel_, kernel_], kernel_num_, layer_name_encode)
            if layer_id_ == 0:
                self.lstms_state.append(tf.concat([self.initial_c_0, self.initial_h_0], 3))
            else:
                self.lstms_state.append(tf.concat([self.initial_c_1, self.initial_h_1], 3))
            self.lstms.append(temp_cell)
        self.dec_conv_W = []
        self.dec_conv_b = []
        for layer_id_, kernel_, input_, output_ in zip(xrange(self.layer_num_cnn),
                                                       self.kernel_size_dec,
                                                       self.num_dec_input,
                                                       self.num_dec_output):
            with tf.variable_scope("dec_conv_{}".format(layer_id_)):
                self.dec_conv_W.append(tf.get_variable(
                    "matrix", shape=kernel_ + [output_, input_],
                    initializer=tf.random_uniform_initializer(-0.01, 0.01)))
                self.dec_conv_b.append(tf.get_variable(
                    "bias", shape=[output_], initializer=tf.constant_initializer(0.01)))
        # one dummy pass to create the variables, then enable reuse
        input_ = tf.zeros((self.batch_size, 14, 14, 512))
        for layer_lstm in range(self.layer_num_lstm):
            input_, _ = self.lstms[layer_lstm](input_, self.lstms_state[layer_lstm])
        for layer_id_, kernel_num_, num_input, num_output in zip(xrange(self.layer_num_cnn),
                                                                 self.kernel_size_dec,
                                                                 self.num_dec_input,
                                                                 self.num_dec_output):
            output_shape = tf.stack([tf.shape(input_)[0], tf.shape(input_)[1] * 2,
                                     tf.shape(input_)[2] * 2, num_output])
            input_ = tf.nn.conv2d_transpose(input_, self.dec_conv_W[layer_id_],
                                            output_shape=output_shape, strides=[1, 2, 2, 1],
                                            padding='SAME') + self.dec_conv_b[layer_id_]
            if layer_id_ == self.layer_num_cnn - 1:
                input_ = tf.nn.sigmoid(input_)
            else:
                input_ = tf.maximum(0.2 * input_, input_)  # leaky ReLU
        scope.reuse_variables()

        def get_template(frame, center, crop_size):
            # crop a (crop_size x crop_size) search template around `center`
            pad_frame = tf.image.pad_to_bounding_box(
                frame, crop_size / 2, crop_size / 2,
                self.img_height + crop_size, self.img_width + crop_size)
            output = tf.image.crop_to_bounding_box(pad_frame, center[0], center[1],
                                                   crop_size, crop_size)
            return output

        def get_center(mask):
            # compute the mask centroid
            yv, xv = tf.meshgrid(tf.range(start=0, limit=self.img_width),
                                 tf.range(start=0, limit=self.img_height))
            mask = tf.squeeze(mask)
            xv = tf.to_float(xv)
            yv = tf.to_float(yv)
            # caveat: this Python-level comparison of a tensor is evaluated at graph
            # construction, not execution; a tf.cond would be needed for a real
            # empty-mask check
            if tf.reduce_sum(mask) == 0:
                return None
            center_x = tf.to_int32(tf.reduce_sum(xv * mask) / tf.reduce_sum(mask))
            center_y = tf.to_int32(tf.reduce_sum(yv * mask) / tf.reduce_sum(mask))
            return [center_x, center_y]

        def fill_template(mask, center, crop_size):
            # paste the template prediction back at the cropped location
            # caveat: try/except only catches graph-construction errors, not run-time ones
            try:
                temp_img = tf.image.pad_to_bounding_box(
                    mask, center[0], center[1],
                    self.img_height + crop_size, self.img_width + crop_size)
                return tf.image.crop_to_bounding_box(temp_img, crop_size / 2, crop_size / 2,
                                                     self.img_height, self.img_width)
            except Exception:
                return None

        def get_radius(mask):
            return tf.round(tf.sqrt(tf.reduce_sum(mask)))

        def forward():
            output_list = []
            center_list = []
            cell_state = []
            output_template_list = []
            for layer_lstm in range(self.layer_num_lstm):
                cell_state.append([])
            for frame_id in xrange(self.length):
                for layer_lstm in range(self.layer_num_lstm):
                    cell_state[layer_lstm].append(self.lstms_state[layer_lstm])
                if frame_id == 0:
                    center = get_center(self.first_mask)
                    center_list.append(center)
                    temp = tf.to_float(tf.greater(self.first_mask, 0.5))
                else:
                    center = get_center(output_frame)
                    center_list.append(center)
                    temp = tf.to_float(tf.greater(output_frame, 0.5))
                crop_size = 4 * get_radius(temp)
                crop_size = tf.minimum(crop_size, 224 * 2)
                crop_size = tf.maximum(crop_size, 112)
                crop_size = tf.to_int32(crop_size)
                if frame_id > 0 and center is None:
                    return tf.stack(output_list, axis=1)
                template = get_template(self.input_frames[0, frame_id, :, :, :], center, crop_size)
                template = tf.image.resize_images(template, tf.constant([224, 224]))
                template = tf.expand_dims(template, axis=0)
                # extract features and run the ConvLSTM stack
                input_ = feature_extract(template)
                for layer_lstm in range(self.layer_num_lstm):
                    input_, self.lstms_state[layer_lstm] = self.lstms[layer_lstm](
                        input_, self.lstms_state[layer_lstm])
                for layer_id_, kernel_num_, num_input, num_output in zip(
                        xrange(self.layer_num_cnn), self.kernel_size_dec,
                        self.num_dec_input, self.num_dec_output):
                    if layer_id_ == self.layer_num_cnn - 1:
                        output_shape = tf.constant([1, 224, 224, 1])
                        input_ = tf.nn.conv2d_transpose(
                            input_, self.dec_conv_W[layer_id_], output_shape=output_shape,
                            strides=[1, 2, 2, 1], padding='SAME') + self.dec_conv_b[layer_id_]
                        output_template = tf.sigmoid(input_)
                    else:
                        output_shape = tf.stack([tf.shape(input_)[0], tf.shape(input_)[1] * 2,
                                                 tf.shape(input_)[2] * 2, num_output])
                        input_ = tf.nn.conv2d_transpose(
                            input_, self.dec_conv_W[layer_id_], output_shape=output_shape,
                            strides=[1, 2, 2, 1], padding='SAME') + self.dec_conv_b[layer_id_]
                        input_ = tf.maximum(0.2 * input_, input_)
                output_template = tf.squeeze(output_template, axis=0)
                output_template = tf.image.resize_images(output_template, [crop_size, crop_size])
                output_frame = fill_template(output_template, center, crop_size)
                output_template_list.append(template)
                if output_frame is None:
                    return tf.stack(output_list, axis=0), tf.stack(center_list, axis=1)
                output_list.append(output_frame)
            return (tf.stack(output_list, axis=0), tf.stack(center_list, axis=1),
                    tf.stack(cell_state), tf.stack(output_template_list))

        self.predicts, self.centers, self.cell_states, self.templates = forward()
        # note: outputs are logits; convert with a sigmoid in testing mode
def generator(self, x_com, is_training, reuse): with tf.variable_scope('generator', reuse=reuse): # CNN框架 with tf.variable_scope('cnn1'): x_cnn = deconv_layer( x_com, [3, 3, 64, 3], [self.batch_size, self.image_size, self.image_size, 64], 1) x_cnn = prelu(x_cnn) #shortcut = x for i in range(6): with tf.variable_scope('block{}cnn1'.format(i + 1)): x1 = x2 = x3 = x_cnn for j in range(3): with tf.variable_scope('block{}_{}cnn1'.format( i + 1, j + 1)): with tf.variable_scope('ud1'): a1 = prelu( deconv_layer(x1, [3, 3, 64, 64], [ self.batch_size, self.image_size, self.image_size, 64 ], 1)) #a1 = batch_normalize(a1, is_training) with tf.variable_scope('ud2'): b1 = prelu( deconv_layer(x2, [3, 3, 64, 64], [ self.batch_size, self.image_size, self.image_size, 64 ], 1)) #b1 = batch_normalize(b1, is_training) with tf.variable_scope('ud3'): c1 = prelu( deconv_layer(x3, [3, 3, 64, 64], [ self.batch_size, self.image_size, self.image_size, 64 ], 1)) #c1 = batch_normalize(c1, is_training) sum = tf.concat([a1, b1, c1], 3) #sum = batch_normalize(sum, is_training) with tf.variable_scope('ud4'): x1 = prelu( deconv_layer(tf.concat( [sum, x1], 3), [1, 1, 64, 256], [ self.batch_size, self.image_size, self.image_size, 64 ], 1)) #x1 = batch_normalize(x1, is_training) with tf.variable_scope('ud5'): x2 = prelu( deconv_layer(tf.concat( [sum, x2], 3), [1, 1, 64, 256], [ self.batch_size, self.image_size, self.image_size, 64 ], 1)) #x2 = batch_normalize(x2, is_training) with tf.variable_scope('ud6'): x3 = prelu( deconv_layer(tf.concat( [sum, x3], 3), [1, 1, 64, 256], [ self.batch_size, self.image_size, self.image_size, 64 ], 1)) #x3 = batch_normalize(x3, is_training) with tf.variable_scope('ud7'): block_out = prelu( deconv_layer(tf.concat([x1, x2, x3], 3), [3, 3, 64, 192], [ self.batch_size, self.image_size, self.image_size, 64 ], 1)) #with tf.variable_scope('ud8'): #block_out_att = deconv_layer(block_out, [3, 3, 64, 64], [self.batch_size, self.image_size, self.image_size, 64], 1) #block_out_att = tf.nn.sigmoid(block_out_att) #block_out = block_out_att*block_out + block_out #x = x1+x2+x3+x x_cnn += block_out #x = batch_normalize(x, is_training)) with tf.variable_scope('cnn6'): x_cnn = deconv_layer( x_cnn, [3, 3, 256, 64], [self.batch_size, self.image_size, self.image_size, 256], 1) #2 x_cnn = pixel_shuffle_layerg(x_cnn, 2, 64) # n_split = 256 / 2 ** 2 x_cnn = prelu(x_cnn) with tf.variable_scope('cnn7'): x_cnn = deconv_layer(x_cnn, [3, 3, 256, 64], [ self.batch_size, self.image_size * 2, self.image_size * 2, 256 ], 1) #2 x_cnn = pixel_shuffle_layerg(x_cnn, 2, 64) # n_split = 256 / 2 ** 2 x_cnn = prelu(x_cnn) with tf.variable_scope('cnn8'): x_cnn = deconv_layer(x_cnn, [3, 3, 256, 64], [ self.batch_size, self.image_size * 4, self.image_size * 4, 256 ], 1) #2 x_cnn = pixel_shuffle_layerg(x_cnn, 2, 64) # n_split = 256 / 2 ** 2 x_cnn = prelu(x_cnn) with tf.variable_scope('cnn9'): x_cnn = deconv_layer(x_cnn, [3, 3, 3, 64], [ self.batch_size, self.image_size * 8, self.image_size * 8, 3 ], 1) x_cnn_SR = x_cnn + self.bic_ref x_cnn_edge = self.Laplacian(x_cnn_SR) # GAN框架 with tf.variable_scope('gan1'): x_gan = deconv_layer( x_com, [3, 3, 64, 3], [self.batch_size, self.image_size, self.image_size, 64], 1) x_gan = prelu(x_gan) #shortcut = x for i in range(6): with tf.variable_scope('block{}gan1'.format(i + 1)): x1 = x2 = x3 = x_gan for j in range(3): with tf.variable_scope('block{}_{}gan1'.format( i + 1, j + 1)): with tf.variable_scope('ud1'): a1 = prelu( deconv_layer(x1, [3, 3, 64, 64], [ self.batch_size, self.image_size, 
                             self.image_size, 64], 1))
            #a1 = batch_normalize(a1, is_training)
        with tf.variable_scope('ud2'):
            b1 = prelu(
                deconv_layer(x2, [3, 3, 64, 64],
                             [self.batch_size, self.image_size, self.image_size, 64], 1))
            #b1 = batch_normalize(b1, is_training)
        with tf.variable_scope('ud3'):
            c1 = prelu(
                deconv_layer(x3, [3, 3, 64, 64],
                             [self.batch_size, self.image_size, self.image_size, 64], 1))
            #c1 = batch_normalize(c1, is_training)
        sum = tf.concat([a1, b1, c1], 3)  # NB: shadows Python's built-in sum
        #sum = batch_normalize(sum, is_training)
        with tf.variable_scope('ud4'):
            x1 = prelu(
                deconv_layer(tf.concat([sum, x1], 3), [1, 1, 64, 256],
                             [self.batch_size, self.image_size, self.image_size, 64], 1))
            #x1 = batch_normalize(x1, is_training)
        with tf.variable_scope('ud5'):
            x2 = prelu(
                deconv_layer(tf.concat([sum, x2], 3), [1, 1, 64, 256],
                             [self.batch_size, self.image_size, self.image_size, 64], 1))
            #x2 = batch_normalize(x2, is_training)
        with tf.variable_scope('ud6'):
            x3 = prelu(
                deconv_layer(tf.concat([sum, x3], 3), [1, 1, 64, 256],
                             [self.batch_size, self.image_size, self.image_size, 64], 1))
            #x3 = batch_normalize(x3, is_training)
        with tf.variable_scope('ud7'):
            block_out = prelu(
                deconv_layer(tf.concat([x1, x2, x3], 3), [3, 3, 64, 192],
                             [self.batch_size, self.image_size, self.image_size, 64], 1))
        #with tf.variable_scope('ud8'):
        #    block_out_att = deconv_layer(block_out, [3, 3, 64, 64],
        #                                 [self.batch_size, self.image_size, self.image_size, 64], 1)
        #    block_out_att = tf.nn.sigmoid(block_out_att)
        #    block_out = block_out_att*block_out + block_out
        #x = x1+x2+x3+x
        x_gan += block_out

        with tf.variable_scope('gan6'):
            x_gan = deconv_layer(
                x_gan, [3, 3, 256, 64],
                [self.batch_size, self.image_size, self.image_size, 256], 1)  #2
            x_gan = pixel_shuffle_layerg(x_gan, 2, 64)  # n_split = 256 / 2 ** 2
            x_gan = prelu(x_gan)
        with tf.variable_scope('gan7'):
            x_gan = deconv_layer(
                x_gan, [3, 3, 256, 64],
                [self.batch_size, self.image_size * 2, self.image_size * 2, 256], 1)  #2
            x_gan = pixel_shuffle_layerg(x_gan, 2, 64)  # n_split = 256 / 2 ** 2
            x_gan = prelu(x_gan)
        with tf.variable_scope('gan8'):
            x_gan = deconv_layer(
                x_gan, [3, 3, 256, 64],
                [self.batch_size, self.image_size * 4, self.image_size * 4, 256], 1)  #2
            x_gan = pixel_shuffle_layerg(x_gan, 2, 64)  # n_split = 256 / 2 ** 2
            x_gan = prelu(x_gan)
        with tf.variable_scope('gan9'):
            x_gan = deconv_layer(
                x_gan, [3, 3, 3, 64],
                [self.batch_size, self.image_size * 8, self.image_size * 8, 3], 1)
        x_gan_SR = x_gan + self.bic_ref
        x_gan_edge = self.Laplacian(x_gan_SR)

        # RNN branch
        with tf.variable_scope('LSTM'):
            cell = BasicConvLSTMCell([self.image_size, self.image_size], [3, 3], 128)
            rnn_state = cell.zero_state(batch_size=self.batch_size, dtype=tf.float32)
        '''with tf.variable_scope('LSTM1'):
            cell_2 = BasicConvLSTMCell([self.image_size, self.image_size], [3, 3], 128)
            rnn_state_2 = cell_2.zero_state(batch_size=self.batch_size, dtype=tf.float32)'''
        with tf.variable_scope('rnn1'):
            x_rnn = deconv_layer(
                x_com, [3, 3, 64, 3],
                [self.batch_size, self.image_size, self.image_size, 64], 1)
            x_rnn = prelu(x_rnn)
        lstm_input = lstm_in = x_rnn
        for n in range(6):
            with tf.variable_scope('lstm_1_{}'.format(n)):
                x_rnn = deconv_layer(
                    lstm_in, [3, 3, 128, 64],
                    [self.batch_size, self.image_size, self.image_size, 128], 1)
                x_rnn = prelu(x_rnn)
                y_1, rnn_state = cell(x_rnn, rnn_state)
            #with tf.variable_scope('lstm_2_{}'.format(n)):
            #    x_rnn = deconv_layer(y_1, [3, 3, 64, 128],
            #                         [self.batch_size, self.image_size, self.image_size, 64], 1)
            #    x_rnn = prelu(x_rnn)
            #with tf.variable_scope('lstm_3_{}'.format(n)):
            #    x_rnn = deconv_layer(y_1, [3, 3, 128, 128],
            #                         [self.batch_size, self.image_size, self.image_size, 128], 1)
            #    x_rnn = prelu(x_rnn)
            #    y_2, rnn_state = cell(x_rnn, rnn_state)
            with tf.variable_scope('lstm_4_{}'.format(n)):
                x_rnn = deconv_layer(
                    y_1, [3, 3, 64, 128],
                    [self.batch_size, self.image_size, self.image_size, 64], 1)
                x_rnn = prelu(x_rnn)
            lstm_in += x_rnn
        lstm_output = tf.concat([lstm_in, lstm_input], 3)
        with tf.variable_scope('rnn4'):
            x_rnn = deconv_layer(
                lstm_output, [3, 3, 64, 128],
                [self.batch_size, self.image_size, self.image_size, 64], 1)
            x_rnn = prelu(x_rnn)
        with tf.variable_scope('rnn5'):
            x_rnn = deconv_layer(
                x_rnn, [3, 3, 256, 64],
                [self.batch_size, self.image_size, self.image_size, 256], 1)  #2
            x_rnn = pixel_shuffle_layerg(x_rnn, 2, 64)  # n_split = 256 / 2 ** 2
            x_rnn = prelu(x_rnn)
        with tf.variable_scope('rnn6'):
            x_rnn = deconv_layer(
                x_rnn, [3, 3, 256, 64],
                [self.batch_size, self.image_size * 2, self.image_size * 2, 256], 1)  #2
            x_rnn = pixel_shuffle_layerg(x_rnn, 2, 64)  # n_split = 256 / 2 ** 2
            x_rnn = prelu(x_rnn)
        with tf.variable_scope('rnn7'):
            x_rnn = deconv_layer(
                x_rnn, [3, 3, 256, 64],
                [self.batch_size, self.image_size * 4, self.image_size * 4, 256], 1)  #2
            x_rnn = pixel_shuffle_layerg(x_rnn, 2, 64)  # n_split = 256 / 2 ** 2
            x_rnn = prelu(x_rnn)
        with tf.variable_scope('rnn8'):
            x_rnn = deconv_layer(
                x_rnn, [3, 3, 3, 64],
                [self.batch_size, self.image_size * 8, self.image_size * 8, 3], 1)
        x_rnn_SR = x_rnn + self.bic_ref
        #all_features = tf.concat([x_gan_SR, x_cnn_SR], 3)

        # attention branch
        # CNN mask
        with tf.variable_scope('cnnmask1'):
            x_cnn_mask = conv_layer(x_cnn_SR, [3, 3, 3, 32], 1)  #128
            x_cnn_mask = prelu(x_cnn_mask)
        with tf.variable_scope('cnnmask1_1'):
            x_cnn_mask = conv_layer(x_cnn_mask, [3, 3, 32, 64], 2)  #64
            x_cnn_mask = prelu(x_cnn_mask)
        with tf.variable_scope('cnnmask1_2'):
            x_cnn_mask = conv_layer(x_cnn_mask, [3, 3, 64, 128], 2)  #32
            x_cnn_mask = prelu(x_cnn_mask)
        with tf.variable_scope('cnnmask1_3'):
            x_cnn_mask = conv_layer(x_cnn_mask, [3, 3, 128, 64], 1)  #128
            x_cnn_mask = prelu(x_cnn_mask)
        res_input_cnn = res_in_cnn = x_cnn_mask
        for j in range(3):
            with tf.variable_scope('cnnmask_1{}'.format(j)):
                fuse = deconv_layer(
                    res_in_cnn, [3, 3, 64, 64],
                    [self.batch_size, self.image_size * 2, self.image_size * 2, 64], 1)  #2
                fuse = prelu(fuse)
            with tf.variable_scope('cnnmask_2{}'.format(j)):
                fuse = deconv_layer(
                    fuse, [3, 3, 64, 64],
                    [self.batch_size, self.image_size * 2, self.image_size * 2, 64], 1)  #2
                fuse = prelu(fuse)
            res_in_cnn += fuse
        res_output_cnn = tf.concat([res_input_cnn, res_in_cnn], 3)
        with tf.variable_scope('cnnmask1_4'):
            x_cnn_mask = deconv_layer(
                res_output_cnn, [3, 3, 128, 128],
                [self.batch_size, self.image_size * 4, self.image_size * 4, 128], 2)  #2
            #x_cnn_mask = pixel_shuffle_layerg(x_cnn_mask, 2, 32)  # n_split = 256 / 2 ** 2
            x_cnn_mask = prelu(x_cnn_mask)
        with tf.variable_scope('cnnmask1_5'):
            x_cnn_mask = deconv_layer(
                x_cnn_mask, [3, 3, 128, 128],
                [self.batch_size, self.image_size * 8, self.image_size * 8, 128], 2)  #2
            #x_cnn_mask = pixel_shuffle_layerg(x_cnn_mask, 2, 32)  # n_split = 256 / 2 ** 2
            x_cnn_mask = prelu(x_cnn_mask)
        with tf.variable_scope('cnnmask2'):
            x_cnn_mask = conv_layer(x_cnn_mask, [3, 3, 128, 3], 1)  #128
            #x_cnn_mask = prelu(x_cnn_mask)
            x_cnn_mask = tf.nn.sigmoid(x_cnn_mask)
        #x_cnn_att = x_cnn_SR + x_cnnmask*x_cnn_SR
        x_cnn_att = x_cnn_mask * x_cnn_SR + x_cnn_SR
        #x_cnn_att_gan = (1-x_ganmask)*short1

        # GAN mask
        with tf.variable_scope('ganmask1'):
            x_gan_mask = conv_layer(x_gan_SR, [3, 3, 3, 32], 1)  #128
            x_gan_mask = prelu(x_gan_mask)
        with tf.variable_scope('ganmask1_1'):
            x_gan_mask = conv_layer(x_gan_mask, [3, 3, 32, 64], 2)  #64
            x_gan_mask = prelu(x_gan_mask)
        with tf.variable_scope('ganmask1_2'):
            x_gan_mask = conv_layer(x_gan_mask, [3, 3, 64, 128], 2)  #32
            x_gan_mask = prelu(x_gan_mask)
        with tf.variable_scope('ganmask1_3'):
            x_gan_mask = conv_layer(x_gan_mask, [3, 3, 128, 64], 1)  #128
            x_gan_mask = prelu(x_gan_mask)
        res_input_gan = res_in_gan = x_gan_mask
        for j in range(3):
            with tf.variable_scope('ganmask_1{}'.format(j)):
                fuse = deconv_layer(
                    res_in_gan, [3, 3, 64, 64],
                    [self.batch_size, self.image_size * 2, self.image_size * 2, 64], 1)  #2
                fuse = prelu(fuse)
            with tf.variable_scope('ganmask_2{}'.format(j)):
                fuse = deconv_layer(
                    fuse, [3, 3, 64, 64],
                    [self.batch_size, self.image_size * 2, self.image_size * 2, 64], 1)  #2
                fuse = prelu(fuse)
            res_in_gan += fuse
        res_output_gan = tf.concat([res_input_gan, res_in_gan], 3)
        with tf.variable_scope('ganmask1_4'):
            x_gan_mask = deconv_layer(
                res_output_gan, [3, 3, 128, 128],
                [self.batch_size, self.image_size * 4, self.image_size * 4, 128], 2)  #2
            #x_gan_mask = pixel_shuffle_layerg(x_gan_mask, 2, 32)  # n_split = 256 / 2 ** 2
            x_gan_mask = prelu(x_gan_mask)
        with tf.variable_scope('ganmask1_5'):
            x_gan_mask = deconv_layer(
                x_gan_mask, [3, 3, 128, 128],
                [self.batch_size, self.image_size * 8, self.image_size * 8, 128], 2)  #2
            #x_gan_mask = pixel_shuffle_layerg(x_gan_mask, 2, 32)  # n_split = 256 / 2 ** 2
            x_gan_mask = prelu(x_gan_mask)
        with tf.variable_scope('ganmask2'):
            x_gan_mask = conv_layer(x_gan_mask, [3, 3, 128, 3], 1)  #128
            #x_gan_mask = prelu(x_gan_mask)
            x_gan_mask = tf.nn.sigmoid(x_gan_mask)
        #x_gan_att = x_gan_SR + x_ganmask*x_gan_SR
        x_gan_att = x_gan_mask * x_gan_SR + x_gan_SR

        # A third, fully commented-out mask branch ('rnnmask1' ... 'rnnmask2')
        # mirrors the cnnmask/ganmask blocks above and would produce
        # x_rnn_att = x_rnnmask * x_rnn_SR + x_rnn_SR; it is disabled here.
        #att_SR = x_gan_att + x_cnn_att

        # fusion: fuse x_cnn_SR, x_gan_SR, x_ganmask, x_cnnmask, x_rnnmask
        att_feature = tf.concat(
            [x_gan_att, x_cnn_att, x_gan_mask, x_cnn_mask], 3)  #x_gan_att + x_cnn_att
        with tf.variable_scope('fu1'):
            fuse = deconv_layer(
                att_feature, [3, 3, 128, 12],
                [self.batch_size, self.image_size * 8, self.image_size * 8, 128], 1)  #2
            fuse = prelu(fuse)
        #res_input = res_in = fuse
        #for j in range(3):
        #    with tf.variable_scope('res1_{}'.format(j)):
        #        fuse = deconv_layer(res_in, [3, 3, 64, 128],
        #                            [self.batch_size, self.image_size*8, self.image_size*8, 64], 1)  #2
        #        fuse = prelu(fuse)
        #    with tf.variable_scope('res2_{}'.format(j)):
        #        fuse = deconv_layer(fuse, [3, 3, 128, 64],
        #                            [self.batch_size, self.image_size*8, self.image_size*8, 128], 1)  #2
        #        fuse = prelu(fuse)
        #    res_in += fuse
        #res_output = res_input + res_in
        with tf.variable_scope('fu5'):
            fuse = deconv_layer(
                fuse, [3, 3, 3, 128],
                [self.batch_size, self.image_size * 8, self.image_size * 8, 3], 1)
        att_SR = fuse + x_rnn_SR  #*(1-x_ganmask-x_cnnmask-x_rnnmask)
        x_att_edge = self.Laplacian(att_SR)

        self.g_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='generator')
        return x_att_edge, x_cnn_mask, x_gan_mask, x_cnn_SR, x_gan_SR, x_rnn_SR, att_SR
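# The upsampling stages above ('gan6'-'gan8' and 'rnn5'-'rnn7') each pair a
# deconv_layer with pixel_shuffle_layerg(x, 2, 64). pixel_shuffle_layerg is not
# defined in this file; the sketch below shows what such a sub-pixel
# (depth-to-space) layer typically looks like -- an assumption about its
# behavior, not the author's exact implementation.
import tensorflow as tf

def pixel_shuffle_sketch(x, r, n_split):
    """Rearrange [B, H, W, n_split*r*r] features into [B, H*r, W*r, n_split]."""
    # Split the channels into n_split groups of r*r maps, move each group's
    # channels into space, then re-stack the groups as output channels.
    groups = tf.split(x, n_split, axis=3)
    return tf.concat([tf.depth_to_space(g, r) for g in groups], axis=3)

# e.g. a [batch, 32, 32, 256] tensor with r=2, n_split=64 becomes
# [batch, 64, 64, 64], matching the shapes expected between 'gan6' and 'gan7'.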
def train(train_data, validate_data, test_data):
    """Train ring_net for a number of steps."""
    with tf.Graph().as_default():
        # make inputs
        x = tf.placeholder(tf.float32, [
            None, FLAGS.seq_length, FLAGS.input_height, FLAGS.input_width,
            FLAGS.input_channel
        ])
        # possible dropout inside
        #keep_prob = tf.placeholder("float")
        #x_dropout = tf.nn.dropout(x, keep_prob)

        # create network
        x_unwrap = []
        with tf.variable_scope('conv_lstm',
                               initializer=tf.random_uniform_initializer(-.01, 0.1)):
            # BasicConvLSTMCell: (shape, filter_size, num_features, state_is_tuple)
            cell_1 = BasicConvLSTMCell.BasicConvLSTMCell([16, 16], [3, 3], 64,
                                                         state_is_tuple=True)
            # new_state: [batch_size, 16, 16, 64] for c and h
            new_state_1 = cell_1.zero_state(FLAGS.batch_size)
            # cell_2
            cell_2 = BasicConvLSTMCell.BasicConvLSTMCell([16, 16], [3, 3], 64,
                                                         state_is_tuple=True)
            new_state_2 = cell_2.zero_state(FLAGS.batch_size)
            # cell_3
            cell_3 = BasicConvLSTMCell.BasicConvLSTMCell([16, 16], [3, 3], 64,
                                                         state_is_tuple=True)
            #new_state_3 = cell_3.zero_state(FLAGS.batch_size)
            # cell_4
            cell_4 = BasicConvLSTMCell.BasicConvLSTMCell([16, 16], [3, 3], 64,
                                                         state_is_tuple=True)
            #new_state_4 = cell_4.zero_state(FLAGS.batch_size)

        # conv network
        # conv_layer: (input, kernel_size, stride, num_features, scope_name_idx, linear, reuseL)
        # BasicConvLSTMCell: __call__(inputs, state, scope, reuseL)
        reuseL = None
        with tf.variable_scope(tf.get_variable_scope()) as scope:
            for i in xrange(FLAGS.seq_length - 1):
                # -------------------------- encode ------------------------
                # conv1: [16, 64, 64, 1] -> [16, 32, 32, 8]
                # (shared across warm-up (i < seq_start) and prediction steps)
                conv1 = ld.conv_layer(x[:, i, :, :, :], 3, 2, 8, "encode_1",
                                      reuseL=reuseL)
                print(conv1.get_shape().as_list())
                # conv2: [16, 32, 32, 8] -> [16, 16, 16, 16]
                conv2 = ld.conv_layer(conv1, 3, 2, 16, "encode_2", reuseL=reuseL)
                print(conv2.get_shape().as_list())
                # convLSTM 3: [16, 16, 16, 16] -> [16, 16, 16, 64]
                y_0 = conv2
                print(y_0.get_shape().as_list())
                y_1, new_state_1 = cell_1(y_0, new_state_1,
                                          "encode_3_convLSTM_1", reuseL=reuseL)
                # convLSTM 4: [16, 16, 16, 64] -> [16, 16, 16, 64]
                print(y_1.get_shape().as_list())
                y_2, new_state_2 = cell_2(y_1, new_state_2,
                                          "encode_4_convLSTM_2", reuseL=reuseL)
                print(y_2.get_shape().as_list())

                # ------------------------ decode -------------------------
                # convLSTM 5: [16, 16, 16, 64] -> [16, 16, 16, 64]
                # copy the initial states and cell outputs from convLSTM 4
                if i == 0:
                    new_state_3 = new_state_2
                y_3, new_state_3 = cell_3(y_2, new_state_3,
                                          "encode_5_convLSTM_3", reuseL=reuseL)
                # convLSTM 6: [16, 16, 16, 64] -> [16, 16, 16, 64]
                print(y_3.get_shape().as_list())
                if i == 0:
                    new_state_4 = new_state_1
                y_4, new_state_4 = cell_4(y_3, new_state_4,
                                          "encode_6_convLSTM_4", reuseL=reuseL)
                print(y_4.get_shape().as_list())
                # conv7: [16, 16, 16, 64] -> [16, 32, 32, 8]
                conv6 = y_4
                conv7 = ld.transpose_conv_layer(conv6, 3, 2, 8, "decode_7",
                                                reuseL=reuseL)
                # x_1: [16, 32, 32, 8] -> [16, 64, 64, 1]
                print(conv7.get_shape().as_list())
                x_1 = ld.transpose_conv_layer(conv7, 3, 2, 1, "decode_8",
                                              linear=True,
                                              reuseL=reuseL)  # set activation to linear
                if i >= FLAGS.seq_start - 1:
                    x_unwrap.append(x_1)
                # set reuse to true after first pass
                if i == 0:
                    #tf.get_variable_scope().reuse_variables()
                    reuseL = True

        # stack them all together
        x_unwrap = tf.stack(x_unwrap)
        x_unwrap = tf.transpose(x_unwrap, [1, 0, 2, 3, 4])

        # calc total loss (compare x_t to x_t+1)
        loss = tf.nn.l2_loss(x[:, FLAGS.seq_start:, :, :, :] -
                             x_unwrap[:, :, :, :, :])
        tf.summary.scalar('loss', loss)

        # training
        train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)

        # List of all Variables
        #variables = tf.global_variables()

        # Build a saver
        saver = tf.train.Saver(tf.global_variables())

        # Summary op
        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # --------------- Start running operations on the Graph ---------------
        sess = tf.Session()

        # init if this is the very first time training
        print("init network from scratch")
        sess.run(init)

        # Summary op
        #graph_def = sess.graph.as_graph_def(add_shapes=True)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, graph=sess.graph)

        # ----------------------- training for train_data -----------------------
        print("now training step:")
        for step in xrange(FLAGS.max_step):
            dat = get_batch_data(train_data)
            t = time.time()
            _, loss_r = sess.run([train_op, loss], feed_dict={x: dat})
            elapsed = time.time() - t

            if step % 100 == 0 and step != 0:
                summary_str = sess.run(summary_op, feed_dict={x: dat})
                summary_writer.add_summary(summary_str, step)
                #print("time per batch is " + str(elapsed))
                print("at step " + str(step) + ", loss is " + str(loss_r))
                #print(loss_r)

            assert not np.isnan(loss_r), 'Model diverged with loss = NaN'

            if step % 1000 == 0 and step != 0:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)
                print("saved to " + FLAGS.train_dir)

                # ------------------ validation for each 1000 steps ------------------
                print("at step " + str(step) + ", validate...")
                v_whole_loss = 0
                # [all_size, seq_length, in_height, in_width, in_channel]
                all_batch_v = np.zeros(
                    (validate_data.shape[0] - FLAGS.seq_length, FLAGS.seq_length,
                     FLAGS.input_height, FLAGS.input_width, FLAGS.input_channel),
                    dtype=np.float32)
                #all_batch_v = validate_data[v_i, v_i:v_i+FLAGS.seq_length, :, :, :]
                #while v_i < validate_data.shape[0]:
                for v_i in xrange(validate_data.shape[0] - FLAGS.seq_length):
                    all_batch_v[v_i] = validate_data[v_i:v_i + FLAGS.seq_length, :, :, :]
                v_i = 0
                while v_i < all_batch_v.shape[0] - FLAGS.batch_size:
                    # [batch_size, seq_length, in_height, in_width, channel]
                    dat_v = all_batch_v[v_i:v_i + FLAGS.batch_size, :, :, :, :]
                    v_loss = sess.run(loss, feed_dict={x: dat_v})
                    v_whole_loss += v_loss
                    v_i += FLAGS.batch_size  # advance to the next batch window
                print("validation loss: " + str(v_whole_loss))

        # ---------------------------- for test data ----------------------------
        print("now testing step:")
        t_whole_loss = 0
        all_batch_t = np.zeros(
            (test_data.shape[0] - FLAGS.seq_length, FLAGS.seq_length,
             FLAGS.input_height, FLAGS.input_width, FLAGS.input_channel),
            dtype=np.float32)
        for t_i in xrange(test_data.shape[0] - FLAGS.seq_length):
            all_batch_t[t_i] = test_data[t_i:t_i + FLAGS.seq_length, :, :, :]
        t_i = 0
        while t_i < all_batch_t.shape[0] - FLAGS.batch_size:
            dat_t = all_batch_t[t_i:t_i + FLAGS.batch_size, :, :, :, :]
            predict, t_loss = sess.run([x_unwrap, loss], feed_dict={x: dat_t})
            t_whole_loss += t_loss
            t_i += FLAGS.batch_size  # advance to the next batch window
        print("test loss: " + str(t_whole_loss))
        np.save('prediction.npy', predict)
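# train() calls a get_batch_data(train_data) helper that is not defined in this
# file. Below is a plausible minimal version -- an assumption (the real helper
# may sample differently) -- that draws random windows of seq_length
# consecutive frames:
import numpy as np

def get_batch_data_sketch(data, batch_size=16, seq_length=10):
    """data: [num_frames, h, w, c] -> batch of shape [batch_size, seq_length, h, w, c]."""
    # Pick random window start indices, then stack the consecutive-frame windows.
    starts = np.random.randint(0, data.shape[0] - seq_length, size=batch_size)
    return np.stack([data[s:s + seq_length] for s in starts]).astype(np.float32)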
def forward(self, frames_lr, is_training=True, reuse=False):
    def MMB(block_in, cell, rnn_state=[], num_b=1, num_d=3, max_feature=64, scope='MMB'):
        # Multi-memory block: each of the num_b blocks threads its input through
        # num_d ConvLSTM cells, concatenating every cell output onto the features.
        for i in range(num_b):
            block_out = block_in
            block_out = slim.conv2d(block_out, 16, [3, 3], scope='conv0_{}'.format(i))
            with tf.variable_scope('multi_memory', reuse=False) as multi_memory:
                with tf.variable_scope(scope, reuse=False):
                    for j in range(num_d):
                        conv1, rnn_state[i][j] = cell[j](
                            block_out, rnn_state[i][j], scope='rnn{}_{}'.format(i, j))
                        block_out = tf.concat([block_out, conv1], 3)
            block_out = slim.conv2d(block_out, max_feature, [3, 3],
                                    scope='out_{}'.format(i))
            block_in += block_out  # residual connection per block
        return block_in, rnn_state

    def RAB(block_in, scope='RAB'):
        # Residual attention block: channel attention + spatial attention gates.
        with tf.variable_scope(scope, reuse=False):
            conv1 = slim.conv2d(block_in, 64, [3, 3], scope='conv1')
            conv2 = slim.conv2d(conv1, 64, [3, 3], activation_fn=None, scope='conv2')
            # channel attention
            pool = tf.reduce_mean(conv2, axis=[1, 2], keep_dims=True)
            conv_ca1 = slim.conv2d(pool, 4, [1, 1], scope='conv_ca1')
            conv_ca2 = slim.conv2d(conv_ca1, 64, [1, 1], activation_fn=None,
                                   scope='conv_ca2')
            # spatial attention
            conv_sa1 = slim.conv2d(conv2, 64, [1, 1], scope='conv_sa1')
            conv_sa2 = slim.separable_conv2d(conv_sa1, 64, [3, 3], depth_multiplier=1,
                                             activation_fn=None, scope='conv_sa2')
            out = conv_ca2 + conv_sa2
            scale = tf.nn.sigmoid(out)
            fa = scale * conv2
            block_out = slim.conv2d(fa, 64, [1, 1], scope='conv_fa', activation_fn=None)
            block_out = block_out + block_in
        return block_out

    def RB(block_in, scope='RB'):
        # Plain residual block.
        with tf.variable_scope(scope, reuse=False):
            conv1 = slim.conv2d(block_in, 64, [3, 3], scope='conv1')
            conv2 = slim.conv2d(conv1, 64, [3, 3], scope='conv2', activation_fn=None)
            block_out = conv2 + block_in
            block_out = tf.nn.relu(block_out)
        return block_out

    def Feat_Ex(input, scope='feat_ex'):
        with tf.variable_scope(scope, reuse=False):
            RB1 = RB(input, scope='RB1')
            RB2 = RB(RB1, scope='RB2')
            RB3 = RB(RB2, scope='RB3')
            RB4 = RB(RB3, scope='RB4')
        return RB4

    def RAG(input, scope='RAG'):
        # Residual attention group: six RABs plus a group-level skip.
        with tf.variable_scope(scope, reuse=False):
            RAB1 = RAB(input, scope='RB1')
            RAB2 = RAB(RAB1, scope='RB2')
            RAB3 = RAB(RAB2, scope='RB3')
            RAB4 = RAB(RAB3, scope='RB4')
            RAB5 = RAB(RAB4, scope='RB5')
            RAB6 = RAB(RAB5, scope='RB6')
            conv_g = slim.conv2d(RAB6, 64, [3, 3], scope='conv_g')
            out = conv_g + input
        return out

    def Recon(input, scope='Recon'):
        with tf.variable_scope(scope, reuse=False):
            rag1 = RAG(input, scope='rag1')
            out1 = slim.conv2d(rag1, 64, [1, 1], scope='out1')
            rag2 = RAG(rag1, scope='rag2')
            out2 = slim.conv2d(rag2, 64, [1, 1], scope='out2')
            rag3 = RAG(rag2, scope='rag3')
            out3 = slim.conv2d(rag3, 64, [1, 1], scope='out3')
            rag4 = RAG(rag3, scope='rag4')
            out4 = slim.conv2d(rag4, 64, [1, 1], scope='out4')
            rag5 = RAG(rag4, scope='rag5')
            out5 = slim.conv2d(rag5, 64, [1, 1], scope='out5')
            rag6 = RAG(rag5, scope='rag6')
            out6 = slim.conv2d(rag6, 64, [1, 1], scope='out6')
            #rag7 = RAG(rag6, scope='rag7')
            #out7 = slim.conv2d(rag7, 64, [1, 1], scope='out7')
            #rag8 = RAG(rag7, scope='rag8')
            #out8 = slim.conv2d(rag8, 64, [1, 1], scope='out8')
            #res = out1 + out2 + out3 + out4
            res = out1 + out2 + out3 + out4 + out5 + out6
        return res

    num_batch, num_frame, height, width, num_channels = frames_lr.get_shape().as_list()
    out_height = height * self.scale_factor
    out_width = width * self.scale_factor
    # calculate flow against the center reference frame
    idx0 = num_frame // 2
    frames_y = rgb2y(frames_lr)
    frame_ref_y = frames_y[:, int(idx0), :, :, :]
    self.frames_y = frames_y
    self.frame_ref_y = frame_ref_y
    #frame_0up_ref = zero_upsampling(frame_ref_y, scale_factor=self.scale_factor)
    frame_bic_ref = tf.image.resize_images(frame_ref_y, [out_height, out_width],
                                           method=2)
    #tf.summary.image('inp_0', im2uint8(frames_y[0, :, :, :, :]), max_outputs=3)
    #tf.summary.image('bic', im2uint8(frame_bic_ref), max_outputs=3)

    x_unwrap = []
    with tf.variable_scope('LSTM'):
        cell = []
        cell.append(BasicConvLSTMCell.BasicConvLSTMCell(
            [out_height // 4, out_width // 4], [3, 3], 16))
        cell.append(BasicConvLSTMCell.BasicConvLSTMCell(
            [out_height // 4, out_width // 4], [3, 3], 32))
        cell.append(BasicConvLSTMCell.BasicConvLSTMCell(
            [out_height // 4, out_width // 4], [3, 3], 64))
        # one state per (block, depth) pair: 7 MMB blocks x 3 cells
        rnn_state = []
        for i in range(7):
            rs = []
            for j in range(3):
                rs.append(cell[j].zero_state(batch_size=num_batch, dtype=tf.float32))
            rnn_state.append(rs)

    self.uv = []
    frame_i_fw_all = []
    max_feature = 64
    for i in range(num_frame):
        if i > 0 and not reuse:
            reuse = True
        frame_i = frames_y[:, i, :, :, :]
        if i == 0:
            uv = self.flownets.forward(frame_i, frame_ref_y, reuse=reuse)
        else:
            uv = self.flownets.forward(frame_i, frame_ref_y, reuse=True)
        self.uv.append(uv)
        print('Build model - frame_{}'.format(i), frame_i.get_shape(), uv.get_shape())
        frame_i_fw = imwarp_forward(uv, tf.concat([frame_i], -1), [height, width])
        if i == 0:
            tem = imwarp_forward(uv, tf.concat([frame_i], -1), [height, width])
            frame_i_fw_all = tem
        else:
            tem = imwarp_forward(uv, tf.concat([frame_i], -1), [height, width])
            frame_i_fw_all = tf.concat([frame_i_fw_all, tem], axis=0)
        detail = frame_i - frame_i_fw  # high-frequency residual left by warping
        with tf.variable_scope('srmodel', reuse=reuse) as scope_sr:
            with slim.arg_scope([slim.conv2d], activation_fn=prelu, stride=1,
                                weights_initializer=tf.contrib.layers.xavier_initializer(uniform=True),
                                biases_initializer=tf.constant_initializer(0.0)), \
                 slim.arg_scope([slim.batch_norm], center=True, scale=False,
                                updates_collections=None, activation_fn=prelu,
                                epsilon=1e-5, is_training=is_training):
                rnn_input = tf.concat([frame_i_fw, detail], 3)
                conv1 = slim.conv2d(rnn_input, 64, [3, 3], scope='enc1')
                feat = Feat_Ex(conv1, scope='feat')
                conv2 = slim.conv2d(feat, 64, [3, 3], scope='enc2')
                block_in = conv2
                block_in, rnn_state = MMB(block_in, cell, rnn_state,
                                          num_b=7, num_d=3, scope='MMB')
                feat_r = Recon(block_in, scope='Recon')
                conv3 = slim.conv2d(feat_r, max_feature, [3, 3], scope='conv6')
                out = slim.conv2d(conv3,
                                  self.scale_factor * self.scale_factor * max_feature,
                                  [3, 3], activation_fn=None, scope='out')
                ps_out = ps._PS(out, self.scale_factor, max_feature)
                sr_out = slim.conv2d(ps_out, 1, [3, 3], activation_fn=None,
                                     scope='sr_out')
                rnn_out = sr_out + frame_bic_ref
        if i >= 0:
            x_unwrap.append(rnn_out)
        if i == 0:
            tf.get_variable_scope().reuse_variables()

    x_unwrap = tf.stack(x_unwrap, 1)
    self.uv = tf.stack(self.uv, 1)
    return x_unwrap, frame_i_fw_all
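# forward() converts the input frames to luma with rgb2y() before computing
# flow and super-resolving, but rgb2y is not shown in this file. A minimal
# sketch under the assumption of full-range BT.601 luma weights (the real
# helper may use the studio-swing variant instead):
import tensorflow as tf

def rgb2y_sketch(frames):
    """[B, T, H, W, 3] RGB in [0, 1] -> [B, T, H, W, 1] luma."""
    r, g, b = tf.unstack(frames, axis=-1)      # split the RGB channels
    y = 0.299 * r + 0.587 * g + 0.114 * b      # BT.601 luma weights (assumed)
    return tf.expand_dims(y, -1)               # restore the channel axis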