Example #1
class MultiGPUTrainer(object):  # class name is assumed; the declaration is not part of the original excerpt
    def __init__(self, config, models):
        model = models[0]
        assert isinstance(model, Model)
        self.config = config
        self.model = model
        self.opt = tf.train.AdadeltaOptimizer(config.init_lr)
        self.var_list = model.get_var_list()
        self.global_step = model.get_global_step()
        self.summary = model.get_summary()
        self.models = models
        losses = []
        grads_list = []
        for gpu_idx, model in enumerate(models):
            with tf.name_scope("grads_{}".format(gpu_idx)), tf.device("/{}:{}".format(config.device_type, gpu_idx)):
                loss = model.get_loss()
                grads = self.opt.compute_gradients(loss, var_list=self.var_list)
                losses.append(loss)
                grads_list.append(grads)

        self.loss = tf.add_n(losses) / len(losses)
        self.grads = average_gradients(grads_list)
        self.train_op = self.opt.apply_gradients(self.grads, global_step=self.global_step)
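
# Both examples call an `average_gradients` helper that is not defined in this
# excerpt. The sketch below is an assumption following the standard TensorFlow
# multi-tower pattern (it presumes `import tensorflow as tf` as elsewhere in
# the file): for each variable, average its gradient across all towers and keep
# the variable reference from the first tower.
def average_gradients(tower_grads):
    average_grads = []
    # `tower_grads` is a list over towers; each entry is a list of
    # (gradient, variable) pairs in the same variable order.
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0)
                 for g, _ in grad_and_vars if g is not None]
        if not grads:
            continue
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # All towers share the same variables, so the first tower's reference works.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads
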
def train():
    import multiprocessing as mp
    mp.set_start_method('spawn', force=True)
    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.TRAIN.GPU_LIST
    gpus = list(range(len(cfg.TRAIN.GPU_LIST.split(','))))
    num_gpus = len(gpus)

    # If the log directory already exists, resume training from the latest
    # checkpoint it contains; otherwise restore the original ResNet backbone
    # checkpoint further below.
    restore_from_original_checkpoint = True
    checkpoint_path = cfg.TRAIN.LOG_DIR + COMMON_POSTFIX
    if not tf.io.gfile.exists(checkpoint_path):
        tf.io.gfile.makedirs(checkpoint_path)
    else:
        restore_from_original_checkpoint = False

    register_coco(os.path.expanduser(cfg.DATA.BASEDIR))

    data_iter = get_train_dataflow(batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU *
                                   num_gpus)
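    # Wrap the Python dataflow in a tf.data pipeline. Each yielded dict is
    # mapped to a tuple whose element order matches the dtypes and shapes
    # declared below: image batch, GT boxes/labels/counts, then anchors,
    # anchor labels and anchor boxes for FPN levels 2-6.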
    ds = tf.data.Dataset.from_generator(
        lambda: map(
            lambda x: tuple([
                x[k] for k in [
                    'images', 'gt_boxes', 'gt_labels', 'orig_gt_counts',
                    'all_anchors_level2', 'anchor_labels_level2',
                    'anchor_boxes_level2', 'all_anchors_level3',
                    'anchor_labels_level3', 'anchor_boxes_level3',
                    'all_anchors_level4', 'anchor_labels_level4',
                    'anchor_boxes_level4', 'all_anchors_level5',
                    'anchor_labels_level5', 'anchor_boxes_level5',
                    'all_anchors_level6', 'anchor_labels_level6',
                    'anchor_boxes_level6'
                ]
            ]), data_iter),
        (tf.float32, tf.float32, tf.int64, tf.int32, tf.float32, tf.int32,
         tf.float32, tf.float32, tf.int32, tf.float32, tf.float32, tf.int32,
         tf.float32, tf.float32, tf.int32, tf.float32, tf.float32, tf.int32,
         tf.float32),
        (
            tf.TensorShape([None, None, None, 3]),
            tf.TensorShape([None, None, 4]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None, 4]),  #lv2
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None, 4]),  #lv3
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None, 4]),  #lv4
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None, 4]),  #lv5
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape([None, None, None, None, 4])  #lv6
        ))
    ds = ds.prefetch(buffer_size=128)
    iterator = ds.make_one_shot_iterator()
    images, gt_boxes, gt_labels, orig_gt_counts, \
    all_anchors_level2, anchor_labels_level2, anchor_boxes_level2, \
    all_anchors_level3, anchor_labels_level3, anchor_boxes_level3, \
    all_anchors_level4, anchor_labels_level4, anchor_boxes_level4, \
    all_anchors_level5, anchor_labels_level5, anchor_boxes_level5, \
    all_anchors_level6, anchor_labels_level6, anchor_boxes_level6 \
        = iterator.get_next()

    # build the global step, learning-rate schedule and optimizer
    global_step = tf.train.get_or_create_global_step()
    learning_rate = warmup_lr_schedule(init_learning_rate=cfg.TRAIN.BASE_LR,
                                       global_step=global_step,
                                       warmup_step=cfg.TRAIN.WARMUP_STEP)
    opt = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)

    sess_config = tf.ConfigProto()
    sess_config.allow_soft_placement = True
    sess_config.log_device_placement = False
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    # Multi-GPU path: split each batched tensor along the batch dimension,
    # build one model tower per GPU (reusing variables across towers) and
    # average the per-tower gradients for a synchronous update.
    if num_gpus > 1:

        base_inputs_list = [
            tf.split(value, num_or_size_splits=num_gpus, axis=0)
            for value in [images, gt_boxes, gt_labels, orig_gt_counts]
        ]
        fpn_all_anchors_list = \
            [[tf.identity(value) for _ in range(num_gpus)] for value in
             [all_anchors_level2, all_anchors_level3, all_anchors_level4, all_anchors_level5, all_anchors_level6]]
        fpn_anchor_gt_labels_list = \
            [tf.split(value, num_or_size_splits=num_gpus, axis=0) for value in
             [anchor_labels_level2, anchor_labels_level3, anchor_labels_level4,
              anchor_labels_level5, anchor_labels_level6]]
        fpn_anchor_gt_boxes_list = \
            [tf.split(value, num_or_size_splits=num_gpus, axis=0) for value in
             [anchor_boxes_level2, anchor_boxes_level3, anchor_boxes_level4,
              anchor_boxes_level5, anchor_boxes_level6]]

        tower_grads = []
        total_loss_dict = {
            'rpn_cls_loss': tf.constant(0.),
            'rpn_box_loss': tf.constant(0.),
            'rcnn_cls_loss': tf.constant(0.),
            'rcnn_box_loss': tf.constant(0.)
        }
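        # Build one tower per GPU, accumulating the per-tower losses for
        # logging and collecting the per-tower gradients for averaging.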
        for i, gpu_id in enumerate(gpus):
            with tf.device('/gpu:%d' % gpu_id):
                with tf.name_scope('model_%d' % gpu_id) as scope:
                    inputs1 = [inp[i] for inp in base_inputs_list]
                    inputs2 = [[inp[i] for inp in fpn_all_anchors_list]]
                    inputs3 = [[inp[i] for inp in fpn_anchor_gt_labels_list]]
                    inputs4 = [[inp[i] for inp in fpn_anchor_gt_boxes_list]]
                    net_inputs = inputs1 + inputs2 + inputs3 + inputs4
                    tower_loss_dict = tower_loss_func(net_inputs,
                                                      reuse=(gpu_id > 0))
                    batch_norm_updates = tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS, scope)

                    tower_loss = tf.add_n(
                        [v for k, v in tower_loss_dict.items()])

                    for k, v in tower_loss_dict.items():
                        total_loss_dict[k] += v

                    if i == num_gpus - 1:
                        wd_loss = regularize_cost('.*/kernel',
                                                  l2_regularizer(
                                                      cfg.TRAIN.WEIGHT_DECAY),
                                                  name='wd_cost')
                        tower_loss = tower_loss + wd_loss

                        # Retain the summaries from the final tower.
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                      scope)

                        if cfg.FRCNN.VISUALIZATION:
                            with tf.device('/cpu:0'):
                                with tf.name_scope('loss-summaries'):
                                    for k, v in tower_loss_dict.items():
                                        summaries.append(
                                            tf.summary.scalar(k, v))

                    grads = opt.compute_gradients(tower_loss)
                    tower_grads.append(grads)

        # Average the per-tower gradients and report the mean of each loss term.
        grads = average_gradients(tower_grads)
        for k, v in total_loss_dict.items():
            total_loss_dict[k] = v / tf.cast(num_gpus, tf.float32)
        average_total_loss = tf.add_n([v for k, v in total_loss_dict.items()] +
                                      [wd_loss])
    else:
        fpn_all_anchors = \
            [all_anchors_level2, all_anchors_level3, all_anchors_level4, all_anchors_level5, all_anchors_level6]
        fpn_anchor_gt_labels = \
            [anchor_labels_level2, anchor_labels_level3, anchor_labels_level4, anchor_labels_level5,
             anchor_labels_level6]
        fpn_anchor_gt_boxes = \
            [anchor_boxes_level2, anchor_boxes_level3, anchor_boxes_level4, anchor_boxes_level5, anchor_boxes_level6]
        net_inputs = [
            images, gt_boxes, gt_labels, orig_gt_counts, fpn_all_anchors,
            fpn_anchor_gt_labels, fpn_anchor_gt_boxes
        ]
        tower_loss_dict = tower_loss_func(net_inputs)
        batch_norm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        wd_loss = regularize_cost('.*/kernel',
                                  l2_regularizer(cfg.TRAIN.WEIGHT_DECAY),
                                  name='wd_cost')
        average_total_loss = tf.add_n([v for k, v in tower_loss_dict.items()] +
                                      [wd_loss])
        grads = opt.compute_gradients(average_total_loss)
        total_loss_dict = tower_loss_dict

        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        if cfg.FRCNN.VISUALIZATION:
            with tf.device('/cpu:0'):
                with tf.name_scope('loss-summaries'):
                    for k, v in tower_loss_dict.items():
                        summaries.append(tf.summary.scalar(k, v))

    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
    summaries.append(tf.summary.scalar('learning_rate', learning_rate))

    # add histograms for gradients
    for grad, var in grads:
        # print(grad, var)
        if grad is not None:
            summaries.append(
                tf.summary.histogram(var.op.name + '/gradients', grad))

    # add histograms for trainable variables
    for var in tf.trainable_variables():
        summaries.append(tf.summary.histogram(var.op.name, var))

    variable_averages = tf.train.ExponentialMovingAverage(
        cfg.TRAIN.MOVING_AVERAGE_DECAY, num_updates=global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
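    # Maintain exponential moving averages of the trainable variables
    # (typically restored for evaluation).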

    all_global_vars = []
    for var in tf.global_variables():
        all_global_vars.append(var.name + '\n')
        # print(var.name, var.shape)
    with open('all_global_vars.txt', 'w') as fp:
        fp.writelines(all_global_vars)

    all_trainable_vars = []
    for var in tf.trainable_variables():
        all_trainable_vars.append(var.name + '\n')
    with open('all_trainable_vars.txt', 'w') as fp:
        fp.writelines(all_trainable_vars)

    all_moving_average_vars = []
    for var in tf.moving_average_variables():
        all_moving_average_vars.append(var.name + '\n')
    with open('all_moving_average_variables.txt', 'w') as fp:
        fp.writelines(all_moving_average_vars)

    # batch norm updates
    batch_norm_updates_op = tf.group(*batch_norm_updates)
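    # The final train op groups the gradient update with the EMA update and
    # the batch-norm statistics updates.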
    with tf.control_dependencies(
        [apply_gradient_op, variable_averages_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')

    saver = tf.train.Saver(tf.global_variables())
    summary_op = tf.summary.merge(summaries)
    summary_writer = tf.summary.FileWriter(checkpoint_path,
                                           tf.get_default_graph())

    init_op = tf.group(
        [tf.global_variables_initializer(),
         tf.local_variables_initializer()])
    sess.run(init_op)
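    # Weight restoration: depending on the manual toggles below and on whether
    # the log directory already existed, either converted npz weights, a
    # backbone-only checkpoint, the official ResNet checkpoint, or the latest
    # local checkpoint is loaded.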

    # Manual toggle left in by the author: set to True to initialize the
    # backbone from the converted MSRA-R50.npz weight file.
    if False:
        print('load weights ...')
        ckpt_params = dict(np.load('MSRA-R50.npz'))
        assign_ops = []
        all_variables = []
        for var in tf.global_variables():
            dst_name = var.name
            all_variables.append(dst_name + '\n')
            if 'resnet50' in dst_name:
                src_name = dst_name.replace('resnet50/', ''). \
                    replace('conv2d/kernel:0', 'W') \
                    .replace('conv2d/bias:0', 'b') \
                    .replace('batch_normalization/gamma:0', 'gamma') \
                    .replace('batch_normalization/beta:0', 'beta') \
                    .replace('batch_normalization/moving_mean:0', 'mean/EMA') \
                    .replace('batch_normalization/moving_variance:0', 'variance/EMA') \
                    .replace('kernel:0', 'W').replace('bias:0', 'b')
                if 'batch_normalization' in dst_name:
                    src_name = src_name.replace('res', 'bn')
                    if 'conv1' in src_name:
                        src_name = 'bn_' + src_name

                if src_name == 'fc1000/W':
                    print('{} --> {} {}'.format('fc1000/W', dst_name,
                                                var.shape))
                    assign_ops.append(
                        tf.assign(
                            var, np.reshape(ckpt_params[src_name],
                                            [2048, 1000])))
                    continue
                if src_name in ckpt_params:
                    print('{} --> {} {}'.format(src_name, dst_name, var.shape))
                    assign_ops.append(tf.assign(var, ckpt_params[src_name]))
        print('load weights done.')
        with open('all_vars.txt', 'w') as fp:
            fp.writelines(all_variables)
        all_update_ops = []
        for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS):
            all_update_ops.append(op.name + '\n')
        with open('all_update_ops.txt', 'w') as fp:
            fp.writelines(all_update_ops)
        sess.run(assign_ops)
    else:
        # Another manual toggle: set to True to restore only the non-head
        # variables from cfg.BACKBONE.CHECKPOINT_PATH without renaming them.
        if False:
            all_vars = []
            restore_var_dict = {}
            for var in tf.global_variables():
                all_vars.append(var.name + '\n')
                if 'rpn' not in var.name and 'rcnn' not in var.name and 'global_step' not in var.name and \
                        'Momentum' not in var.name and 'ExponentialMovingAverage' not in var.name:
                    restore_var_dict[var.name.replace(':0', '')] = var
            with open('all_vars.txt', 'w') as fp:
                fp.writelines(all_vars)
            restorer = tf.train.Saver(var_list=restore_var_dict)
            restorer.restore(sess, cfg.BACKBONE.CHECKPOINT_PATH)
        else:
            if restore_from_original_checkpoint:
                # restore from official ResNet checkpoint
                all_vars = []
                restore_var_dict = {}
                for var in tf.global_variables():
                    all_vars.append(var.name + '\n')
                    if 'rpn' not in var.name and 'rcnn' not in var.name and 'fpn' not in var.name \
                            and 'global_step' not in var.name and \
                            'Momentum' not in var.name and 'ExponentialMovingAverage' not in var.name:
                        restore_var_dict[var.name.replace('resnet50/',
                                                          '').replace(
                                                              ':0', '')] = var
                        print(var.name, var.shape)
                with open('all_vars.txt', 'w') as fp:
                    fp.writelines(all_vars)
                restore_vars_names = [
                    k + '\n' for k in restore_var_dict.keys()
                ]
                with open('all_restore_vars.txt', 'w') as fp:
                    fp.writelines(restore_vars_names)
                restorer = tf.train.Saver(var_list=restore_var_dict)
                restorer.restore(sess, cfg.BACKBONE.CHECKPOINT_PATH)
            else:
                all_vars = []
                restore_var_dict = {}
                for var in tf.global_variables():
                    all_vars.append(var.name + '\n')
                    restore_var_dict[var.name.replace(':0', '')] = var
                with open('all_vars.txt', 'w') as fp:
                    fp.writelines(all_vars)
                # restore from local checkpoint
                restorer = tf.train.Saver(tf.global_variables())
                try:
                    restorer.restore(
                        sess, tf.train.latest_checkpoint(checkpoint_path))
                except Exception:
                    # No usable checkpoint in the log directory yet; keep the
                    # freshly initialized variables.
                    pass

    # record all ops
    all_operations = []
    for op in sess.graph.get_operations():
        all_operations.append(op.name + '\n')
    with open('all_ops.txt', 'w') as fp:
        fp.writelines(all_operations)

    loss_names = [
        'rpn_cls_loss', 'rpn_box_loss', 'rcnn_cls_loss', 'rcnn_box_loss'
    ]
    sess2run = list()
    sess2run.append(train_op)
    sess2run.append(learning_rate)
    sess2run.append(average_total_loss)
    sess2run.append(wd_loss)
    sess2run.extend([total_loss_dict[k] for k in loss_names])

    print('begin training ...')
    step = sess.run(global_step)
    step0 = step
    start = time.time()
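    # Main loop: on most steps only the train op is run; every
    # SAVE_SUMMARY_STEPS steps the losses and summaries are fetched as well,
    # and a checkpoint is written every 1000 steps.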
    for step in range(step, cfg.TRAIN.MAX_STEPS):

        if step % cfg.TRAIN.SAVE_SUMMARY_STEPS == 0:

            _, lr_, tl_, wd_loss_, \
            rpn_cls_loss_, rpn_box_loss_, \
            rcnn_cls_loss_, rcnn_box_loss_, \
            summary_str = sess.run(sess2run + [summary_op])

            avg_time_per_step = (time.time() -
                                 start) / cfg.TRAIN.SAVE_SUMMARY_STEPS
            avg_examples_per_second = (cfg.TRAIN.SAVE_SUMMARY_STEPS * cfg.TRAIN.BATCH_SIZE_PER_GPU * num_gpus) \
                                      / (time.time() - start)
            start = time.time()
            print('Step {:06d}, LR: {:.6f} LOSS: {:.4f}, '
                  'RPN: {:.4f}, {:.4f}, RCNN: {:.4f}, {:.4f}, wd: {:.4f}, '
                  '{:.2f} s/step, {:.2f} samples/s'.format(
                      step, lr_, tl_, rpn_cls_loss_, rpn_box_loss_,
                      rcnn_cls_loss_, rcnn_box_loss_, wd_loss_,
                      avg_time_per_step, avg_examples_per_second))

            summary_writer.add_summary(summary_str, global_step=step)
        else:
            sess.run(train_op)

        if step % 1000 == 0:
            saver.save(sess, checkpoint_path + '/model.ckpt', global_step=step)
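

# `warmup_lr_schedule` is referenced above but not defined in this excerpt.
# A minimal sketch, assuming a simple linear warmup from one tenth of the base
# rate up to `init_learning_rate` over `warmup_step` steps (the real helper may
# also add piecewise decay afterwards); it presumes `import tensorflow as tf`
# as elsewhere in the file.
def warmup_lr_schedule(init_learning_rate, global_step, warmup_step):
    global_step_f = tf.cast(global_step, tf.float32)
    warmup_step_f = tf.cast(warmup_step, tf.float32)
    # Ramp linearly until `warmup_step`, then hold the base rate constant.
    warmup_factor = tf.minimum(global_step_f / warmup_step_f, 1.0)
    return init_learning_rate * (0.1 + 0.9 * warmup_factor)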