def main(argv=None):
    """Train the aggregation model on EAST features and estimated flow maps.

    Builds one loss tower per entry of the module-level ``gpus`` list,
    averages the tower gradients and applies them with Adam. It then loops
    for ``FLAGS.max_steps`` steps. Each step:

    1. pulls a batch from ``icdar.get_batch_flow``;
    2. calls ``estimate`` on (center frame, frame) pairs to get flow maps;
    3. runs the EAST network (test mode) to get feature maps;
    4. runs a single optimizer step in ``sess1`` and, when due, logs,
       checkpoints and writes summaries.

    Args:
        argv: Unused. Presumably present so the function can be passed to
            ``tf.app.run`` -- confirm against the caller.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    # Start from an empty checkpoint directory unless we are resuming.
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    elif not FLAGS.restore:
        tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
        tf.gfile.MkDir(FLAGS.checkpoint_path)

    # ------------------------------------------------------------------ #
    # Placeholders: features and flow carry batch_size * num_steps frames,
    # the ground-truth maps carry batch_size entries.
    # ------------------------------------------------------------------ #
    frames_per_batch = cfg_flow.batch_size * cfg_flow.num_steps
    geo_channels = 5 if FLAGS.geometry == 'RBOX' else 8
    input_feat_maps = tf.placeholder(
        tf.float32,
        shape=[frames_per_batch, 128, 128, 32],
        name='input_images')
    input_flow_maps = tf.placeholder(
        tf.float32,
        shape=[frames_per_batch, 128, 128, 2],
        name='input_flow_maps')
    input_score_maps = tf.placeholder(
        tf.float32,
        shape=[cfg_flow.batch_size, 128, 128, 1],
        name='input_score_maps')
    input_geo_maps = tf.placeholder(
        tf.float32,
        shape=[cfg_flow.batch_size, 128, 128, geo_channels],
        name='input_geo_maps')
    input_training_masks = tf.placeholder(
        tf.float32,
        shape=[cfg_flow.batch_size, 128, 128, 1],
        name='input_training_masks')

    # ------------------------------------------------------------------ #
    # Learning-rate schedule and optimizer.
    # ------------------------------------------------------------------ #
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                               global_step,
                                               decay_steps=5000,
                                               decay_rate=0.5,
                                               staircase=True)
    tf.summary.scalar('learning_rate', learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)

    # Split every input along the batch axis, one slice per GPU tower.
    num_towers = len(gpus)
    input_feature_split = tf.split(input_feat_maps, num_towers)
    input_score_maps_split = tf.split(input_score_maps, num_towers)
    input_geo_maps_split = tf.split(input_geo_maps, num_towers)
    input_training_masks_split = tf.split(input_training_masks, num_towers)
    input_flow_maps_split = tf.split(input_flow_maps, num_towers)

    tower_grads = []
    reuse_variables = None
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                # Build this tower's loss on its slice of the inputs.
                total_loss, model_loss = tower_loss(
                    input_feature_split[i], input_flow_maps_split[i],
                    input_score_maps_split[i], input_geo_maps_split[i],
                    input_training_masks_split[i], reuse_variables)
                # NOTE(review): this is overwritten on every iteration, so
                # only the last tower's UPDATE_OPS end up in train_op.
                batch_norm_updates_op = tf.group(
                    *tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                # Every tower after the first reuses the same variables.
                reuse_variables = True
                # Only variables under these two scopes receive gradients.
                tvars = (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope='tiny_embed') +
                         tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope='pred_module'))
                tower_grads.append(opt.compute_gradients(total_loss, tvars))

    # Average the per-tower gradients and apply them in one update.
    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summary_op = tf.summary.merge_all()
    # Maintain moving averages of all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # train_op is a no-op that forces the update ops to run first.
    with tf.control_dependencies(
        [variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path,
                                           tf.get_default_graph())

    init = tf.global_variables_initializer()
    g = tf.get_default_graph()
    print("Step 2:Get default graph!")
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.per_process_gpu_memory_fraction = 1
    config.gpu_options.allow_growth = True
    with g.as_default():
        sess1 = tf.Session(config=config)
        if FLAGS.restore:
            print('continue training from previous checkpoint')
            ckpt = FLAGS.prev_checkpoint_path + 'model.ckpt-14301'
            saver.restore(sess1, ckpt)
        else:
            sess1.run(init)

    # The EAST network is built separately and run through its own session
    # (east_net.sess) during training.
    east_net = model_flow_east.EAST(mode='test', options=east_opts)
    print("Step 3: EAST model has been reconstructed")
    GPUtil.showUtilization()
    data_generator = icdar.get_batch_flow(num_workers=FLAGS.num_readers,
                                          config=cfg_flow,
                                          is_training=True)
    # Prefix for checkpoint files; os.path.join keeps them inside the
    # checkpoint directory whether or not the flag ends with a separator.
    ckpt_prefix = os.path.join(FLAGS.checkpoint_path, 'model.ckpt')
    flow_mini_batch = 6    # frame pairs per call to estimate()
    east_mini_batch = 12   # frames per EAST forward pass

    start = time.time()
    print('Step 4: start training!!!')
    for step in range(FLAGS.max_steps):
        data = next(data_generator)
        # Flatten the frames into a list of 512x512 RGB images.
        east_feed = np.reshape(data[0], [-1, 512, 512, 3])
        # Frame index 1 of each sequence, with the step axis kept (size 1).
        center_frame = np.array(data[0])[:, 1, :, :, :][:, np.newaxis, :, :, :]
        # Repeat it num_steps times so it lines up one-to-one with east_feed.
        center_frame_rep = np.reshape(
            np.tile(center_frame, (1, cfg_flow.num_steps, 1, 1, 1)),
            [-1, 512, 512, 3])

        # ----------------------- flow estimation ----------------------- #
        flow_maps_stack = []
        # NOTE(review): frames beyond a multiple of the mini batch are
        # silently dropped by the floor division.
        rounds = east_feed.shape[0] // flow_mini_batch
        print("rounds number is %d" % rounds)
        for rr in range(rounds):
            lo, hi = rr * flow_mini_batch, (rr + 1) * flow_mini_batch
            # NOTE(review): current_device() result is unused; presumably
            # kept to force CUDA initialisation -- confirm before removing.
            torch.cuda.current_device()
            torch.cuda.set_device(0)
            # NHWC -> NCHW float32, scaled by 1/255.
            tensorFirst = torch.FloatTensor(
                center_frame_rep[lo:hi].transpose(0, 3, 1, 2).astype(
                    np.float32) * (1.0 / 255.0))
            tensorSecond = torch.FloatTensor(
                east_feed[lo:hi].transpose(0, 3, 1, 2).astype(np.float32) *
                (1.0 / 255.0))
            print("Step 4-1: flow estimation ready to begin!")
            tensorOutput = estimate(tensorFirst, tensorSecond)
            # Back to NHWC, keep every 4th pixel and divide the vectors by 4
            # (assumes estimate() returns input-resolution flow so the
            # result matches the 128x128 placeholders -- TODO confirm).
            flow_maps_stack.append(
                tensorOutput.numpy().transpose(0, 2, 3, 1)[:, 1::4, 1::4, :] /
                4)
        torch.cuda.empty_cache()
        flow_maps = np.concatenate(flow_maps_stack, axis=0)
        print('Step 5: optical-flow maps done!!!')
        GPUtil.showUtilization()

        # ------------------ feature extraction with EAST ---------------- #
        feature_stack = []
        rounds = east_feed.shape[0] // east_mini_batch
        for r in range(rounds):
            chunk = east_feed[r * east_mini_batch:(r + 1) * east_mini_batch,
                              :, :, :]
            # [0] unwraps the single fetch; the second [0] takes the first
            # element of that fetched value.
            feature_stack.append(
                east_net.sess.run([east_net.y_hat_test_tnsr],
                                  feed_dict={east_net.x_tnsr: chunk})[0][0])
        feature_maps = np.concatenate(feature_stack, axis=0)
        print('Step 6: feature maps done!!!')

        # --------------------------- train step ------------------------- #
        feed_dict = {
            input_feat_maps: feature_maps,
            input_score_maps: data[1],
            input_geo_maps: data[2],
            input_training_masks: data[3],
            input_flow_maps: flow_maps
        }
        # Fetch the summary in the same run as the update, so summary steps
        # do not execute train_op a second time on the same batch.
        write_summary = step % FLAGS.save_summary_steps == 0
        fetches = [model_loss, total_loss, train_op]
        if write_summary:
            fetches.append(summary_op)
        with g.as_default():
            results = sess1.run(fetches, feed_dict=feed_dict)
            ml, tl = results[0], results[1]
            if np.isnan(tl):
                print('Loss diverged, stop training')
                break
            if step % 10 == 0:
                elapsed = time.time() - start
                avg_time_per_step = elapsed / 10
                avg_examples_per_second = (10 * FLAGS.batch_size_per_gpu *
                                           len(gpus)) / elapsed
                start = time.time()
                print(
                    'Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'
                    .format(step, ml, tl, avg_time_per_step,
                            avg_examples_per_second))

            if step % FLAGS.save_checkpoint_steps == 0:
                saver.save(sess1, ckpt_prefix, global_step=global_step)

            if write_summary:
                summary_writer.add_summary(results[3], global_step=step)
def main(argv=None):
    """Train the GRU aggregation model on EAST features and PWC-Net flow.

    The function:

    I.   prepares option dictionaries for PWC-Net (``nn_opts``) and EAST
         (``east_opts``);
    II.  builds one ``model_gru_agg.tower_loss`` tower per GPU listed in
         ``FLAGS.gpu_list``, averages the gradients and creates ``sess1``;
    III. constructs the EAST and PWC-Net models (both in test mode) and the
         training data generator;
    IV.  loops for ``FLAGS.max_steps`` steps: extract features with EAST,
         estimate flow between consecutive frames with PWC-Net, then run a
         single optimizer step, logging / checkpointing / summarising when
         due.

    Args:
        argv: Unused. Presumably present so the function can be passed to
            ``tf.app.run`` -- confirm against the caller.
    """
    # Return value is not used in this function; the call is kept in case
    # sys_cfg() has side effects.
    m_cfg = sys_cfg()
    config = get_config(FLAGS)
    config.batch_size = FLAGS.batch_size_per_gpu * FLAGS.num_gpus
    config.num_layers = 3
    config.num_steps = 5
    # Evaluation config; only referenced by the disabled validation
    # generator further down.
    eval_config = get_config(FLAGS)
    eval_config.batch_size = 2
    eval_config.num_layers = 3
    eval_config.num_steps = 5

    # ========================= I. Model options ======================== #
    # ---- PWC-Net ----
    nn_opts = deepcopy(_DEFAULT_PWCNET_VAL_OPTIONS)
    # Compare by value: `is 'small'` tests object identity and only matches
    # when the flag string happens to be the same interned object.
    use_small_flownet = FLAGS.flownet_type == 'small'
    nn_opts['use_dense_cx'] = not use_small_flownet
    nn_opts['use_res_cx'] = not use_small_flownet
    nn_opts['pyr_lvls'] = 6
    nn_opts['flow_pred_lvl'] = 2
    if use_small_flownet:
        nn_opts[
            'ckpt_path'] = '/work/cascades/lxiaol9/ARC/PWC/checkpoints/pwcnet-sm-6-2-multisteps-chairsthingsmix/pwcnet.ckpt-592000'  # Model to eval
    else:
        nn_opts[
            'ckpt_path'] = '/work/cascades/lxiaol9/ARC/PWC/checkpoints/pwcnet-lg-6-2-multisteps-chairsthingsmix/pwcnet.ckpt-595000'

    nn_opts['verbose'] = True
    nn_opts['batch_size'] = 32  # batch size per GPU
    nn_opts['use_tf_data'] = False  # feed numpy arrays, no tf.data reader
    nn_opts['gpu_devices'] = ['/device:GPU:0', '/device:GPU:1']
    nn_opts['controller'] = '/device:CPU:0'
    nn_opts['adapt_info'] = (1, 436, 1024, 2)
    nn_opts['x_shape'] = [2, 512, 512, 3]  # image pair input [2, H, W, 3]
    nn_opts['y_shape'] = [512, 512, 2]  # u,v flow output [H, W, 2]

    # ---- EAST ----
    east_opts = {
        'verbose': True,
        'ckpt_path': FLAGS.pretrained_model_path,
        'batch_size': 40,
        'batch_size_per_gpu': 20,
        'gpu_devices': ['/device:GPU:0', '/device:GPU:1'],
        # controller device to put the model's variables on
        'controller': '/device:CPU:0',
        'x_dtype': tf.float32,  # input image dtype
        'x_shape': [512, 512, 3],  # input image shape [H, W, 3]
        'y_score_shape': [128, 128, 1],
        'y_geometry_shape': [128, 128, 5],
        'x_mask_shape': [128, 128, 1]
    }
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    # ================ II. Build the aggregation graph ================== #
    # 1.1 Input placeholders: [batch, sequence step, H, W, C]. There is one
    # flow map fewer than frames per sequence.
    batch_size = FLAGS.batch_size_per_gpu * FLAGS.num_gpus
    len_seq = FLAGS.num_steps
    geo_channels = 5 if FLAGS.geometry == 'RBOX' else 8
    input_feat_maps = tf.placeholder(
        tf.float32,
        shape=[batch_size, len_seq, 128, 128, 32],
        name='input_feature_maps')
    input_flow_maps = tf.placeholder(
        tf.float32,
        shape=[batch_size, len_seq - 1, 128, 128, 2],
        name='input_flow_maps')
    input_score_maps = tf.placeholder(
        tf.float32,
        shape=[batch_size, len_seq, 128, 128, 1],
        name='input_score_maps')
    input_geo_maps = tf.placeholder(
        tf.float32,
        shape=[batch_size, len_seq, 128, 128, geo_channels],
        name='input_geo_maps')
    input_training_masks = tf.placeholder(
        tf.float32,
        shape=[batch_size, len_seq, 128, 128, 1],
        name='input_training_masks')

    # 1.2 Learning-rate schedule and optimizer.
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                               global_step,
                                               decay_steps=500,
                                               decay_rate=0.8,
                                               staircase=True)
    opt = tf.train.AdamOptimizer(learning_rate)
    # 1.3 Summaries.
    tf.summary.scalar('learning_rate', learning_rate)

    # 1.4 Split each input along the batch axis, one slice per tower.
    input_flow_maps_split = tf.split(input_flow_maps, FLAGS.num_gpus)
    input_feature_split = tf.split(input_feat_maps, FLAGS.num_gpus)
    input_score_maps_split = tf.split(input_score_maps, FLAGS.num_gpus)
    input_geo_maps_split = tf.split(input_geo_maps, FLAGS.num_gpus)
    input_training_masks_split = tf.split(input_training_masks, FLAGS.num_gpus)

    tower_grads = []
    reuse_variables = None
    # Towers are indexed 0..k-1, where k is the number of comma-separated
    # entries in FLAGS.gpu_list.
    gpus = list(range(len(FLAGS.gpu_list.split(','))))
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                total_loss, model_loss = model_gru_agg.tower_loss(
                    input_feature_split[i],
                    input_flow_maps_split[i],
                    input_score_maps_split[i],
                    input_geo_maps_split[i],
                    input_training_masks_split[i],
                    gpu_id=gpu_id,
                    config=config,
                    reuse_variables=reuse_variables)
                # NOTE(review): overwritten on every iteration, so only the
                # last tower's UPDATE_OPS end up in train_op.
                batch_norm_updates_op = tf.group(
                    *tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                # Every tower after the first reuses the same variables.
                reuse_variables = True
                # No var_list: gradients for all trainable variables.
                tower_grads.append(opt.compute_gradients(total_loss))

    # 1.5 Average tower gradients and apply them.
    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
    # 1.6 train_op is a no-op that forces the update ops to run first.
    summary_op = tf.summary.merge_all()
    with tf.control_dependencies([apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    # 1.8 Saver, session and restore.
    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path,
                                           tf.get_default_graph())
    init = tf.global_variables_initializer()
    g = tf.get_default_graph()
    with g.as_default():
        config1 = tf.ConfigProto()
        config1.gpu_options.allow_growth = True
        config1.allow_soft_placement = True
        sess1 = tf.Session(config=config1)
        if FLAGS.restore:
            # Full restore of every global variable.
            print('continue training from previous checkpoint')
            ckpt = FLAGS.prev_checkpoint_path
            saver.restore(sess1, ckpt)
        else:
            # Fresh initialisation, then overwrite only the variables under
            # the 'multi_rnn_cell' scope from the previous checkpoint.
            sess1.run(init)
            var_list_part1 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope='multi_rnn_cell')
            saver_alter1 = tf.train.Saver(
                {v.op.name: v
                 for v in var_list_part1})
            print('continue training from previous weights')
            ckpt1 = FLAGS.prev_checkpoint_path
            print('Restore from {}'.format(ckpt1))
            saver_alter1.restore(sess1, ckpt1)

    # ======== III. Other necessary components before training ========== #
    print("Step 1: AGG model has been reconstructed")
    GPUtil.showUtilization()
    # EAST model, run through its own session (east_net.sess).
    east_net = model_flow_east.EAST(mode='test', options=east_opts)
    print("Step 2: EAST model has been reconstructed")
    GPUtil.showUtilization()
    # PWC-Net model, run through its own session (nn.sess).
    nn = ModelPWCNet(mode='test', options=nn_opts)
    print("Step 3: PWC model has been reconstructed")
    GPUtil.showUtilization()
    train_data_generator = icdar_smart.get_batch_seq(
        num_workers=FLAGS.num_readers, config=config, is_training=True)
    start = time.time()

    # ==================== IV. Training over steps ====================== #
    print("Now we're starting training!!!")
    # NOTE(review): this wipes FLAGS.checkpoint_path after the FileWriter
    # above was opened on it, so the already-written graph event may be
    # lost. The order is left unchanged because prev_checkpoint_path might
    # live inside this directory -- confirm, then move this to the top.
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    elif not FLAGS.restore:
        tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    # Prefix for checkpoint files; os.path.join keeps them inside the
    # checkpoint directory whether or not the flag ends with a separator.
    ckpt_prefix = os.path.join(FLAGS.checkpoint_path, 'model.ckpt')
    east_batch = east_opts['batch_size']

    for step in range(FLAGS.max_steps):
        # ------------------------------ data ---------------------------- #
        if FLAGS.mode == "debug":
            # Synthetic all-ones batch shaped like the placeholders; the
            # geometry channel count follows FLAGS.geometry so the feed
            # matches input_geo_maps in both modes.
            label_shape = (batch_size, len_seq, 128, 128)
            data = [
                np.ones((config.batch_size, FLAGS.num_steps, 512, 512, 3),
                        dtype=np.float32),
                np.ones(label_shape + (1, ), dtype=np.float32),
                np.ones(label_shape + (geo_channels, ), dtype=np.float32),
                np.ones(label_shape + (1, ), dtype=np.float32),
            ]
        else:
            data = next(train_data_generator)

        if step < 10:
            # Dump the first sample's score map and training mask for a
            # visual sanity check of the first ten batches.
            print("Data ready!!!")
            fig = plt.figure(dpi=300)
            plt.subplot(121)
            plt.imshow(data[1][0][0, :, :, 0])
            plt.title("score map")
            plt.subplot(122)
            plt.imshow(data[3][0][0, :, :, 0] * 255)
            plt.title("training mask")
            print("saving figure")
            plt.savefig("/home/lxiaol9/debug/running/" + str(step) + ".png")
            # Release the figure; pyplot otherwise keeps every one alive.
            plt.close(fig)

        frames = np.array(data[0])
        east_feed = np.reshape(data[0], [-1, 512, 512, 3])
        # Consecutive frame pairs: frames 0..3 as targets, 1..4 as sources
        # (hard-coded for five-frame sequences).
        target_frame = np.reshape(frames[:, 0:4, :, :, :], [-1, 512, 512, 3])
        source_frame = np.reshape(frames[:, 1:5, :, :, :], [-1, 512, 512, 3])
        flow_feed = np.concatenate((source_frame[:, np.newaxis, :, :, :],
                                    target_frame[:, np.newaxis, :, :, :]),
                                   axis=1)

        # ------------------ feature extraction with EAST ---------------- #
        # NOTE(review): frames beyond a multiple of the batch size are
        # silently dropped by the floor division.
        rounds = east_feed.shape[0] // east_batch
        feature_stack = []
        for r in range(rounds):
            chunk = east_feed[r * east_batch:(r + 1) * east_batch, :, :, :]
            # [0] unwraps the single fetch; the second [0] takes the first
            # element of that fetched value.
            feature_stack.append(
                east_net.sess.run([east_net.y_hat_test_tnsr],
                                  feed_dict={east_net.x_tnsr: chunk})[0][0])
        feature_maps = np.concatenate(feature_stack, axis=0)
        feature_maps_reshape = np.reshape(feature_maps,
                                          [-1, config.num_steps, 128, 128, 32])

        # ------------------ flow estimation with PWC-Net ---------------- #
        x_adapt, x_adapt_info = nn.adapt_x(flow_feed)
        if x_adapt_info is not None:
            y_adapt_info = (x_adapt_info[0], x_adapt_info[2], x_adapt_info[3],
                            2)
        else:
            y_adapt_info = None
        mini_batch = nn_opts['batch_size'] * nn.num_gpus
        rounds = flow_feed.shape[0] // mini_batch
        flow_maps_stack = []
        for r in range(rounds):
            feed_dict = {
                nn.x_tnsr:
                x_adapt[r * mini_batch:(r + 1) * mini_batch, :, :, :, :]
            }
            y_hat = nn.sess.run(nn.y_hat_test_tnsr, feed_dict=feed_dict)
            if FLAGS.mode == "debug":
                print(
                    "Step 5: now finish running one round of PWCnet for flow estimation"
                )
                GPUtil.showUtilization()
            y_hats, _ = nn.postproc_y_hat_test(y_hat, y_adapt_info)
            # Keep every 4th pixel and divide the vectors by 4 (assumes
            # y_hats is [batch, 512, 512, 2] so the result matches the
            # 128x128 placeholders -- TODO confirm).
            flow_maps_stack.append(y_hats[:, 1::4, 1::4, :] / 4)
        flow_maps = np.concatenate(flow_maps_stack, axis=0)
        flow_maps = np.reshape(flow_maps,
                               [-1, FLAGS.num_steps - 1, 128, 128, 2])

        # --------------------------- train step ------------------------- #
        train_feed = {
            input_feat_maps: feature_maps_reshape,
            input_score_maps: data[1],
            input_geo_maps: data[2],
            input_training_masks: data[3],
            input_flow_maps: flow_maps
        }
        # Fetch the summary in the same run as the update, so summary steps
        # do not execute train_op a second time on the same batch.
        write_summary = step % FLAGS.save_summary_steps == 0
        fetches = [model_loss, total_loss, train_op]
        if write_summary:
            fetches.append(summary_op)
        with g.as_default():
            results = sess1.run(fetches, feed_dict=train_feed)
            ml, tl = results[0], results[1]
            if FLAGS.mode == "debug":
                print("Step 6: running one round on training!!!")
                GPUtil.showUtilization()
            if np.isnan(tl):
                print('Loss diverged, stop training')
                break
            if step % 10 == 0:
                elapsed = time.time() - start
                avg_time_per_step = elapsed / 10
                avg_examples_per_second = (10 * FLAGS.batch_size_per_gpu *
                                           len(gpus)) / elapsed
                start = time.time()
                print(
                    'Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'
                    .format(step, ml, tl, avg_time_per_step,
                            avg_examples_per_second))

            if step % FLAGS.save_checkpoint_steps == 0:
                saver.save(sess1, ckpt_prefix, global_step=global_step)

            if write_summary:
                summary_writer.add_summary(results[3], global_step=step)
Example #3
0
def main(argv=None):
    """Build the multi-GPU training graph and run the training loop.

    The trainable graph consumes feature maps produced by an EAST network and
    flow maps produced by a PWC-Net network (both instantiated here with
    ``mode='test'``), together with the ground-truth score/geometry maps and
    training masks yielded by ``icdar.get_batch_flow``.  Only variables in the
    ``tiny_embed`` and ``pred_module`` scopes receive gradients.

    Args:
        argv: Unused in this function.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    # Start from an empty checkpoint directory unless we are resuming.
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    elif not FLAGS.restore:
        tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
        tf.gfile.MkDir(FLAGS.checkpoint_path)

    # Placeholders for the trainable graph.  Features and flow carry one entry
    # per frame (batch_size * num_steps); the ground-truth maps carry one entry
    # per sample (batch_size).
    frames_per_batch = cfg_flow.batch_size * cfg_flow.num_steps
    input_feat_maps = tf.placeholder(
        tf.float32, shape=[frames_per_batch, 128, 128, 32], name='input_images')
    input_flow_maps = tf.placeholder(
        tf.float32, shape=[frames_per_batch, 128, 128, 2], name='input_flow_maps')
    input_score_maps = tf.placeholder(
        tf.float32, shape=[cfg_flow.batch_size, 128, 128, 1], name='input_score_maps')
    # RBOX geometry uses 5 channels, any other geometry setting uses 8.
    geo_channels = 5 if FLAGS.geometry == 'RBOX' else 8
    input_geo_maps = tf.placeholder(
        tf.float32,
        shape=[cfg_flow.batch_size, 128, 128, geo_channels],
        name='input_geo_maps')
    input_training_masks = tf.placeholder(
        tf.float32, shape=[cfg_flow.batch_size, 128, 128, 1], name='input_training_masks')

    global_step = tf.get_variable(
        'global_step', [], initializer=tf.constant_initializer(0), trainable=False)
    learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate, global_step, decay_steps=5000, decay_rate=0.8, staircase=True)
    # add summary on learning rate
    tf.summary.scalar('learning_rate', learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)

    # Split the features, flow maps, and GT labels evenly across the GPUs.
    input_feature_split = tf.split(input_feat_maps, len(gpus))
    input_score_maps_split = tf.split(input_score_maps, len(gpus))
    input_geo_maps_split = tf.split(input_geo_maps, len(gpus))
    input_training_masks_split = tf.split(input_training_masks, len(gpus))
    input_flow_maps_split = tf.split(input_flow_maps, len(gpus))

    tower_grads = []
    reuse_variables = None
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                iis = input_feature_split[i]
                ifms = input_flow_maps_split[i]
                isms = input_score_maps_split[i]
                igms = input_geo_maps_split[i]
                itms = input_training_masks_split[i]
                # Build this device's tower.  NOTE: total_loss, model_loss and
                # batch_norm_updates_op are rebound on every iteration, so
                # after the loop they refer to the last tower only.
                total_loss, model_loss = tower_loss(
                    iis, ifms, isms, igms, itms, reuse_variables)
                batch_norm_updates_op = tf.group(
                    *tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                # Every tower after the first shares the variables.
                reuse_variables = True
                # Restrict the gradients to the two trainable sub-modules.
                tvars = (
                    tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='tiny_embed')
                    + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='pred_module'))
                grads = opt.compute_gradients(total_loss, tvars)
                tower_grads.append(grads)

    # Average the per-device gradients and apply them in a single update.
    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summary_op = tf.summary.merge_all()
    # save moving average
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # train_op bundles the EMA update, the gradient step and the batch-norm
    # updates behind a single no-op.
    with tf.control_dependencies(
            [variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables())
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path, tf.get_default_graph())

    init = tf.global_variables_initializer()
    g = tf.get_default_graph()
    with g.as_default():
        sess1 = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

    # Everything below runs under try/finally so the session is released and
    # pending summary events are flushed even on divergence or an exception.
    try:
        with g.as_default():
            if FLAGS.restore:
                print('continue training from previous checkpoint')
                # NOTE: the checkpoint step to resume from is hard-coded.
                ckpt = os.path.join(FLAGS.prev_checkpoint_path, 'model.ckpt-14301')
                saver.restore(sess1, ckpt)
            else:
                sess1.run(init)
        # NOTE: warm-starting pred_module from a pretrained EAST checkpoint
        # (copying feature_fusion/Conv_7..Conv_9 weights and biases into
        # pred_module/Conv..Conv_2 via a checkpoint reader and Variable.load)
        # used to be done here and is currently disabled.

        # >>>>>>>>>>>>>>>>>>>>>>>>>>>>> EAST model >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> #
        east_net = model_flow_east.EAST(mode='test', options=east_opts)
        # >>>>>>>>>>>>>>>>>>>>>>>>>>>>> PWCnet model >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> #
        nn = ModelPWCNet(mode='test', options=nn_opts)
        nn.print_config()
        data_generator = icdar.get_batch_flow(
            num_workers=FLAGS.num_readers, config=cfg_flow, is_training=True)

        # The two auxiliary networks are fed at most this many frames per run.
        chunk_size = 40
        start = time.time()
        # Start training
        for step in range(FLAGS.max_steps):
            data = next(data_generator)
            east_feed = np.reshape(data[0], [-1, 512, 512, 3])
            # Data for the flow net: pair every frame with the frame at index 2
            # of its sequence, tiled num_steps times.
            center_frame = np.array(data[0])[:, 2, :, :, :][:, np.newaxis, :, :, :]
            flow_feed_1 = np.reshape(
                np.tile(center_frame, (1, cfg_flow.num_steps, 1, 1, 1)),
                [-1, 512, 512, 3])
            flow_feed = np.concatenate(
                (east_feed[:, np.newaxis, :, :, :], flow_feed_1[:, np.newaxis, :, :, :]),
                axis=1)

            # Walk over every chunk, including a trailing partial one, so that
            # no frame is dropped when the frame count is not a multiple of
            # chunk_size (or is smaller than it).
            # NOTE(review): assumes both networks accept a short final chunk -
            # confirm against their placeholder shapes.
            chunk_starts = range(0, east_feed.shape[0], chunk_size)

            # >>>>>>>>>>>>>>>>>>>>>>>>>>> feature extraction with EAST >>>>>>>>>>>>>>>>>>>>>>>> #
            feature_stack = []
            for lo in chunk_starts:
                feature_stack.append(
                    east_net.sess.run(
                        [east_net.y_hat_test_tnsr],
                        feed_dict={east_net.x_tnsr: east_feed[lo:lo + chunk_size, :, :, :]}
                    )[0][0])
            feature_maps = np.concatenate(feature_stack, axis=0)

            # >>>>>>>>>>>>>>>>>>>>>>>>>>> flow estimation with PWCnet >>>>>>>>>>>>>>>>>>>>>>>>> #
            # x: [batch_size,2,H,W,3] uint8; x_adapt: [batch_size,2,H,W,3] float32
            x_adapt, x_adapt_info = nn.adapt_x(flow_feed)
            if x_adapt_info is not None:
                y_adapt_info = (x_adapt_info[0], x_adapt_info[2], x_adapt_info[3], 2)
            else:
                y_adapt_info = None
            # Run the adapted samples through the network
            flow_maps_stack = []
            for lo in chunk_starts:
                y_hat = nn.sess.run(
                    nn.y_hat_test_tnsr,
                    feed_dict={nn.x_tnsr: x_adapt[lo:lo + chunk_size, :, :, :, :]})
                # y_hats is expected to be [batch, height, width, 2].
                y_hats, _ = nn.postproc_y_hat_test(y_hat, y_adapt_info)
                # Keep every 4th pixel and divide the flow by 4; presumably
                # this maps the 512x512 flow onto the 128x128 feature grid.
                flow_maps_stack.append(y_hats[:, 1::4, 1::4, :] / 4)
            flow_maps = np.concatenate(flow_maps_stack, axis=0)

            feed = {
                input_feat_maps: feature_maps,
                input_score_maps: data[1],
                input_geo_maps: data[2],
                input_training_masks: data[3],
                input_flow_maps: flow_maps,
            }
            # Fetch the summary in the same run as the training step, so a
            # summary step does not apply a second gradient update (and a
            # second global_step increment) on the same batch.
            write_summary = step % FLAGS.save_summary_steps == 0
            fetches = [model_loss, total_loss, train_op]
            if write_summary:
                fetches.append(summary_op)
            with g.as_default():
                results = sess1.run(fetches, feed_dict=feed)
            ml, tl = results[0], results[1]

            if np.isnan(tl):
                print('Loss diverged, stop training')
                break

            if step % 10 == 0:
                elapsed = time.time() - start
                avg_time_per_step = elapsed / 10
                avg_examples_per_second = (
                    10 * FLAGS.batch_size_per_gpu * len(gpus)) / elapsed
                start = time.time()
                print('Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'.format(
                    step, ml, tl, avg_time_per_step, avg_examples_per_second))

            if step % FLAGS.save_checkpoint_steps == 0:
                # os.path.join works whether or not the flag ends with a slash.
                saver.save(
                    sess1,
                    os.path.join(FLAGS.checkpoint_path, 'model.ckpt'),
                    global_step=global_step)

            if write_summary:
                summary_writer.add_summary(results[3], global_step=step)
    finally:
        summary_writer.close()
        sess1.close()