Example #1
0
    def compute_losses(self, batch_img1, batch_img2, batch_img3,
            flow_fw_12, flow_bw_21, flow_fw_23, flow_bw_32,
            mask_fw_12, mask_bw_21, mask_fw_23, mask_bw_32, train=True, is_scale=True):
        """Photometric reconstruction losses over a three-frame window.

        Every frame is compared with a neighbour warped onto it through the
        ``'full_res'`` entry of the corresponding flow dict, in four
        directions:

        - 1<-2 via flow_fw_12
        - 2<-1 via flow_bw_21
        - 2<-3 via flow_fw_23
        - 3<-2 via flow_bw_32

        Returns:
            ``{'abs_robust_mean': {...}, 'census': {...}}``. Each inner dict
            has two entries, each summed over the four directions:

            - ``'no_occlusion'``: every mask replaced by ``tf.ones_like(mask)``.
            - ``'occlusion'``: the given masks applied.

            ``is_scale`` is accepted but not used here.
        """
        shape = get_shape(batch_img1, train=train)
        height, width = shape[1], shape[2]

        def warp(image, flow):
            return tf_warp(image, flow['full_res'], height, width)

        img1_warp2 = warp(batch_img1, flow_bw_21)
        img2_warp1 = warp(batch_img2, flow_fw_12)
        img2_warp3 = warp(batch_img2, flow_bw_32)
        img3_warp2 = warp(batch_img3, flow_fw_23)

        # (reference frame, neighbour warped onto it, occlusion mask)
        directions = (
            (batch_img1, img2_warp1, mask_fw_12),
            (batch_img2, img1_warp2, mask_bw_21),
            (batch_img2, img3_warp2, mask_fw_23),
            (batch_img3, img2_warp3, mask_bw_32),
        )

        def summed(term, use_masks):
            """Add ``term(reference, warped, mask)`` over all directions, left to right."""
            total = None
            for reference, warped, mask in directions:
                weight = mask if use_masks else tf.ones_like(mask)
                value = term(reference, warped, weight)
                total = value if total is None else total + value
            return total

        def abs_robust(reference, warped, mask):
            return self.abs_robust_loss(reference - warped, mask)

        def census(reference, warped, mask):
            return self.census_loss(reference, warped, mask, max_distance=3)

        return {
            'abs_robust_mean': {
                'no_occlusion': summed(abs_robust, use_masks=False),
                'occlusion': summed(abs_robust, use_masks=True),
            },
            'census': {
                'no_occlusion': summed(census, use_masks=False),
                'occlusion': summed(census, use_masks=True),
            },
        }
Example #2
0
def estimator(x0, x1, x2, flow_fw, flow_bw, train=True, trainable=True, reuse=None, regularizer=None, name='estimator'):
    """Estimate forward and backward flow outputs for the middle frame ``x1``.

    Setup: ``x2`` is warped toward ``x1`` with ``flow_fw``, and ``x0`` with
    ``flow_bw``. A cost volume (``d=9``) is computed between ``x1`` and each
    warped neighbour.

    Branch inputs:

    - Forward branch: ``[cost_fw, cost_bw]`` and ``[flow_fw, -flow_bw]``.
    - Backward branch: the mirrored concatenations.

    Both branches run ``estimator_network`` under the same ``name``. The
    backward call always passes ``reuse=True``, presumably so it shares the
    forward call's variables.

    Returns:
        (net_fw, net_bw): the outputs of the two ``estimator_network`` calls.
    """
    # Static Python shape list while training, dynamic shape tensor otherwise.
    dims = x1.get_shape().as_list() if train else tf.shape(x1)
    height, width, channel = dims[1], dims[2], dims[3]

    # Align both neighbours with x1.
    x2_aligned = tf_warp(x2, flow_fw, height, width)
    x0_aligned = tf_warp(x0, flow_bw, height, width)

    # ---------------cost volume-----------------
    cost_fw = compute_cost_volume(x1, x2_aligned, height, width, channel, d=9)
    cost_bw = compute_cost_volume(x1, x0_aligned, height, width, channel, d=9)

    # Each branch lists its own direction first, then the mirrored one.
    cost_in_fw = tf.concat([cost_fw, cost_bw], -1)
    cost_in_bw = tf.concat([cost_bw, cost_fw], -1)
    flow_in_fw = tf.concat([flow_fw, -flow_bw], -1)
    flow_in_bw = tf.concat([flow_bw, -flow_fw], -1)

    shared = dict(train=train, trainable=trainable, regularizer=regularizer, name=name)
    net_fw = estimator_network(x1, cost_in_fw, flow_in_fw, reuse=reuse, **shared)
    net_bw = estimator_network(x1, cost_in_bw, flow_in_bw, reuse=True, **shared)
    return net_fw, net_bw
