def __init__(self, config, models):
    """Build a synchronous multi-tower training op over replica `models`.

    Each entry of `models` is treated as one replica placed on its own
    device; per-replica losses and gradients are computed in a device scope
    and then averaged into a single Adadelta update.

    Args:
        config: experiment configuration; must provide `init_lr` (initial
            learning rate for Adadelta) and `device_type` (e.g. "gpu").
        models: non-empty list of `Model` replicas, one per device.
            `models[0]` is the reference replica supplying the variable
            list, global step and summary op.

    Raises:
        TypeError: if `models[0]` is not a `Model`. (Explicit raise instead
            of the original `assert`, which is stripped under `python -O`.)
    """
    if not isinstance(models[0], Model):
        raise TypeError(
            'models[0] must be a Model, got {!r}'.format(type(models[0])))
    self.config = config
    self.model = models[0]
    self.models = models
    self.opt = tf.train.AdadeltaOptimizer(config.init_lr)
    # Variables, step counter and summaries come from the reference replica
    # only — all replicas presumably share these variables (TODO confirm).
    self.var_list = self.model.get_var_list()
    self.global_step = self.model.get_global_step()
    self.summary = self.model.get_summary()

    losses = []
    grads_list = []
    # One loss/gradient computation per replica, each pinned to its own
    # device. Loop variable renamed from `model` to avoid shadowing the
    # reference replica bound above.
    for gpu_idx, tower_model in enumerate(models):
        with tf.name_scope("grads_{}".format(gpu_idx)), \
                tf.device("/{}:{}".format(config.device_type, gpu_idx)):
            tower_loss = tower_model.get_loss()
            tower_grads = self.opt.compute_gradients(
                tower_loss, var_list=self.var_list)
            losses.append(tower_loss)
            grads_list.append(tower_grads)

    # Mean loss across replicas; `average_gradients` (defined elsewhere)
    # presumably averages each variable's gradient across the towers.
    self.loss = tf.add_n(losses) / len(losses)
    self.grads = average_gradients(grads_list)
    self.train_op = self.opt.apply_gradients(
        self.grads, global_step=self.global_step)
