# Example #1
# 0
class SolverWrapper(object):
    """A wrapper class for the training process.

    Builds the TensorFlow training graph for ``network``, initialises it from
    ``pretrained_model`` or resumes from the newest snapshot found in
    ``output_dir``, runs the training loop, and writes periodic snapshots and
    TensorBoard summaries (training summaries to ``tbdir``, validation
    summaries to ``tbdir + '_val'``).
    """

    def __init__(self,
                 sess,
                 network,
                 imdb,
                 roidb,
                 valroidb,
                 output_dir,
                 tbdir,
                 pretrained_model=None):
        # ``sess`` is accepted for interface compatibility; it is not stored.
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        """Save a checkpoint plus data-layer / numpy RNG state for iteration ``iter``.

        Returns:
            ``(filename, nfilename)``: the checkpoint path passed to
            ``self.saver.save`` and the path of the pickled meta-info file.
        """
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        prefix = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter)

        # Store the model snapshot
        filename = os.path.join(self.output_dir, prefix + '.ckpt')
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = os.path.join(self.output_dir, prefix + '.pkl')
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info; the order must match the load order in _restore.
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def get_variables_in_checkpoint_file(self, file_name):
        """Return the ``{variable_name: shape}`` map stored in checkpoint ``file_name``.

        Raises:
            Whatever the checkpoint reader raises. The error is printed (with a
            hint for SNAPPY-compressed files) and then re-raised. It is not
            turned into a ``None`` return, which only made the caller fail
            later with an unrelated error.
        """
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            return reader.get_variable_to_shape_map()
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")
            raise

    def _construct_graph(self, sess):
        """Build the network, optimizer, saver and summary writers.

        Returns:
            ``(lr, train_op)``: the non-trainable learning-rate variable and the
            op that applies one optimisation step.
        """
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                sess,
                'TRAIN',
                self.imdb.num_classes,
                tag='default',
                anchor_scales=cfg.ANCHOR_SCALES,
                anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

            # Compute the gradients wrt the loss
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        # Variables that do not affect the loss have a None
                        # gradient; tf.multiply cannot take None, and
                        # apply_gradients skips None entries anyway.
                        if grad is not None and '/biases:' in var.name:
                            grad = tf.multiply(grad, 2.)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def _find_previous(self):
        """Return ``(sfiles, nfiles)``: resumable checkpoints and meta-info files.

        Both lists are sorted oldest-first by modification time. The snapshot
        written just before the learning-rate drop (iteration STEPSIZE + 1) is
        excluded from both.
        """
        # Get the snapshot name in TensorFlow
        redstr = '_iter_{:d}.'.format(cfg.TRAIN.STEPSIZE + 1)

        sfiles = glob.glob(os.path.join(
            self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta'))
        sfiles.sort(key=os.path.getmtime)
        sfiles = [ss.replace('.meta', '') for ss in sfiles]
        sfiles = [ss for ss in sfiles if redstr not in ss]

        nfiles = glob.glob(os.path.join(
            self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl'))
        nfiles.sort(key=os.path.getmtime)
        nfiles = [nn for nn in nfiles if redstr not in nn]

        return sfiles, nfiles

    def _initialize(self, sess, lr):
        """Fresh start: initialise all variables, then load ``pretrained_model``."""
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(
            self.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(
            variables, var_keep_dic)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, self.pretrained_model)
        print('Loaded.')
        # Adjust the freshly restored weights. Per the original notes, this
        # changes RGB weights to BGR and, for VGG16, converts the
        # convolutional fc6/fc7 weights to fully connected ones.
        self.net.fix_variables(sess, self.pretrained_model)
        print('Fixed.')
        sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

    def _restore(self, sess, sfile, nfile, lr):
        """Resume from checkpoint ``sfile`` and meta-info file ``nfile``.

        Restores the variables, the numpy RNG state and the data-layer
        positions, and sets ``lr`` (reduced by GAMMA past STEPSIZE).

        Returns:
            The iteration at which the snapshot was taken.
        """
        print('Restorining model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, str(sfile))
        print('Restored.')
        # Needs to restore the other hyperparameters/states for training, (TODO xinlei) I have
        # tried my best to find the random states so that it can be recovered exactly
        # However the Tensorflow state is currently not available
        with open(str(nfile), 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        # Set the learning rate, only reduce once
        if last_snapshot_iter > cfg.TRAIN.STEPSIZE:
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
        else:
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

        return last_snapshot_iter

    def _remove_old_snapshots(self, np_paths, ss_paths, keep=None):
        """Delete the oldest snapshots so at most ``keep`` of each kind remain.

        ``np_paths`` (meta-info ``.pkl`` files) and ``ss_paths`` (checkpoint
        prefixes) are oldest-first lists and are updated in place. ``keep``
        defaults to ``cfg.TRAIN.SNAPSHOT_KEPT``.
        """
        if keep is None:
            keep = cfg.TRAIN.SNAPSHOT_KEPT

        while len(np_paths) > keep:
            os.remove(str(np_paths.pop(0)))

        while len(ss_paths) > keep:
            sfile = ss_paths.pop(0)
            # To make the code compatible to earlier versions of Tensorflow,
            # where the naming tradition for checkpoints are different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            os.remove(str(sfile + '.meta'))

    def train_model(self, sess, max_iters):
        """Train the network up to iteration ``max_iters`` (inclusive).

        Resumes from the most recent snapshot in ``output_dir`` when there is
        one; otherwise starts from ``pretrained_model``. Snapshots are written
        right before the learning-rate drop at ``cfg.TRAIN.STEPSIZE + 1``,
        every ``cfg.TRAIN.SNAPSHOT_ITERS`` iterations, and at the end.

        Raises:
            RuntimeError: if the numbers of checkpoint and meta-info files in
                ``output_dir`` differ.
        """
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        lr, train_op = self._construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        sfiles, nfiles = self._find_previous()
        if len(nfiles) != len(sfiles):
            raise RuntimeError(
                'Found {} checkpoint(s) but {} meta-info file(s) in {}'.format(
                    len(sfiles), len(nfiles), self.output_dir))

        if not sfiles:
            self._initialize(sess, lr)
            last_snapshot_iter = 0
            np_paths, ss_paths = [], []
        else:
            # Get the most recent snapshot and restore. Only this one is
            # tracked for later clean-up; older files on disk are left alone.
            last_snapshot_iter = self._restore(sess, sfiles[-1], nfiles[-1], lr)
            np_paths = [nfiles[-1]]
            ss_paths = [sfiles[-1]]

        timer = Timer()
        cur_iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        try:
            while cur_iter < max_iters + 1:
                # Learning rate
                if cur_iter == cfg.TRAIN.STEPSIZE + 1:
                    # Add snapshot here before reducing the learning rate
                    self.snapshot(sess, cur_iter)
                    sess.run(
                        tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))

                timer.tic()
                # Get training data, one batch at a time
                blobs = self.data_layer.forward()

                now = time.time()
                if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                    # Compute the graph with summary
                    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                        self.net.train_step_with_summary(sess, blobs, train_op)
                    self.writer.add_summary(summary, float(cur_iter))
                    # Also check the summary on the validation set
                    blobs_val = self.data_layer_val.forward()
                    summary_val = self.net.get_summary(sess, blobs_val)
                    self.valwriter.add_summary(summary_val, float(cur_iter))
                    last_summary_time = now
                else:
                    # Compute the graph without summary
                    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                        self.net.train_step(sess, blobs, train_op)
                timer.toc()

                # Display training information
                if cur_iter % (cfg.TRAIN.DISPLAY) == 0:
                    # sess.run(lr) works even when sess is not the default session
                    print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                          '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' %
                          (cur_iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box,
                           loss_cls, loss_box, sess.run(lr)))
                    print('speed: {:.3f}s / iter'.format(timer.average_time))

                if cur_iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                    last_snapshot_iter = cur_iter
                    snapshot_path, np_path = self.snapshot(sess, cur_iter)
                    np_paths.append(np_path)
                    ss_paths.append(snapshot_path)
                    # Remove the old snapshots if there are too many
                    self._remove_old_snapshots(np_paths, ss_paths)

                cur_iter += 1

            if last_snapshot_iter != cur_iter - 1:
                self.snapshot(sess, cur_iter - 1)
        finally:
            # Flush and close the event files even if training fails
            self.writer.close()
            self.valwriter.close()
# Example #2
# 0
class SolverWrapper(object):
    """Training driver for the PyTorch Faster R-CNN ``network``.

    Handles weight initialisation or checkpoint resumption and the SGD
    optimizer with per-parameter learning rates. It also does periodic
    logging and validation (optionally to TensorBoard via tensorboardX),
    checkpointing to ``model_dir``, and optional periodic evaluation on
    ``valimdb``. ``set_learn_strategy`` must be called before ``train_model``.
    """

    def __init__(self,
                 network,
                 imdb,
                 valimdb,
                 roidb,
                 valroidb,
                 model_dir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.valimdb = valimdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.model_dir = model_dir
        # TensorBoard logs live in <model_dir>/train_log
        self.tbdir = os.path.join(model_dir, 'train_log')
        if not os.path.exists(self.tbdir):
            os.makedirs(self.tbdir)
        self.pretrained_model = pretrained_model

    def set_learn_strategy(self, learn_dict):
        """Read logging / validation / checkpoint / lr-decay settings.

        Required keys in ``learn_dict``: ``disp_interval``, ``use_tensorboard``,
        ``use_valid``, ``evaluate``, ``save_point_interval``,
        ``lr_decay_steps``. ``begin_eval_point`` is also required when
        ``evaluate`` is truthy. Validation runs every ``5 * disp_interval``
        steps.
        """
        self._disp_interval = learn_dict['disp_interval']
        self._valid_interval = learn_dict['disp_interval'] * 5
        self._use_tensorboard = learn_dict['use_tensorboard']
        self._use_valid = learn_dict['use_valid']
        self._evaluate = learn_dict['evaluate']
        self._save_point_interval = learn_dict['save_point_interval']
        self._lr_decay_steps = learn_dict['lr_decay_steps']

        if self._evaluate:
            self._begin_eval_point = learn_dict['begin_eval_point']
            self.evaluate_dir = os.path.join(self.model_dir, 'evaluate')
            self._evaluate_thresh = 0.05
            self._evaluate_max_per_image = 1

    @staticmethod
    def _new_stats():
        """Return a zeroed accumulator for summed losses and fg/bg metrics."""
        return {
            'loss': 0.,
            'rpn_cls': 0.,
            'rpn_bbox': 0.,
            'frcnn_cls': 0.,
            'frcnn_bbox': 0.,
            'tp': 0.,
            'tf': 0.,
            'fg': 0,
            'bg': 0,
            'cnt': 0,
        }

    def _accumulate(self, stats, loss_r):
        """Add one batch to ``stats`` in place.

        ``loss_r`` holds (total, rpn_cls, rpn_bbox, fast_rcnn_cls,
        fast_rcnn_bbox) losses. The fg/bg counters come from
        ``self.net.metrics_dict``.
        """
        stats['loss'] += loss_r[0]
        stats['rpn_cls'] += loss_r[1]
        stats['rpn_bbox'] += loss_r[2]
        stats['frcnn_cls'] += loss_r[3]
        stats['frcnn_bbox'] += loss_r[4]
        # fg: objects, bg: background, tp: true positives, tf: true negatives
        for key in ('fg', 'bg', 'tp', 'tf'):
            stats[key] += self.net.metrics_dict[key]
        stats['cnt'] += 1

    @staticmethod
    def _safe_div(num, den, default=0.0):
        """Return ``num * 1.0 / den``, or ``default`` when ``den`` is zero."""
        if den == 0:
            return default
        return num * 1.0 / den

    def train_model(self, resume=None, max_iters=100000):
        """Train from ``self.start_step`` up to and including ``max_iters``.

        Args:
            resume: iteration of a checkpoint in ``model_dir`` to resume from.
                A falsy value (``None``/``0``) starts a fresh run.
            max_iters: last training step (inclusive).
        """
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Load a checkpoint or initialise weights, set up the optimizer and
        # the per-parameter learning rates.
        self.prepare_construct(resume)

        net = self.net
        stats = self._new_stats()
        t = Timer()
        t.tic()
        for step in range(self.start_step, max_iters + 1):
            blobs = self.data_layer.forward()

            loss_r = None
            if step % self._valid_interval == 0 and self._use_tensorboard:
                loss_r, image_r = net.train_operation(
                    blobs,
                    self._optimizer,
                    image_if=True,
                    clip_parameters=self._parameters)
                self._tensor_writer.add_image('Image', image_r, step)
            else:
                try:
                    loss_r, _ = net.train_operation(
                        blobs,
                        self._optimizer,
                        image_if=False,
                        clip_parameters=self._parameters)
                except Exception:  # pylint: disable=broad-except
                    # Best effort: report the offending image and skip this
                    # batch. Its statistics are NOT accumulated (previously the
                    # prior batch's losses were silently added again).
                    print('=' * 40)
                    print('=' * 40)
                    print('=' * 40)
                    print(blobs['im_name'])

            if loss_r is not None:
                self._accumulate(stats, loss_r)

            reset_stats = False
            if step % self._disp_interval == 0 and stats['cnt'] > 0:
                self._log_train(step, blobs['im_name'], stats, t)
                reset_stats = True

            if self._use_valid and step % self._valid_interval == 0 and step != 0:
                self._validate(step, stats)

            if (step % self._save_point_interval == 0) and step != 0:
                save_name, _ = self.save_check_point(step)
                print('save model: {}'.format(save_name))

            if (self._evaluate and step > self._begin_eval_point
                    and step % cfg.TRAIN.EVALUATE_POINT == 0):
                self._run_evaluation(step)

            if step in self._lr_decay_steps:
                self._lr *= self._lr_decay
                # NOTE(review): rebuilding the optimizer discards the SGD
                # momentum buffers; updating each param_group['lr'] in place
                # would keep them. Confirm which behaviour is intended.
                self._optimizer = self._train_optimizer()

            if reset_stats:
                stats = self._new_stats()
                t.tic()

        if self._use_tensorboard:
            self._tensor_writer.export_scalars_to_json(
                os.path.join(self.tbdir, 'all_scalars.json'))
            self._tensor_writer.close()

    def _log_train(self, step, im_name, stats, timer):
        """Print (and optionally log to TensorBoard) running training averages.

        ``stats['cnt']`` must be positive.
        """
        cnt = stats['cnt']
        duration = timer.toc(average=False)
        fps = cnt / duration

        log_text = 'step %d, image: %s, loss: %.4f, fps: %.2f (%.2fs per batch)' % (
            step, im_name, stats['loss'] / cnt, fps, 1. / fps)
        tp_text = 'step {}, tp: {}/{}, tf: {}/{}'.format(
            step, int(stats['tp'] / cnt), int(stats['fg'] / cnt),
            int(stats['tf'] / cnt), int(stats['bg'] / cnt))
        pprint.pprint(log_text)
        pprint.pprint(tp_text)

        if self._use_tensorboard:
            self._tensor_writer.add_text('Train', log_text, global_step=step)
            self._tensor_writer.add_scalars(
                'TrainSetLoss', {
                    'RPN_cls_loss': stats['rpn_cls'] / cnt,
                    'RPN_bbox_loss': stats['rpn_bbox'] / cnt,
                    'FastRcnn_cls_loss': stats['frcnn_cls'] / cnt,
                    'FastRcnn_bbox_loss': stats['frcnn_bbox'] / cnt
                },
                global_step=step)
            self._tensor_writer.add_scalar('Learning_rate',
                                           self._lr,
                                           global_step=step)

    def _validate(self, step, train_stats):
        """Run ``disp_interval`` validation batches and log them vs ``train_stats``."""
        net = self.net
        valid = self._new_stats()
        start_time = time.time()

        valid_length = self._disp_interval
        for valid_batch in range(valid_length):
            # get one batch
            blobs = self.data_layer_val.forward()

            # No optimizer is passed here, so the network is not updated;
            # only the losses are computed.
            if self._use_tensorboard and valid_batch % valid_length == 0:
                loss_r, image_r = net.train_operation(blobs,
                                                      None,
                                                      image_if=True)
                self._tensor_writer.add_image('Image_Valid', image_r, step)
            else:
                loss_r, _ = net.train_operation(blobs, None, image_if=False)
            self._accumulate(valid, loss_r)

        cnt = valid['cnt']
        duration = time.time() - start_time
        fps = cnt / duration

        log_text = 'step %d, valid average loss: %.4f, fps: %.2f (%.2fs per batch)' % (
            step, valid['loss'] / cnt, fps, 1. / fps)
        pprint.pprint(log_text)

        if not self._use_tensorboard:
            return

        div = self._safe_div
        train_cnt = train_stats['cnt']
        real_total_loss_valid = (valid['rpn_cls'] + valid['rpn_bbox'] +
                                 valid['frcnn_cls'] + valid['frcnn_bbox'])
        real_total_loss = (train_stats['rpn_cls'] + train_stats['rpn_bbox'] +
                           train_stats['frcnn_cls'] + train_stats['frcnn_bbox'])

        writer = self._tensor_writer
        writer.add_text('Valid', log_text, global_step=step)
        writer.add_scalars(
            'Total_Loss', {
                'train': div(train_stats['loss'], train_cnt),
                'valid': valid['loss'] / cnt
            },
            global_step=step)
        writer.add_scalars(
            'Real_loss', {
                'train': div(real_total_loss, train_cnt),
                'valid': real_total_loss_valid / cnt
            },
            global_step=step)
        for tag, key in (('RPN_cls_loss', 'rpn_cls'),
                         ('RPN_bbox_loss', 'rpn_bbox'),
                         ('FastRcnn_cls_loss', 'frcnn_cls'),
                         ('FastRcnn_bbox_loss', 'frcnn_bbox')):
            writer.add_scalars(
                tag, {
                    'train': div(train_stats[key], train_cnt),
                    'valid': valid[key] / cnt
                },
                global_step=step)
        # Rates fall back to 0.0 when there were no fg / bg samples
        writer.add_scalars(
            'tpr', {
                'train': div(train_stats['tp'], train_stats['fg']),
                'valid': div(valid['tp'], valid['fg'])
            },
            global_step=step)
        writer.add_scalars(
            'tfr', {
                'train': div(train_stats['tf'], train_stats['bg']),
                'valid': div(valid['tf'], valid['bg'])
            },
            global_step=step)
        writer.add_scalars(
            'ValidSetLoss', {
                'RPN_cls_loss': valid['rpn_cls'] / cnt,
                'RPN_bbox_loss': valid['rpn_bbox'] / cnt,
                'FastRcnn_cls_loss': valid['frcnn_cls'] / cnt,
                'FastRcnn_bbox_loss': valid['frcnn_bbox'] / cnt
            },
            global_step=step)

    def _run_evaluation(self, step):
        """Evaluate on ``valimdb`` and optionally log the metrics to TensorBoard."""
        self.net.eval()
        evaluate_solverwrapper = EvaluateSolverWrapper(
            network=self.net,
            imdb=self.valimdb,
            model_dir=None,
            output_dir=self.evaluate_dir)
        metrics_cls, metrics_reg = evaluate_solverwrapper.eval_model(
            step, self._evaluate_max_per_image, self._evaluate_thresh)

        # Back to training mode (resnet BN layers stay in eval mode)
        self.after_model_mode()
        del evaluate_solverwrapper

        # Keep only the first element of each metric value
        # (NOTE(review): values are presumably sequences; confirm in eval_model).
        for key in metrics_cls.keys():
            metrics_cls[key] = metrics_cls[key][0]
        for key in metrics_reg.keys():
            metrics_reg[key] = metrics_reg[key][0]

        if self._use_tensorboard:
            self._tensor_writer.add_scalars('metrics_cls',
                                            metrics_cls,
                                            global_step=step)
            self._tensor_writer.add_scalars('metrics_reg',
                                            metrics_reg,
                                            global_step=step)

    def save_check_point(self, step):
        """Save weights (HDF5) and data-layer / RNG / lr state (pickle) for ``step``.

        Returns:
            ``(filename, nfilename)``: paths of the ``.h5`` and ``.pkl`` files.
        """
        net = self.net

        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        # store the model snapshot; the context manager flushes and closes it
        filename = os.path.join(self.model_dir,
                                'fasterRcnn_iter_{}.h5'.format(step))
        with h5py.File(filename, mode='w') as h5f:
            for k, v in net.state_dict().items():
                h5f.create_dataset(k, data=v.cpu().numpy())

        # store data information
        nfilename = os.path.join(self.model_dir,
                                 'fasterRcnn_iter_{}.pkl'.format(step))
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm
        # current learning rate
        lr = self._lr

        # Dump the meta info; the order must match load_check_point
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(lr, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(step, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def load_check_point(self, step):
        """Restore weights and training state saved by ``save_check_point(step)``.

        Exits the process if the ``.h5`` file is missing.

        Raises:
            ValueError: if the stored iteration does not match ``step``.

        Returns:
            The stored iteration.
        """
        net = self.net
        filename = os.path.join(self.model_dir,
                                'fasterRcnn_iter_{}.h5'.format(step))
        nfilename = os.path.join(self.model_dir,
                                 'fasterRcnn_iter_{}.pkl'.format(step))
        print('Restoring model snapshots from {:s}'.format(filename))

        if not os.path.exists(filename):
            print('The checkPoint is not Right')
            sys.exit(1)

        # load model; the context manager closes the HDF5 file
        with h5py.File(filename, mode='r') as h5f:
            for k, v in net.state_dict().items():
                v.copy_(torch.from_numpy(np.asarray(h5f[k])))

        # load data information
        with open(nfilename, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            lr = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val
        self._lr = lr

        if last_snapshot_iter == step:
            print('Restore over ')
        else:
            print('The checkPoint is not Right')
            raise ValueError(
                'checkpoint {} stores iteration {}, expected {}'.format(
                    nfilename, last_snapshot_iter, step))

        return last_snapshot_iter

    # Initialize network weights
    def weights_normal_init(self, model, dev=0.01):
        """Initialise ``model`` (a module or a list of modules) in place.

        The branch is chosen by ``cfg.TRAIN.INIT_WAY`` ('resnet' or 'vgg').
        Both currently apply N(0, ``dev``) to Conv2d/Linear weights with zero
        biases, and weight 1 / bias 0 to BatchNorm2d.

        Raises:
            NotImplementedError: for any other ``cfg.TRAIN.INIT_WAY``.
        """
        import math

        def _gaussian_init(m, dev):
            m.weight.data.normal_(0.0, dev)
            if hasattr(m.bias, 'data'):
                m.bias.data.zero_()

        def _hekaiming_init(m):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
            if hasattr(m.bias, 'data'):
                m.bias.data.zero_()

        def _resnet_init(model, dev):
            if isinstance(model, list):
                for m in model:
                    self.weights_normal_init(m, dev)
            else:
                for m in model.modules():
                    if isinstance(m, nn.Conv2d):
                        _hekaiming_init(m)
                    elif isinstance(m, nn.BatchNorm2d):
                        m.weight.data.fill_(1)
                        m.bias.data.zero_()
                    elif isinstance(m, nn.Linear):
                        _gaussian_init(m, dev)

        def _vgg_init(model, dev):
            if isinstance(model, list):
                for m in model:
                    self.weights_normal_init(m, dev)
            else:
                for m in model.modules():
                    if isinstance(m, (nn.Conv2d, nn.Linear)):
                        _gaussian_init(m, dev)
                    elif isinstance(m, nn.BatchNorm2d):
                        m.weight.data.fill_(1)
                        m.bias.data.zero_()

        # NOTE(review): the 'resnet' branch calls _vgg_init, so _resnet_init
        # (He init for convolutions) is never used. Kept as-is to preserve the
        # current initialisation; confirm whether this is intended.
        if cfg.TRAIN.INIT_WAY == 'resnet':
            _vgg_init(model, dev)
        elif cfg.TRAIN.INIT_WAY == 'vgg':
            _vgg_init(model, dev)
        else:
            raise NotImplementedError

    # Load a checkpoint, initialise weights, set up the optimizer and the
    # per-parameter learning rates.
    def prepare_construct(self, resume_iter):
        """Prepare the network and optimizer for training.

        When ``resume_iter`` is truthy, training state is restored from that
        checkpoint and resumes at ``resume_iter + 1``. Otherwise weights are
        initialised and ``pretrained_model`` (if given) is loaded.
        """
        # init network
        self.net.init_fasterRCNN()

        # Set the random seed
        torch.manual_seed(cfg.RNG_SEED)
        np.random.seed(cfg.RNG_SEED)

        # Set learning rate and momentum
        self._lr = cfg.TRAIN.LEARNING_RATE
        self._lr_decay = 0.1
        self._momentum = cfg.TRAIN.MOMENTUM
        self._weight_decay = cfg.TRAIN.WEIGHT_DECAY

        # load model
        if resume_iter:
            self.start_step = resume_iter + 1
            self.load_check_point(resume_iter)
        else:
            self.start_step = 0
            self.weights_normal_init(self.net, dev=0.01)
            # refer to caffe faster RCNN
            self.net.init_special_bbox_fc(dev=0.001)
            if self.pretrained_model is not None:
                self.net._rpn._network._load_pre_trained_model(
                    self.pretrained_model)
                print('Load parameters from Path: {}'.format(
                    self.pretrained_model))

        if cfg.CUDA_IF:
            self.net.cuda()

        # BN should be fixed
        self.after_model_mode()

        # set optimizer
        self._parameters = [
            params for params in self.net.parameters() if params.requires_grad
        ]
        self._optimizer = self._train_optimizer()

        # tensorboard
        if self._use_tensorboard:
            import tensorboardX as tbx
            self._tensor_writer = tbx.SummaryWriter(log_dir=self.tbdir)

    def after_model_mode(self):
        """Put the net in training mode, keeping resnet BN layers in eval mode."""
        # model
        self.net.train()

        # resnet fixed BN should be eval
        if cfg.TRAIN.INIT_WAY == 'resnet':
            self.net._rpn._network._bn_eval()

    def _train_optimizer(self):
        """Build an SGD optimizer from the per-parameter groups."""
        parameters = self._train_parameter()
        optimizer = torch.optim.SGD(parameters, momentum=self._momentum)
        return optimizer

    def _train_parameter(self):
        """Return one SGD param group per trainable parameter.

        Biases use ``lr * (DOUBLE_BIAS + 1)`` and no weight decay. Other
        parameters use ``lr`` and ``self._weight_decay``.
        """
        params = []
        for key, value in self.net.named_parameters():
            if not value.requires_grad:
                continue
            if 'bias' in key:
                params.append({
                    'params': [value],
                    'lr': self._lr * (cfg.TRAIN.DOUBLE_BIAS + 1),
                    'weight_decay': 0
                })
            else:
                params.append({
                    'params': [value],
                    'lr': self._lr,
                    'weight_decay': self._weight_decay
                })
        return params


class SolverWrapper(object):
    """
    A wrapper class for the (TensorFlow graph based) training process.

    The network's loss dict holds an overall 'total_loss' plus three
    sub-module entries ('M1', 'M2', 'M3'), each with its own 'total_loss'.
    `construct_graph` builds one train op for the overall loss and one per
    sub-module. `train_model` runs the three sub-module ops together as a
    single grouped op.
    """

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
        # NOTE: `sess` is not used here; it is kept for interface compatibility.
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        """Save the model checkpoint and the resumable training state for `iter`.

        Writes `<output_dir>/<SNAPSHOT_PREFIX>_iter_<iter>.ckpt` via
        `self.saver`, plus a companion `.pkl` file. The pickle holds, in
        order: the numpy RNG state, the train data-layer cursor and
        permutation, the validation data-layer cursor and permutation, and
        `iter`. `from_snapshot` reads them back in the same order.

        Returns:
            (filename, nfilename): the checkpoint path and the pickle path.
        """
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # The dump order is the load order in from_snapshot; keep them in sync.
        state = (
            np.random.get_state(),     # current state of numpy random
            self.data_layer._cur,      # current position in the database
            self.data_layer._perm,     # current shuffled indexes of the database
            self.data_layer_val._cur,  # current position in the validation database
            self.data_layer_val._perm, # current shuffled indexes of the validation database
            iter,
        )

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            for item in state:
                pickle.dump(item, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        """Restore weights from checkpoint `sfile` and training state from `nfile`.

        `nfile` must be a pickle written by `snapshot`. The numpy RNG state and
        the train/validation data-layer cursors and permutations are restored.

        Returns:
            The iteration number stored in `nfile`.
        """
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Needs to restore the other hyper-parameters/states for training, (TODO xinlei) I have
        # tried my best to find the random states so that it can be recovered exactly
        # However the Tensorflow state is currently not available
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # snapshot files produced by this trainer.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        """Return the {variable name: shape} map stored in checkpoint `file_name`.

        This is best effort: if the checkpoint cannot be read, the error is
        printed and None is returned.
        """
        try:
            print('&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&')
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            return reader.get_variable_to_shape_map()
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")
            return None

    @staticmethod
    def _double_bias_gradients(grads_and_vars):
        """Return a new (grad, var) list with bias gradients multiplied by 2.

        A variable counts as a bias when its name contains '/biases:'.
        Entries whose gradient is None are passed through unchanged. This
        happens when a loss does not depend on the variable, e.g. another
        sub-module's head. `tf.multiply(None, ...)` would raise, and
        `apply_gradients` already skips None gradients.
        """
        scaled = []
        for grad, var in grads_and_vars:
            if grad is not None and '/biases:' in var.name:
                grad = tf.multiply(grad, 2.)
            scaled.append((grad, var))
        return scaled

    def construct_graph(self, sess):
        """Build the training graph inside `sess.graph`.

        One MomentumOptimizer is created for the overall loss and one for each
        sub-module loss (M1, M2, M3). All of them share the non-trainable
        learning-rate variable `lr`. The saver and the train/validation
        summary writers are also created here.

        Returns:
            (lr, train_op, train_m1_op, train_m2_op, train_m3_op, final_train_op),
            where final_train_op groups the three sub-module train ops.
        """
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture('TRAIN', self.imdb.num_classes, tag='default',
                                                  anchor_scales=cfg.ANCHOR_SCALES,
                                                  anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            losses = layers['all_losses']
            loss = losses['total_loss']
            m1_loss = losses['M1']['total_loss']
            m2_loss = losses['M2']['total_loss']
            m3_loss = losses['M3']['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)
            self.optimizer_m1 = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)
            self.optimizer_m2 = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)
            self.optimizer_m3 = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

            # Compute the gradients with regard to the loss
            gvs = self.optimizer.compute_gradients(loss)
            gvs_m1 = self.optimizer_m1.compute_gradients(m1_loss)
            gvs_m2 = self.optimizer_m2.compute_gradients(m2_loss)
            gvs_m3 = self.optimizer_m3.compute_gradients(m3_loss)

            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                with tf.variable_scope('Gradient_Mult'):
                    gvs = self._double_bias_gradients(gvs)
                with tf.variable_scope('Gradient_Mult_m1'):
                    gvs_m1 = self._double_bias_gradients(gvs_m1)
                with tf.variable_scope('Gradient_Mult_m2'):
                    gvs_m2 = self._double_bias_gradients(gvs_m2)
                with tf.variable_scope('Gradient_Mult_m3'):
                    gvs_m3 = self._double_bias_gradients(gvs_m3)

            # Each loss is applied by its own optimizer, with or without
            # DOUBLE_BIAS, so each train op keeps its own momentum slots.
            train_op = self.optimizer.apply_gradients(gvs)
            train_m1_op = self.optimizer_m1.apply_gradients(gvs_m1)
            train_m2_op = self.optimizer_m2.apply_gradients(gvs_m2)
            train_m3_op = self.optimizer_m3.apply_gradients(gvs_m3)

            # group the three independent train ops
            final_train_op = tf.group(train_m1_op, train_m2_op, train_m3_op)
            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op, train_m1_op, train_m2_op, train_m3_op, final_train_op

    def find_previous(self):
        """Locate resumable snapshots in `self.output_dir`.

        Snapshots written at iteration `stepsize + 1` for any
        cfg.TRAIN.STEPSIZE (the ones `train_model` takes right before a
        learning-rate decay) are excluded. Both lists are sorted oldest-first
        by modification time.

        Returns:
            (count, pickle_paths, checkpoint_prefixes)
        """
        sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redfiles = [
            os.path.join(self.output_dir,
                         cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.ckpt.meta'.format(stepsize + 1))
            for stepsize in cfg.TRAIN.STEPSIZE
        ]
        sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles]

        nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        # Every checkpoint must have its companion pickle (and vice versa).
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def initialize(self, sess):
        """Initialise all variables and load the pretrained weights.

        If `self.pretrained_model` contains 'darknet53', the weights are
        loaded through `self.net.restored_from_npz`. Otherwise the path is
        restored as a TensorFlow checkpoint (only the variables selected by
        `self.net.get_variables_to_restore`) and then passed to
        `self.net.fix_variables`.

        Returns:
            (rate, last_snapshot_iter, stepsizes, np_paths, ss_paths) for a fresh run.
        """
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(self.pretrained_model))
        variables = tf.global_variables()
        print("variables:", variables)
        if 'darknet53' in self.pretrained_model:
            print('the base network is Darknet53!!!')
            sess.run(tf.variables_initializer(variables, name='init'))
            self.net.restored_from_npz(sess)
            print('Loaded.')
        else:
            sess.run(tf.variables_initializer(variables, name='init'))
            var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
            # Get the variables to restore, ignoring the variables to fix
            variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, self.pretrained_model)
            print('Loaded.')

            # Need to fix the variables before loading, so that the RGB weights are changed to BGR
            # For VGG16 it also changes the convolutional weights
            self.net.fix_variables(sess, self.pretrained_model)
            print('Fixed.')

        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        """Resume from snapshot (`sfile`, `nfile`) and rebuild the LR schedule.

        The learning rate is decayed by cfg.TRAIN.GAMMA once for every step
        boundary already passed. Boundaries still ahead are returned in
        `stepsizes`.

        Returns:
            (rate, last_snapshot_iter, stepsizes, np_paths, ss_paths)
        """
        # Get the most recent snapshot and restore
        variables = tf.global_variables()
        print("variables:", variables)
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        """Delete the oldest snapshots until at most cfg.TRAIN.SNAPSHOT_KEPT remain.

        `np_paths` and `ss_paths` are mutated in place, dropping entries from
        the front (oldest first).
        """
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible to earlier versions of Tensorflow,
            # where the naming tradition for checkpoints are different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def train_model(self, sess, max_iters):
        """Train up to `max_iters`, running the grouped M1/M2/M3 train op each step.

        Training resumes from the newest snapshot found by `find_previous`, or
        starts from the pretrained weights if there is none. At each
        cfg.TRAIN.STEPSIZE boundary a snapshot is written and the learning
        rate is multiplied by cfg.TRAIN.GAMMA. Summaries are written on the
        first iteration and then whenever cfg.TRAIN.SUMMARY_INTERVAL seconds
        have passed. Regular snapshots are written every
        cfg.TRAIN.SNAPSHOT_ITERS iterations, keeping at most
        cfg.TRAIN.SNAPSHOT_KEPT.
        """
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

        # Construct the computation graph
        lr, _, _, _, _, final_train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess,
                                                                                   str(sfiles[-1]),
                                                                                   str(nfiles[-1]))
        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                losses, summary = self.net.train_step_with_summary(sess, blobs, final_train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                losses = self.net.train_step(sess, blobs, final_train_op)
            timer.toc()

            # get the corresponding loss to show
            m1_cls_loss = losses['M1']['rpn_cross_entropy']
            m1_box_loss = losses['M1']['rpn_loss_box']
            m1_kp_loss = losses['M1']['kpoints_loss']
            m1_total_loss = losses['M1']['total_loss']
            m2_cls_loss = losses['M2']['rpn_cross_entropy']
            m2_box_loss = losses['M2']['rpn_loss_box']
            m2_kp_loss = losses['M2']['kpoints_loss']
            m2_total_loss = losses['M2']['total_loss']
            m3_cls_loss = losses['M3']['rpn_cross_entropy']
            m3_box_loss = losses['M3']['rpn_loss_box']
            m3_kp_loss = losses['M3']['kpoints_loss']
            m3_total_loss = losses['M3']['total_loss']
            total_loss = losses['total_loss']

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d \n >>> m1_cls_loss: %.6f, m1_box_loss: %.6f, m1_kp_loss: %.6f, m1_total_loss: %.6f\n '
                      '>>> m2_cls_loss: %.6f, m2_box_loss: %.6f, m2_kp_loss: %.6f, m2_total_loss: %.6f\n '
                      '>>> m3_cls_loss: %.6f, m3_box_loss: %.6f, m3_kp_loss: %.6f, m3_total_loss: %.6f\n '
                      '>>> total_loss: %.6f, lr: %f' %
                      (iter, max_iters, m1_cls_loss, m1_box_loss, m1_kp_loss, m1_total_loss, m2_cls_loss, m2_box_loss, m2_kp_loss, m2_total_loss,
                       m3_cls_loss, m3_box_loss, m3_kp_loss, m3_total_loss, total_loss, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()

    def train_model_old(self, sess, max_iters):
        """Legacy loop: each iteration trains only one randomly chosen sub-module.

        M1, M2 or M3 is picked with roughly one-third probability each, and
        only that module's train op is run. Schedule, summaries and
        snapshotting follow `train_model`.
        """
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

        # Construct the computation graph
        lr, train_op, train_m1_op, train_m2_op, train_m3_op, _ = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess,
                                                                                   str(sfiles[-1]),
                                                                                   str(nfiles[-1]))
        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        m1_iters = 0
        m2_iters = 0
        m3_iters = 0
        while iter < max_iters + 1:
            random_seed = np.random.rand()
            if random_seed < 0.33:
                module = "M1"
                train_op = train_m1_op
                m1_iters += 1
            elif random_seed < 0.67:
                module = "M2"
                train_op = train_m2_op
                m2_iters += 1
            else:
                module = "M3"
                train_op = train_m3_op
                m3_iters += 1
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary. The key-point loss is bound
                # to the same name as in the plain branch so the display
                # below always shows the current iteration's value.
                rpn_loss_cls, rpn_loss_box, kpoints_loss, total_loss, summary = \
                    self.net.train_step_with_summary_old(sess, blobs, train_op, module)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, kpoints_loss, total_loss = self.net.train_step_old(sess, blobs, train_op, module)
            timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                if module == 'M1':
                    iters = m1_iters
                elif module == 'M2':
                    iters = m2_iters
                else:
                    iters = m3_iters
                print('iter: %d / %d, now training module: %s, iters: %d,\n >>> total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> kpoints_loss: %.6f\n >>> lr: %f' %
                      (iter, max_iters, module, iters, total_loss, rpn_loss_cls, rpn_loss_box, kpoints_loss, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
class SolverWrapper(object):
    def __init__(self,
                 network,
                 imdb,
                 roidb,
                 valroidb,
                 model_dir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.model_dir = model_dir
        self.tbdir = os.path.join(model_dir, 'train_log')
        if not os.path.exists(self.tbdir):
            os.makedirs(self.tbdir)
        self.pretrained_model = pretrained_model

    def set_learn_strategy(self, learn_dict):
        self._disp_interval = learn_dict['disp_interval']
        self._valid_interval = learn_dict['disp_interval'] * 5
        self._use_tensorboard = learn_dict['use_tensorboard']
        self._use_valid = learn_dict['use_valid']
        self._save_point_interval = learn_dict['save_point_interval']
        self._lr_decay_steps = learn_dict['lr_decay_steps']

    def train_model(self, resume=None, max_iters=100000):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        self.prepare_construct(resume)

        net = self.net
        # training
        train_loss = 0
        rpn_cls_loss = 0
        rpn_bbox_loss = 0
        fast_rcnn_cls_loss = 0
        fast_rcnn_bbox_loss = 0
        tp, tf, fg, bg = 0., 0., 0, 0
        step_cnt = 0
        re_cnt = False
        t = Timer()
        t.tic()
        for step in range(self.start_step, max_iters + 1):
            blobs = self.data_layer.forward()

            im_data = blobs['data']
            im_info = blobs['im_info']
            gt_boxes = blobs['gt_boxes']
            # forward
            result_cls_prob, result_bbox_pred, result_rois = net(
                im_data, im_info, gt_boxes)

            loss = net.loss + net._rpn.loss

            train_loss += loss.data.cpu()[0]
            rpn_cls_loss += net._rpn.cross_entropy.data.cpu()[0]
            rpn_bbox_loss += net._rpn.loss_box.data.cpu()[0]
            fast_rcnn_cls_loss += net.cross_entropy.data.cpu()[0]
            fast_rcnn_bbox_loss += net.loss_box.data.cpu()[0]
            step_cnt += 1

            # backward
            self._optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm(self._parameters, max_norm=10)
            self._optimizer.step()
            # clear middle memory
            net._delete_cache()

            if step % self._disp_interval == 0:
                duration = t.toc(average=False)
                fps = step_cnt / duration

                log_text = 'step %d, image: %s, loss: %.4f, fps: %.2f (%.2fs per batch)' % (
                    step, blobs['im_name'], train_loss / step_cnt, fps,
                    1. / fps)
                pprint.pprint(log_text)

                if self._use_tensorboard:
                    self._tensor_writer.add_text('Train',
                                                 log_text,
                                                 global_step=step)
                    # Train
                    avg_rpn_cls_loss = rpn_cls_loss / step_cnt
                    avg_rpn_bbox_loss = rpn_bbox_loss / step_cnt
                    avg_fast_rcnn_cls_loss = fast_rcnn_cls_loss / step_cnt
                    avg_fast_rcnn_bbox_loss = fast_rcnn_bbox_loss / step_cnt

                    self._tensor_writer.add_scalars(
                        'TrainSetLoss', {
                            'RPN_cls_loss': avg_rpn_cls_loss,
                            'RPN_bbox_loss': avg_rpn_bbox_loss,
                            'FastRcnn_cls_loss': avg_fast_rcnn_cls_loss,
                            'FastRcnn_bbox_loss': avg_fast_rcnn_bbox_loss
                        },
                        global_step=step)
                    self._tensor_writer.add_scalar('Learning_rate',
                                                   self._lr,
                                                   global_step=step)

                re_cnt = True

            if self._use_tensorboard and step % self._valid_interval == 0:
                new_gt_boxes = gt_boxes.copy()
                new_gt_boxes[:, :4] = new_gt_boxes[:, :4]
                image = self.back_to_image(blobs['data']).astype(np.uint8)

                im_shape = image.shape
                pred_boxes, scores, classes = net.interpret_faster_rcnn_scale(
                    result_cls_prob,
                    result_bbox_pred,
                    result_rois,
                    im_shape,
                    min_score=0.1)
                image = self.draw_photo(image, pred_boxes, scores, classes,
                                        new_gt_boxes)
                image = torchtrans.ToTensor()(image)
                image = vutils.make_grid([image])
                self._tensor_writer.add_image('Image', image, step)

            if self._use_valid and step % self._valid_interval == 0:
                total_valid_loss = 0.0
                valid_rpn_cls_loss = 0.0
                valid_rpn_bbox_loss = 0.0
                valid_fast_rcnn_cls_loss = 0.0
                valid_fast_rcnn_bbox_loss = 0.0
                valid_step_cnt = 0
                start_time = time.time()

                valid_length = self._disp_interval
                net.eval()
                for valid_batch in range(valid_length):
                    # get one batch
                    blobs = self.data_layer_val.forward()

                    im_data = blobs['data']
                    im_info = blobs['im_info']
                    gt_boxes = blobs['gt_boxes']

                    # forward
                    result_cls_prob, result_bbox_pred, result_rois = net(
                        im_data, im_info, gt_boxes)
                    valid_loss = net.loss + net._rpn.loss

                    total_valid_loss += valid_loss.data.cpu()[0]
                    valid_rpn_cls_loss += net._rpn.cross_entropy.data.cpu()[0]
                    valid_rpn_bbox_loss += net._rpn.loss_box.data.cpu()[0]
                    valid_fast_rcnn_cls_loss += net.cross_entropy.data.cpu()[0]
                    valid_fast_rcnn_bbox_loss += net.loss_box.data.cpu()[0]
                    valid_step_cnt += 1
                net.train()
                duration = time.time() - start_time
                fps = valid_step_cnt / duration

                log_text = 'step %d, valid average loss: %.4f, fps: %.2f (%.2fs per batch)' % (
                    step, total_valid_loss / valid_step_cnt, fps, 1. / fps)
                pprint.pprint(log_text)

                if self._use_tensorboard:
                    self._tensor_writer.add_text('Valid',
                                                 log_text,
                                                 global_step=step)
                    new_gt_boxes = gt_boxes.copy()
                    new_gt_boxes[:, :4] = new_gt_boxes[:, :4]
                    image = self.back_to_image(blobs['data']).astype(np.uint8)

                    im_shape = image.shape
                    pred_boxes, scores, classes = net.interpret_faster_rcnn_scale(
                        result_cls_prob,
                        result_bbox_pred,
                        result_rois,
                        im_shape,
                        min_score=0.1)
                    image = self.draw_photo(image, pred_boxes, scores, classes,
                                            new_gt_boxes)
                    image = torchtrans.ToTensor()(image)
                    image = vutils.make_grid([image])
                    self._tensor_writer.add_image('Image_Valid', image, step)

                if self._use_tensorboard:
                    # Valid
                    avg_rpn_cls_loss_valid = valid_rpn_cls_loss / valid_step_cnt
                    avg_rpn_bbox_loss_valid = valid_rpn_bbox_loss / valid_step_cnt
                    avg_fast_rcnn_cls_loss_valid = valid_fast_rcnn_cls_loss / valid_step_cnt
                    avg_fast_rcnn_bbox_loss_valid = valid_fast_rcnn_bbox_loss / valid_step_cnt
                    real_total_loss_valid = valid_rpn_cls_loss + valid_rpn_bbox_loss + valid_fast_rcnn_cls_loss + valid_fast_rcnn_bbox_loss

                    # Train
                    avg_rpn_cls_loss = rpn_cls_loss / step_cnt
                    avg_rpn_bbox_loss = rpn_bbox_loss / step_cnt
                    avg_fast_rcnn_cls_loss = fast_rcnn_cls_loss / step_cnt
                    avg_fast_rcnn_bbox_loss = fast_rcnn_bbox_loss / step_cnt
                    real_total_loss = rpn_cls_loss + rpn_bbox_loss + fast_rcnn_cls_loss + fast_rcnn_bbox_loss

                    self._tensor_writer.add_scalars(
                        'Total_Loss', {
                            'train': train_loss / step_cnt,
                            'valid': total_valid_loss / valid_step_cnt
                        },
                        global_step=step)
                    self._tensor_writer.add_scalars(
                        'Real_loss', {
                            'train': real_total_loss / step_cnt,
                            'valid': real_total_loss_valid / valid_step_cnt
                        },
                        global_step=step)
                    self._tensor_writer.add_scalars(
                        'RPN_cls_loss', {
                            'train': avg_rpn_cls_loss,
                            'valid': avg_rpn_cls_loss_valid
                        },
                        global_step=step)
                    self._tensor_writer.add_scalars(
                        'RPN_bbox_loss', {
                            'train': avg_rpn_bbox_loss,
                            'valid': avg_rpn_bbox_loss_valid
                        },
                        global_step=step)
                    self._tensor_writer.add_scalars(
                        'FastRcnn_cls_loss', {
                            'train': avg_fast_rcnn_cls_loss,
                            'valid': avg_fast_rcnn_cls_loss_valid
                        },
                        global_step=step)
                    self._tensor_writer.add_scalars(
                        'FastRcnn_bbox_loss', {
                            'train': avg_fast_rcnn_bbox_loss,
                            'valid': avg_fast_rcnn_bbox_loss_valid
                        },
                        global_step=step)

                    self._tensor_writer.add_scalars(
                        'ValidSetLoss', {
                            'RPN_cls_loss': avg_rpn_cls_loss_valid,
                            'RPN_bbox_loss': avg_rpn_bbox_loss_valid,
                            'FastRcnn_cls_loss': avg_fast_rcnn_cls_loss_valid,
                            'FastRcnn_bbox_loss': avg_fast_rcnn_bbox_loss_valid
                        },
                        global_step=step)

                    # self._tensor_writer.add_scalars('TrainSetLoss', {
                    #   'RPN_cls_loss': avg_rpn_cls_loss,
                    #   'RPN_bbox_loss': avg_rpn_bbox_loss,
                    #   'FastRcnn_cls_loss': avg_fast_rcnn_cls_loss,
                    #   'FastRcnn_bbox_loss': avg_fast_rcnn_bbox_loss
                    # }, global_step=step)
                    # self._tensor_writer.add_scalar('Learning_rate', self._lr, global_step=step)

            if (step % self._save_point_interval == 0) and step > 0:
                save_name, _ = self.save_check_point(step)
                print('save model: {}'.format(save_name))

            if step in self._lr_decay_steps:
                self._lr *= self._lr_decay
                self._optimizer = self._train_optimizer()

            if re_cnt:
                tp, tf, fg, bg = 0., 0., 0, 0
                train_loss = 0
                rpn_cls_loss = 0
                rpn_bbox_loss = 0
                fast_rcnn_cls_loss = 0
                fast_rcnn_bbox_loss = 0
                step_cnt = 0
                t.tic()
                re_cnt = False

        if self._use_tensorboard:
            self._tensor_writer.export_scalars_to_json(
                os.path.join(self.tbdir, 'all_scalars.json'))

    def draw_photo(self, image, dets, scores, classes, gt_boxes):
        """Draw predicted and ground-truth boxes on ``image`` and return it.

        No copy is made: ``image`` is drawn on in place and the same object
        is returned.

        Args:
            image: image passed to ``cv2.rectangle`` / ``cv2.putText``.
            dets: predicted boxes; after ``int`` conversion, entries 0-1 and
                2-3 are used as the two rectangle corners.
            scores: score per detection, indexed in parallel with ``dets``.
            classes: label per detection, indexed in parallel with ``dets``.
            gt_boxes: ground-truth boxes; entries 0-3 are the corners and the
                last entry is an index into ``self.net._classes``.

        Returns:
            ``image`` with the annotations drawn on it.
        """
        im2show = image
        for i, det in enumerate(dets):
            det = tuple(int(x) for x in det)
            # The box colour shifts with the detection index: the first
            # component grows by 10 per box, the second by 10 every 5 boxes,
            # the third by 1 every 25 boxes. Floor division keeps the
            # components integral (plain `/` gives floats on Python 3).
            r_i = i // 5
            g_i = r_i // 5
            color = (min(i * 10, 255),
                     min(150 + r_i * 10, 255),
                     min(200 + g_i, 255))
            cv2.rectangle(im2show, det[0:2], det[2:4], color, 2)
            cv2.putText(im2show,
                        '%s: %.3f' % (classes[i], scores[i]),
                        (det[0], det[1] + 15),
                        cv2.FONT_HERSHEY_PLAIN,
                        1.0, (0, 0, 255),
                        thickness=1)
        for det in gt_boxes:
            det = tuple(int(x) for x in det)
            gt_class = self.net._classes[det[-1]]
            cv2.rectangle(im2show, det[0:2], det[2:4], (255, 0, 0), 2)
            cv2.putText(im2show,
                        '%s' % (gt_class), (det[0], det[1] + 15),
                        cv2.FONT_HERSHEY_PLAIN,
                        1.0, (0, 0, 255),
                        thickness=1)
        return im2show

    def back_to_image(self, img):
        """Turn the first image of a network input batch back into a displayable image.

        Adds ``cfg.PIXEL_MEANS`` back to ``img[0]``, reverses the third axis
        (presumably the channel axis, BGR <-> RGB — TODO confirm) and returns
        a C-contiguous copy.
        """
        restored = img[0] + cfg.PIXEL_MEANS
        flipped = restored[:, :, ::-1]
        return flipped.copy(order='C')

    def save_check_point(self, step):
        """Save the model weights and training-state metadata for ``step``.

        Two files are written into ``self.model_dir`` (created if missing):

        * ``fasterRcnn_iter_{step}.h5``: one HDF5 dataset per entry of
          ``self.net.state_dict()``, named after its key.
        * ``fasterRcnn_iter_{step}.pkl``: pickled, in this order, the numpy
          RNG state, the training data layer's ``_cur`` and ``_perm``, the
          validation data layer's ``_cur`` and ``_perm``, the current
          learning rate ``self._lr`` and ``step``. ``load_check_point`` reads
          them back in the same order.

        Args:
            step: training iteration; used in the file names and stored in
                the pickle.

        Returns:
            Tuple ``(h5_filename, pkl_filename)``.
        """
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        # Store the model snapshot. The context manager guarantees the HDF5
        # file is flushed and closed, even if a tensor fails to convert.
        filename = os.path.join(self.model_dir,
                                'fasterRcnn_iter_{}.h5'.format(step))
        with h5py.File(filename, mode='w') as h5f:
            for k, v in self.net.state_dict().items():
                h5f.create_dataset(k, data=v.cpu().numpy())

        # Store the data / optimizer state needed to resume exactly.
        nfilename = os.path.join(self.model_dir,
                                 'fasterRcnn_iter_{}.pkl'.format(step))
        meta = (
            np.random.get_state(),       # numpy RNG state
            self.data_layer._cur,        # position in the training set
            self.data_layer._perm,       # shuffled training indexes
            self.data_layer_val._cur,    # position in the validation set
            self.data_layer_val._perm,   # shuffled validation indexes
            self._lr,                    # current learning rate
            step,
        )
        with open(nfilename, 'wb') as fid:
            for item in meta:
                pickle.dump(item, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def load_check_point(self, step):
        """Restore the weights and training state saved by ``save_check_point``.

        Copies every tensor from ``fasterRcnn_iter_{step}.h5`` into
        ``self.net.state_dict()``. It then restores the numpy RNG state, the
        training/validation data-layer positions and ``self._lr`` from
        ``fasterRcnn_iter_{step}.pkl``.

        Note: ``pickle.load`` can execute arbitrary code for crafted files;
        only load checkpoints you produced yourself.

        Args:
            step: iteration of the checkpoint to load.

        Returns:
            The iteration stored in the pickle (equal to ``step``).

        Raises:
            SystemExit: via ``sys.exit(1)`` when the ``.h5`` file is missing.
            ValueError: when the stored iteration differs from ``step``.
        """
        net = self.net
        filename = os.path.join(self.model_dir,
                                'fasterRcnn_iter_{}.h5'.format(step))
        nfilename = os.path.join(self.model_dir,
                                 'fasterRcnn_iter_{}.pkl'.format(step))
        print('Restoring model snapshots from {:s}'.format(filename))

        if not os.path.exists(filename):
            print('The checkPoint is not Right')
            sys.exit(1)

        # Load model weights; the with-block closes the HDF5 handle.
        with h5py.File(filename, mode='r') as h5f:
            for k, v in net.state_dict().items():
                param = torch.from_numpy(np.asarray(h5f[k]))
                v.copy_(param)

        # Load data information, in the order save_check_point wrote it.
        with open(nfilename, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            lr = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val
            self._lr = lr

        if last_snapshot_iter != step:
            print('The checkPoint is not Right')
            raise ValueError(
                'checkpoint {} stores iteration {}, expected {}'.format(
                    nfilename, last_snapshot_iter, step))

        print('Restore over ')
        return last_snapshot_iter

    def weights_normal_init(self, model, dev=0.01):
        """Initialise the layers of ``model`` according to ``cfg.TRAIN.INIT_WAY``.

        ``model`` may be a module or a list of modules; list items are handled
        by recursing into this method. Within a module, every ``nn.Conv2d``
        and ``nn.Linear`` gets weights drawn from N(0, dev) and a zeroed bias
        (when it has one). Every ``nn.BatchNorm2d`` gets weight 1 and bias 0.

        NOTE(review): both the 'resnet' and the 'vgg' settings currently use
        this Gaussian scheme. The He-initialisation helper ``_resnet_init``
        below is defined but never called. Confirm whether that is intended.

        Raises:
            NotImplementedError: if ``cfg.TRAIN.INIT_WAY`` is neither
                'resnet' nor 'vgg'.
        """
        import math

        def _gaussian_init(layer, std):
            # N(0, std) weights; zero the bias only if the layer has one.
            layer.weight.data.normal_(0.0, std)
            if hasattr(layer.bias, 'data'):
                layer.bias.data.zero_()

        def _xaiver_init(layer):
            # Unused alternative: Xavier-normal weights.
            nn.init.xavier_normal(layer.weight.data)
            if hasattr(layer.bias, 'data'):
                layer.bias.data.zero_()

        def _hekaiming_init(layer):
            # He / Kaiming normal init based on the conv fan-out.
            fan_out = layer.kernel_size[0] * layer.kernel_size[1] * layer.out_channels
            layer.weight.data.normal_(0, math.sqrt(2. / fan_out))
            if hasattr(layer.bias, 'data'):
                layer.bias.data.zero_()

        def _resnet_init(target, std):
            # Unused alternative: He init for convs, Gaussian for linears.
            if isinstance(target, list):
                for sub in target:
                    self.weights_normal_init(sub, std)
                return
            for layer in target.modules():
                if isinstance(layer, nn.Conv2d):
                    _hekaiming_init(layer)
                elif isinstance(layer, nn.BatchNorm2d):
                    layer.weight.data.fill_(1)
                    layer.bias.data.zero_()
                elif isinstance(layer, nn.Linear):
                    _gaussian_init(layer, std)

        def _vgg_init(target, std):
            if isinstance(target, list):
                for sub in target:
                    self.weights_normal_init(sub, std)
                return
            for layer in target.modules():
                if isinstance(layer, (nn.Conv2d, nn.Linear)):
                    _gaussian_init(layer, std)
                elif isinstance(layer, nn.BatchNorm2d):
                    layer.weight.data.fill_(1)
                    layer.bias.data.zero_()

        if cfg.TRAIN.INIT_WAY in ('resnet', 'vgg'):
            _vgg_init(model, dev)
        else:
            raise NotImplementedError

    def prepare_construct(self, resume_iter):
        """Prepare the network, RNG seeds, hyper-parameters, optimizer and logging.

        Args:
            resume_iter: iteration of a checkpoint written by
                ``save_check_point`` to resume from. Any falsy value (e.g. 0
                or None) starts fresh from initialised / pretrained weights.

        Sets ``self.start_step``, ``self._lr``, ``self._lr_decay``,
        ``self._momentum``, ``self._weight_decay``, ``self._parameters`` and
        ``self._optimizer``. When TensorBoard is enabled it also sets
        ``self._tensor_writer``.
        """
        # init network
        self.net.init_fasterRCNN()

        # Set the random seed
        torch.manual_seed(cfg.RNG_SEED)
        np.random.seed(cfg.RNG_SEED)

        # Set learning rate and momentum
        self._lr = cfg.TRAIN.LEARNING_RATE
        self._lr_decay = 0.1
        self._momentum = cfg.TRAIN.MOMENTUM
        self._weight_decay = cfg.TRAIN.WEIGHT_DECAY

        # load model
        if resume_iter:
            self.start_step = resume_iter + 1
            # Also restores the numpy RNG state, data-layer positions and self._lr.
            self.load_check_point(resume_iter)
        else:
            self.start_step = 0
            self.weights_normal_init(self.net, dev=0.01)
            # refer to caffe faster RCNN
            self.net.init_special_bbox_fc(dev=0.001)
            if self.pretrained_model is not None:
                self.net._rpn._network._load_pre_trained_model(
                    self.pretrained_model)
                print('Load parameters from Path: {}'.format(
                    self.pretrained_model))

        # model
        self.net.train()
        if cfg.CUDA_IF:
            self.net.cuda()

        # resnet fixed BN should be eval
        if cfg.TRAIN.INIT_WAY == 'resnet':
            self.net._rpn._network._bn_eval()

        # set optimizer (only parameters that require gradients)
        self._parameters = [
            params for params in self.net.parameters() if params.requires_grad
        ]
        self._optimizer = self._train_optimizer()

        # tensorboard
        if self._use_tensorboard:
            import tensorboardX as tbx
            self._tensor_writer = tbx.SummaryWriter(log_dir=self.tbdir)

    def _train_optimizer(self):
        """Return a momentum SGD optimizer over ``self._train_parameter()``.

        Learning rate and weight decay come from the per-parameter groups
        built by ``_train_parameter``; only the momentum is passed here.
        """
        return torch.optim.SGD(self._train_parameter(), momentum=self._momentum)

    def _train_parameter(self):
        """Build one optimizer parameter group per trainable parameter.

        Parameters whose name contains 'bias' get learning rate
        ``self._lr * (cfg.TRAIN.DOUBLE_BIAS + 1)`` and no weight decay. All
        other parameters get ``self._lr`` and ``self._weight_decay``.
        Parameters with ``requires_grad`` False are skipped.
        """
        groups = []
        for name, param in self.net.named_parameters():
            if not param.requires_grad:
                continue
            is_bias = 'bias' in name
            groups.append({
                'params': [param],
                'lr': self._lr * (cfg.TRAIN.DOUBLE_BIAS + 1) if is_bias else self._lr,
                'weight_decay': 0 if is_bias else self._weight_decay,
            })
        return groups
class SolverWrapper(object):
    """
    A wrapper class for the training process
    据作者的说法,这个类就是为了方便自己使用Python代码来控制训练过程中的相关东西  
    
    
  """
    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir,
                 pretrained_model):
        self.net = network  #网络 vgg 或者 resnet
        self.imdb = imdb  #数据库
        self.roidb = roidb  #region of insterest
        self.valroidb = valroidb  #tensorboard 输出文件
        self.output_dir = output_dir  #结果输出文件
        #self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        #self.tbvaldir = tbdir + '_val'
        #if not os.path.exists(self.tbvaldir):
        #os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        """Save a TensorFlow checkpoint plus the matching training metadata.

        Writes ``<SNAPSHOT_PREFIX>_iter_<iter>.ckpt`` via ``self.saver`` into
        ``self.output_dir`` (created if missing). Next to it, a ``.pkl`` file
        holds, in this order:

        1. the numpy RNG state;
        2. the training data layer's ``_cur`` and ``_perm``;
        3. the validation data layer's ``_cur`` and ``_perm``;
        4. ``iter``.

        This is the order ``from_snapshot`` reads back, and ``find_previous``
        expects one ``.pkl`` per ``.ckpt``.

        Returns:
            Tuple ``(ckpt_filename, pkl_filename)``.
        """
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        prefix = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter)
        filename = os.path.join(self.output_dir, prefix + '.ckpt')
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store the numpy random state and the data layers' positions so
        # a restored run continues with the same data order.
        nfilename = os.path.join(self.output_dir, prefix + '.pkl')
        meta = (np.random.get_state(),
                self.data_layer._cur,
                self.data_layer._perm,
                self.data_layer_val._cur,
                self.data_layer_val._perm,
                iter)
        with open(nfilename, 'wb') as fid:
            for item in meta:
                pickle.dump(item, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Needs to restore the other hyper-parameters/states for training, (TODO xinlei) I have
        # tried my best to find the random states so that it can be recovered exactly
        # However the Tensorflow state is currently not available
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        """Return the variable-name -> shape map stored in a TF checkpoint.

        Any error is printed (with a hint when the file looks
        SNAPPY-compressed) and ``None`` is returned instead of raising.
        """
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            return reader.get_variable_to_shape_map()
        except Exception as e:  # pylint: disable=broad-except
            message = str(e)
            print(message)
            if "corrupted compressed block contents" in message:
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")
            return None

    def construct_graph(self, sess):
        """Build the training graph inside ``sess.graph``.

        Creates the network in 'TRAIN' mode, a non-trainable learning-rate
        variable and a momentum optimizer. When ``cfg.TRAIN.DOUBLE_BIAS`` is
        set, gradients of variables whose name contains '/biases:' are
        multiplied by 2. Also creates ``self.optimizer`` and ``self.saver``.
        TensorBoard writers are not created here (that setup is disabled in
        this version).

        Returns:
            ``(lr, train_op)``: the learning-rate variable (so the caller can
            assign decayed values) and the op that applies one update.
        """
        with sess.graph.as_default():
            tf.set_random_seed(cfg.RNG_SEED)
            layers = self.net.create_architecture(
                'TRAIN',
                self.imdb.num_classes,
                anchor_scales=cfg.ANCHOR_SCALES,
                anchor_ratios=cfg.ANCHOR_RATIOS)
            total_loss = layers['total_loss']

            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

            grads_and_vars = self.optimizer.compute_gradients(total_loss)
            if cfg.TRAIN.DOUBLE_BIAS:
                # Double the gradient of every bias variable.
                scaled = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in grads_and_vars:
                        if '/biases:' in var.name:
                            grad = tf.multiply(grad, 2.)
                        scaled.append((grad, var))
                grads_and_vars = scaled
            train_op = self.optimizer.apply_gradients(grads_and_vars)

            # Snapshots are handled manually, so keep (practically) all of them.
            self.saver = tf.train.Saver(max_to_keep=100000)
            print("构建网络通过")
        return lr, train_op

    def find_previous(self):
        """Locate existing snapshots in ``self.output_dir``.

        Snapshots taken at iteration ``stepsize + 1`` for each
        ``cfg.TRAIN.STEPSIZE`` entry are ignored. Both lists are sorted by
        file modification time, oldest first.

        Returns:
            ``(count, pkl_files, ckpt_prefixes)``, where ``ckpt_prefixes``
            are the ``.ckpt.meta`` paths with ``.meta`` stripped.
        """
        prefix = cfg.TRAIN.SNAPSHOT_PREFIX
        skipped_meta = [
            os.path.join(self.output_dir,
                         prefix + '_iter_{:d}.ckpt.meta'.format(stepsize + 1))
            for stepsize in cfg.TRAIN.STEPSIZE
        ]
        skipped_pkl = [path.replace('.ckpt.meta', '.pkl') for path in skipped_meta]

        meta_files = sorted(
            glob.glob(os.path.join(self.output_dir, prefix + '_iter_*.ckpt.meta')),
            key=os.path.getmtime)
        sfiles = [path.replace('.meta', '') for path in meta_files
                  if path not in skipped_meta]

        pkl_files = sorted(
            glob.glob(os.path.join(self.output_dir, prefix + '_iter_*.pkl')),
            key=os.path.getmtime)
        nfiles = [path for path in pkl_files if path not in skipped_pkl]

        count = len(sfiles)
        assert len(nfiles) == count
        return count, nfiles, sfiles

    def initialize(self, sess):
        """Start a fresh run from the pretrained (e.g. ImageNet) weights.

        Initialises every global variable. It then restores the variables that
        ``self.net.get_variables_to_restore`` selects from
        ``self.pretrained_model``, and finally calls ``self.net.fix_variables``.
        Per the original notes, that call converts RGB weights to BGR and, for
        VGG16, turns fc6/fc7 convolution weights into fully connected ones.

        Returns:
            ``(rate, last_snapshot_iter, stepsizes, np_paths, ss_paths)``:
            the base learning rate, 0, all configured step sizes, and two
            empty path lists.
        """
        print('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        all_vars = tf.global_variables()
        sess.run(tf.variables_initializer(all_vars, name='init'))

        # Variables stored in the pretrained checkpoint (name -> shape).
        ckpt_vars = self.get_variables_in_checkpoint_file(self.pretrained_model)
        to_restore = self.net.get_variables_to_restore(all_vars, ckpt_vars)
        tf.train.Saver(to_restore).restore(sess, self.pretrained_model)
        print('Loaded.')

        self.net.fix_variables(sess, self.pretrained_model)
        print('Fixed.')

        return cfg.TRAIN.LEARNING_RATE, 0, list(cfg.TRAIN.STEPSIZE), [], []

    def restore(self, sess, sfile, nfile):
        """Resume from the given snapshot and rebuild the learning-rate schedule.

        Returns:
            ``(rate, last_snapshot_iter, stepsizes, np_paths, ss_paths)``.
            ``rate`` is the base learning rate multiplied by
            ``cfg.TRAIN.GAMMA`` once for every step size already passed.
            ``stepsizes`` holds the remaining ones. The path lists contain
            just ``nfile`` and ``sfile``.
        """
        snapshot_iter = self.from_snapshot(sess, sfile, nfile)

        rate = cfg.TRAIN.LEARNING_RATE
        remaining = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if snapshot_iter <= stepsize:
                remaining.append(stepsize)
            else:
                rate *= cfg.TRAIN.GAMMA

        return rate, snapshot_iter, remaining, [nfile], [sfile]

    def remove_snapshot(self, np_paths, ss_paths):
        """Delete snapshots from the front of each list until SNAPSHOT_KEPT remain.

        ``np_paths`` holds ``.pkl`` metadata files and ``ss_paths`` checkpoint
        prefixes. Both lists are mutated in place; callers append new paths at
        the end, so the front holds the oldest.
        """
        keep = cfg.TRAIN.SNAPSHOT_KEPT

        for _ in range(len(np_paths) - keep):
            oldest = np_paths[0]
            os.remove(str(oldest))
            np_paths.remove(oldest)

        for _ in range(len(ss_paths) - keep):
            oldest = ss_paths[0]
            # Older TensorFlow versions write a single checkpoint file; newer
            # ones split it into a data shard and an index file.
            if os.path.exists(str(oldest)):
                os.remove(str(oldest))
            else:
                for suffix in ('.data-00000-of-00001', '.index'):
                    os.remove(str(oldest + suffix))
            os.remove(str(oldest + '.meta'))
            ss_paths.remove(oldest)

    def train_model(self, sess, max_iters):
        """Run the main training loop up to iteration ``max_iters`` (inclusive).

        The method works in these stages:

        * Builds the training and validation data layers and the graph.
        * Either initialises from ``self.pretrained_model`` or restores the
          newest snapshot reported by ``find_previous``.
        * Runs one training step per iteration.
        * Writes a snapshot right before each learning-rate drop (iteration
          ``stepsize + 1``) and every ``cfg.TRAIN.SNAPSHOT_ITERS`` iterations,
          pruning old ones beyond ``cfg.TRAIN.SNAPSHOT_KEPT``.
        * Writes a final snapshot if the last iteration was not saved.

        Args:
            sess: TensorFlow session the graph is built in and run with.
            max_iters: last training iteration.
        """
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        print("训练准备数据通过")
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)
        print("测试数据通过")

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
                sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                sess, str(sfiles[-1]), str(nfiles[-1]))
        timer = Timer()
        cur_iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # max_iters acts as a sentinel so the schedule list never runs empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        while cur_iter < max_iters + 1:
            # Learning rate
            if cur_iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, cur_iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time (forward pass of the data layer)
            blobs = self.data_layer.forward()

            now = time.time()
            if cur_iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Run the step together with the summary op. The summary is not
                # written anywhere because no TensorBoard writer is set up.
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, _summary = \
                    self.net.train_step_with_summary(sess, blobs, train_op)
                # Also pull a validation batch (it advances the validation
                # data layer; the batch itself is currently unused).
                self.data_layer_val.forward()

                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                    self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if cur_iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (cur_iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Snapshotting
            if cur_iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = cur_iter
                ss_path, np_path = self.snapshot(sess, cur_iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            cur_iter += 1

        if last_snapshot_iter != cur_iter - 1:
            self.snapshot(sess, cur_iter - 1)

        # The TensorBoard FileWriters are commented out in construct_graph,
        # so close them only if they were actually attached.
        for writer in (getattr(self, 'writer', None),
                       getattr(self, 'valwriter', None)):
            if writer is not None:
                writer.close()