def decompose(It, Vt_O, Vt_B, I_O_init, I_B_init, A_init):
    """Solve for the obstruction / background layers (and optional alpha) with motion fixed.

    Resets the default graph, then runs 1000 Adam steps (learning rate 0.001)
    on the sum of these terms:

    - ``l1_norm(It - warp(I_O) - [warp(A) *] warp(I_B))``
    - ``LAMBDA_1 * ||grad A||_2^2``, only when ``A_init`` is given
    - ``LAMBDA_2 * (l1_norm(grad I_O) + l1_norm(grad I_B))``
    - ``LAMBDA_3 * ||grad I_O^2 * grad I_B^2||_2^2``
    - ``constraint_penalty`` on I_O, I_B (and A)

    Here ``grad`` is ``spatial_gradient``. Each layer is copied once per frame
    of its motion field and warped with ``tf_warp``. The loss is printed every
    step, and the results are passed to ``visualize_image``.

    Args:
        It: observed frames.
        Vt_O, Vt_B: per-frame motion for the obstruction / background layer.
            Their leading axis sets how many warped copies of the layer are
            made.
        I_O_init, I_B_init: initial obstruction / background layers.
        A_init: initial alpha map, or None for the purely additive model.

    Returns:
        (I_O, I_B, A) as values fetched from the session. A is None when
        A_init is None.
    """
    tf.reset_default_graph()
    It = tf.constant(It, tf.float32)
    Vt_O = tf.constant(Vt_O, tf.float32)
    Vt_B = tf.constant(Vt_B, tf.float32)

    def warp_per_frame(layer, motion):
        # One copy of the layer per frame of `motion` (instead of a fixed 5),
        # so sequences of any length can be decomposed.
        num_frames = motion.get_shape().as_list()[0]
        return tf_warp(tf.tile(tf.expand_dims(layer, 0), [num_frames, 1, 1, 1]), motion)

    I_O = tf.Variable(I_O_init, name='I_O', dtype=tf.float32)
    I_B = tf.Variable(I_B_init, name='I_B', dtype=tf.float32)

    warp_I_O = warp_per_frame(I_O, Vt_O)
    warp_I_B = warp_per_frame(I_B, Vt_B)

    g_O = spatial_gradient(tf.expand_dims(I_O, 0))
    g_B = spatial_gradient(tf.expand_dims(I_B, 0))

    A = None
    if A_init is not None:
        A = tf.Variable(A_init, name='A', dtype=tf.float32)
        # Alpha moves with the obstruction layer.
        warp_A = warp_per_frame(A, Vt_O)
        residual = It - warp_I_O - tf.multiply(warp_A, warp_I_B)
    else:
        residual = It - warp_I_O - warp_I_B

    # Data term: l1_norm applied once to the raw residual (previously the
    # residual was l1-normed twice, which is redundant for a scalar norm).
    loss = l1_norm(residual)
    if A is not None:
        loss += LAMBDA_1 * \
            tf.norm(spatial_gradient(tf.expand_dims(A, 0)), ord=2)**2
    loss += LAMBDA_2 * (l1_norm(g_O) + l1_norm(g_B))
    # Penalise places where both layers have strong gradients at once.
    loss += LAMBDA_3 * tf.norm(g_O * g_O * g_B * g_B, ord=2)**2

    loss += constraint_penalty(I_O) + constraint_penalty(I_B)
    if A is not None:
        loss += constraint_penalty(A)

    train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for step in range(1000):
            _, loss_val = session.run([train_op, loss])
            print("step {}:loss = {}".format(step, loss_val))
        if A is not None:
            I_O, I_B, A = session.run([I_O, I_B, A])
        else:
            I_O, I_B = session.run([I_O, I_B])
        visualize_image(I_O, 'obstruction')
        visualize_image(I_B, 'background')
        if A is not None:
            visualize_image(A, 'alpha')
    return I_O, I_B, A
