def main(argv=None):
    """Train the GRU aggregation model on EAST features and PWC-Net flow.

    The function builds a multi-GPU training graph around
    ``model_gru_agg.tower_loss`` whose inputs are placeholders for feature
    maps, flow maps, score maps, geometry maps and training masks.  Two
    further models, EAST and PWC-Net, are constructed in ``'test'`` mode and
    are only run for inference: on every step EAST turns the raw frames into
    feature maps and PWC-Net estimates the flow between consecutive frames.
    Those arrays are then fed to the training graph.

    All settings are read from the module-level ``FLAGS``.

    Args:
        argv: Unused by this function.
    """
    m_cfg = sys_cfg()  # not used below; call retained in case it has side effects
    config = get_config(FLAGS)
    config.batch_size = FLAGS.batch_size_per_gpu * FLAGS.num_gpus
    config.num_layers = 3
    config.num_steps = 5
    # Only referenced by the disabled validation generator further down.
    eval_config = get_config(FLAGS)
    eval_config.batch_size = 2
    eval_config.num_layers = 3
    eval_config.num_steps = 5

    # ============================ I. Model options ========================== #
    # ---- PWC-Net options
    nn_opts = deepcopy(_DEFAULT_PWCNET_VAL_OPTIONS)
    # Compare by value: `is 'small'` tests object identity, which is not
    # guaranteed to hold for a flag string created at runtime.
    if FLAGS.flownet_type == 'small':
        nn_opts['use_dense_cx'] = False
        nn_opts['use_res_cx'] = False
        nn_opts['ckpt_path'] = '/work/cascades/lxiaol9/ARC/PWC/checkpoints/pwcnet-sm-6-2-multisteps-chairsthingsmix/pwcnet.ckpt-592000'
    else:
        nn_opts['use_dense_cx'] = True
        nn_opts['use_res_cx'] = True
        nn_opts['ckpt_path'] = '/work/cascades/lxiaol9/ARC/PWC/checkpoints/pwcnet-lg-6-2-multisteps-chairsthingsmix/pwcnet.ckpt-595000'
    # Shared by both network sizes.
    nn_opts['pyr_lvls'] = 6
    nn_opts['flow_pred_lvl'] = 2

    nn_opts['verbose'] = True
    # Multiplied by nn.num_gpus below to get the PWC-Net mini-batch size.
    nn_opts['batch_size'] = 32
    nn_opts['use_tf_data'] = False  # feed numpy arrays instead of tf.data
    nn_opts['gpu_devices'] = ['/device:GPU:0', '/device:GPU:1']
    nn_opts['controller'] = '/device:CPU:0'
    nn_opts['adapt_info'] = (1, 436, 1024, 2)
    nn_opts['x_shape'] = [2, 512, 512, 3]  # image pair shape [2, H, W, 3]
    nn_opts['y_shape'] = [512, 512, 2]  # u,v flow shape [H, W, 2]

    # ---- EAST options
    east_opts = {
        'verbose': True,
        'ckpt_path': FLAGS.pretrained_model_path,
        # Number of frames sent through EAST per session run (see the loop).
        'batch_size': 40,
        'batch_size_per_gpu': 20,
        'gpu_devices': ['/device:GPU:0', '/device:GPU:1'],
        # Device that holds the model variables.
        'controller': '/device:CPU:0',
        'x_dtype': tf.float32,
        'x_shape': [512, 512, 3],  # single image shape [H, W, 3]
        'y_score_shape': [128, 128, 1],
        'y_geometry_shape': [128, 128, 5],
        'x_mask_shape': [128, 128, 1]
    }
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    # ===================== II. Build the aggregation graph ================== #
    # 1.1 Input placeholders
    batch_size = FLAGS.batch_size_per_gpu * FLAGS.num_gpus
    len_seq = FLAGS.num_steps
    input_feat_maps = tf.placeholder(tf.float32,
                                     shape=[batch_size, len_seq, 128, 128, 32],
                                     name='input_feature_maps')
    # One flow map per pair of consecutive frames, hence len_seq - 1.
    input_flow_maps = tf.placeholder(
        tf.float32,
        shape=[batch_size, len_seq - 1, 128, 128, 2],
        name='input_flow_maps')
    input_score_maps = tf.placeholder(tf.float32,
                                      shape=[batch_size, len_seq, 128, 128, 1],
                                      name='input_score_maps')
    # RBOX geometry uses 5 channels, any other geometry setting uses 8.
    geo_channels = 5 if FLAGS.geometry == 'RBOX' else 8
    input_geo_maps = tf.placeholder(
        tf.float32,
        shape=[batch_size, len_seq, 128, 128, geo_channels],
        name='input_geo_maps')
    input_training_masks = tf.placeholder(
        tf.float32,
        shape=[batch_size, len_seq, 128, 128, 1],
        name='input_training_masks')

    # 1.2 Learning rate schedule and optimizer
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                               global_step,
                                               decay_steps=500,
                                               decay_rate=0.8,
                                               staircase=True)
    opt = tf.train.AdamOptimizer(learning_rate)

    # 1.3 Summaries
    tf.summary.scalar('learning_rate', learning_rate)

    # 1.4 One tower per GPU, each fed an equal slice of the batch
    input_flow_maps_split = tf.split(input_flow_maps, FLAGS.num_gpus)
    input_feature_split = tf.split(input_feat_maps, FLAGS.num_gpus)
    input_score_maps_split = tf.split(input_score_maps, FLAGS.num_gpus)
    input_geo_maps_split = tf.split(input_geo_maps, FLAGS.num_gpus)
    input_training_masks_split = tf.split(input_training_masks, FLAGS.num_gpus)

    tower_grads = []
    reuse_variables = None
    # NOTE(review): the towers are enumerated from FLAGS.gpu_list while the
    # inputs are split by FLAGS.num_gpus; the two must agree.
    gpus = list(range(len(FLAGS.gpu_list.split(','))))
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                iis = input_feature_split[i]
                ifms = input_flow_maps_split[i]
                isms = input_score_maps_split[i]
                igms = input_geo_maps_split[i]
                itms = input_training_masks_split[i]
                total_loss, model_loss = model_gru_agg.tower_loss(
                    iis,
                    ifms,
                    isms,
                    igms,
                    itms,
                    gpu_id=gpu_id,
                    config=config,
                    reuse_variables=reuse_variables)
                # Rebound on every iteration, so after the loop this (and
                # total_loss / model_loss) refers to the last tower only.
                batch_norm_updates_op = tf.group(
                    *tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                # Every tower after the first reuses the same variables.
                reuse_variables = True
                grads = opt.compute_gradients(total_loss)
                tower_grads.append(grads)

    # 1.5 Average the per-tower gradients and apply them
    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # 1.6 Training op
    summary_op = tf.summary.merge_all()
    with tf.control_dependencies([apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    # 1.8 Saver, session and weight restore
    saver = tf.train.Saver(tf.global_variables())
    init = tf.global_variables_initializer()
    g = tf.get_default_graph()
    with g.as_default():
        config1 = tf.ConfigProto()
        config1.gpu_options.allow_growth = True
        config1.allow_soft_placement = True
        sess1 = tf.Session(config=config1)
        if FLAGS.restore:
            print('continue training from previous checkpoint')
            ckpt = FLAGS.prev_checkpoint_path
            saver.restore(sess1, ckpt)
        else:
            sess1.run(init)
            # Only the variables under 'multi_rnn_cell' are loaded from the
            # previous checkpoint; everything else keeps its initial value.
            var_list_part1 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope='multi_rnn_cell')
            saver_alter1 = tf.train.Saver(
                {v.op.name: v
                 for v in var_list_part1})
            print('continue training from previous weights')
            ckpt1 = FLAGS.prev_checkpoint_path
            print('Restore from {}'.format(ckpt1))
            saver_alter1.restore(sess1, ckpt1)

    # ============ III. Other necessary components before training =========== #
    print("Step 1: AGG model has been reconstructed")
    GPUtil.showUtilization()
    # ---- EAST model (inference only)
    east_net = model_flow_east.EAST(mode='test', options=east_opts)
    print("Step 2: EAST model has been reconstructed")
    GPUtil.showUtilization()
    # ---- PWC-Net model (inference only)
    nn = ModelPWCNet(mode='test', options=nn_opts)
    print("Step 3: PWC model has been reconstructed")
    GPUtil.showUtilization()
    train_data_generator = icdar_smart.get_batch_seq(
        num_workers=FLAGS.num_readers, config=config, is_training=True)
    start = time.time()

    # ======================== IV. Training over steps ======================= #
    print("Now we're starting training!!!")
    # Reset the output directory only after the weights above were restored,
    # and create the summary writer only after the reset: a writer created
    # earlier would have its first events (including the graph) deleted.
    # MakeDirs also creates missing parent directories.
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MakeDirs(FLAGS.checkpoint_path)
    elif not FLAGS.restore:
        tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
        tf.gfile.MakeDirs(FLAGS.checkpoint_path)
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path, g)
    # os.path.join gives the same prefix as plain concatenation when the
    # flag ends with a separator and a valid one when it does not.
    ckpt_prefix = os.path.join(FLAGS.checkpoint_path, 'model.ckpt')

    for step in range(FLAGS.max_steps):
        # ---- data
        if FLAGS.mode == "debug":
            # Constant dummy batch: frames, score maps, geometry maps, masks.
            data = [
                np.ones((config.batch_size, FLAGS.num_steps, 512, 512, 3),
                        dtype=np.float32),
                np.ones((batch_size, len_seq, 128, 128, 1), dtype=np.float32),
                np.ones((batch_size, len_seq, 128, 128, 5), dtype=np.float32),
                np.ones((batch_size, len_seq, 128, 128, 1), dtype=np.float32),
            ]
        else:
            data = next(train_data_generator)

        if step < 10:
            # Dump the first score map and training mask of the first
            # batches as a visual sanity check.
            print("Data ready!!!")
            plt.figure(dpi=300)
            plt.subplot(121)
            plt.imshow(data[1][0][0, :, :, 0])
            plt.title("score map")
            plt.subplot(122)
            plt.imshow(data[3][0][0, :, :, 0] * 255)
            plt.title("training mask")
            print("saving figure")
            plt.savefig("/home/lxiaol9/debug/running/" + str(step) + ".png")
            # Release the figure; pyplot keeps every figure alive otherwise.
            plt.close()

        frames = np.array(data[0])
        east_feed = np.reshape(data[0], [-1, 512, 512, 3])
        # Pair frame t (target) with frame t + 1 (source) for t = 0..3.
        target_frame = np.reshape(frames[:, 0:4, :, :, :], [-1, 512, 512, 3])
        source_frame = np.reshape(frames[:, 1:5, :, :, :], [-1, 512, 512, 3])
        flow_feed = np.concatenate((source_frame[:, np.newaxis, :, :, :],
                                    target_frame[:, np.newaxis, :, :, :]),
                                   axis=1)

        # ---- feature extraction with EAST
        east_batch = east_opts['batch_size']
        # Whole mini-batches only; a trailing remainder is not processed.
        rounds = east_feed.shape[0] // east_batch
        feature_stack = []
        for r in range(rounds):
            east_out = east_net.sess.run(
                [east_net.y_hat_test_tnsr],
                feed_dict={
                    east_net.x_tnsr:
                    east_feed[r * east_batch:(r + 1) * east_batch, :, :, :]
                })
            feature_stack.append(east_out[0][0])
        feature_maps = np.concatenate(feature_stack, axis=0)
        feature_maps_reshape = np.reshape(feature_maps,
                                          [-1, config.num_steps, 128, 128, 32])

        # ---- flow estimation with PWC-Net
        x_adapt, x_adapt_info = nn.adapt_x(flow_feed)
        if x_adapt_info is not None:
            y_adapt_info = (x_adapt_info[0], x_adapt_info[2], x_adapt_info[3],
                            2)
        else:
            y_adapt_info = None
        mini_batch = nn_opts['batch_size'] * nn.num_gpus
        rounds = flow_feed.shape[0] // mini_batch
        flow_maps_stack = []
        for r in range(rounds):
            feed_dict = {
                nn.x_tnsr:
                x_adapt[r * mini_batch:(r + 1) * mini_batch, :, :, :, :]
            }
            y_hat = nn.sess.run(nn.y_hat_test_tnsr, feed_dict=feed_dict)
            if FLAGS.mode == "debug":
                print(
                    "Step 5: now finish running one round of PWCnet for flow estimation"
                )
                GPUtil.showUtilization()
            y_hats, _ = nn.postproc_y_hat_test(y_hat, y_adapt_info)
            # Keep every 4th pixel (offset 1) and divide the vectors by 4,
            # presumably to match the 128x128 feature-map resolution.
            flow_maps_stack.append(y_hats[:, 1::4, 1::4, :] / 4)
        flow_maps = np.concatenate(flow_maps_stack, axis=0)
        flow_maps = np.reshape(flow_maps,
                               [-1, FLAGS.num_steps - 1, 128, 128, 2])

        # ---- training step
        train_feed = {
            input_feat_maps: feature_maps_reshape,
            input_score_maps: data[1],
            input_geo_maps: data[2],
            input_training_masks: data[3],
            input_flow_maps: flow_maps
        }
        with g.as_default():
            ml, tl, _ = sess1.run([model_loss, total_loss, train_op],
                                  feed_dict=train_feed)
            if FLAGS.mode == "debug":
                print("Step 6: running one round on training!!!")
                GPUtil.showUtilization()
            if np.isnan(tl):
                print('Loss diverged, stop training')
                break
            if step % 10 == 0:
                avg_time_per_step = (time.time() - start) / 10
                avg_examples_per_second = (10 * FLAGS.batch_size_per_gpu *
                                           len(gpus)) / (time.time() - start)
                start = time.time()
                print(
                    'Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'
                    .format(step, ml, tl, avg_time_per_step,
                            avg_examples_per_second))

            if step % FLAGS.save_checkpoint_steps == 0:
                saver.save(sess1, ckpt_prefix, global_step=global_step)

            if step % FLAGS.save_summary_steps == 0:
                # NOTE(review): train_op is fetched again here, so the same
                # batch is applied a second time on summary steps.
                _, tl, summary_str = sess1.run(
                    [train_op, total_loss, summary_op],
                    feed_dict=train_feed)
                summary_writer.add_summary(summary_str, global_step=step)
Exemple #2
0
def main(argv=None):
    """Train the tower model on raw frames plus PWC-Net flow maps.

    The function builds a multi-GPU training graph around ``tower_loss``
    whose inputs are placeholders for images, flow maps, score maps,
    geometry maps and training masks.  A PWC-Net model constructed in
    ``'test'`` mode is only run for inference: on every step it estimates
    the flow from the frame at index 2 of each sequence to every frame of
    that sequence, and the result is fed to the training graph.

    All settings are read from the module-level ``FLAGS``.

    Args:
        argv: Unused by this function.
    """
    m_cfg = sys_cfg()

    # ========================= I. PWC-Net model options ===================== #
    nn_opts = deepcopy(_DEFAULT_PWCNET_VAL_OPTIONS)
    # Compare by value: `is 'small'` tests object identity, which is not
    # guaranteed to hold for a flag string created at runtime.
    if FLAGS.flownet_type == 'small':
        nn_opts['use_dense_cx'] = False
        nn_opts['use_res_cx'] = False
        nn_opts['ckpt_path'] = '/work/cascades/lxiaol9/ARC/PWC/checkpoints/pwcnet-sm-6-2-multisteps-chairsthingsmix/pwcnet.ckpt-592000'
    else:
        nn_opts['use_dense_cx'] = True
        nn_opts['use_res_cx'] = True
        nn_opts['ckpt_path'] = '/work/cascades/lxiaol9/ARC/PWC/checkpoints/pwcnet-lg-6-2-multisteps-chairsthingsmix/pwcnet.ckpt-595000'
    # Shared by both network sizes.
    nn_opts['pyr_lvls'] = 6
    nn_opts['flow_pred_lvl'] = 2

    nn_opts['verbose'] = True
    nn_opts['batch_size'] = 10
    nn_opts['use_tf_data'] = False  # feed numpy arrays instead of tf.data
    nn_opts['gpu_devices'] = ['/device:GPU:0', '/device:GPU:1']
    nn_opts['controller'] = '/device:GPU:0'
    nn_opts['adapt_info'] = (1, 436, 1024, 2)
    nn_opts['x_shape'] = [2, 512, 512, 3]  # image pair shape [2, H, W, 3]
    nn_opts['y_shape'] = [512, 512, 2]  # u,v flow shape [H, W, 2]
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    # Start from an empty output directory unless resuming.  MakeDirs also
    # creates missing parent directories.
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MakeDirs(FLAGS.checkpoint_path)
    elif not FLAGS.restore:
        tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
        tf.gfile.MakeDirs(FLAGS.checkpoint_path)
    # os.path.join gives the same prefix as plain concatenation when the
    # flag ends with a separator and a valid one when it does not.
    ckpt_prefix = os.path.join(FLAGS.checkpoint_path, 'model.ckpt')

    # ======================= II. Build the training graph =================== #
    # 1.1 Input placeholders
    batch_size = FLAGS.batch_size_per_gpu * FLAGS.num_gpus
    len_seq = FLAGS.num_steps
    # Images and flow maps carry every frame of every sequence; the label
    # placeholders carry one entry per sequence.
    input_images = tf.placeholder(tf.float32,
                                  shape=[batch_size * len_seq, 512, 512, 3],
                                  name='input_images')
    input_flow_maps = tf.placeholder(tf.float32,
                                     shape=[batch_size * len_seq, 128, 128, 2],
                                     name='input_flow_maps')
    input_score_maps = tf.placeholder(tf.float32,
                                      shape=[batch_size, 128, 128, 1],
                                      name='input_score_maps')
    # RBOX geometry uses 5 channels, any other geometry setting uses 8.
    geo_channels = 5 if FLAGS.geometry == 'RBOX' else 8
    input_geo_maps = tf.placeholder(tf.float32,
                                    shape=[batch_size, 128, 128, geo_channels],
                                    name='input_geo_maps')
    input_training_masks = tf.placeholder(tf.float32,
                                          shape=[batch_size, 128, 128, 1],
                                          name='input_training_masks')

    # 1.2 Learning rate schedule and optimizer
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                               global_step,
                                               decay_steps=5000,
                                               decay_rate=0.8,
                                               staircase=True)
    opt = tf.train.AdamOptimizer(learning_rate)

    # 1.3 Summaries
    tf.summary.scalar('learning_rate', learning_rate)
    tf.summary.image('input_images', input_images[2:20:5, :, :, :])

    # 1.4 One tower per GPU, each fed an equal slice of the batch
    input_images_split = tf.split(input_images, FLAGS.num_gpus)
    input_score_maps_split = tf.split(input_score_maps, FLAGS.num_gpus)
    input_geo_maps_split = tf.split(input_geo_maps, FLAGS.num_gpus)
    input_training_masks_split = tf.split(input_training_masks, FLAGS.num_gpus)
    input_flow_maps_split = tf.split(input_flow_maps, FLAGS.num_gpus)
    tower_grads = []
    reuse_variables = None
    # NOTE(review): the towers are enumerated from FLAGS.gpu_list while the
    # inputs are split by FLAGS.num_gpus; the two must agree.
    gpus = list(range(len(FLAGS.gpu_list.split(','))))
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                iis = input_images_split[i]
                ifms = input_flow_maps_split[i]
                isms = input_score_maps_split[i]
                igms = input_geo_maps_split[i]
                itms = input_training_masks_split[i]
                total_loss, model_loss = tower_loss(iis, ifms, isms, igms,
                                                    itms, m_cfg,
                                                    reuse_variables)
                # Rebound on every iteration, so after the loop this (and
                # total_loss / model_loss) refers to the last tower only.
                batch_norm_updates_op = tf.group(
                    *tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                # Every tower after the first reuses the same variables.
                reuse_variables = True
                grads = opt.compute_gradients(total_loss)
                tower_grads.append(grads)

    # 1.5 Average the per-tower gradients and apply them
    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    # 1.6 Training op, including moving averages of the trainable variables
    summary_op = tf.summary.merge_all()
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    with tf.control_dependencies(
        [variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    # 1.8 Saver, session and weight restore
    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path,
                                           tf.get_default_graph())
    init = tf.global_variables_initializer()
    g = tf.get_default_graph()
    with g.as_default():
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        sess_config.allow_soft_placement = True
        sess1 = tf.Session(config=sess_config)
        if FLAGS.restore:
            print('continue training from previous checkpoint')
            ckpt = FLAGS.prev_checkpoint_path + '/model.ckpt-28601'
            saver.restore(sess1, ckpt)
        else:
            sess1.run(init)
            # Part 1: variables under 'feature_fusion' and 'resnet_v1_50'
            # are loaded from the pretrained model.
            var_list_part1 = (
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='feature_fusion') +
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='resnet_v1_50'))
            saver_alter1 = tf.train.Saver(
                {v.op.name: v
                 for v in var_list_part1})
            # Part 2: variables under 'tiny_embed' and 'pred_module' are
            # loaded from the previous checkpoint.
            var_list_part2 = (
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='tiny_embed') +
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='pred_module'))
            saver_alter2 = tf.train.Saver(
                {v.op.name: v
                 for v in var_list_part2})
            print('continue training from previous EAST weights')
            ckpt1 = FLAGS.pretrained_model_path
            print('Restore from {}'.format(ckpt1))
            saver_alter1.restore(sess1, ckpt1)
            print('continue training from previous Flow weights')
            ckpt2 = FLAGS.prev_checkpoint_path
            print('Restore from {}'.format(ckpt2))
            saver_alter2.restore(sess1, ckpt2)

    # ============ III. Other necessary components before training =========== #
    nn = ModelPWCNet(mode='test', options=nn_opts)
    data_generator = icdar.get_batch_flow(num_workers=FLAGS.num_readers,
                                          config=m_cfg,
                                          is_training=True)
    start = time.time()

    # ======================== IV. Training over steps ======================= #
    for step in range(FLAGS.max_steps):
        # ---- data
        data = next(data_generator)
        east_feed = np.reshape(data[0], [-1, 512, 512, 3])
        # Repeat the frame at index 2 of each sequence once per step so it
        # can be paired with every frame of that sequence.
        center_frame = np.array(data[0])[:, 2, :, :, :][:, np.newaxis, :, :, :]
        flow_feed_1 = np.reshape(
            np.tile(center_frame, (1, m_cfg.num_steps, 1, 1, 1)),
            [-1, 512, 512, 3])
        # Image pairs: (repeated center frame, each frame).
        flow_feed = np.concatenate((flow_feed_1[:, np.newaxis, :, :, :],
                                    east_feed[:, np.newaxis, :, :, :]),
                                   axis=1)

        # ---- flow estimation with PWC-Net
        x_adapt, x_adapt_info = nn.adapt_x(flow_feed)
        if x_adapt_info is not None:
            y_adapt_info = (x_adapt_info[0], x_adapt_info[2], x_adapt_info[3],
                            2)
        else:
            y_adapt_info = None
        mini_batch = 20
        # Whole mini-batches only; a trailing remainder is not processed.
        rounds = flow_feed.shape[0] // mini_batch
        flow_maps_stack = []
        for r in range(rounds):
            feed_dict = {
                nn.x_tnsr:
                x_adapt[r * mini_batch:(r + 1) * mini_batch, :, :, :, :]
            }
            y_hat = nn.sess.run(nn.y_hat_test_tnsr, feed_dict=feed_dict)
            y_hats, _ = nn.postproc_y_hat_test(y_hat, y_adapt_info)
            # Keep every 4th pixel (offset 1) and divide the vectors by 4,
            # presumably to match the 128x128 placeholder resolution.
            flow_maps_stack.append(y_hats[:, 1::4, 1::4, :] / 4)
        flow_maps = np.concatenate(flow_maps_stack, axis=0)

        # ---- training step
        train_feed = {
            input_images: east_feed,
            input_score_maps: data[1],
            input_geo_maps: data[2],
            input_training_masks: data[3],
            input_flow_maps: flow_maps
        }
        with g.as_default():
            ml, tl, _ = sess1.run([model_loss, total_loss, train_op],
                                  feed_dict=train_feed)

            if np.isnan(tl):
                print('Loss diverged, stop training')
                break
            if step % 10 == 0:
                avg_time_per_step = (time.time() - start) / 10
                avg_examples_per_second = (10 * FLAGS.batch_size_per_gpu *
                                           len(gpus)) / (time.time() - start)
                start = time.time()
                print(
                    'Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'
                    .format(step, ml, tl, avg_time_per_step,
                            avg_examples_per_second))

            if step % FLAGS.save_checkpoint_steps == 0:
                saver.save(sess1, ckpt_prefix, global_step=global_step)

            if step % FLAGS.save_summary_steps == 0:
                # NOTE(review): train_op is fetched again here, so the same
                # batch is applied a second time on summary steps.
                _, tl, summary_str = sess1.run(
                    [train_op, total_loss, summary_op],
                    feed_dict=train_feed)
                summary_writer.add_summary(summary_str, global_step=step)