def main(cfg):
    cprint("-- preparing..")
    max_iter = cfg['solver']['max_iter']
    summary_iter = cfg['solver']['summary_iter']
    save_iter = cfg['solver']['save_iter']
    ckpt_dir = os.path.join(cfg['path']['ckpt_dir'], cfg.cfg_name)
    ckpt_file = os.path.join(ckpt_dir, cfg.cfg_name)
    output_path = '../train_inter_results/'
    mkdir_p(output_path)

    tf.logging.info("-- constructing network..")
    with tf.Graph().as_default():
        with tf.device('/cpu:0'):
            # Sampling configuration handed to the data provider.
            with tf.name_scope('data_provider'):
                sample = {
                    'sample_radii': cfg.radii,
                    'cls': cfg.cls,
                    'x_min': cfg.min_x,
                    'y_min': cfg.min_y,
                    'z_min': cfg.min_z,
                    'x_max': cfg.max_x,
                    'y_max': cfg.max_y,
                    'z_max': cfg.max_z,
                    'num_points': cfg.num_points,
                    'CENTER_PERTURB': cfg.CENTER_PERTURB,
                    'CENTER_PERTURB_Z': cfg.CENTER_PERTURB_Z,
                    'SAMPLE_Z_MIN': cfg.sample_z_min,
                    'SAMPLE_Z_MAX': cfg.sample_z_max,
                    'QUANT_POINTS': cfg.QUANT_POINTS,
                    'QUANT_LEVEL': cfg.QUANT_LEVEL,
                }
                dataset = ObjectProvider(
                    edict({
                        'batch_size': cfg.solver.batch_size,
                        'dataset': 'kitti',
                        'split': 'train',
                        'is_training': True,
                        'num_epochs': None,
                        'sample': sample,
                    }))
                dataset.data_size = 100000  # FIXME: placeholder; use the real size of the train split.

            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)
            learning_rate = _configure_learning_rate(cfg, dataset.data_size, global_step)
            bn_decay = get_bn_decay(cfg, dataset.data_size, global_step)
            optimizer = _configure_optimizer(cfg, learning_rate)
            tf.summary.scalar('learning_rate', learning_rate)

            # Build one model replica ("tower") per GPU and collect the
            # per-tower placeholders, losses, and gradients.
            towers_ph_points = []
            towers_ph_obj = []
            towers_ph_is_training = []
            tower_grads = []
            tower_losses = []
            device_scopes = []
            scope_name = 'rpn'
            with tf.variable_scope(scope_name):
                for gid in range(cfg.num_gpus):
                    with tf.name_scope('gpu%d' % gid) as scope:
                        with tf.device('/gpu:%d' % gid):
                            with tf.name_scope('train_input'):
                                ph_points = tf.placeholder(
                                    tf.float32, shape=(None, cfg.num_points, 3))
                                ph_obj = tf.placeholder(tf.float32, shape=(None,))
                                ph_is_training = tf.placeholder(tf.bool, shape=())

                            net = Net(ph_points=ph_points,
                                      is_training=ph_is_training,
                                      bn_decay=bn_decay,
                                      cfg=cfg)
                            net.losses(target_objs=ph_obj,
                                       cut_off=cfg.iou_cutoff,
                                       gl=global_step)

                            all_losses = tf.get_collection(tf.GraphKeys.LOSSES, scope)
                            sum_loss = tf.add_n(all_losses)
                            for loss in all_losses:
                                tf.summary.scalar(loss.op.name, loss)
                            tf.summary.scalar('sum_loss_tower', sum_loss)

                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()

                            # Calculate the gradients for this tower's batch.
                            grads = optimizer.compute_gradients(sum_loss)

                            # Keep track of the gradients across all towers.
                            tower_grads.append(grads)
                            tower_losses.append(sum_loss)
                            device_scopes.append(scope)

                            # Collect all placeholders.
                            towers_ph_points.append(ph_points)
                            towers_ph_obj.append(ph_obj)
                            towers_ph_is_training.append(ph_is_training)

            total_loss = tf.add_n(tower_losses, name='total_loss')
            grads = _average_gradients(tower_grads)
            apply_gradient_ops = optimizer.apply_gradients(grads, global_step=global_step)
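            # `_average_gradients` is defined elsewhere in this repo; it is
            # assumed to follow the standard TensorFlow multi-tower pattern
            # (cf. the CIFAR-10 multi-GPU example): for every variable, average
            # the gradients computed by each tower. A minimal sketch, assuming
            # all towers share the same variable order and no gradient is None:
            #
            #   def _average_gradients(tower_grads):
            #       average_grads = []
            #       for grad_and_vars in zip(*tower_grads):
            #           # grad_and_vars: one (grad, var) pair per tower for the
            #           # same variable; stack and mean-reduce over towers.
            #           stacked = tf.stack([g for g, _ in grad_and_vars], axis=0)
            #           mean_grad = tf.reduce_mean(stacked, axis=0)
            #           average_grads.append((mean_grad, grad_and_vars[0][1]))
            #       return average_grads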
            # Add histograms for trainable variables.
            for var in tf.trainable_variables():
                tf.summary.histogram(var.op.name, var)

            # Track the moving averages of all trainable variables.
            # NOTE: the decay is hard-coded to 0.005 here instead of using
            # cfg.solver.moving_average_decay.
            with tf.name_scope('expMovingAverage'):
                variable_averages = tf.train.ExponentialMovingAverage(0.005, global_step)
                averages_op = variable_averages.apply(tf.trainable_variables())

            # Group all updates into a single train op.
            train_op = tf.group(apply_gradient_ops, averages_op)
            train_tensor = control_flow_ops.with_dependencies(
                [train_op], total_loss, name='train_op')

            # Create a saver.
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=20)
            init = tf.global_variables_initializer()

            # =========================================================== #
            # Kick off the training.
            # =========================================================== #
            # GPU configuration.
            if cfg.num_gpus == 0:
                config = tf.ConfigProto(device_count={'GPU': 0})
            else:
                config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

            with tf.Session(config=config) as sess:
                dataset.set_session(sess)

                # Initialization / queue runners / writers.
                print('initializing the network may take a few minutes...')
                sess.run(init)
                tf.train.start_queue_runners(sess=sess)
                train_writer, eval_writer = _set_filewriters(ckpt_dir, sess)
                merged = tf.summary.merge_all()

                # Resume from the latest checkpoint, if one exists.
                weight_file = tf.train.latest_checkpoint(ckpt_dir)
                if weight_file is not None:
                    saver.restore(sess, weight_file)
                    tf.logging.info('%s loaded' % weight_file)
                else:
                    tf.logging.info('Training from scratch (no pre-trained weights)..')

                train_timer = Timer()
                print('start training...')
                stat_correct = []

                def inference(feed_dict):
                    # Ad-hoc helper: run the last tower's predictions on the
                    # current batch (not used in the main loop below).
                    loc, conf_iou = sess.run([net.pred_locs, net.pred_conf_iou], feed_dict)
                    return loc, conf_iou

                for step in range(max_iter):
                    train_timer.tic()

                    # Feed each tower its own batch.
                    feed_dict = {}
                    for i in range(cfg.num_gpus):
                        b_points, b_objs, b_locs = dataset.get_batch()
                        feed_dict[towers_ph_points[i]] = b_points
                        feed_dict[towers_ph_obj[i]] = b_objs
                        feed_dict[towers_ph_is_training[i]] = True

                    gl, loss, _ = sess.run([global_step, total_loss, train_tensor],
                                           feed_dict=feed_dict)
                    if gl % 100 == 0:
                        cprint("gl: {} loss: {:.3f}".format(gl, loss))
                    train_timer.toc()

                    if gl % summary_iter == 0:
                        if gl % (summary_iter * 10) == 0:
                            # Summary with run metadata (full trace).
                            run_options = tf.RunOptions(
                                trace_level=tf.RunOptions.FULL_TRACE)
                            run_metadata = tf.RunMetadata()
                            summary_str, loss, _ = sess.run(
                                [merged, total_loss, train_tensor],
                                feed_dict=feed_dict,
                                options=run_options,
                                run_metadata=run_metadata)
                            train_writer.add_run_metadata(
                                run_metadata, 'step_{}'.format(gl), gl)
                            train_writer.add_summary(summary_str, gl)
                        else:
                            # Summary only.
                            summary_str = sess.run(merged, feed_dict=feed_dict)
                            train_writer.add_summary(summary_str, gl)

                        log_str = (
                            '{} Epoch: {:3d}, Step: {:4d}, Learning rate: {:.4e}, Loss: {:5.3f}\n'
                            '{:14s} Speed: {:.3f}s/iter, Remain: {}').format(
                                datetime.datetime.now().strftime('%m/%d %H:%M:%S'),
                                int(cfg.solver.batch_size * gl / dataset.data_size),
                                int(gl),
                                round(learning_rate.eval(session=sess), 6),
                                loss, '',
                                train_timer.average_time,
                                train_timer.remain(step, max_iter))
                        print(log_str)
                        train_timer.reset()

                    if gl % save_iter == 0:
                        print('{} Saving checkpoint file to: {}'.format(
                            datetime.datetime.now().strftime('%m/%d %H:%M:%S'),
                            ckpt_dir))
                        saver.save(sess, ckpt_file, global_step=global_step)

                evaluate_iter = 100000
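                # Sketch (an assumption, not part of the original file): since
                # training maintains an ExponentialMovingAverage over all
                # trainable variables, an evaluation pass would typically
                # restore the shadow (averaged) weights rather than the raw
                # ones, e.g.:
                #
                #   ema = tf.train.ExponentialMovingAverage(0.005)
                #   eval_saver = tf.train.Saver(ema.variables_to_restore())
                #   eval_saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))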