Example #4
0
def occlusion(flow_fw, flow_bw, rel_thresh=0.01, abs_thresh=0.5):
    """Forward-backward consistency occlusion masks.

    A pixel is marked occluded in the forward direction when::

        length_sq(flow_fw + bw_w) > rel_thresh * (length_sq(flow_fw) + length_sq(bw_w)) + abs_thresh

    where ``bw_w = tf_warp(flow_bw, flow_fw, H, W)``. The backward mask is
    the same test with the roles of the two flows swapped.

    Args:
        flow_fw, flow_bw: forward / backward flows. Dims 1 and 2 of
            ``flow_fw`` give the warp height / width.
        rel_thresh: weight of the squared flow magnitudes in the threshold
            (default 0.01, the previously hard-coded value).
        abs_thresh: constant term of the threshold (default 0.5).

    Returns:
        (occ_fw, occ_bw): float32 masks, 1.0 where the test fires and 0.0
        elsewhere.
    """
    x_shape = tf.shape(flow_fw)
    H = x_shape[1]
    W = x_shape[2]

    def _mask(flow, other):
        # Bring the opposite flow into this flow's frame; the test treats a
        # small |flow + other_warped| as consistent.
        other_warped = tf_warp(other, flow, H, W)
        diff_sq = length_sq(flow + other_warped)
        mag_sq = length_sq(flow) + length_sq(other_warped)
        return tf.cast(diff_sq > rel_thresh * mag_sq + abs_thresh, tf.float32)

    return _mask(flow_fw, flow_bw), _mask(flow_bw, flow_fw)