def train() -> None:
    """Build the (multi-)GPU FPN Faster R-CNN training graph and run the loop.

    Reads all settings from the module-level `cfg`, streams batches from the
    COCO dataflow, builds one loss tower per GPU (or a single tower), applies
    momentum SGD with averaged gradients, an exponential moving average of
    trainable variables and batch-norm updates, then iterates up to
    `cfg.TRAIN.MAX_STEPS` with periodic summary logging and checkpointing.

    NOTE(review): TF1 graph-mode code (tf.Session / make_one_shot_iterator).
    Relies on module-level names: `os`, `time`, `np`, `cfg`, `COMMON_POSTFIX`,
    `register_coco`, `get_train_dataflow`, `warmup_lr_schedule`,
    `tower_loss_func`, `regularize_cost`, `l2_regularizer`,
    `average_gradients`.
    """
    # 'spawn' start method for the dataflow worker processes; force=True
    # avoids a RuntimeError if the start method was already set elsewhere.
    import multiprocessing as mp
    mp.set_start_method('spawn', force=True)

    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.TRAIN.GPU_LIST
    # CUDA_VISIBLE_DEVICES above remaps the chosen devices to logical 0..N-1.
    gpus = list(range(len(cfg.TRAIN.GPU_LIST.split(','))))
    num_gpus = len(gpus)

    # A pre-existing log dir is taken to mean "resume from a local
    # checkpoint"; a fresh dir means "initialize from the original
    # (backbone) checkpoint".
    restore_from_original_checkpoint = True
    checkpoint_path = cfg.TRAIN.LOG_DIR + COMMON_POSTFIX
    if not tf.io.gfile.exists(checkpoint_path):
        tf.io.gfile.makedirs(checkpoint_path)
    else:
        restore_from_original_checkpoint = False

    register_coco(os.path.expanduser(cfg.DATA.BASEDIR))
    # One global batch per step; it is tf.split across the towers below.
    data_iter = get_train_dataflow(
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * num_gpus)

    ds = tf.data.Dataset.from_generator(
        # Pull dicts from the dataflow and emit tuples in a fixed key order:
        # image/gt fields first, then (anchors, labels, boxes) per FPN
        # level 2..6.
        lambda: map(
            lambda x: tuple([
                x[k] for k in [
                    'images', 'gt_boxes', 'gt_labels', 'orig_gt_counts',
                    'all_anchors_level2', 'anchor_labels_level2',
                    'anchor_boxes_level2',
                    'all_anchors_level3', 'anchor_labels_level3',
                    'anchor_boxes_level3',
                    'all_anchors_level4', 'anchor_labels_level4',
                    'anchor_boxes_level4',
                    'all_anchors_level5', 'anchor_labels_level5',
                    'anchor_boxes_level5',
                    'all_anchors_level6', 'anchor_labels_level6',
                    'anchor_boxes_level6'
                ]
            ]),
            data_iter),
        # dtypes: images/gt_boxes float32, gt_labels int64, orig_gt_counts
        # int32, then (float32, int32, float32) for each of the 5 levels.
        (tf.float32, tf.float32, tf.int64, tf.int32,
         tf.float32, tf.int32, tf.float32,
         tf.float32, tf.int32, tf.float32,
         tf.float32, tf.int32, tf.float32,
         tf.float32, tf.int32, tf.float32,
         tf.float32, tf.int32, tf.float32),
        (
            tf.TensorShape([None, None, None, 3]),  # images (NHWC assumed)
            tf.TensorShape([None, None, 4]),        # gt_boxes
            tf.TensorShape([None, None]),           # gt_labels
            tf.TensorShape([
                None,
            ]),                                     # orig_gt_counts
            # Per level: all_anchors, anchor_labels, anchor_boxes.
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None, 4]),  # lv2
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None, 4]),  # lv3
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None, 4]),  # lv4
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None, 4]),  # lv5
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None, 4])  # lv6
        ))
    ds = ds.prefetch(buffer_size=128)
    # TF1 one-shot iterator; `ds` now names the iterator, not the dataset.
    ds = ds.make_one_shot_iterator()
    images, gt_boxes, gt_labels, orig_gt_counts, \
    all_anchors_level2, anchor_labels_level2, anchor_boxes_level2, \
    all_anchors_level3, anchor_labels_level3, anchor_boxes_level3, \
    all_anchors_level4, anchor_labels_level4, anchor_boxes_level4, \
    all_anchors_level5, anchor_labels_level5, anchor_boxes_level5, \
    all_anchors_level6, anchor_labels_level6, anchor_boxes_level6 \
        = ds.get_next()

    # build optimizers: momentum SGD with a warmup LR schedule driven by
    # global_step (warmup_lr_schedule is defined elsewhere).
    global_step = tf.train.get_or_create_global_step()
    learning_rate = warmup_lr_schedule(init_learning_rate=cfg.TRAIN.BASE_LR,
                                       global_step=global_step,
                                       warmup_step=cfg.TRAIN.WARMUP_STEP)
    opt = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)

    sess_config = tf.ConfigProto()
    sess_config.allow_soft_placement = True
    sess_config.log_device_placement = False
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    if num_gpus > 1:
        # Split the global batch along axis 0 so each tower gets its share.
        base_inputs_list = [
            tf.split(value, num_or_size_splits=num_gpus, axis=0)
            for value in [images, gt_boxes, gt_labels, orig_gt_counts]
        ]
        # The anchor grids are replicated per tower (tf.identity), not
        # split — presumably identical for every image; TODO confirm.
        fpn_all_anchors_list = \
            [[tf.identity(value) for _ in range(num_gpus)]
             for value in [all_anchors_level2, all_anchors_level3,
                           all_anchors_level4, all_anchors_level5,
                           all_anchors_level6]]
        # Per-image anchor targets are batched, so they are split like the
        # base inputs.
        fpn_anchor_gt_labels_list = \
            [tf.split(value, num_or_size_splits=num_gpus, axis=0)
             for value in [anchor_labels_level2, anchor_labels_level3,
                           anchor_labels_level4, anchor_labels_level5,
                           anchor_labels_level6]]
        fpn_anchor_gt_boxes_list = \
            [tf.split(value, num_or_size_splits=num_gpus, axis=0)
             for value in [anchor_boxes_level2, anchor_boxes_level3,
                           anchor_boxes_level4, anchor_boxes_level5,
                           anchor_boxes_level6]]
        tower_grads = []
        # Running sum of each named loss over towers; divided by num_gpus
        # after the loop.
        total_loss_dict = {
            'rpn_cls_loss': tf.constant(0.),
            'rpn_box_loss': tf.constant(0.),
            'rcnn_cls_loss': tf.constant(0.),
            'rcnn_box_loss': tf.constant(0.)
        }
        for i, gpu_id in enumerate(gpus):
            with tf.device('/gpu:%d' % gpu_id):
                with tf.name_scope('model_%d' % gpu_id) as scope:
                    # net_inputs layout mirrors the single-GPU branch below:
                    # [images, gt_boxes, gt_labels, orig_gt_counts,
                    #  anchors-per-level, labels-per-level, boxes-per-level].
                    inputs1 = [input[i] for input in base_inputs_list]
                    inputs2 = [[input[i] for input in fpn_all_anchors_list]]
                    inputs3 = [[
                        input[i] for input in fpn_anchor_gt_labels_list
                    ]]
                    inputs4 = [[
                        input[i] for input in fpn_anchor_gt_boxes_list
                    ]]
                    net_inputs = inputs1 + inputs2 + inputs3 + inputs4
                    # reuse=(gpu_id > 0): presumably shares variables with
                    # the first tower — TODO confirm in tower_loss_func.
                    tower_loss_dict = tower_loss_func(net_inputs,
                                                      reuse=(gpu_id > 0))
                    # NOTE(review): reassigned each iteration, so only the
                    # LAST tower's UPDATE_OPS reach batch_norm_updates_op
                    # below — confirm this is intentional.
                    batch_norm_updates = tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS, scope)
                    tower_loss = tf.add_n(
                        [v for k, v in tower_loss_dict.items()])
                    for k, v in tower_loss_dict.items():
                        total_loss_dict[k] += v
                    if i == num_gpus - 1:
                        # Weight decay is added once, on the last tower only,
                        # so it is not multiplied by the tower count.
                        wd_loss = regularize_cost('.*/kernel',
                                                  l2_regularizer(
                                                      cfg.TRAIN.WEIGHT_DECAY),
                                                  name='wd_cost')
                        tower_loss = tower_loss + wd_loss
                        # Retain the summaries from the final tower.
                        summaries = tf.get_collection(
                            tf.GraphKeys.SUMMARIES, scope)
                        if cfg.FRCNN.VISUALIZATION:
                            with tf.device('/cpu:0'):
                                with tf.name_scope('loss-summaries'):
                                    for k, v in tower_loss_dict.items():
                                        summaries.append(
                                            tf.summary.scalar(k, v))
                    grads = opt.compute_gradients(tower_loss)
                    tower_grads.append(grads)
        # Average per-variable gradients across towers (helper defined
        # elsewhere), and average the accumulated per-name losses.
        grads = average_gradients(tower_grads)
        for k, v in total_loss_dict.items():
            total_loss_dict[k] = v / tf.cast(num_gpus, tf.float32)
        average_total_loss = tf.add_n(
            [v for k, v in total_loss_dict.items()] + [wd_loss])
    else:
        # Single-GPU path: one tower consuming the whole batch directly.
        fpn_all_anchors = \
            [all_anchors_level2, all_anchors_level3, all_anchors_level4,
             all_anchors_level5, all_anchors_level6]
        fpn_anchor_gt_labels = \
            [anchor_labels_level2, anchor_labels_level3, anchor_labels_level4,
             anchor_labels_level5, anchor_labels_level6]
        fpn_anchor_gt_boxes = \
            [anchor_boxes_level2, anchor_boxes_level3, anchor_boxes_level4,
             anchor_boxes_level5, anchor_boxes_level6]
        net_inputs = [
            images, gt_boxes, gt_labels, orig_gt_counts, fpn_all_anchors,
            fpn_anchor_gt_labels, fpn_anchor_gt_boxes
        ]
        tower_loss_dict = tower_loss_func(net_inputs)
        batch_norm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        wd_loss = regularize_cost('.*/kernel',
                                  l2_regularizer(cfg.TRAIN.WEIGHT_DECAY),
                                  name='wd_cost')
        average_total_loss = tf.add_n(
            [v for k, v in tower_loss_dict.items()] + [wd_loss])
        grads = opt.compute_gradients(average_total_loss)
        total_loss_dict = tower_loss_dict
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        if cfg.FRCNN.VISUALIZATION:
            with tf.device('/cpu:0'):
                with tf.name_scope('loss-summaries'):
                    for k, v in tower_loss_dict.items():
                        summaries.append(tf.summary.scalar(k, v))

    # One optimizer step; EMA and BN updates are folded into train_op below.
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summaries.append(tf.summary.scalar('learning_rate', learning_rate))
    # add histograms for gradients (skipping variables with no gradient)
    for grad, var in grads:
        # print(grad, var)
        if grad is not None:
            summaries.append(
                tf.summary.histogram(var.op.name + '/gradients', grad))
    # add histograms for trainable variables
    for var in tf.trainable_variables():
        summaries.append(tf.summary.histogram(var.op.name, var))

    # Exponential moving average over all trainable variables; num_updates
    # makes the effective decay ramp up with global_step.
    variable_averages = tf.train.ExponentialMovingAverage(
        cfg.TRAIN.MOVING_AVERAGE_DECAY, num_updates=global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())

    # Debug dumps: variable-name inventories written to the current
    # working directory.
    all_global_vars = []
    for var in tf.global_variables():
        all_global_vars.append(var.name + '\n')
        # print(var.name, var.shape)
    with open('all_global_vars.txt', 'w') as fp:
        fp.writelines(all_global_vars)
    all_trainable_vars = []
    for var in tf.trainable_variables():
        all_trainable_vars.append(var.name + '\n')
    with open('all_trainable_vars.txt', 'w') as fp:
        fp.writelines(all_trainable_vars)
    all_moving_average_vars = []
    for var in tf.moving_average_variables():
        all_moving_average_vars.append(var.name + '\n')
    with open('all_moving_average_variables.txt', 'w') as fp:
        fp.writelines(all_moving_average_vars)

    # batch norm updates
    batch_norm_updates_op = tf.group(*batch_norm_updates)
    # train_op runs the gradient step, the EMA update and the BN updates.
    with tf.control_dependencies(
        [apply_gradient_op, variable_averages_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables())
    summary_op = tf.summary.merge(summaries)
    summary_writer = tf.summary.FileWriter(checkpoint_path,
                                           tf.get_default_graph())

    init_op = tf.group(
        [tf.global_variables_initializer(),
         tf.local_variables_initializer()])
    sess.run(init_op)

    # NOTE(review): dead branch (if False) kept verbatim — loads legacy
    # MSRA-R50.npz weights by remapping tensorpack-style keys onto this
    # graph's variable names.
    if False:
        print('load weights ...')
        ckpt_params = dict(np.load('MSRA-R50.npz'))
        assign_ops = []
        all_variables = []
        for var in tf.global_variables():
            dst_name = var.name
            all_variables.append(dst_name + '\n')
            if 'resnet50' in dst_name:
                # Map this graph's variable name to the .npz key scheme.
                src_name = dst_name.replace('resnet50/', ''). \
                    replace('conv2d/kernel:0', 'W') \
                    .replace('conv2d/bias:0', 'b') \
                    .replace('batch_normalization/gamma:0', 'gamma') \
                    .replace('batch_normalization/beta:0', 'beta') \
                    .replace('batch_normalization/moving_mean:0', 'mean/EMA') \
                    .replace('batch_normalization/moving_variance:0',
                             'variance/EMA') \
                    .replace('kernel:0', 'W').replace('bias:0', 'b')
                if 'batch_normalization' in dst_name:
                    # BN scopes use a 'bn' prefix in the checkpoint.
                    src_name = src_name.replace('res', 'bn')
                    if 'conv1' in src_name:
                        src_name = 'bn_' + src_name
                if src_name == 'fc1000/W':
                    # Classifier weights need reshaping to [2048, 1000].
                    print('{} --> {} {}'.format('fc1000/W', dst_name,
                                                var.shape))
                    assign_ops.append(
                        tf.assign(
                            var,
                            np.reshape(ckpt_params[src_name], [2048, 1000])))
                    continue
                if src_name in ckpt_params:
                    print('{} --> {} {}'.format(src_name, dst_name,
                                                var.shape))
                    assign_ops.append(tf.assign(var, ckpt_params[src_name]))
        print('load weights done.')
        with open('all_vars.txt', 'w') as fp:
            fp.writelines(all_variables)
        all_update_ops = []
        for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS):
            all_update_ops.append(op.name + '\n')
        with open('all_update_ops.txt', 'w') as fp:
            fp.writelines(all_update_ops)
        sess.run(assign_ops)
    else:
        # NOTE(review): another dead branch (if False) — restores backbone
        # variables from cfg.BACKBONE.CHECKPOINT_PATH without scope renaming.
        if False:
            all_vars = []
            restore_var_dict = {}
            for var in tf.global_variables():
                all_vars.append(var.name + '\n')
                if 'rpn' not in var.name and 'rcnn' not in var.name and 'global_step' not in var.name and \
                        'Momentum' not in var.name and 'ExponentialMovingAverage' not in var.name:
                    restore_var_dict[var.name.replace(':0', '')] = var
            with open('all_vars.txt', 'w') as fp:
                fp.writelines(all_vars)
            restorer = tf.train.Saver(var_list=restore_var_dict)
            restorer.restore(sess, cfg.BACKBONE.CHECKPOINT_PATH)
        else:
            if restore_from_original_checkpoint:
                # restore from official ResNet checkpoint: backbone variables
                # only (heads, FPN, optimizer slots, EMA shadows and the step
                # counter are excluded by name).
                all_vars = []
                restore_var_dict = {}
                for var in tf.global_variables():
                    all_vars.append(var.name + '\n')
                    if 'rpn' not in var.name and 'rcnn' not in var.name and 'fpn' not in var.name \
                            and 'global_step' not in var.name and \
                            'Momentum' not in var.name and 'ExponentialMovingAverage' not in var.name:
                        # Strip the 'resnet50/' scope so names match the
                        # official checkpoint layout.
                        restore_var_dict[var.name.replace(
                            'resnet50/', '').replace(':0', '')] = var
                        print(var.name, var.shape)
                with open('all_vars.txt', 'w') as fp:
                    fp.writelines(all_vars)
                restore_vars_names = [
                    k + '\n' for k in restore_var_dict.keys()
                ]
                with open('all_restore_vars.txt', 'w') as fp:
                    fp.writelines(restore_vars_names)
                restorer = tf.train.Saver(var_list=restore_var_dict)
                restorer.restore(sess, cfg.BACKBONE.CHECKPOINT_PATH)
            else:
                all_vars = []
                restore_var_dict = {}
                for var in tf.global_variables():
                    all_vars.append(var.name + '\n')
                    restore_var_dict[var.name.replace(':0', '')] = var
                with open('all_vars.txt', 'w') as fp:
                    fp.writelines(all_vars)
                # restore from local checkpoint
                restorer = tf.train.Saver(tf.global_variables())
                try:
                    restorer.restore(
                        sess, tf.train.latest_checkpoint(checkpoint_path))
                except:
                    # NOTE(review): bare except silently falls back to random
                    # init when the checkpoint is missing/incompatible —
                    # consider narrowing the exception and logging.
                    pass

    # record all ops
    all_operations = []
    for op in sess.graph.get_operations():
        all_operations.append(op.name + '\n')
    with open('all_ops.txt', 'w') as fp:
        fp.writelines(all_operations)

    loss_names = [
        'rpn_cls_loss', 'rpn_box_loss', 'rcnn_cls_loss', 'rcnn_box_loss'
    ]
    # Fetches for a logged step; order must match the unpacking in the loop.
    sess2run = list()
    sess2run.append(train_op)
    sess2run.append(learning_rate)
    sess2run.append(average_total_loss)
    sess2run.append(wd_loss)
    sess2run.extend([total_loss_dict[k] for k in loss_names])

    print('begin training ...')
    # Resume the loop counter from the restored global_step.
    step = sess.run(global_step)
    step0 = step  # NOTE(review): unused after this assignment
    start = time.time()
    for step in range(step, cfg.TRAIN.MAX_STEPS):
        if step % cfg.TRAIN.SAVE_SUMMARY_STEPS == 0:
            # Logged step: also fetch losses and the merged summaries.
            _, \
            lr_, tl_, wd_loss_, \
            rpn_cls_loss_, rpn_box_loss_, \
            rcnn_cls_loss_, rcnn_box_loss_, \
            summary_str = sess.run(sess2run + [summary_op])
            # Throughput measured over the whole SAVE_SUMMARY_STEPS window.
            avg_time_per_step = (time.time() -
                                 start) / cfg.TRAIN.SAVE_SUMMARY_STEPS
            avg_examples_per_second = (cfg.TRAIN.SAVE_SUMMARY_STEPS *
                                       cfg.TRAIN.BATCH_SIZE_PER_GPU *
                                       num_gpus) \
                / (time.time() - start)
            start = time.time()
            print('Step {:06d}, LR: {:.6f} LOSS: {:.4f}, '
                  'RPN: {:.4f}, {:.4f}, RCNN: {:.4f}, {:.4f}, wd: {:.4f}, '
                  '{:.2f} s/step, {:.2f} samples/s'.format(
                      step, lr_, tl_, rpn_cls_loss_, rpn_box_loss_,
                      rcnn_cls_loss_, rcnn_box_loss_, wd_loss_,
                      avg_time_per_step, avg_examples_per_second))
            summary_writer.add_summary(summary_str, global_step=step)
        else:
            sess.run(train_op)
        # Checkpoint every 1000 steps (including step 0 on a fresh run).
        if step % 1000 == 0:
            saver.save(sess,
                       checkpoint_path + '/model.ckpt',
                       global_step=step)