def compute_losses(self, batch_img1, batch_img2, batch_img3,
                   flow_fw_12, flow_bw_21, flow_fw_23, flow_bw_32,
                   mask_fw_12, mask_bw_21, mask_fw_23, mask_bw_32,
                   train=True, is_scale=True):
    img_size = get_shape(batch_img1, train=train)

    img1_warp2 = tf_warp(batch_img1, flow_bw_21['full_res'], img_size[1], img_size[2])
    img2_warp1 = tf_warp(batch_img2, flow_fw_12['full_res'], img_size[1], img_size[2])
    img2_warp3 = tf_warp(batch_img2, flow_bw_32['full_res'], img_size[1], img_size[2])
    img3_warp2 = tf_warp(batch_img3, flow_fw_23['full_res'], img_size[1], img_size[2])

    losses = {}

    abs_robust_mean = {}
    abs_robust_mean['no_occlusion'] = \
        self.abs_robust_loss(batch_img1 - img2_warp1, tf.ones_like(mask_fw_12)) + \
        self.abs_robust_loss(batch_img2 - img1_warp2, tf.ones_like(mask_bw_21)) + \
        self.abs_robust_loss(batch_img2 - img3_warp2, tf.ones_like(mask_fw_23)) + \
        self.abs_robust_loss(batch_img3 - img2_warp3, tf.ones_like(mask_bw_32))
    abs_robust_mean['occlusion'] = \
        self.abs_robust_loss(batch_img1 - img2_warp1, mask_fw_12) + \
        self.abs_robust_loss(batch_img2 - img1_warp2, mask_bw_21) + \
        self.abs_robust_loss(batch_img2 - img3_warp2, mask_fw_23) + \
        self.abs_robust_loss(batch_img3 - img2_warp3, mask_bw_32)
    losses['abs_robust_mean'] = abs_robust_mean

    census_loss = {}
    census_loss['no_occlusion'] = \
        self.census_loss(batch_img1, img2_warp1, tf.ones_like(mask_fw_12), max_distance=3) + \
        self.census_loss(batch_img2, img1_warp2, tf.ones_like(mask_bw_21), max_distance=3) + \
        self.census_loss(batch_img2, img3_warp2, tf.ones_like(mask_fw_23), max_distance=3) + \
        self.census_loss(batch_img3, img2_warp3, tf.ones_like(mask_bw_32), max_distance=3)
    census_loss['occlusion'] = \
        self.census_loss(batch_img1, img2_warp1, mask_fw_12, max_distance=3) + \
        self.census_loss(batch_img2, img1_warp2, mask_bw_21, max_distance=3) + \
        self.census_loss(batch_img2, img3_warp2, mask_fw_23, max_distance=3) + \
        self.census_loss(batch_img3, img2_warp3, mask_bw_32, max_distance=3)
    losses['census'] = census_loss

    return losses
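# For reference: abs_robust_loss and census_loss are defined elsewhere in the
# codebase and are not shown in this section. A minimal sketch of a masked,
# generalized-Charbonnier-style penalty that is consistent with how
# abs_robust_loss is called above (a difference image plus a per-pixel weight
# mask) could look like the following; the exponent q and epsilon values are
# assumptions, not the repository's actual hyper-parameters.
import tensorflow as tf

def abs_robust_loss_sketch(diff, mask, q=0.4, eps=0.01):
    # Per-pixel robust penalty (|diff| + eps)^q, averaged over the pixels
    # selected (weighted) by mask.
    error = tf.pow(tf.abs(diff) + eps, q)
    return tf.reduce_sum(error * mask) / (tf.reduce_sum(mask) + 1e-6)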
def estimator(x0, x1, x2, flow_fw, flow_bw,
              train=True, trainable=True, reuse=None,
              regularizer=None, name='estimator'):
    # warp x2 according to flow
    if train:
        x_shape = x1.get_shape().as_list()
    else:
        x_shape = tf.shape(x1)
    H = x_shape[1]
    W = x_shape[2]
    channel = x_shape[3]
    x2_warp = tf_warp(x2, flow_fw, H, W)
    x0_warp = tf_warp(x0, flow_bw, H, W)

    # ---------------cost volume-----------------
    cost_volume_fw = compute_cost_volume(x1, x2_warp, H, W, channel, d=9)
    cost_volume_bw = compute_cost_volume(x1, x0_warp, H, W, channel, d=9)

    cv_concat_fw = tf.concat([cost_volume_fw, cost_volume_bw], -1)
    cv_concat_bw = tf.concat([cost_volume_bw, cost_volume_fw], -1)

    flow_concat_fw = tf.concat([flow_fw, -flow_bw], -1)
    flow_concat_bw = tf.concat([flow_bw, -flow_fw], -1)

    net_fw = estimator_network(x1, cv_concat_fw, flow_concat_fw,
                               train=train, trainable=trainable, reuse=reuse,
                               regularizer=regularizer, name=name)
    net_bw = estimator_network(x1, cv_concat_bw, flow_concat_bw,
                               train=train, trainable=trainable, reuse=True,
                               regularizer=regularizer, name=name)
    return net_fw, net_bw
def decompose(It, Vt_O, Vt_B, I_O_init, I_B_init, A_init):
    tf.reset_default_graph()

    It = tf.constant(It, tf.float32)
    Vt_O = tf.constant(Vt_O, tf.float32)
    Vt_B = tf.constant(Vt_B, tf.float32)

    I_O = tf.Variable(I_O_init, name='I_O', dtype=tf.float32)
    I_B = tf.Variable(I_B_init, name='I_B', dtype=tf.float32)

    warp_I_O = tf_warp(tf.tile(tf.expand_dims(I_O, 0), [5, 1, 1, 1]), Vt_O)
    warp_I_B = tf_warp(tf.tile(tf.expand_dims(I_B, 0), [5, 1, 1, 1]), Vt_B)

    g_O = spatial_gradient(tf.expand_dims(I_O, 0))
    g_B = spatial_gradient(tf.expand_dims(I_B, 0))

    A = None
    if A_init is not None:
        A = tf.Variable(A_init, name='A', dtype=tf.float32)
        warp_A = tf_warp(tf.tile(tf.expand_dims(A, 0), [5, 1, 1, 1]), Vt_O)
        residual = l1_norm(It - warp_I_O - tf.multiply(warp_A, warp_I_B))
    else:
        residual = l1_norm(It - warp_I_O - warp_I_B)

    loss = l1_norm(residual)
    if A is not None:
        loss += LAMBDA_1 * \
            tf.norm(spatial_gradient(tf.expand_dims(A, 0)), ord=2)**2
    loss += LAMBDA_2 * (l1_norm(g_O) + l1_norm(g_B))
    loss += LAMBDA_3 * tf.norm(g_O * g_O * g_B * g_B, ord=2)**2
    loss += constraint_penalty(I_O) + constraint_penalty(I_B)
    if A is not None:
        loss += constraint_penalty(A)

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    train = optimizer.minimize(loss)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for step in range(1000):
            _, loss_val = session.run([train, loss])
            print("step {}: loss = {}".format(step, loss_val))

        if A is not None:
            I_O, I_B, A = session.run([I_O, I_B, A])
        else:
            I_O, I_B = session.run([I_O, I_B])

    visualize_image(I_O, 'obstruction')
    visualize_image(I_B, 'background')
    if A is not None:
        visualize_image(A, 'alpha')
    return I_O, I_B, A
def occlusion(flow_fw, flow_bw):
    x_shape = tf.shape(flow_fw)
    H = x_shape[1]
    W = x_shape[2]

    # Warp each flow field with the other one so they can be compared per pixel.
    flow_bw_warped = tf_warp(flow_bw, flow_fw, H, W)
    flow_fw_warped = tf_warp(flow_fw, flow_bw, H, W)
    flow_diff_fw = flow_fw + flow_bw_warped
    flow_diff_bw = flow_bw + flow_fw_warped

    # Forward-backward consistency check: a pixel is marked occluded when the
    # round-trip flow error exceeds a magnitude-dependent threshold.
    mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped)
    mag_sq_bw = length_sq(flow_bw) + length_sq(flow_fw_warped)
    occ_thresh_fw = 0.01 * mag_sq_fw + 0.5
    occ_thresh_bw = 0.01 * mag_sq_bw + 0.5

    occ_fw = tf.cast(length_sq(flow_diff_fw) > occ_thresh_fw, tf.float32)
    occ_bw = tf.cast(length_sq(flow_diff_bw) > occ_thresh_bw, tf.float32)
    return occ_fw, occ_bw
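# Minimal usage sketch (illustrative only; assumes the tf_warp and length_sq
# helpers from this codebase are in scope). It runs the forward-backward
# consistency check on random flow fields and turns the occlusion maps into
# the loss masks used above.
import tensorflow as tf

flow_fw_demo = tf.random_normal([1, 64, 64, 2])
flow_bw_demo = tf.random_normal([1, 64, 64, 2])
occ_fw_demo, occ_bw_demo = occlusion(flow_fw_demo, flow_bw_demo)
mask_fw_demo = 1.0 - occ_fw_demo  # 1 at non-occluded pixels, 0 at occluded ones
mask_bw_demo = 1.0 - occ_bw_demo
with tf.Session() as sess:
    print(sess.run(tf.reduce_mean(mask_fw_demo)))  # fraction of pixels kept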
def estimator(x1, x2, flow, train=True, trainable=True, reuse=None,
              regularizer=None, name='estimator'):
    # warp x2 according to flow
    x_shape = get_shape(x1, train=train)
    H = x_shape[1]
    W = x_shape[2]
    channel = x_shape[3]
    x2_warp = tf_warp(x2, flow, H, W)

    # ---------------cost volume-----------------
    # normalize
    x1 = tf.nn.l2_normalize(x1, axis=3)
    x2_warp = tf.nn.l2_normalize(x2_warp, axis=3)

    d = 9
    # x2_patches = tf.extract_image_patches(x2_warp, [1, d, d, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='SAME')
    out_channels = d * d
    w = tf.eye(out_channels * channel, dtype=tf.float32)
    w = tf.reshape(w, (d, d, channel, out_channels * channel))
    x2_patches = tf.nn.conv2d(x2_warp, w, strides=[1, 1, 1, 1], padding='SAME')
    x2_patches = tf.reshape(x2_patches, [-1, H, W, d, d, channel])
    x1_reshape = tf.reshape(x1, [-1, H, W, 1, 1, channel])
    x1_dot_x2 = tf.multiply(x1_reshape, x2_patches)
    cost_volume = tf.reduce_sum(x1_dot_x2, axis=-1)
    cost_volume = tf.reshape(cost_volume, [-1, H, W, d * d])

    # --------------estimator network-------------
    net_input = tf.concat([cost_volume, x1, flow], axis=-1)
    with tf.variable_scope(name, reuse=reuse, regularizer=regularizer):
        with slim.arg_scope([slim.conv2d],
                            activation_fn=lrelu,
                            kernel_size=3,
                            padding='SAME',
                            trainable=trainable):
            net = {}
            net['conv1'] = slim.conv2d(net_input, 128, scope='conv1')
            net['conv2'] = slim.conv2d(net['conv1'], 128, scope='conv2')
            net['conv3'] = slim.conv2d(net['conv2'], 96, scope='conv3')
            net['conv4'] = slim.conv2d(net['conv3'], 64, scope='conv4')
            net['conv5'] = slim.conv2d(net['conv4'], 32, scope='conv5')
            net['conv6'] = slim.conv2d(net['conv5'], 2, activation_fn=None, scope='conv6')
            # flow_estimated = net['conv6']
    return net
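# Illustrative check (a small standalone sketch, not part of the model): the
# identity-kernel tf.nn.conv2d used above to gather d x d neighbourhoods
# should produce the same tensor as the commented-out tf.extract_image_patches
# call, since both flatten each patch in row, column, channel order and
# zero-pad the borders under 'SAME' padding.
import numpy as np
import tensorflow as tf

d, ch = 3, 2
x = tf.constant(np.random.rand(1, 8, 8, ch), tf.float32)

w_eye = tf.reshape(tf.eye(d * d * ch, dtype=tf.float32), (d, d, ch, d * d * ch))
patches_conv = tf.nn.conv2d(x, w_eye, strides=[1, 1, 1, 1], padding='SAME')
patches_ref = tf.extract_image_patches(
    x, [1, d, d, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='SAME')

with tf.Session() as sess:
    a, b = sess.run([patches_conv, patches_ref])
    print(np.allclose(a, b, atol=1e-5))  # expected: True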
def estimate_motion(It, I_O, I_B, A, Vt_O_init, Vt_B_init):
    tf.reset_default_graph()

    It = tf.constant(It, tf.float32)
    I_O = tf.constant(I_O, tf.float32)
    I_B = tf.constant(I_B, tf.float32)
    if A is not None:
        A = tf.constant(A, tf.float32)

    Vt_O = tf.Variable(Vt_O_init, name='Vt_O', dtype=tf.float32)
    Vt_B = tf.Variable(Vt_B_init, name='Vt_B', dtype=tf.float32)

    warp_I_O = tf_warp(tf.tile(tf.expand_dims(I_O, 0), [5, 1, 1, 1]), Vt_O)
    warp_I_B = tf_warp(tf.tile(tf.expand_dims(I_B, 0), [5, 1, 1, 1]), Vt_B)

    if A is not None:
        warp_A = tf_warp(tf.tile(tf.expand_dims(A, 0), [5, 1, 1, 1]), Vt_O)
        residual = It - warp_I_O - tf.multiply(warp_A, warp_I_B)
    else:
        residual = It - warp_I_O - warp_I_B

    loss = l1_norm(residual)
    loss += LAMBDA_4 * (l1_norm(spatial_gradient(Vt_O)) +
                        l1_norm(spatial_gradient(Vt_B)))

    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    train = optimizer.minimize(loss)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for step in range(500):
            _, loss_val = session.run([train, loss])
            print("step {}: loss = {}".format(step, loss_val))
        Vt_O, Vt_B = session.run([Vt_O, Vt_B])

    for i in range(5):
        visualize_dense_motion(Vt_O[i])
        visualize_dense_motion(Vt_B[i])
    return Vt_O, Vt_B
def train(model1, model2, model3, model4, model5, train_data, valid_data):
    """
    model1: model path for translation
    model2: model path for rotation
    model3: model path for growth/decay
    model4: model path for gating
    model5: model path for the whole STMoE model
    train_data: np.array for training
    valid_data: np.array for validation
    """
    x = tf.placeholder(
        tf.float32,
        [FLAGS.batch_size, FLAGS.height, FLAGS.width, FLAGS.seq_length])

    x_g = []
    weights = []
    f_g_s = []
    x_g_s, x_g_r, x_g_g = [], [], []
    hidden_state_1, hidden_state_diff_1, cell_state_1, cell_state_diff_1, st_memory_1 = [], [], [], [], []
    hidden_state_2, hidden_state_diff_2, cell_state_2, cell_state_diff_2, st_memory_2 = [], [], [], [], []

    # Warm up the recurrent experts on the first seq_start - 1 input frames.
    for i in range(FLAGS.seq_start - 1):
        with tf.variable_scope('expert2'):
            inputs = x[:, :, :, i]
            inputs = inputs[:, :, :, tf.newaxis]
            x_generate_r, hidden_state_1, hidden_state_diff_1, cell_state_1, cell_state_diff_1, st_memory_1 = network_grow(
                inputs, i, hidden_state_1, hidden_state_diff_1, cell_state_1,
                cell_state_diff_1, st_memory_1, 3, [8, 8, 8], (3, 3),
                FLAGS.h_conv_ksize, stride=1, tln=True,
                trainable_last=FLAGS.trainable_last)
        with tf.variable_scope('expert3'):
            inputs = x[:, :, :, i]
            inputs = inputs[:, :, :, tf.newaxis]
            x_generate_g, hidden_state_2, hidden_state_diff_2, cell_state_2, cell_state_diff_2, st_memory_2 = network_grow(
                inputs, i, hidden_state_2, hidden_state_diff_2, cell_state_2,
                cell_state_diff_2, st_memory_2, 3, [8, 8, 8], (3, 3),
                FLAGS.h_conv_ksize, stride=1, tln=True,
                trainable_last=FLAGS.trainable_last)

    # predict recursively
    for i in range(FLAGS.seq_length - FLAGS.seq_start):
        print('frame_{}'.format(i))
        if i == 0:
            with tf.variable_scope('expert1'):
                f_generate = network_shift(x[:, :, :, i:i + FLAGS.seq_start])
                f_g_s.append(f_generate)
                last_x = x[:, :, :, FLAGS.seq_start - 1]
                x_generate_s = tf_warp(last_x[:, :, :, tf.newaxis], f_generate,
                                       FLAGS.height, FLAGS.width)
                x_generate_s = tf.reshape(
                    x_generate_s[:, :, :, 0],
                    [FLAGS.batch_size, FLAGS.height, FLAGS.width, 1])
                x_g_s.append(x_generate_s)
            with tf.variable_scope('expert2'):
                inputs = x[:, :, :, FLAGS.seq_start - 1]
                inputs = inputs[:, :, :, tf.newaxis]
                x_generate_r, hidden_state_1, hidden_state_diff_1, cell_state_1, cell_state_diff_1, st_memory_1 = network_grow(
                    inputs, i + FLAGS.seq_start, hidden_state_1,
                    hidden_state_diff_1, cell_state_1, cell_state_diff_1,
                    st_memory_1, 3, [8, 8, 8], (3, 3), FLAGS.h_conv_ksize,
                    stride=1, tln=True, trainable_last=FLAGS.trainable_last)
                x_g_r.append(x_generate_r)
            with tf.variable_scope('expert3'):
                inputs = x[:, :, :, FLAGS.seq_start - 1]
                inputs = inputs[:, :, :, tf.newaxis]
                x_generate_g, hidden_state_2, hidden_state_diff_2, cell_state_2, cell_state_diff_2, st_memory_2 = network_grow(
                    inputs, i + FLAGS.seq_start, hidden_state_2,
                    hidden_state_diff_2, cell_state_2, cell_state_diff_2,
                    st_memory_2, 3, [8, 8, 8], (3, 3), FLAGS.h_conv_ksize,
                    stride=1, tln=True, trainable_last=FLAGS.trainable_last)
                x_g_g.append(x_generate_g)
            with tf.variable_scope('gating_network'):
                weight = network_gate(x[:, :, :, i:i + FLAGS.seq_start],
                                      FLAGS.gating_num, moe='moe1')
            # Blend the three expert predictions with the gating weights.
            x_sr = tf.concat([x_generate_s, x_generate_r, x_generate_g], axis=-1)
            x_generate = weight * x_sr
            x_generate = tf.reduce_sum(x_generate, axis=-1)
            x_g.append(x_generate)
            weights.append(weight)
        else:
            x_gen = tf.stack(x_g)
            print(x_gen.shape)
            x_gen = tf.transpose(x_gen, [1, 2, 3, 0])
            if i < FLAGS.seq_start:
                x_input = tf.concat(
                    [x[:, :, :, i:FLAGS.seq_start], x_gen[:, :, :, :i]], axis=3)
            else:
                x_input = x_gen[:, :, :, i - FLAGS.seq_start:i]
            with tf.variable_scope('expert1'):
                f_generate = network_shift(x_input)
                f_g_s.append(f_generate)
                last_x = x_g[-1]
                x_generate_s = tf_warp(last_x[:, :, :, tf.newaxis], f_generate,
                                       FLAGS.height, FLAGS.width)
                x_generate_s = tf.reshape(
                    x_generate_s[:, :, :, 0],
                    [FLAGS.batch_size, FLAGS.height, FLAGS.width, 1])
                x_g_s.append(x_generate_s)
            with tf.variable_scope('expert2'):
                inputs = x_g[-1]
                inputs = inputs[:, :, :, tf.newaxis]
                x_generate_r, hidden_state_1, hidden_state_diff_1, cell_state_1, cell_state_diff_1, st_memory_1 = network_grow(
                    inputs, i + FLAGS.seq_start, hidden_state_1,
                    hidden_state_diff_1, cell_state_1, cell_state_diff_1,
                    st_memory_1, 3, [8, 8, 8], (3, 3), FLAGS.h_conv_ksize,
                    stride=1, tln=True, trainable_last=FLAGS.trainable_last)
                x_g_r.append(x_generate_r)
            with tf.variable_scope('expert3'):
                inputs = x_g[-1]
                inputs = inputs[:, :, :, tf.newaxis]
                x_generate_g, hidden_state_2, hidden_state_diff_2, cell_state_2, cell_state_diff_2, st_memory_2 = network_grow(
                    inputs, i + FLAGS.seq_start, hidden_state_2,
                    hidden_state_diff_2, cell_state_2, cell_state_diff_2,
                    st_memory_2, 3, [8, 8, 8], (3, 3), FLAGS.h_conv_ksize,
                    stride=1, tln=True, trainable_last=FLAGS.trainable_last)
                x_g_g.append(x_generate_g)
            with tf.variable_scope('gating_network'):
                weight = network_gate(x_input, FLAGS.gating_num, moe='moe1')
            # Blend the three expert predictions with the gating weights.
            x_sr = tf.concat([x_generate_s, x_generate_r, x_generate_g], axis=-1)
            x_generate = weight * x_sr
            x_generate = tf.reduce_sum(x_generate, axis=-1)
            x_g.append(x_generate)
            weights.append(weight)

    # Stack the per-frame predictions, flows, and gating weights along a time axis.
    x_g = tf.stack(x_g)
    x_g = tf.transpose(x_g, [1, 2, 3, 0])
    f_g_s = tf.stack(f_g_s)
    f_g_s = tf.transpose(f_g_s, [1, 0, 2, 3, 4])
    weights = tf.stack(weights)
    weights = tf.transpose(weights, [1, 0, 2, 3, 4])

    # Build a saver per expert, stripping the scope prefix so the keys match
    # the variable names in the stand-alone expert checkpoints.
    expert1_varlist = {
        v.op.name[len("expert1/"):]: v
        for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="expert1/")
    }
    expert1_saver = tf.train.Saver(var_list=expert1_varlist)

    expert2_varlist = {
        v.op.name[len("expert2/"):]: v
        for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="expert2/")
    }
    expert2_saver = tf.train.Saver(var_list=expert2_varlist)

    expert3_varlist = {
        v.op.name[len("expert3/"):]: v
        for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="expert3/")
    }
    expert3_saver = tf.train.Saver(var_list=expert3_varlist)

    # build a gating saver
    gating_varlist = {
        v.op.name[len("gating_network/"):]: v
        for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                   scope="gating_network/")
    }
    gating_saver = tf.train.Saver(var_list=gating_varlist, max_to_keep=100)

    # weight time-smoothness loss
    wdt = tf.losses.mean_squared_error(weights[:, 1:], weights[:, :-1])
    # MSE loss
    MSE = tf.losses.mean_squared_error(x[:, :, :, FLAGS.seq_start:], x_g[:, :, :, :])

    # loss function
    if FLAGS.training == 'all':
        first = tf.reduce_mean(
            lossfunc(x[:, :, :, FLAGS.seq_start:] - x_g[:, :, :, :],
                     alpha=tf.constant(0.0),
                     scale=tf.constant(FLAGS.robust_x)))
        second = tf.reduce_mean(
            lossfunc(-tf.log(weights + 1e-10) * weights,
                     alpha=tf.constant(0.0),
                     scale=tf.constant(0.3)))
        third = flow_const(f_g_s, FLAGS.time_smo, FLAGS.smo, FLAGS.mag, robust=True)
        fourth = tf.reduce_mean(
            lossfunc(weights[:, 1:] - weights[:, :-1],
                     alpha=tf.constant(0.0),
                     scale=tf.constant(0.75)))
        loss = first + FLAGS.w_ent * second + third + FLAGS.w_time_smo * fourth
    else:
        first = tf.reduce_mean(
            lossfunc(x[:, :, :, FLAGS.seq_start:] - x_g[:, :, :, :],
                     alpha=tf.constant(0.0),
                     scale=tf.constant(FLAGS.robust_x)))
        second = tf.reduce_mean(
            lossfunc(-tf.log(weights + 1e-10) * weights,
                     alpha=tf.constant(0.0),
                     scale=tf.constant(0.3)))
        fourth = tf.reduce_mean(
            lossfunc(weights[:, 1:] - weights[:, :-1],
                     alpha=tf.constant(0.0),
                     scale=tf.constant(0.75)))
        loss = first + FLAGS.w_ent * second + FLAGS.w_time_smo * fourth

    if FLAGS.training == 'gating':
        train_var = list(
            tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="gating_network/"))
    elif FLAGS.training == 'all':
        train_var = list(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
    train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss, var_list=train_var)

    # list of all variables
    variables = tf.global_variables()

    # start running operations on the graph
    sess = tf.Session()
    init = tf.global_variables_initializer()
    print('init network from scratch....')
    sess.run(init)
    expert1_saver.restore(sess, model1)
    expert2_saver.restore(sess, model2)
    expert3_saver.restore(sess, model3)

    # restore gating saver
    if FLAGS.restore_gating:
        gating_saver.restore(sess, model4)
    # restore all saver
    all_saver = tf.train.Saver(max_to_keep=100)
    if FLAGS.restore_all:
        all_saver.restore(sess, model5)

    Loss_MSE, Loss_MSE_v = [], []
    all_Loss, all_Loss_v = [], []
    np.random.seed(2020)

    # train
    for epoch in range(FLAGS.num_epoch):
        loss_MSE, loss_MSE_v = [], []
        loss_all_epoch, loss_all_epoch_v = [], []
        sff_idx = np.random.permutation(train_data.shape[0])
        sff_idx_v = np.random.permutation(valid_data.shape[0])

        # train
        for idx in range(0, train_data.shape[0], FLAGS.batch_size):
            if idx + FLAGS.batch_size < train_data.shape[0]:
                batch_x = train_data[sff_idx[idx:idx + FLAGS.batch_size]]
                batch_x = batch_x.transpose(0, 2, 3, 1)
                __, train_loss_all, train_mse = sess.run(
                    [train_op, loss, MSE], feed_dict={x: batch_x})
                loss_MSE.append(train_mse)
                loss_all_epoch.append(train_loss_all)

        # validation
        for idx in range(0, valid_data.shape[0], FLAGS.batch_size):
            if idx + FLAGS.batch_size < valid_data.shape[0]:
                batch_x = valid_data[sff_idx_v[idx:idx + FLAGS.batch_size]]
                batch_x = batch_x.transpose(0, 2, 3, 1)
                valid_all_loss, valid_mse = sess.run([loss, MSE],
                                                     feed_dict={x: batch_x})
                loss_MSE_v.append(valid_mse)
                loss_all_epoch_v.append(valid_all_loss)

        Loss_MSE.append(np.mean(loss_MSE))
        Loss_MSE_v.append(np.mean(loss_MSE_v))
        all_Loss.append(np.mean(loss_all_epoch))
        all_Loss_v.append(np.mean(loss_all_epoch_v))
        print('epoch, MSE, valid_MSE: {} {} {}'.format(epoch, Loss_MSE[-1],
                                                       Loss_MSE_v[-1]))

        if (epoch + 1) % 10 == 0 or epoch == 0:
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           'STMoE-1_{}'.format(FLAGS.training))
            if FLAGS.training == 'gating':
                print('save gating saver')
                gating_saver.save(sess, checkpoint_path, global_step=epoch + 1)
            elif FLAGS.training == 'all':
                print('save all saver')
                all_saver.save(sess, checkpoint_path, global_step=epoch + 1)
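# lossfunc is defined elsewhere in the repository. Since it is called with
# alpha and scale arguments, it is presumably a Barron-style general robust
# loss; that is an assumption, not something stated in this section. For
# alpha = 0, which is what every call above uses, that family reduces to the
# Cauchy / Lorentzian penalty, a minimal sketch of which is:
import tensorflow as tf

def cauchy_loss_sketch(x, scale):
    # General robust loss evaluated at alpha = 0: log(0.5 * (x / scale)^2 + 1).
    return tf.log(0.5 * tf.square(x / scale) + 1.0)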
def estimator(x1, x2, flow, train=True, trainable=True, reuse=None,
              regularizer=None, name='estimator'):
    """Estimator network."""
    # warp x2 according to flow
    x_shape = get_shape(x1, train=train)
    height = x_shape[1]
    width = x_shape[2]
    # channel = x_shape[3]
    channel = x1.get_shape().as_list()[-1]
    x2_warp = tf_warp(x2, flow, height, width)

    # ---------------cost volume-----------------
    # normalize
    x1 = tf.nn.l2_normalize(x1, axis=3)
    x2_warp = tf.nn.l2_normalize(x2_warp, axis=3)

    d = 9
    x2_patches = tf.extract_image_patches(
        x2_warp, [1, d, d, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1],
        padding='SAME')
    x2_patches = tf.reshape(x2_patches, [-1, height, width, d, d, channel])
    x1_reshape = tf.reshape(x1, [-1, height, width, 1, 1, channel])

    # get symmetric positive definite kernel matrix
    with tf.variable_scope(name, reuse=reuse, regularizer=regularizer):
        # obtain orthogonal matrix
        raw_P = tf.get_variable(
            'raw_P',
            shape=[channel, channel],
            initializer=tf.keras.initializers.Identity(),
            dtype=tf.float32)
        raw_P_upper = tf.matrix_band_part(raw_P, 0, -1)
        skew_P = (raw_P_upper - tf.transpose(raw_P_upper)) / 2
        # Cayley transformation: P is in the special orthogonal group SO(n)
        P = tf.matmul((tf.eye(channel) + skew_P),
                      tf.linalg.inv(tf.eye(channel) - skew_P))
        # obtain the diagonal matrix with positive numbers
        raw_D = tf.get_variable(
            'raw_D',
            shape=[channel],
            initializer=tf.zeros_initializer(),
            dtype=tf.float32)
        trans_D = tf.atan(raw_D) * 2 / math.pi
        D = tf.matrix_diag(tf.div(1 + trans_D, 1 - trans_D))
        # the symmetric positive definite kernel matrix
        W = tf.matmul(tf.matmul(tf.transpose(P), D), P)

    x1_dot_x2 = tf.multiply(tf.tensordot(x1_reshape, W, axes=[-1, 0]), x2_patches)
    cost_volume = tf.reduce_sum(x1_dot_x2, axis=-1)
    cost_volume = tf.reshape(cost_volume, [-1, height, width, d * d])

    # --------------estimator network-------------
    net_input = tf.concat([cost_volume, x1, flow], axis=-1)
    with tf.variable_scope(name, reuse=reuse, regularizer=regularizer):
        with slim.arg_scope([slim.conv2d],
                            activation_fn=lrelu,
                            kernel_size=3,
                            padding='SAME',
                            trainable=trainable):
            net = {}
            net['conv1'] = slim.conv2d(net_input, 128, scope='conv1')
            net['conv2'] = slim.conv2d(net['conv1'], 128, scope='conv2')
            net['conv3'] = slim.conv2d(net['conv2'], 96, scope='conv3')
            net['conv4'] = slim.conv2d(net['conv3'], 64, scope='conv4')
            net['conv5'] = slim.conv2d(net['conv4'], 32, scope='conv5')
            net['conv6'] = slim.conv2d(
                net['conv5'], 2, activation_fn=None, scope='conv6')
            # flow_estimated = net['conv6']
    return net
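# Illustrative check (a NumPy sketch, not part of the model): the Cayley
# transform of a skew-symmetric matrix S, P = (I + S)(I - S)^{-1}, is
# orthogonal, so W = P^T D P with a positive diagonal D is symmetric positive
# definite -- the property the learned kernel matrix above relies on.
import numpy as np

n = 4
raw = np.random.randn(n, n)
S = (np.triu(raw) - np.triu(raw).T) / 2              # skew-symmetric
P = (np.eye(n) + S) @ np.linalg.inv(np.eye(n) - S)   # Cayley transform
D = np.diag(np.random.rand(n) + 0.1)                 # positive diagonal
W = P.T @ D @ P

print(np.allclose(P.T @ P, np.eye(n)))               # orthogonality: True
print(np.all(np.linalg.eigvalsh(W) > 0))             # positive definite: True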