Example #5
0
def estimator(x1,
              x2,
              flow,
              train=True,
              trainable=True,
              reuse=None,
              regularizer=None,
              name='estimator'):
    """Refine ``flow`` from a 9x9 cost volume between x1 and flow-warped x2.

    The cost volume is built as follows:

    - Take the channel-wise dot product of L2-normalised x1 features with
      every offset of a 9x9 window of L2-normalised, warped x2 features.
    - This gives 81 channels, which are concatenated with x1 and ``flow``.
    - The result is fed to a six-layer conv stack under variable scope
      ``name``.

    Returns:
        dict of layer outputs ``'conv1'``..``'conv6'``. ``'conv6'`` has two
        channels and no activation.
    """
    # warp x2 according to flow
    x_shape = get_shape(x1, train=train)
    H = x_shape[1]
    W = x_shape[2]
    # Prefer the static channel count. It sizes the patch-extraction kernel
    # below, and a Python int keeps that kernel's shape static even if
    # get_shape() returned a dynamic shape. Fall back otherwise.
    channel = x1.get_shape().as_list()[-1]
    if channel is None:
        channel = x_shape[3]
    x2_warp = tf_warp(x2, flow, H, W)

    # ---------------cost volume-----------------
    # normalize
    x1 = tf.nn.l2_normalize(x1, axis=3)
    x2_warp = tf.nn.l2_normalize(x2_warp, axis=3)
    d = 9
    out_channels = d * d
    # Patch extraction as a convolution. The identity matrix reshaped to a
    # (d, d, channel, d*d*channel) kernel copies input channel c at kernel
    # offset (dy, dx) to output channel (dy*d + dx)*channel + c. The reshape
    # to [.., d, d, channel] then recovers one window per pixel.
    # NOTE(review): the kernel holds (d*d*channel)**2 floats; for wide
    # feature maps tf.extract_image_patches may be much cheaper.
    w = tf.eye(out_channels * channel, dtype=tf.float32)
    w = tf.reshape(w, (d, d, channel, out_channels * channel))
    x2_patches = tf.nn.conv2d(x2_warp, w, strides=[1, 1, 1, 1], padding='SAME')
    x2_patches = tf.reshape(x2_patches, [-1, H, W, d, d, channel])
    x1_reshape = tf.reshape(x1, [-1, H, W, 1, 1, channel])
    x1_dot_x2 = tf.multiply(x1_reshape, x2_patches)
    cost_volume = tf.reduce_sum(x1_dot_x2, axis=-1)
    cost_volume = tf.reshape(cost_volume, [-1, H, W, d * d])

    # --------------estimator network-------------
    net_input = tf.concat([cost_volume, x1, flow], axis=-1)
    with tf.variable_scope(name, reuse=reuse, regularizer=regularizer):
        with slim.arg_scope([slim.conv2d],
                            activation_fn=lrelu,
                            kernel_size=3,
                            padding='SAME',
                            trainable=trainable):
            net = {}
            net['conv1'] = slim.conv2d(net_input, 128, scope='conv1')
            net['conv2'] = slim.conv2d(net['conv1'], 128, scope='conv2')
            net['conv3'] = slim.conv2d(net['conv2'], 96, scope='conv3')
            net['conv4'] = slim.conv2d(net['conv3'], 64, scope='conv4')
            net['conv5'] = slim.conv2d(net['conv4'], 32, scope='conv5')
            net['conv6'] = slim.conv2d(net['conv5'],
                                       2,
                                       activation_fn=None,
                                       scope='conv6')

    return net
def estimate_motion(It, I_O, I_B, A, Vt_O_init, Vt_B_init):
    """Solve for per-frame obstruction / background motion with the layers fixed.

    Resets the default graph, then runs 500 Adam steps (learning rate 0.01)
    on the sum of:

    - ``l1_norm(It - warp(I_O) - [warp(A) *] warp(I_B))``
    - ``LAMBDA_4 * (l1_norm(grad Vt_O) + l1_norm(grad Vt_B))``

    Here ``grad`` is ``spatial_gradient``. Each layer is copied once per frame
    of its motion field and warped with ``tf_warp``. The loss is printed every
    step, and every recovered motion frame is passed to
    ``visualize_dense_motion``.

    Args:
        It: observed frames.
        I_O, I_B: obstruction / background layers (held constant).
        A: alpha map (held constant), or None for the additive model.
        Vt_O_init, Vt_B_init: initial per-frame motion. Their leading axis
            sets how many warped copies of each layer are made.

    Returns:
        (Vt_O, Vt_B) as values fetched from the session.
    """
    tf.reset_default_graph()
    It = tf.constant(It, tf.float32)
    I_O = tf.constant(I_O, tf.float32)
    I_B = tf.constant(I_B, tf.float32)
    if A is not None:
        A = tf.constant(A, tf.float32)

    Vt_O = tf.Variable(Vt_O_init, name='Vt_O', dtype=tf.float32)
    Vt_B = tf.Variable(Vt_B_init, name='Vt_B', dtype=tf.float32)

    def warp_per_frame(layer, motion):
        # One copy of the layer per frame of `motion` (instead of a fixed 5).
        num_frames = motion.get_shape().as_list()[0]
        return tf_warp(tf.tile(tf.expand_dims(layer, 0), [num_frames, 1, 1, 1]), motion)

    warp_I_O = warp_per_frame(I_O, Vt_O)
    warp_I_B = warp_per_frame(I_B, Vt_B)

    if A is not None:
        # Alpha moves with the obstruction layer.
        warp_A = warp_per_frame(A, Vt_O)
        residual = It - warp_I_O - tf.multiply(warp_A, warp_I_B)
    else:
        residual = It - warp_I_O - warp_I_B

    loss = l1_norm(residual)
    loss += LAMBDA_4 * (l1_norm(spatial_gradient(Vt_O)) +
                        l1_norm(spatial_gradient(Vt_B)))

    train_op = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for step in range(500):
            _, loss_val = session.run([train_op, loss])
            print("step {}:loss = {}".format(step, loss_val))
        Vt_O, Vt_B = session.run([Vt_O, Vt_B])
        for motion_o, motion_b in zip(Vt_O, Vt_B):
            visualize_dense_motion(motion_o)
            visualize_dense_motion(motion_b)
    return Vt_O, Vt_B
