def main(cfg):
    cprint("-- preparing..")
    max_iter = cfg['solver']['max_iter']
    summary_iter = cfg['solver']['summary_iter']
    save_iter = cfg['solver']['save_iter']
    ckpt_dir = os.path.join(cfg['path']['ckpt_dir'], cfg.cfg_name)
    ckpt_file = os.path.join(ckpt_dir, cfg.cfg_name)

    output_path = '../train_inter_results/'
    mkdir_p(output_path)

    tf.logging.info("-- constructing network..")
    with tf.Graph().as_default():

        with tf.device('/cpu:0'):
            with tf.name_scope('data_provider'):
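                # Sampling and augmentation parameters handed to the data provider.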
                sample = {
                    'sample_radii': cfg.radii,
                    'cls': cfg.cls,
                    'x_min': cfg.min_x,
                    'y_min': cfg.min_y,
                    'z_min': cfg.min_z,
                    'x_max': cfg.max_x,
                    'y_max': cfg.max_y,
                    'z_max': cfg.max_z,
                    'num_points': cfg.num_points,
                    'CENTER_PERTURB': cfg.CENTER_PERTURB,
                    'CENTER_PERTURB_Z': cfg.CENTER_PERTURB_Z,
                    'SAMPLE_Z_MIN': cfg.sample_z_min,
                    'SAMPLE_Z_MAX': cfg.sample_z_max,
                    'QUANT_POINTS': cfg.QUANT_POINTS,
                    'QUANT_LEVEL': cfg.QUANT_LEVEL
                }
                dataset = ObjectProvider(
                    edict({
                        'batch_size': cfg.solver.batch_size,
                        'dataset': 'kitti',
                        'split': 'train',
                        'is_training': True,
                        'num_epochs': None,
                        'sample': sample
                    }))
                dataset.data_size = 100000  # FIXME. random number

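            # Global step, learning-rate schedule, and batch-norm decay are
            # created on the CPU and shared by every GPU tower.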
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)
            learning_rate = _configure_learning_rate(cfg, dataset.data_size,
                                                     global_step)
            bn_decay = get_bn_decay(cfg, dataset.data_size, global_step)

            optimizer = _configure_optimizer(cfg, learning_rate)
            tf.summary.scalar('learning_rate', learning_rate)

        # Calculate the gradients for each model tower.
        towers_ph_points = []
        towers_ph_obj = []
        towers_ph_is_training = []

        tower_grads = []
        tower_losses = []
        device_scopes = []
        scope_name = 'rpn'
        with tf.variable_scope(scope_name):
            for gid in range(cfg.num_gpus):
                with tf.name_scope('gpu%d' % gid) as scope:
                    with tf.device('/gpu:%d' % gid):
                        with tf.name_scope("train_input"):
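                            # Per-tower inputs: a batch of point clouds of shape
                            # (batch, num_points, 3), per-sample target objectness
                            # values, and a training-phase flag.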
                            ph_points = tf.placeholder(tf.float32,
                                                       shape=(None,
                                                              cfg.num_points,
                                                              3))
                            ph_obj = tf.placeholder(tf.float32, shape=(None, ))

                            ph_is_training = tf.placeholder(tf.bool, shape=())

                            net = Net(ph_points=ph_points,
                                      is_training=ph_is_training,
                                      bn_decay=bn_decay,
                                      cfg=cfg)
                            net.losses(target_objs=ph_obj,
                                       cut_off=cfg.iou_cutoff,
                                       gl=global_step)

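                        # Sum every loss registered under this tower's scope.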
                        all_losses = tf.get_collection(tf.GraphKeys.LOSSES,
                                                       scope)
                        sum_loss = tf.add_n(all_losses)

                        for loss in all_losses:
                            tf.summary.scalar(loss.op.name, loss)
                        tf.summary.scalar("sum_loss_tower", sum_loss)

                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()

                        # Calculate the gradients for the batch of data
                        grads = optimizer.compute_gradients(sum_loss)

                        # Keep track of the gradients across all towers.
                        tower_grads.append(grads)
                        tower_losses.append(sum_loss)
                        device_scopes.append(scope)

                        # Collect all placeholders
                        towers_ph_points.append(ph_points)
                        towers_ph_obj.append(ph_obj)
                        towers_ph_is_training.append(ph_is_training)

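        # Sum the tower losses and average the per-tower gradients before
        # applying a single update.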
        total_loss = tf.add_n(tower_losses, name='total_loss')
        grads = _average_gradients(tower_grads)
        apply_gradient_ops = optimizer.apply_gradients(grads,
                                                       global_step=global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)

        # Track the moving averages of all trainable variables.
        # if cfg.solver.moving_average_decay:
        with tf.name_scope('expMovingAverage'):
            variable_averages = tf.train.ExponentialMovingAverage(
                0.005, global_step)
            # cfg.solver.moving_average_decay, global_step)
            averages_op = variable_averages.apply(tf.trainable_variables())
        # else:
        # averages_op = None

        # Group all updates into a single train op.
        train_op = tf.group(apply_gradient_ops, averages_op)
        train_tensor = control_flow_ops.with_dependencies([train_op],
                                                          total_loss,
                                                          name='train_op')

        # Create a saver
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=20)
        init = tf.global_variables_initializer()

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        # GPU configuration
        if cfg.num_gpus == 0:
            config = tf.ConfigProto(device_count={'GPU': 0})
        else:
            config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

        with tf.Session(config=config) as sess:

            dataset.set_session(sess)

            # initialization / session / writer / saver
            print('initializing the network (this may take a few minutes)...')
            sess.run(init)

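            # Launch the input queue threads used by the data pipeline.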
            tf.train.start_queue_runners(sess=sess)

            train_writer, eval_writer = _set_filewriters(ckpt_dir, sess)
            merged = tf.summary.merge_all()

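            # Resume from the latest checkpoint in the checkpoint directory, if any.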
            weight_file = tf.train.latest_checkpoint(ckpt_dir)
            if weight_file is not None:
                saver.restore(sess, weight_file)
                tf.logging.info('%s loaded' % weight_file)
            else:
                tf.logging.info(
                    'Training from scratch (no pre-trained weights found)..')

            train_timer = Timer()
            print('start training...')
            stat_correct = []

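            # Helper that runs a forward pass (predicted locations and IoU
            # confidences); note it is bound to the last tower built above.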
            def inference(feed_dict):
                loc, conf_iou = sess.run([net.pred_locs, net.pred_conf_iou],
                                         feed_dict)
                return loc, conf_iou

            for step in range(max_iter):

                train_timer.tic()
                feed_dict = {}
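                # Fetch one batch per GPU and route it to that tower's placeholders.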
                for i in range(cfg.num_gpus):
                    b_points, b_objs, b_locs = dataset.get_batch()
                    feed_dict[towers_ph_points[i]] = b_points
                    feed_dict[towers_ph_obj[i]] = b_objs
                    feed_dict[towers_ph_is_training[i]] = True

                gl, loss, _ = sess.run([global_step, total_loss, train_tensor],
                                       feed_dict=feed_dict)

                if gl % 100 == 0:
                    cprint("gl: {} loss: {:.3f}".format(gl, loss))

                train_timer.toc()
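                # Write TensorBoard summaries periodically; every tenth summary
                # also records full run metadata for profiling.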
                if gl % summary_iter == 0:
                    if gl % (summary_iter * 10) == 0:
                        # Summary with run meta data
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()
                        summary_str, loss, _ = sess.run(
                            [merged, total_loss, train_tensor],
                            feed_dict=feed_dict,
                            options=run_options,
                            run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step_{}'.format(gl), gl)
                        train_writer.add_summary(summary_str, gl)
                    else:
                        # Summary
                        summary_str = sess.run(merged, feed_dict=feed_dict)
                        train_writer.add_summary(summary_str, gl)

                    log_str = (
                        '{} Epoch: {:3d}, Step: {:4d}, Learning rate: {:.4e}, Loss: {:5.3f}\n'
                        '{:14s} Speed: {:.3f}s/iter, Remain: {}').format(
                            datetime.datetime.now().strftime('%m/%d %H:%M:%S'),
                            int(cfg.solver.batch_size * gl /
                                dataset.data_size), int(gl),
                            round(learning_rate.eval(session=sess), 6), loss,
                            '', train_timer.average_time,
                            train_timer.remain(step, max_iter))
                    print(log_str)
                    train_timer.reset()

                if gl % save_iter == 0:
                    print('{} Saving checkpoint file to: {}'.format(
                        datetime.datetime.now().strftime('%m/%d %H:%M:%S'),
                        ckpt_dir))
                    saver.save(sess, ckpt_file, global_step=global_step)

                evaluate_iter = 100000