def _scope_varlist(scope, use_op_name=True):
    """Map the variables under ``scope`` to their names with ``scope/`` removed.

    Used to restore expert / gating checkpoints that were saved without the
    enclosing variable scope. The prefix is removed by slicing, not
    ``str.lstrip``: ``lstrip`` strips a *set of characters*, so
    ``"expert1/pred/w".lstrip("expert1/")`` would give ``"d/w"``.

    Args:
        scope: variable-scope name without the trailing slash.
        use_op_name: key by ``v.op.name`` (no ``:0`` suffix) when True, or by
            ``v.name`` when False. The gating saver has always used
            ``v.name``, so False keeps its checkpoints loadable.
    """
    prefix = scope + '/'
    varlist = {}
    for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=prefix):
        full_name = var.op.name if use_op_name else var.name
        varlist[full_name[len(prefix):]] = var
    return varlist


def _grow_step(inputs, t, state):
    """Advance one ``network_grow`` expert by a single frame.

    Args:
        inputs: one frame with a trailing channel axis of size 1.
        t: time index forwarded to ``network_grow``.
        state: 5-tuple (hidden_state, hidden_state_diff, cell_state,
            cell_state_diff, st_memory) returned by the previous step. Use
            empty lists before the first step.

    Returns:
        (generated_frame, new_state) with new_state in the same 5-tuple layout.
    """
    hidden, hidden_diff, cell, cell_diff, st_memory = state
    outputs = network_grow(
        inputs,
        t,
        hidden,
        hidden_diff,
        cell,
        cell_diff,
        st_memory,
        3, [8, 8, 8], (3, 3),
        FLAGS.h_conv_ksize,
        stride=1,
        tln=True,
        trainable_last=FLAGS.trainable_last)
    return outputs[0], tuple(outputs[1:])


def _shift_step(x_input, last_frame):
    """Run the translation expert on ``x_input`` and warp ``last_frame`` with its output.

    Returns:
        (f_generate, warped) where ``f_generate`` is the ``network_shift``
        output used as the warp flow, and ``warped`` has shape
        [batch_size, height, width, 1].
    """
    f_generate = network_shift(x_input)
    warped = tf_warp(last_frame[:, :, :, tf.newaxis], f_generate,
                     FLAGS.height, FLAGS.width)
    warped = tf.reshape(warped[:, :, :, 0],
                        [FLAGS.batch_size, FLAGS.height, FLAGS.width, 1])
    return f_generate, warped


def _iter_full_batches(data, order, batch_size):
    """Yield every complete batch of ``data`` taken in ``order``, frames moved last.

    Data layout: the ``transpose(0, 2, 3, 1)`` must produce the placeholder
    layout [batch, height, width, seq_length]. So ``data`` is indexed as
    [sample, frame, height, width].

    Batching: only full batches are produced, because the placeholder has a
    fixed batch size. The bound includes the final full batch when the
    sample count is an exact multiple of ``batch_size``.
    """
    for start in range(0, data.shape[0] - batch_size + 1, batch_size):
        yield data[order[start:start + batch_size]].transpose(0, 2, 3, 1)


def train(model1, model2, model3, model4, model5, train_data, valid_data):
    """Build the STMoE graph, restore the experts, and train.

    The three experts are restored from their own checkpoints. Depending on
    ``FLAGS.training``, either only the gating network (``'gating'``) or every
    variable (``'all'``) is optimised. A checkpoint is written to
    ``FLAGS.train_dir`` after the first epoch and after every 10th epoch.

    model1: model path for translation
    model2: model path for rotation
    model3: model path for growth/decay
    model4: model path for gating
    model5: model path for the whole STMoE model
    train_data: np.array for training
    valid_data: np.array for validation

    Raises:
        ValueError: if FLAGS.training is neither 'gating' nor 'all'.
    """
    if FLAGS.training not in ('gating', 'all'):
        raise ValueError(
            "FLAGS.training must be 'gating' or 'all', got {!r}".format(
                FLAGS.training))

    x = tf.placeholder(
        tf.float32,
        [FLAGS.batch_size, FLAGS.height, FLAGS.width, FLAGS.seq_length])
    x_g = []
    weights = []
    f_g_s = []

    # Recurrent state of the two network_grow experts.
    state_r = ([], [], [], [], [])
    state_g = ([], [], [], [], [])

    # Warm the recurrent experts up on the observed frames (all but the last).
    for i in range(FLAGS.seq_start - 1):
        with tf.variable_scope('expert2'):
            inputs = x[:, :, :, i][:, :, :, tf.newaxis]
            _, state_r = _grow_step(inputs, i, state_r)
        with tf.variable_scope('expert3'):
            inputs = x[:, :, :, i][:, :, :, tf.newaxis]
            _, state_g = _grow_step(inputs, i, state_g)

    # predict recursively
    for i in range(FLAGS.seq_length - FLAGS.seq_start):
        print('frame_{}'.format(i))
        if i == 0:
            # First prediction: condition on the observed frames only.
            x_input = x[:, :, :, i:i + FLAGS.seq_start]
            last_frame = x[:, :, :, FLAGS.seq_start - 1]
        else:
            # Later predictions feed previous outputs back in (time last).
            x_gen = tf.transpose(tf.stack(x_g), [1, 2, 3, 0])
            if i < FLAGS.seq_start:
                x_input = tf.concat(
                    [x[:, :, :, i:FLAGS.seq_start], x_gen[:, :, :, :i]],
                    axis=3)
            else:
                x_input = x_gen[:, :, :, i - FLAGS.seq_start:i]
            last_frame = x_g[-1]

        with tf.variable_scope('expert1'):
            f_generate, x_generate_s = _shift_step(x_input, last_frame)
            f_g_s.append(f_generate)

        with tf.variable_scope('expert2'):
            x_generate_r, state_r = _grow_step(
                last_frame[:, :, :, tf.newaxis], i + FLAGS.seq_start, state_r)

        with tf.variable_scope('expert3'):
            x_generate_g, state_g = _grow_step(
                last_frame[:, :, :, tf.newaxis], i + FLAGS.seq_start, state_g)

        with tf.variable_scope('gating_network'):
            weight = network_gate(x_input, FLAGS.gating_num, moe='moe1')
            x_sr = tf.concat([x_generate_s, x_generate_r, x_generate_g],
                             axis=-1)
            x_generate = tf.reduce_sum(weight * x_sr, axis=-1)
            x_g.append(x_generate)
            weights.append(weight)

    x_g = tf.transpose(tf.stack(x_g), [1, 2, 3, 0])
    f_g_s = tf.transpose(tf.stack(f_g_s), [1, 0, 2, 3, 4])
    weights = tf.transpose(tf.stack(weights), [1, 0, 2, 3, 4])

    # Savers for the pretrained experts and the gating network; checkpoint
    # names are stored without the enclosing scope.
    expert1_saver = tf.train.Saver(var_list=_scope_varlist('expert1'))
    expert2_saver = tf.train.Saver(var_list=_scope_varlist('expert2'))
    expert3_saver = tf.train.Saver(var_list=_scope_varlist('expert3'))
    gating_saver = tf.train.Saver(
        var_list=_scope_varlist('gating_network', use_op_name=False),
        max_to_keep=100)

    target = x[:, :, :, FLAGS.seq_start:]

    # MSE loss (reported, not optimised)
    MSE = tf.losses.mean_squared_error(target, x_g)

    # Robust reconstruction term.
    first = tf.reduce_mean(
        lossfunc(target - x_g,
                 alpha=tf.constant(0.0),
                 scale=tf.constant(FLAGS.robust_x)))
    # -w * log(w) term on the gating weights.
    second = tf.reduce_mean(
        lossfunc(-tf.log(weights + 1e-10) * weights,
                 alpha=tf.constant(0.0),
                 scale=tf.constant(0.3)))
    loss = first + FLAGS.w_ent * second
    if FLAGS.training == 'all':
        # Flow constraint term, only added when every variable is trained.
        loss = loss + flow_const(f_g_s,
                                 FLAGS.time_smo,
                                 FLAGS.smo,
                                 FLAGS.mag,
                                 robust=True)
    # Temporal smoothness of the gating weights.
    fourth = tf.reduce_mean(
        lossfunc(weights[:, 1:] - weights[:, :-1],
                 alpha=tf.constant(0.0),
                 scale=tf.constant(0.75)))
    loss = loss + FLAGS.w_time_smo * fourth

    if FLAGS.training == 'gating':
        train_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                      scope="gating_network/")
    else:  # 'all'
        train_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss,
                                                         var_list=train_var)

    # start running operations on the graph; the session is closed on exit
    with tf.Session() as sess:
        print('init netwrok from scratch....')
        sess.run(tf.global_variables_initializer())

        expert1_saver.restore(sess, model1)
        expert2_saver.restore(sess, model2)
        expert3_saver.restore(sess, model3)

        if FLAGS.restore_gating:
            gating_saver.restore(sess, model4)

        # Covers every variable, including the optimizer slots.
        all_saver = tf.train.Saver(max_to_keep=100)
        if FLAGS.restore_all:
            all_saver.restore(sess, model5)

        Loss_MSE, Loss_MSE_v = [], []
        all_Loss, all_Loss_v = [], []

        np.random.seed(2020)

        for epoch in range(FLAGS.num_epoch):
            loss_MSE, loss_MSE_v = [], []
            loss_all_epoch, loss_all_epoch_v = [], []
            sff_idx = np.random.permutation(train_data.shape[0])
            sff_idx_v = np.random.permutation(valid_data.shape[0])

            # train
            for batch_x in _iter_full_batches(train_data, sff_idx,
                                              FLAGS.batch_size):
                _, train_loss_all, train_mse = sess.run(
                    [train_op, loss, MSE], feed_dict={x: batch_x})
                loss_MSE.append(train_mse)
                loss_all_epoch.append(train_loss_all)

            # validation
            for batch_x in _iter_full_batches(valid_data, sff_idx_v,
                                              FLAGS.batch_size):
                valid_all_loss, valid_mse = sess.run(
                    [loss, MSE], feed_dict={x: batch_x})
                loss_MSE_v.append(valid_mse)
                loss_all_epoch_v.append(valid_all_loss)

            Loss_MSE.append(np.mean(loss_MSE))
            Loss_MSE_v.append(np.mean(loss_MSE_v))
            all_Loss.append(np.mean(loss_all_epoch))
            all_Loss_v.append(np.mean(loss_all_epoch_v))
            print('epoch, MSE, valid_MSE:{} {} {}'.format(epoch, Loss_MSE[-1],
                                                          Loss_MSE_v[-1]))

            if (epoch + 1) % 10 == 0 or epoch == 0:
                checkpoint_path = os.path.join(
                    FLAGS.train_dir, 'STMoE-1_{}'.format(FLAGS.training))
                if FLAGS.training == 'gating':
                    print('save gating saver')
                    gating_saver.save(sess, checkpoint_path,
                                      global_step=epoch + 1)
                else:  # 'all'
                    print('save all saver')
                    all_saver.save(sess, checkpoint_path,
                                   global_step=epoch + 1)
Example #8
0
def estimator(x1,
              x2,
              flow,
              train=True,
              trainable=True,
              reuse=None,
              regularizer=None,
              name='estimator'):
  """Estimator network whose cost volume uses a learned SPD metric.

  Cost volume: features are L2-normalised along axis 3. At each of the 9x9
  window offsets, the score between x1 and flow-warped x2 is the bilinear
  form ``x1^T M x2``, with ``M = P^T D P``:

  * ``P`` is orthogonal with determinant 1. It is the Cayley transform
    ``(I + S)(I - S)^-1`` of the skew-symmetric ``S`` built from the
    upper-triangular part of the variable ``raw_P``.
  * ``D`` is diagonal with strictly positive entries ``(1 + t) / (1 - t)``,
    where ``t = 2 * atan(raw_D) / pi`` lies in (-1, 1).

  Hence ``M`` is symmetric positive definite. With the identity / zeros
  initialisers it starts out as the identity.

  Network: the 81-channel cost volume is concatenated with x1 and ``flow``
  and fed to a six-layer conv stack under variable scope ``name``.

  Returns:
    dict with keys 'conv1'..'conv6'; 'conv6' has 2 channels and no
    activation.
  """
  dims = get_shape(x1, train=train)
  height, width = dims[1], dims[2]
  # Static channel count: it sizes the raw_P / raw_D variables below.
  channel = x1.get_shape().as_list()[-1]
  d = 9

  # ---------------cost volume-----------------
  warped = tf_warp(x2, flow, height, width)
  x1 = tf.nn.l2_normalize(x1, axis=3)
  warped = tf.nn.l2_normalize(warped, axis=3)
  windows = tf.reshape(
      tf.extract_image_patches(
          warped, [1, d, d, 1],
          strides=[1, 1, 1, 1],
          rates=[1, 1, 1, 1],
          padding='SAME'),
      [-1, height, width, d, d, channel])
  centers = tf.reshape(x1, [-1, height, width, 1, 1, channel])

  with tf.variable_scope(name, reuse=reuse, regularizer=regularizer):
    identity = tf.eye(channel)

    # Orthogonal P from the Cayley transform of a skew-symmetric matrix.
    raw_P = tf.get_variable(
        'raw_P',
        shape=[channel, channel],
        initializer=tf.keras.initializers.Identity(),
        dtype=tf.float32)
    upper = tf.matrix_band_part(raw_P, 0, -1)
    skew = (upper - tf.transpose(upper)) / 2
    P = tf.matmul(identity + skew, tf.linalg.inv(identity - skew))

    # Positive diagonal D: atan squashes to (-1, 1), then (1 + t) / (1 - t) > 0.
    raw_D = tf.get_variable(
        'raw_D',
        shape=[channel,],
        initializer=tf.zeros_initializer(),
        dtype=tf.float32)
    t = tf.atan(raw_D) * 2 / math.pi
    D = tf.matrix_diag(tf.div(1 + t, 1 - t))

    # Symmetric positive definite metric.
    metric = tf.matmul(tf.matmul(tf.transpose(P), D), P)

    scores = tf.tensordot(centers, metric, axes=[-1, 0]) * windows
    cost_volume = tf.reshape(
        tf.reduce_sum(scores, axis=-1), [-1, height, width, d * d])

  # --------------estimator network-------------
  net_input = tf.concat([cost_volume, x1, flow], axis=-1)
  with tf.variable_scope(name, reuse=reuse, regularizer=regularizer):
    with slim.arg_scope([slim.conv2d],
                        activation_fn=lrelu,
                        kernel_size=3,
                        padding='SAME',
                        trainable=trainable):
      net = {}
      layer = net_input
      for scope, depth in (('conv1', 128), ('conv2', 128), ('conv3', 96),
                           ('conv4', 64), ('conv5', 32)):
        layer = slim.conv2d(layer, depth, scope=scope)
        net[scope] = layer
      net['conv6'] = slim.conv2d(
          layer, 2, activation_fn=None, scope='conv6')
  return net