Example #1
class SolverWrapper(object):
  """
    A wrapper class for the training process
  """
  def __init__(self, sess, network, rfcn_network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
    self.net = network
    self.rfcn_network = rfcn_network
    self.imdb = imdb
    self.roidb = roidb
    self.valroidb = valroidb
    self.output_dir = output_dir
    self.tbdir = tbdir
    # Simply put '_val' at the end to save the summaries from the validation set
    self.tbvaldir = tbdir + '_val'
    if not os.path.exists(self.tbvaldir):
      os.makedirs(self.tbvaldir)
    self.pretrained_model = pretrained_model

  def snapshot(self, sess, iter):
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
    filename = os.path.join(self.output_dir, filename)

    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store some meta information, random state, etc.
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
    st0 = np.random.get_state()
    # current position in the database
    cur = self.data_layer._cur
    # current shuffled indexes of the database
    perm = self.data_layer._perm
    # current position in the validation database
    cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
    perm_val = self.data_layer_val._perm

    # Dump the meta info
    with open(nfilename, 'wb') as fid:
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)
    return filename, nfilename

  def get_variables_in_checkpoint_file(self, file_name):
    try:
      reader = pywrap_tensorflow.NewCheckpointReader(file_name)
      var_to_shape_map = reader.get_variable_to_shape_map()
      return var_to_shape_map 
    except Exception as e:  # pylint: disable=broad-except
      print(str(e))
      if "corrupted compressed block contents" in str(e):
        print("It's likely that your checkpoint file has been compressed "
              "with SNAPPY.")

  def train_model(self, sess, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Determine different scales for anchors, see paper
    with sess.graph.as_default():
      # Set the random seed for tensorflow
      tf.set_random_seed(cfg.RNG_SEED)
      # Build the main computation graph
      rpn_layers, proposal_targets = self.net.create_architecture(sess, 'TRAIN', self.imdb.num_classes, scope='rpn_network', tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS)
      rfcn_layers, _ = self.rfcn_network.create_architecture(sess, 'TRAIN', self.imdb.num_classes, scope='rfcn_network', tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS,
                                            input_rois=rpn_layers['rois'],
                                            roi_scores=rpn_layers['roi_scores'],
                                            proposal_targets=proposal_targets)

      # Define the loss
      rpn_loss = rpn_layers['rpn_loss']
      rfcn_loss = rfcn_layers['rfcn_loss']
      rpn_rfcn_loss = rpn_layers['rfcn_loss']
      # Set learning rate and momentum
      lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
      momentum = cfg.TRAIN.MOMENTUM
      self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

      # Compute the gradients wrt the loss
      rpn_trainable_variables_stage1 = self.net.get_train_variables('rpn_network')
      rfcn_trainable_variables_stage2 = self.rfcn_network.get_train_variables('rfcn_network')
      rpn_trainable_variables_stage3 = self.net.get_train_variables_stage3('rpn_network')
      rpn_trainable_variables_stage4 = self.net.get_train_variables_stage4('rpn_network')
      gvs_rpn_stage1 = self.optimizer.compute_gradients(rpn_loss, rpn_trainable_variables_stage1)
      gvs_rfcn_stage2 = self.optimizer.compute_gradients(rfcn_loss, rfcn_trainable_variables_stage2)
      gvs_rpn_stage3 = self.optimizer.compute_gradients(rpn_loss, rpn_trainable_variables_stage3)
      gvs_rfcn_stage4 = self.optimizer.compute_gradients(rpn_rfcn_loss, rpn_trainable_variables_stage4)

      train_op_stage1 = self.optimizer.apply_gradients(gvs_rpn_stage1)
      train_op_stage2 = self.optimizer.apply_gradients(gvs_rfcn_stage2)
      train_op_stage3 = self.optimizer.apply_gradients(gvs_rpn_stage3)
      train_op_stage4 = self.optimizer.apply_gradients(gvs_rfcn_stage4)

      # We will handle the snapshots ourselves
      self.saver = tf.train.Saver(max_to_keep=1000000)
      # Write the train and validation information to tensorboard
      self.writer_stage1 = tf.summary.FileWriter(self.tbdir+'/stage1', sess.graph)
      self.writer_stage2 = tf.summary.FileWriter(self.tbdir+'/stage2', sess.graph)
      self.writer_stage3 = tf.summary.FileWriter(self.tbdir+'/stage3', sess.graph)
      self.writer_stage4 = tf.summary.FileWriter(self.tbdir+'/stage4', sess.graph)
      self.valwriter_stage1 = tf.summary.FileWriter(self.tbvaldir+'/stage1')
      self.valwriter_stage2 = tf.summary.FileWriter(self.tbvaldir+'/stage2')
      self.valwriter_stage3 = tf.summary.FileWriter(self.tbvaldir+'/stage3')
      self.valwriter_stage4 = tf.summary.FileWriter(self.tbvaldir+'/stage4')

    # Find previous snapshots if there is any to restore from
    sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
    sfiles = glob.glob(sfiles)
    sfiles.sort(key=os.path.getmtime)
    # Get the snapshot name in TensorFlow
    redstr = '_iter_{:d}.'.format(cfg.TRAIN.STEPSIZE+1)
    sfiles = [ss.replace('.meta', '') for ss in sfiles]
    sfiles = [ss for ss in sfiles if redstr not in ss]

    nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
    nfiles = glob.glob(nfiles)
    nfiles.sort(key=os.path.getmtime)
    nfiles = [nn for nn in nfiles if redstr not in nn]

    lsf = len(sfiles)
    assert len(nfiles) == lsf

    np_paths = nfiles
    ss_paths = sfiles

    if lsf == 0:
      # Fresh train directly from ImageNet weights
      print('Loading initial model weights from {:s}'.format(self.pretrained_model))
      variables = tf.global_variables()
      # Initialize all variables first
      sess.run(tf.variables_initializer(variables, name='init'))
      var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
      # Get the variables to restore, ignoring the variables to fix
      variables_to_restore_rpn, variables_to_restore_rfcn = self.net.get_variables_to_restore(variables, var_keep_dic)

      self.saver_rpn = tf.train.Saver(variables_to_restore_rpn)
      self.saver_rpn.restore(sess, self.pretrained_model)

      self.saver_rfcn = tf.train.Saver(variables_to_restore_rfcn)
      self.saver_rfcn.restore(sess, self.pretrained_model)

      print('Loaded.')
      # Need to fix the variables before loading, so that the RGB weights are changed to BGR
      # For VGG16 it also changes the convolutional weights fc6 and fc7 to
      # fully connected weights
      self.net.fix_variables(sess, self.pretrained_model)
      print('Fixed.')
      sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
      last_snapshot_iter = 0
    else:
      # Get the most recent snapshot and restore
      ss_paths = [ss_paths[-1]]
      np_paths = [np_paths[-1]]

      print('Restoring model snapshots from {:s}'.format(sfiles[-1]))
      self.saver.restore(sess, str(sfiles[-1]))
      print('Restored.')
      # Needs to restore the other hyperparameters/states for training. (TODO xinlei) I have
      # tried my best to find the random states so that it can be recovered exactly.
      # However, the TensorFlow state is currently not available.
      with open(str(nfiles[-1]), 'rb') as fid:
        st0 = pickle.load(fid)
        cur = pickle.load(fid)
        perm = pickle.load(fid)
        cur_val = pickle.load(fid)
        perm_val = pickle.load(fid)
        last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        # Set the learning rate, only reduce once
        if last_snapshot_iter > cfg.TRAIN.STEPSIZE:
          sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
        else:
          sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

    timer = Timer()
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    stage_infor = ''
    # while iter < max_iters + 1:
    while iter < 200001:
    # while iter < 201:
      # Learning rate
      if iter == 80001:
      # if iter == 81:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, 0.001))
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()
      # stage 1  training rpn layers and backbones  in rpn network
      if iter < 80001:
      # if iter < 81:
        stage_infor = 'stage1'
        if iter == 60001:
        # if iter == 61:
          sess.run(tf.assign(lr, 0.0001))
        timer.tic()
        now = time.time()
        if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
          # Compute the graph with summary
          rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
            self.net.train_rpn_step_with_summary(sess, blobs, train_op_stage1)
          self.writer_stage1.add_summary(summary, float(iter))
          # Also check the summary on the validation set
          blobs_val = self.data_layer_val.forward()
          summary_val = self.net.get_summary(sess, blobs_val)
          self.valwriter_stage1.add_summary(summary_val, float(iter))
          last_summary_time = now
        else:
          # Compute the graph without summary
          rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
            self.net.train_rpn_step(sess, blobs, train_op_stage1)
        timer.toc()
        if iter == 80000:
          self.writer_stage1.close()
          self.valwriter_stage1.close()
      # stage 2 training rfcn layers and backbones  in rfcn network
      elif 80001 <= iter < 200001:
      # elif 81 <= iter < 201:
        stage_infor = 'stage2'
        rpn_loss_cls = 0
        rpn_loss_box = 0
        if iter == 160001:
        # if iter == 161:
          self.snapshot(sess, iter)
          sess.run(tf.assign(lr, 0.0001))
        timer.tic()
        now = time.time()
        if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
          # Compute the graph with summary
          loss_cls, loss_box, total_loss, summary = \
            self.rfcn_network.train_rfcn_step_with_summary_stage2(sess, blobs, train_op_stage2, self.net)
          self.writer_stage2.add_summary(summary, float(iter))
          # Also check the summary on the validation set
          blobs_val = self.data_layer_val.forward()
          summary_val = self.rfcn_network.get_summary_stage2(sess, blobs_val, self.net)
          self.valwriter_stage2.add_summary(summary_val, float(iter))
          last_summary_time = now
        else:
          # Compute the graph without summary
          loss_cls, loss_box, total_loss = \
            self.rfcn_network.train_rfcn_step_stage2(sess, blobs, train_op_stage2, self.net)
        timer.toc()
        if iter == 200000:
          self.writer_stage2.close()
          self.valwriter_stage2.close()
      else:
        raise ValueError('illegal iter value')

      # Display training information
      # if iter % (cfg.TRAIN.DISPLAY) == 0:
      if iter % 20 == 0:
        print('iter: %d / %d, stage: %s, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, stage_infor, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))

      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        snapshot_path, np_path = self.snapshot(sess, iter)
        np_paths.append(np_path)
        ss_paths.append(snapshot_path)
      iter += 1

    if iter <= 200001:
    # if iter <= 201:
      ###############################################
      #####  merge rfcn_network to rpn_network  #####
      ###############################################
      merged_ops = merged_networks_rfcn2rpn('rfcn_network', 'rpn_network')
      with tf.variable_scope('rpn_network', reuse=True):
        rpn_conv1 = tf.get_variable('resnet_v1_101/block2/unit_1/bottleneck_v1/shortcut/weights')
        rpn_conv2 = tf.get_variable('resnet_v1_101/refined_reduce_depth/weights')
        rpn_conv3 = tf.get_variable('resnet_v1_101/block4/unit_3/bottleneck_v1/conv3/weights')
        rpn_conv4 = tf.get_variable('resnet_v1_101/block3/unit_18/bottleneck_v1/conv2/weights')
        rpn_conv5 = tf.get_variable('resnet_v1_101/block3/unit_14/bottleneck_v1/conv3/weights')
      with tf.variable_scope('rfcn_network', reuse=True):
        rfcn_conv1 = tf.get_variable('resnet_v1_101/block2/unit_1/bottleneck_v1/shortcut/weights')
        rfcn_conv2 = tf.get_variable('resnet_v1_101/refined_reduce_depth/weights')
        rfcn_conv3 = tf.get_variable('resnet_v1_101/block4/unit_3/bottleneck_v1/conv3/weights')
        rfcn_conv4 = tf.get_variable('resnet_v1_101/block3/unit_18/bottleneck_v1/conv2/weights')
        rfcn_conv5 = tf.get_variable('resnet_v1_101/block3/unit_14/bottleneck_v1/conv3/weights')
      with tf.control_dependencies(merged_ops):
        rpn_conv1 = tf.identity(rpn_conv1)
        bool1 = tf.equal(rpn_conv1, rfcn_conv1)
        bool2 = tf.equal(rpn_conv2, rfcn_conv2)
        bool3 = tf.equal(rpn_conv3, rfcn_conv3)
        bool4 = tf.equal(rpn_conv4, rfcn_conv4)
        bool5 = tf.equal(rpn_conv5, rfcn_conv5)
        bool1_val, bool2_val, bool3_val, bool4_val, bool5_val = \
          sess.run([bool1, bool2, bool3, bool4, bool5])
    # stage 3 and stage 4
    sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
    # while iter < max_iters + 1:
    while iter < 480001:
    # while iter < 401:
        if iter == 280001:
        # if iter == 261:
          # Add snapshot here before reducing the learning rate
          self.snapshot(sess, iter)
          sess.run(tf.assign(lr, 0.001))
        if iter == 400001:
          self.snapshot(sess, iter)
          sess.run(tf.assign(lr, 0.001))

        blobs = self.data_layer.forward()
        # stage 3 training rpn layers only  in rpn network rpn layers
        if 200001 <= iter < 280001:
        # if 201 <= iter < 581:
          stage_infor = 'stage3'
          # if iter == 260001:
          if iter == 260001:
            self.snapshot(sess, iter)
            sess.run(tf.assign(lr, 0.0001))
          timer.tic()
          now = time.time()
          if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
            # Compute the graph with summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
              self.net.train_rpn_step_with_summary(sess, blobs, train_op_stage3)
            self.writer_stage3.add_summary(summary, float(iter))
            # Also check the summary on the validation set
            blobs_val = self.data_layer_val.forward()
            summary_val = self.net.get_summary(sess, blobs_val)
            self.valwriter_stage3.add_summary(summary_val, float(iter))
            last_summary_time = now
          else:
            # Compute the graph without summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
              self.net.train_rpn_step(sess, blobs, train_op_stage3)
          timer.toc()
          if iter == 280000:
            self.writer_stage3.close()
            self.valwriter_stage3.close()
        # stage 4 training rfcn layer only in rpn network rfcn layers
        elif 280001 <= iter < 400001:
        # elif 581 <= iter < 1401:
          stage_infor = 'stage4'
          if iter == 360001:
          # if iter == 1361:
            sess.run(tf.assign(lr, 0.0001))
          timer.tic()
          now = time.time()
          if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
            # Compute the graph with summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
              self.net.train_rpn_step_with_summary(sess, blobs, train_op_stage4)
            self.writer_stage4.add_summary(summary, float(iter))
            # Also check the summary on the validation set
            blobs_val = self.data_layer_val.forward()
            summary_val = self.net.get_summary(sess, blobs_val)
            self.valwriter_stage4.add_summary(summary_val, float(iter))
            last_summary_time = now
          else:
            # Compute the graph without summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
              self.net.train_rpn_step(sess, blobs, train_op_stage4)
          timer.toc()
          if iter == 400000:
            self.writer_stage4.close()
            self.valwriter_stage4.close()
        elif 400001 <= iter < 480001:
          # if 401 <= iter < 481:
          stage_infor = 'stage5'
          # if iter == 461:
          if iter == 460001:
            self.snapshot(sess, iter)
            sess.run(tf.assign(lr, 0.0001))
          timer.tic()
          now = time.time()
          if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
            # Compute the graph with summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
              self.net.train_rpn_step_with_summary(sess, blobs, train_op_stage3)
            self.writer_stage3.add_summary(summary, float(iter))
            # Also check the summary on the validation set
            blobs_val = self.data_layer_val.forward()
            summary_val = self.net.get_summary(sess, blobs_val)
            self.valwriter_stage3.add_summary(summary_val, float(iter))
            last_summary_time = now
          else:
            # Compute the graph without summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
              self.net.train_rpn_step(sess, blobs, train_op_stage3)
          timer.toc()
          if iter == 480000:
            self.writer_stage3.close()
            self.valwriter_stage3.close()
        else:
          raise ValueError('iter is not allowed')

        # Display training information
        if iter % (cfg.TRAIN.DISPLAY) == 0:
          print('iter: %d / %d, stage: %s, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                (iter, max_iters, stage_infor, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
          print('speed: {:.3f}s / iter'.format(timer.average_time))

        if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
          last_snapshot_iter = iter
          snapshot_path, np_path = self.snapshot(sess, iter)
          np_paths.append(np_path)
          ss_paths.append(snapshot_path)
        iter += 1

    if last_snapshot_iter != iter - 1:
      self.snapshot(sess, iter - 1)
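
A minimal driver sketch for the wrapper above, assuming a TF1-style session and hypothetical module paths and network classes (the resnetv1 constructor and train_net itself are illustrative assumptions, not part of the original listing):

# Hypothetical driver for the two-network SolverWrapper above (TF1 graph/session).
# Module paths and network class names are assumptions for illustration only.
import tensorflow as tf

from nets.resnet_v1 import resnetv1   # assumed backbone/network class

def train_net(imdb, roidb, valroidb, output_dir, tb_dir,
              pretrained_model=None, max_iters=480000):
  """Drive the four-stage RPN / R-FCN training loop defined above."""
  tfconfig = tf.ConfigProto(allow_soft_placement=True)
  tfconfig.gpu_options.allow_growth = True

  with tf.Session(config=tfconfig) as sess:
    # Two network instances: one owns the RPN graph, the other the R-FCN head.
    rpn_network = resnetv1(num_layers=101)
    rfcn_network = resnetv1(num_layers=101)
    sw = SolverWrapper(sess, rpn_network, rfcn_network, imdb, roidb, valroidb,
                       output_dir, tb_dir, pretrained_model=pretrained_model)
    sw.train_model(sess, max_iters)
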
Example #2
class SolverWrapper(object):
  """
    A wrapper class for the training process
  """

  def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
    self.net = network
    self.imdb = imdb
    self.roidb = roidb
    self.valroidb = valroidb
    self.output_dir = output_dir
    self.tbdir = tbdir
    # Simply put '_val' at the end to save the summaries from the validation set
    self.tbvaldir = tbdir + '_val'
    if not os.path.exists(self.tbvaldir):
      os.makedirs(self.tbvaldir)
    self.pretrained_model = pretrained_model

  def snapshot(self, sess, iter):
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
    filename = os.path.join(self.output_dir, filename)
    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store some meta information, random state, etc.
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
    st0 = np.random.get_state()
    # current position in the database
    cur = self.data_layer._cur
    # current shuffled indexes of the database
    perm = self.data_layer._perm
    # current position in the validation database
    #cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
    #perm_val = self.data_layer_val._perm

    # Dump the meta info
    with open(nfilename, 'wb') as fid:
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      #pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      #pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename

  def from_snapshot(self, sess, sfile, nfile):
    print('Restoring model snapshots from {:s}'.format(sfile))
    self.saver.restore(sess, sfile)
    print('Restored.')
    # Needs to restore the other hyper-parameters/states for training. (TODO xinlei) I have
    # tried my best to find the random states so that it can be recovered exactly.
    # However, the TensorFlow state is currently not available.
    with open(nfile, 'rb') as fid:
      st0 = pickle.load(fid)
      cur = pickle.load(fid)
      perm = pickle.load(fid)
      #cur_val = pickle.load(fid)
      #perm_val = pickle.load(fid)
      last_snapshot_iter = pickle.load(fid)

      np.random.set_state(st0)
      self.data_layer._cur = cur
      self.data_layer._perm = perm
      #self.data_layer_val._cur = cur_val
      #self.data_layer_val._perm = perm_val

    return last_snapshot_iter

  def get_variables_in_checkpoint_file(self, file_name):
    try:
      reader = pywrap_tensorflow.NewCheckpointReader(file_name)
      var_to_shape_map = reader.get_variable_to_shape_map()
      return var_to_shape_map 
    except Exception as e:  # pylint: disable=broad-except
      print(str(e))
      if "corrupted compressed block contents" in str(e):
        print("It's likely that your checkpoint file has been compressed "
              "with SNAPPY.")

  def construct_graph(self, sess):
    with sess.graph.as_default():
      # Set the random seed for tensorflow
      tf.set_random_seed(cfg.RNG_SEED)
      # Build the main computation graph
      layers = self.net.create_architecture('TRAIN', self.imdb.num_classes, tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS)
      # Define the loss
      loss = layers['total_loss']
      # Set learning rate and momentum
      lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
      self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

      # Compute the gradients with regard to the loss
      gvs = self.optimizer.compute_gradients(loss)
      # Double the gradient of the bias if set
      if cfg.TRAIN.DOUBLE_BIAS:
        final_gvs = []
        with tf.variable_scope('Gradient_Mult') as scope:
          for grad, var in gvs:
            scale = 1.
            if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
              scale *= 2.
            if not np.allclose(scale, 1.0):
              grad = tf.multiply(grad, scale)
            final_gvs.append((grad, var))
        train_op = self.optimizer.apply_gradients(final_gvs)
      else:
        train_op = self.optimizer.apply_gradients(gvs)

      # We will handle the snapshots ourselves
      self.saver = tf.train.Saver(max_to_keep=100000)
      # Write the train and validation information to tensorboard
      self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
      #self.valwriter = tf.summary.FileWriter(self.tbvaldir)

    return lr, train_op

  def find_previous(self):
    sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
    sfiles = glob.glob(sfiles)
    sfiles.sort(key=os.path.getmtime)
    # Get the snapshot name in TensorFlow
    redfiles = []
    for stepsize in cfg.TRAIN.STEPSIZE:
      redfiles.append(os.path.join(self.output_dir, 
                      cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.ckpt.meta'.format(stepsize+1)))
    sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles]

    nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
    nfiles = glob.glob(nfiles)
    nfiles.sort(key=os.path.getmtime)
    redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles]
    nfiles = [nn for nn in nfiles if nn not in redfiles]

    lsf = len(sfiles)
    assert len(nfiles) == lsf

    return lsf, nfiles, sfiles

  def initialize(self, sess):
    # Initial file lists are empty
    np_paths = []
    ss_paths = []
    # Fresh train directly from ImageNet weights
    print('Loading initial model weights from {:s}'.format(self.pretrained_model))
    variables = tf.global_variables()
    # Initialize all variables first
    sess.run(tf.variables_initializer(variables, name='init'))
    var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
    # Get the variables to restore, ignoring the variables to fix
    variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, self.pretrained_model)
    print('Loaded.')
    # Need to fix the variables before loading, so that the RGB weights are changed to BGR
    # For VGG16 it also changes the convolutional weights fc6 and fc7 to
    # fully connected weights
    self.net.fix_variables(sess, self.pretrained_model)
    print('Fixed.')
    last_snapshot_iter = 0
    rate = cfg.TRAIN.LEARNING_RATE
    stepsizes = list(cfg.TRAIN.STEPSIZE)

    return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

  def restore(self, sess, sfile, nfile):
    # Get the most recent snapshot and restore
    np_paths = [nfile]
    ss_paths = [sfile]
    # Restore model from snapshots
    last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
    # Set the learning rate
    rate = cfg.TRAIN.LEARNING_RATE
    stepsizes = []
    for stepsize in cfg.TRAIN.STEPSIZE:
      if last_snapshot_iter > stepsize:
        rate *= cfg.TRAIN.GAMMA
      else:
        stepsizes.append(stepsize)

    return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

  def remove_snapshot(self, np_paths, ss_paths):
    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      nfile = np_paths[0]
      os.remove(str(nfile))
      np_paths.remove(nfile)

    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      sfile = ss_paths[0]
      # To make the code compatible with earlier versions of TensorFlow,
      # where the naming convention for checkpoints is different
      if os.path.exists(str(sfile)):
        os.remove(str(sfile))
      else:
        os.remove(str(sfile + '.data-00000-of-00001'))
        os.remove(str(sfile + '.index'))
      sfile_meta = sfile + '.meta'
      os.remove(str(sfile_meta))
      ss_paths.remove(sfile)

  def train_model(self, sess, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    #self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Construct the computation graph
    lr, train_op = self.construct_graph(sess)

    # Find previous snapshots if there is any to restore from
    lsf, nfiles, sfiles = self.find_previous()

    # Initialize the variables or restore them from the last snapshot
    if lsf == 0:
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
    else:
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess, 
                                                                            str(sfiles[-1]), 
                                                                            str(nfiles[-1]))
    timer = Timer()
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    # Make sure the lists are not empty
    stepsizes.append(max_iters)
    stepsizes.reverse()
    next_stepsize = stepsizes.pop()
    while iter < max_iters + 1:
      # Learning rate
      if iter == next_stepsize + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        rate *= cfg.TRAIN.GAMMA
        sess.run(tf.assign(lr, rate))
        next_stepsize = stepsizes.pop()

      timer.tic()
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()

      now = time.time()
      if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
        # Compute the graph with summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
          self.net.train_step_with_summary(sess, blobs, train_op)
        self.writer.add_summary(summary, float(iter))
        # Also check the summary on the validation set
        #blobs_val = self.data_layer_val.forward()
        #summary_val = self.net.get_summary(sess, blobs_val)
        #self.valwriter.add_summary(summary_val, float(iter))
        last_summary_time = now
      else:
        # Compute the graph without summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
          self.net.train_step(sess, blobs, train_op)
      timer.toc()

      # Display training information
      if iter % (cfg.TRAIN.DISPLAY) == 0:
        print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))

      # Snapshotting
      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        ss_path, np_path = self.snapshot(sess, iter)
        np_paths.append(np_path)
        ss_paths.append(ss_path)

        # Remove the old snapshots if there are too many
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          self.remove_snapshot(np_paths, ss_paths)

      iter += 1

    if last_snapshot_iter != iter - 1:
      self.snapshot(sess, iter - 1)

    self.writer.close()
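
The learning-rate bookkeeping in train_model above (append max_iters as a sentinel, reverse, pop) is easy to misread. Below is a small standalone sketch of the same schedule logic; the concrete stepsizes, iteration count and gamma are assumed values for illustration:

# Standalone sketch of the stepwise learning-rate schedule used in train_model.
def lr_schedule(stepsizes, max_iters, base_lr, gamma):
  """Yield (iteration, lr) pairs, decaying lr by gamma after each stepsize."""
  lr = base_lr
  pending = list(stepsizes)
  pending.append(max_iters)      # sentinel so pop() never empties the list
  pending.reverse()
  next_stepsize = pending.pop()
  for it in range(1, max_iters + 1):
    if it == next_stepsize + 1:  # mirrors the `iter == next_stepsize + 1` check
      lr *= gamma
      next_stepsize = pending.pop()
    yield it, lr

# Example: decay at 50k and 75k out of 100k iterations (assumed numbers).
schedule = dict(lr_schedule([50000, 75000], 100000, 0.001, 0.1))
for it in (50000, 50001, 75001):
  print(it, schedule[it])        # 0.001, then 1e-4, then 1e-5 (up to rounding)
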
Example #3
class SolverWrapper(object):
  """ A wrapper class for the training process """

  def __init__(self, sess, network, imdb, roidb, valroidb, output_dir,
               pretrained_model=None):
    self.net = network
    self.imdb = imdb
    self.roidb = roidb
    self.valroidb = valroidb
    self.output_dir = output_dir
    self.pretrained_model = pretrained_model


  def snapshot(self, sess, iter):
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
    filename = os.path.join(self.output_dir, filename)
    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store some meta information, random state, etc.
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
    st0 = np.random.get_state()
    # current position in the database
    cur = self.data_layer._cur
    # current shuffled indexes of the database
    perm = self.data_layer._perm
    # current position in the validation database
    cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
    perm_val = self.data_layer_val._perm

    # Dump the meta info
    with open(nfilename, 'wb') as fid:
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename


  def from_snapshot(self, sess, sfile, nfile):
    print('Restoring model snapshots from {:s}'.format(sfile))
    self.saver.restore(sess, sfile)
    print('Restored.')
    # Needs to restore the other hyper-parameters/states for training. I have
    # tried my best to find the random states so that it can be recovered exactly.
    # However, the TensorFlow state is currently not available.
    with open(nfile, 'rb') as fid:
      st0 = pickle.load(fid)
      cur = pickle.load(fid)
      perm = pickle.load(fid)
      cur_val = pickle.load(fid)
      perm_val = pickle.load(fid)
      last_snapshot_iter = pickle.load(fid)

      np.random.set_state(st0)
      self.data_layer._cur = cur
      self.data_layer._perm = perm
      self.data_layer_val._cur = cur_val
      self.data_layer_val._perm = perm_val

    return last_snapshot_iter


  def get_variables_in_checkpoint_file(self, file_name):
    try:
      reader = pywrap_tensorflow.NewCheckpointReader(file_name)
      var_to_shape_map = reader.get_variable_to_shape_map()
      return var_to_shape_map 
    except Exception as e:  # pylint: disable=broad-except
      print(str(e))
      if "corrupted compressed block contents" in str(e):
        print("It's likely that your checkpoint file has been compressed "
              "with SNAPPY.")


  def construct_graph(self, sess):
    # Set the random seed for tensorflow
    tf.set_random_seed(cfg.RNG_SEED)
    with sess.graph.as_default():
      # Build the main computation graph
      layers = self.net.create_architecture('TRAIN', tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS)
      # Define the loss
      loss = layers['total_loss']
      # Set learning rate and momentum
      lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
      self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

      # Compute the gradients with regard to the loss
      gvs = self.optimizer.compute_gradients(loss)
      # Double the gradient of the bias if set
      if cfg.TRAIN.DOUBLE_BIAS:
        final_gvs = []
        with tf.variable_scope('Gradient_Mult') as scope:
          for grad, var in gvs:
            scale = 1.
            if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
              scale *= 2.
            if not np.allclose(scale, 1.0):
              grad = tf.multiply(grad, scale)
            final_gvs.append((grad, var))
        train_op = self.optimizer.apply_gradients(final_gvs)
      else:
        train_op = self.optimizer.apply_gradients(gvs)

      # Initialize post-hist module of drl-RPN
      if cfg.DRL_RPN.USE_POST:
        loss_post = layers['total_loss_hist']
        lr_post = tf.Variable(cfg.DRL_RPN_TRAIN.POST_LR, trainable=False)
        self.optimizer_post = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM) 
        gvs_post = self.optimizer_post.compute_gradients(loss_post)
        train_op_post = self.optimizer_post.apply_gradients(gvs_post)
      else:
        lr_post = None
        train_op_post = None

      # Initialize main drl-RPN network
      self.net.build_drl_rpn_network()

    return lr, train_op, lr_post, train_op_post


  def initialize(self, sess):
    # Initial file lists are empty
    np_paths = []
    ss_paths = []
    # Fresh train directly from ImageNet weights
    print('Loading initial model weights from {:s}'.format(self.pretrained_model))
    variables = tf.global_variables()
    # Initialize all variables first
    sess.run(tf.variables_initializer(variables, name='init'))
    var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
    #print(self.pretrained_model)
    #sleep(100)
    # Get the variables to restore, ignoring the variables to fix
    variables_to_restore = self.net.get_variables_to_restore(variables,
                                                             var_keep_dic)
    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, self.pretrained_model)
    print('Loaded.')
    # Need to fix the variables before loading, so that the RGB weights are
    # changed to BGR. For VGG16 it also changes the convolutional weights fc6
    # and fc7 to fully connected weights
    #
    # NOTE: IF YOU WANT TO TRAIN FROM EXISTING FASTER
    # R-CNN WEIGHTS, AND NOT FROM IMAGENET WEIGHTS, SET BELOW FLAG TO FALSE!!!!
    self.net.fix_variables(sess, self.pretrained_model, False)
    print('Fixed.')
    last_snapshot_iter = 0
    rate = cfg.TRAIN.LEARNING_RATE
    stepsizes = list(cfg.TRAIN.STEPSIZE)
    return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths


  def restore(self, sess, sfile, nfile):
    # Get the most recent snapshot and restore
    np_paths = [nfile]
    ss_paths = [sfile]
    # Restore model from snapshots
    last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
    # Set the learning rate
    rate = cfg.TRAIN.LEARNING_RATE
    stepsizes = []
    for stepsize in cfg.TRAIN.STEPSIZE:
      if last_snapshot_iter > stepsize:
        rate *= cfg.TRAIN.GAMMA
      else:
        stepsizes.append(stepsize)
    return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths


  def remove_snapshot(self, np_paths, ss_paths):
    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      nfile = np_paths[0]
      os.remove(str(nfile))
      np_paths.remove(nfile)
    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      sfile = ss_paths[0]
      # To make the code compatible with earlier versions of TensorFlow,
      # where the naming convention for checkpoints is different
      if os.path.exists(str(sfile)):
        os.remove(str(sfile))
      else:
        os.remove(str(sfile + '.data-00000-of-00001'))
        os.remove(str(sfile + '.index'))
      sfile_meta = sfile + '.meta'
      os.remove(str(sfile_meta))
      ss_paths.remove(sfile)


  def _print_det_loss(self, iter, max_iters, tot_loss, loss_cls, loss_box,
                      lr, timer, in_string='detector'):
    if (iter + 1) % (cfg.TRAIN.DISPLAY) == 0:
      if loss_box is not None:
        print('iter: %d / %d, total loss: %.6f\n '
              '>>> loss_cls (%s): %.6f\n '
              '>>> loss_box (%s): %.6f\n >>> lr: %f' % \
              (iter + 1, max_iters, tot_loss, in_string, loss_cls, in_string,
               loss_box, lr))
      else:
        print('iter: %d / %d, total loss (%s): %.6f\n >>> lr: %f' % \
              (iter + 1, max_iters, in_string, tot_loss, lr))
      print('speed: {:.3f}s / iter'.format(timer.average_time))


  def _check_if_continue(self, iter, max_iters, snapshot_add):
    img_start_idx = cfg.DRL_RPN_TRAIN.IMG_START_IDX
    if iter > img_start_idx:
      return iter, max_iters, snapshot_add, False
    if iter < img_start_idx:
      print("iter %d < img_start_idx %d -- continuing" % (iter, img_start_idx))
      iter += 1
      return iter, max_iters, snapshot_add, True
    if iter == img_start_idx:
      print("Adjusting stepsize, train-det-start etcetera")
      snapshot_add = img_start_idx
      max_iters -= img_start_idx
      iter = 0
      cfg_from_list(['DRL_RPN_TRAIN.IMG_START_IDX', -1])
      cfg_from_list(['DRL_RPN_TRAIN.DET_START',
                     cfg.DRL_RPN_TRAIN.DET_START - img_start_idx])
      cfg_from_list(['DRL_RPN_TRAIN.STEPSIZE',
                     cfg.DRL_RPN_TRAIN.STEPSIZE - img_start_idx])
      cfg_from_list(['TRAIN.STEPSIZE', [cfg.TRAIN.STEPSIZE[0] - img_start_idx]])
      cfg_from_list(['DRL_RPN_TRAIN.POST_SS',
                    [cfg.DRL_RPN_TRAIN.POST_SS[0] - img_start_idx]])
      print("Done adjusting stepsize, train-det-start etcetera")
      return iter, max_iters, snapshot_add, False


  def train_model(self, sess, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, cfg.NBR_CLASSES)
    self.data_layer_val = RoIDataLayer(self.valroidb, cfg.NBR_CLASSES, True)

    # Construct the computation graph corresponding to the original Faster R-CNN
    # architecture first (and potentially the post-hist module of drl-RPN)
    lr_det_op, train_op, lr_post_op, train_op_post \
      = self.construct_graph(sess)

    # Initialize the variables or restore them from the last snapshot
    rate, last_snapshot_iter, stepsizes, np_paths, ss_paths \
      = self.initialize(sess)

    # We will handle the snapshots ourselves
    self.saver = tf.train.Saver(max_to_keep=100000)

    # Initialize
    self.net.init_rl_train(sess)

    # Setup initial learning rates
    lr_rl = cfg.DRL_RPN_TRAIN.LEARNING_RATE
    lr_det = cfg.TRAIN.LEARNING_RATE
    sess.run(tf.assign(lr_det_op, lr_det))
    if cfg.DRL_RPN.USE_POST:
      lr_post = cfg.DRL_RPN_TRAIN.POST_LR
      sess.run(tf.assign(lr_post_op, lr_post))

    # Sample first beta
    if cfg.DRL_RPN_TRAIN.USE_POST:
      betas = cfg.DRL_RPN_TRAIN.POST_BETAS
    else:
      betas = cfg.DRL_RPN_TRAIN.BETAS
    beta_idx = 0
    beta = betas[beta_idx]

    # Setup drl-RPN timers
    timers = {'init': Timer(), 'fulltraj': Timer(), 'upd-obs-vol': Timer(),
                  'upd-seq': Timer(), 'upd-rl': Timer(), 'action-rl': Timer(),
                  'coll-traj': Timer(), 'run-drl-rpn': Timer(),
                  'train-drl-rpn': Timer()}

    # Create StatCollector (tracks various RL training statistics)
    stat_strings = ['reward', 'rew-done', 'traj-len', 'frac-area',
                    'gt >= 0.5 frac', 'gt-IoU-frac']
    sc = StatCollector(max_iters, stat_strings)

    timer = Timer()
    iter = 0
    snapshot_add = 0
    while iter < max_iters:

      # Get training data, one batch at a time (assumes batch size 1)
      blobs = self.data_layer.forward()

      # Allows the possibility to start at arbitrary image, rather
      # than always starting from first image in dataset. Useful if
      # want to load parameters and keep going from there, rather
      # than having those and encountering visited images again.
      iter, max_iters, snapshot_add, do_continue \
        = self._check_if_continue(iter, max_iters, snapshot_add)
      if do_continue:
        continue

      if not cfg.DRL_RPN_TRAIN.USE_POST:

        # Potentially update drl-RPN learning rate
        if (iter + 1) % cfg.DRL_RPN_TRAIN.STEPSIZE == 0:
          lr_rl *= cfg.DRL_RPN_TRAIN.GAMMA

        # Run drl-RPN in training mode
        timers['run-drl-rpn'].tic()
        stats = run_drl_rpn(sess, self.net, blobs, timers, mode='train',
                            beta=beta, im_idx=None, extra_args=lr_rl)
        timers['run-drl-rpn'].toc()

        if (iter + 1) % cfg.DRL_RPN_TRAIN.BATCH_SIZE == 0:
          print("\n##### DRL-RPN BATCH GRADIENT UPDATE - START ##### \n")
          print('iter: %d / %d' % (iter + 1, max_iters))
          print('lr-rl: %f' % lr_rl)
          timers['train-drl-rpn'].tic()
          self.net.train_drl_rpn(sess, lr_rl, sc, stats)
          timers['train-drl-rpn'].toc()
          sc.print_stats()
          print('TIMINGS:')
          print('runnn-drl-rpn: %.4f' % timers['run-drl-rpn'].get_avg())
          print('train-drl-rpn: %.4f' % timers['train-drl-rpn'].get_avg())
          print("\n##### DRL-RPN BATCH GRADIENT UPDATE - DONE ###### \n")

          # Also sample new beta for next batch
          beta_idx += 1
          beta_idx %= len(betas)
          beta = betas[beta_idx]
        else:
          sc.update(0, stats)

        # At this point we assume that an RL-trajectory has been performed.
        # We next train detector with drl-RPN running in deterministic mode.
        # Potentially train detector component of network
        if cfg.DRL_RPN_TRAIN.DET_START >= 0 and \
          iter >= cfg.DRL_RPN_TRAIN.DET_START:

          # Run drl-RPN in deterministic mode
          net_conv, rois_drl_rpn, gt_boxes, im_info, timers, _ \
            = run_drl_rpn(sess, self.net, blobs, timers, mode='train_det',
                          beta=beta, im_idx=None)

          # Learning rate
          if (iter + 1) % cfg.TRAIN.STEPSIZE[0] == 0:
            lr_det *= cfg.TRAIN.GAMMA
            sess.run(tf.assign(lr_det_op, lr_det))

          timer.tic()
          # Train detector part
          loss_cls, loss_box, tot_loss \
            = self.net.train_step_det(sess, train_op, net_conv, rois_drl_rpn,
                                      gt_boxes, im_info)
          timer.toc()

          # Display training information
          self._print_det_loss(iter, max_iters, tot_loss, loss_cls, loss_box,
                               lr_det, timer)
   
      # Train post-hist module AFTER we have trained rest of drl-RPN! Specifically
      # once rest of drl-RPN has been trained already, copy those weights into
      # the folder of pretrained weights and rerun training with those as initial
      # weights, which will then train only the posterior-history module
      else:

        # The very first time we need to assign the ordinary detector weights
        # as starting point
        if iter == 0:
          self.net.assign_post_hist_weights(sess)

        # Sample beta
        beta = betas[beta_idx]
        beta_idx += 1
        beta_idx %= len(betas)

        # Run drl-RPN in deterministic mode
        net_conv, rois_drl_rpn, gt_boxes, im_info, timers, cls_hist \
          = run_drl_rpn(sess, self.net, blobs, timers, mode='train_det',
                        beta=beta, im_idx=None)

        # Learning rate (assume only one learning rate iter for now!)
        if (iter + 1) % cfg.DRL_RPN_TRAIN.POST_SS[0] == 0:
          lr_post *= cfg.TRAIN.GAMMA
          sess.run(tf.assign(lr_post_op, lr_post))

        # Train post-hist detector part
        tot_loss = self.net.train_step_post(sess, train_op_post, net_conv,
                                            rois_drl_rpn, gt_boxes, im_info,
                                            cls_hist)

        # Display training information
        self._print_det_loss(iter, max_iters, tot_loss, None, None,
                             lr_post, timer, 'post-hist')

      # Snapshotting
      if (iter + 1) % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter + 1
        ss_path, np_path = self.snapshot(sess, iter + 1 + snapshot_add)
        np_paths.append(np_path)
        ss_paths.append(ss_path)

        # Remove the old snapshots if there are too many
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          self.remove_snapshot(np_paths, ss_paths)

      # Increase iteration counter
      iter += 1

    # Potentially save one last time
    if last_snapshot_iter != iter:
      self.snapshot(sess, iter + snapshot_add)
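
The .pkl sidecar written by snapshot() above is order-dependent: from_snapshot() must unpickle exactly as many objects, in exactly the order they were dumped. A self-contained sketch of that round trip (the file name and the dummy values are assumptions for illustration):

import os
import pickle
import tempfile

import numpy as np

meta_path = os.path.join(tempfile.gettempdir(), 'example_iter_1000.pkl')  # assumed path

# Write side: same order as snapshot() above.
st0 = np.random.get_state()
cur, perm = 7, np.random.permutation(10)
cur_val, perm_val = 3, np.random.permutation(10)
last_iter = 1000
with open(meta_path, 'wb') as fid:
  for obj in (st0, cur, perm, cur_val, perm_val, last_iter):
    pickle.dump(obj, fid, pickle.HIGHEST_PROTOCOL)

# Read side: same order as from_snapshot() above.
with open(meta_path, 'rb') as fid:
  st0_r = pickle.load(fid)
  cur_r = pickle.load(fid)
  perm_r = pickle.load(fid)
  cur_val_r = pickle.load(fid)
  perm_val_r = pickle.load(fid)
  last_iter_r = pickle.load(fid)

np.random.set_state(st0_r)                # restores the numpy RNG exactly
assert (cur_r, cur_val_r, last_iter_r) == (7, 3, 1000)
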
Example #4
class SolverWrapper(object):
    """
    A wrapper class for the training process
  """
    def __init__(self,
                 network,
                 imdb,
                 roidb,
                 valroidb,
                 output_dir,
                 tbdir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir, exist_ok=True)
        self.pretrained_model = pretrained_model

    def snapshot(self, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pth'
        filename = os.path.join(self.output_dir, filename)
        torch.save(self.net.state_dict(), filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sfile, nfile):
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.net.load_state_dict(torch.load(str(sfile)))
        print('Restored.')
        # Needs to restore the other hyper-parameters/states for training. (TODO xinlei) I have
        # tried my best to find the random states so that it can be recovered exactly.
        # However, the TensorFlow state is currently not available.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def construct_graph(self):
        # Set the random seed
        torch.manual_seed(cfg.RNG_SEED)
        # Build the main computation graph
        self.net.create_architecture(self.imdb.num_classes,
                                     tag='default',
                                     anchor_scales=cfg.ANCHOR_SCALES,
                                     anchor_ratios=cfg.ANCHOR_RATIOS)
        # Define the loss
        # loss = layers['total_loss']
        # Set learning rate and momentum
        lr = cfg.TRAIN.LEARNING_RATE
        params = []
        for key, value in dict(self.net.named_parameters()).items():
            if value.requires_grad:
                if 'bias' in key:
                    params += [{
                        'params': [value],
                        'lr':
                        lr * (cfg.TRAIN.DOUBLE_BIAS + 1),
                        'weight_decay':
                        cfg.TRAIN.BIAS_DECAY and cfg.TRAIN.WEIGHT_DECAY or 0
                    }]
                else:
                    params += [{
                        'params': [value],
                        'lr': lr,
                        'weight_decay': cfg.TRAIN.WEIGHT_DECAY
                    }]
        self.optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)
        # Write the train and validation information to tensorboard
        self.writer = tb.writer.FileWriter(self.tbdir)
        self.valwriter = tb.writer.FileWriter(self.tbvaldir)

        return lr, self.optimizer

    def find_previous(self):
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pth')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in pytorch
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(
                os.path.join(
                    self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX +
                    '_iter_{:d}.pth'.format(stepsize + 1)))
        sfiles = [ss for ss in sfiles if ss not in redfiles]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [redfile.replace('.pth', '.pkl') for redfile in redfiles]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def initialize(self):
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        self.net.load_pretrained_cnn(torch.load(self.pretrained_model))
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are changed to BGR
        # For VGG16 it also changes the convolutional weights fc6 and fc7 to
        # fully connected weights
        last_snapshot_iter = 0
        lr = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sfile, nfile):
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sfile, nfile)
        # Set the learning rate
        lr_scale = 1
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                lr_scale *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)
        scale_lr(self.optimizer, lr_scale)
        lr = cfg.TRAIN.LEARNING_RATE * lr_scale
        return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible with earlier versions of Tensorflow,
            # where the naming convention for checkpoints is different
            os.remove(str(sfile))
            ss_paths.remove(sfile)

    def train_model(self, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph()

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
            )
        else:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                str(sfiles[-1]), str(nfiles[-1]))
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()

        self.net.train()
        self.net.cuda()

        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(iter)
                lr *= cfg.TRAIN.GAMMA
                scale_lr(self.optimizer, cfg.TRAIN.GAMMA)
                next_stepsize = stepsizes.pop()

            utils.timer.timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                  self.net.train_step_with_summary(blobs, self.optimizer)
                for _sum in summary:
                    self.writer.add_summary(_sum, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(blobs_val)
                for _sum in summary_val:
                    self.valwriter.add_summary(_sum, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                  self.net.train_step(blobs, self.optimizer)
            utils.timer.timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr))
                print('speed: {:.3f}s / iter'.format(
                    utils.timer.timer.average_time()))

                # for k in utils.timer.timer._average_time.keys():
                #   print(k, utils.timer.timer.average_time(k))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(iter - 1)

        self.writer.close()
        self.valwriter.close()
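The PyTorch training code above calls a helper scale_lr that is not defined anywhere in this listing. The following is a minimal sketch of what such a helper presumably does, assuming a standard torch.optim optimizer whose parameter groups each carry their own 'lr' entry; it is an illustration, not the project's actual implementation:

def scale_lr(optimizer, scale):
    # Multiply the learning rate of every parameter group by the given factor.
    for param_group in optimizer.param_groups:
        param_group['lr'] *= scale

Under that assumption, scale_lr(self.optimizer, cfg.TRAIN.GAMMA) in the training loop lowers the rate by the configured decay factor, while scale_lr(self.optimizer, lr_scale) in restore() re-applies however many decay steps had already happened before the snapshot was taken.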
Example #5
0
class SolverWrapper(object):
  """
    A wrapper class for the training process
  """

  def __init__(self, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
    self.net = network
    self.imdb = imdb
    self.roidb = roidb
    self.valroidb = valroidb
    self.output_dir = output_dir
    self.tbdir = tbdir
    # Simply put '_val' at the end to save the summaries from the validation set
    self.tbvaldir = tbdir + '_val'
    if not os.path.exists(self.tbvaldir):
      os.makedirs(self.tbvaldir)
    self.pretrained_model = pretrained_model

  def snapshot(self, iter):
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pth'
    filename = os.path.join(self.output_dir, filename)
    torch.save(self.net.state_dict(), filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store some meta information, random state, etc.
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
    st0 = np.random.get_state()
    # current position in the database
    cur = self.data_layer._cur
    # current shuffled indexes of the database
    perm = self.data_layer._perm
    # current position in the validation database
    cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
    perm_val = self.data_layer_val._perm

    # Dump the meta info
    with open(nfilename, 'wb') as fid:
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename

  def from_snapshot(self, sfile, nfile):
    print('Restoring model snapshots from {:s}'.format(sfile))
    self.net.load_state_dict(torch.load(str(sfile)))
    print('Restored.')
    # Need to restore the other hyper-parameters/states for training. (TODO xinlei) I have
    # tried my best to capture the random states so that training can be recovered exactly;
    # however, the TensorFlow random state is currently not available.
    with open(nfile, 'rb') as fid:
      st0 = pickle.load(fid)
      cur = pickle.load(fid)
      perm = pickle.load(fid)
      cur_val = pickle.load(fid)
      perm_val = pickle.load(fid)
      last_snapshot_iter = pickle.load(fid)

      np.random.set_state(st0)
      self.data_layer._cur = cur
      self.data_layer._perm = perm
      self.data_layer_val._cur = cur_val
      self.data_layer_val._perm = perm_val

    return last_snapshot_iter

  def construct_graph(self):
    # Set the random seed
    torch.manual_seed(cfg.RNG_SEED)
    # Build the main computation graph
    self.net.create_architecture(self.imdb.num_classes, tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS)
    # Define the loss
    # loss = layers['total_loss']
    # Set learning rate and momentum
    lr = cfg.TRAIN.LEARNING_RATE
    params = []
    for key, value in dict(self.net.named_parameters()).items():
      if value.requires_grad:
        if 'bias' in key:
          params += [{'params':[value],'lr':lr*(cfg.TRAIN.DOUBLE_BIAS + 1), 'weight_decay': cfg.TRAIN.BIAS_DECAY and cfg.TRAIN.WEIGHT_DECAY or 0}]
        else:
          params += [{'params':[value],'lr':lr, 'weight_decay': cfg.TRAIN.WEIGHT_DECAY}]
    self.optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)
    # Write the train and validation information to tensorboard
    self.writer = tb.writer.FileWriter(self.tbdir)
    self.valwriter = tb.writer.FileWriter(self.tbvaldir)

    return lr, self.optimizer

  def find_previous(self):
    sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pth')
    sfiles = glob.glob(sfiles)
    sfiles.sort(key=os.path.getmtime)
    # Get the snapshot name in pytorch
    redfiles = []
    for stepsize in cfg.TRAIN.STEPSIZE:
      redfiles.append(os.path.join(self.output_dir, 
                      cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.pth'.format(stepsize+1)))
    sfiles = [ss for ss in sfiles if ss not in redfiles]

    nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
    nfiles = glob.glob(nfiles)
    nfiles.sort(key=os.path.getmtime)
    redfiles = [redfile.replace('.pth', '.pkl') for redfile in redfiles]
    nfiles = [nn for nn in nfiles if nn not in redfiles]

    lsf = len(sfiles)
    assert len(nfiles) == lsf

    return lsf, nfiles, sfiles

  def initialize(self):
    # Initial file lists are empty
    np_paths = []
    ss_paths = []
    # Fresh train directly from ImageNet weights
    print('Loading initial model weights from {:s}'.format(self.pretrained_model))
    self.net.load_pretrained_cnn(torch.load(self.pretrained_model))
    print('Loaded.')
    # Need to fix the variables before loading, so that the RGB weights are changed to BGR
    # For VGG16 it also changes the convolutional weights fc6 and fc7 to
    # fully connected weights
    last_snapshot_iter = 0
    lr = cfg.TRAIN.LEARNING_RATE
    stepsizes = list(cfg.TRAIN.STEPSIZE)

    return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

  def restore(self, sfile, nfile):
    # Get the most recent snapshot and restore
    np_paths = [nfile]
    ss_paths = [sfile]
    # Restore model from snapshots
    last_snapshot_iter = self.from_snapshot(sfile, nfile)
    # Set the learning rate
    lr_scale = 1
    stepsizes = []
    for stepsize in cfg.TRAIN.STEPSIZE:
      if last_snapshot_iter > stepsize:
        lr_scale *= cfg.TRAIN.GAMMA
      else:
        stepsizes.append(stepsize)
    scale_lr(self.optimizer, lr_scale)
    lr = cfg.TRAIN.LEARNING_RATE * lr_scale
    return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

  def remove_snapshot(self, np_paths, ss_paths):
    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      nfile = np_paths[0]
      os.remove(str(nfile))
      np_paths.remove(nfile)

    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      sfile = ss_paths[0]
      # To keep the code compatible with earlier versions of TensorFlow,
      # where the naming convention for checkpoints is different
      os.remove(str(sfile))
      ss_paths.remove(sfile)

  def train_model(self, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Construct the computation graph
    lr, train_op = self.construct_graph()

    # Find previous snapshots if there is any to restore from
    lsf, nfiles, sfiles = self.find_previous()

    # Initialize the variables or restore them from the last snapshot
    if lsf == 0:
      lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize()
    else:
      lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(str(sfiles[-1]), 
                                                                             str(nfiles[-1]))
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    # Make sure the lists are not empty
    stepsizes.append(max_iters)
    stepsizes.reverse()
    next_stepsize = stepsizes.pop()

    self.net.train()
    self.net.cuda()

    while iter < max_iters + 1:
      # Learning rate
      if iter == next_stepsize + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(iter)
        lr *= cfg.TRAIN.GAMMA
        scale_lr(self.optimizer, cfg.TRAIN.GAMMA)
        next_stepsize = stepsizes.pop()

      utils.timer.timer.tic()
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()

      now = time.time()
      if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
        # Compute the graph with summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
          self.net.train_step_with_summary(blobs, self.optimizer)
        for _sum in summary: self.writer.add_summary(_sum, float(iter))
        # Also check the summary on the validation set
        blobs_val = self.data_layer_val.forward()
        summary_val = self.net.get_summary(blobs_val)
        for _sum in summary_val: self.valwriter.add_summary(_sum, float(iter))
        last_summary_time = now
      else:
        # Compute the graph without summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
          self.net.train_step(blobs, self.optimizer)
      utils.timer.timer.toc()

      # Display training information
      if iter % (cfg.TRAIN.DISPLAY) == 0:
        print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr))
        print('speed: {:.3f}s / iter'.format(utils.timer.timer.average_time()))

        # for k in utils.timer.timer._average_time.keys():
        #   print(k, utils.timer.timer.average_time(k))

      # Snapshotting
      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        ss_path, np_path = self.snapshot(iter)
        np_paths.append(np_path)
        ss_paths.append(ss_path)

        # Remove the old snapshots if there are too many
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          self.remove_snapshot(np_paths, ss_paths)

      iter += 1

    if last_snapshot_iter != iter - 1:
      self.snapshot(iter - 1)

    self.writer.close()
    self.valwriter.close()
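For context, a wrapper like the one above is normally driven by a small training entry point. The sketch below is hypothetical: the function name train_net, its argument list, and the default iteration count are assumptions rather than part of the listing:

def train_net(network, imdb, roidb, valroidb, output_dir, tb_dir,
              pretrained_model=None, max_iters=40000):
    # Hypothetical driver: build the wrapper and run the training loop.
    sw = SolverWrapper(network, imdb, roidb, valroidb, output_dir, tb_dir,
                       pretrained_model=pretrained_model)
    print('Solving...')
    sw.train_model(max_iters)
    print('done solving')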
Example #6
0
class SolverWrapper(object):
    """
    A wrapper class for the training process
  """
    def __init__(self,
                 sess,
                 network,
                 imdb,
                 roidb,
                 valroidb,
                 output_dir,
                 tbdir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Need to restore the other hyper-parameters/states for training. (TODO xinlei) I have
        # tried my best to capture the random states so that training can be recovered exactly;
        # however, the TensorFlow random state is currently not available.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def construct_graph(self, sess):
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                'TRAIN',
                self.imdb.num_classes,
                tag='default',
                anchor_scales=cfg.ANCHOR_SCALES,
                anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

            # Compute the gradients with regard to the loss
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult') as scope:
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def find_previous(self):
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(
                os.path.join(
                    self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX +
                    '_iter_{:d}.ckpt.meta'.format(stepsize + 1)))
        sfiles = [
            ss.replace('.meta', '') for ss in sfiles if ss not in redfiles
        ]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [
            redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles
        ]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def initialize(self, sess):
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(
            self.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(
            variables, var_keep_dic)
        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, self.pretrained_model)
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are changed to BGR
        # For VGG16 it also changes the convolutional weights fc6 and fc7 to
        # fully connected weights
        self.net.fix_variables(sess, self.pretrained_model)
        print('Fixed.')
        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            sfile = ss_paths[0]
            # To keep the code compatible with earlier versions of TensorFlow,
            # where the naming convention for checkpoints is different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
                sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                sess, str(sfiles[-1]), str(nfiles[-1]))
        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                  self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                  self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
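Unlike the PyTorch variants, the TensorFlow wrapper above expects an externally created tf.Session and takes it as an argument to most methods. A hypothetical sketch of how it might be driven follows; the session options and the surrounding variable names are illustrative, not taken from the listing:

tfconfig = tf.ConfigProto(allow_soft_placement=True)
tfconfig.gpu_options.allow_growth = True   # grab GPU memory on demand
with tf.Session(config=tfconfig) as sess:
    sw = SolverWrapper(sess, network, imdb, roidb, valroidb,
                       output_dir, tb_dir, pretrained_model=pretrained_model)
    sw.train_model(sess, max_iters)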
Example #7
0
class Solver(object):
  def __init__(self, roidb, net, freeze=0):
    # Holds current iteration number. 
    self.iter = 0

    # How frequently we should print the training info.
    self.display_freq = 1

    # Holds the path prefix for snapshots.
    self.snapshot_prefix = 'snapshot'
    
    self.roidb = roidb
    self.net = net
    self.freeze = freeze
    self.roi_data_layer = RoIDataLayer()
    self.roi_data_layer.setup()
    self.roi_data_layer.set_roidb(self.roidb)
    self.stepfn = self.build_step_fn(self.net)
    self.predfn = self.build_pred_fn(self.net)

  # This might be a useful static method to have.
  #@staticmethod not so static anymore
  def build_step_fn(self, net):
    target_y = T.vector("target Y",dtype='int64')
    tl = lasagne.objectives.categorical_crossentropy(net.prediction,target_y)
    loss = tl.mean()
    accuracy = lasagne.objectives.categorical_accuracy(net.prediction,target_y).mean()
   
    weights = net.params
    grads = theano.grad(loss, weights)
    
    scales = np.ones(len(weights))

    if self.freeze:
        scales[:-self.freeze] = 0
        
    print('GRAD SCALE >>>', scales)
    
    for idx, param in enumerate(weights):
        grads[idx] *= scales[idx]
        grads[idx] = grads[idx].astype('float32')
        
    #updates_sgd = lasagne.updates.sgd(loss, net.params, learning_rate=0.0001)
    updates_sgd = lasagne.updates.sgd(grads, net.params, learning_rate=0.0001)
    
    stepfn = theano.function([net.inp, target_y], [loss, accuracy], updates=updates_sgd, allow_input_downcast=True)
    return stepfn

  @staticmethod
  def build_pred_fn(net):
    predfn = theano.function([net.inp], net.prediction, allow_input_downcast=True)
    return predfn

  def get_training_batch(self):
    """Uses ROIDataLayer to fetch a training batch.

    Returns:
      input_data (ndarray): input data suitable for R-CNN processing
      labels (ndarray): batch labels (of type int32)
    """
    data, rois, labels = deepcopy(self.roi_data_layer.top[: 3])
    X = roi_layer(data, rois)
    y = labels.astype('int')
    
    return X, y

  def step(self):
    """Conducts a single step of SGD."""
    self.roi_data_layer.forward()
    data, labels = self.get_training_batch()
    
    loss, acc = self.stepfn(data, labels)
    
    self.loss = loss
    self.acc = acc
    ###################################################### Your code goes here.
    # Among other things, assign the current loss value to self.loss.

    self.iter += 1
    if self.iter % self.display_freq == 0:
      print('Iteration {:<5} Train loss: {} Train acc: {}'.format(self.iter, self.loss, self.acc))

  def save(self, filename):
    self.net.save(filename)
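A hypothetical usage sketch for the Lasagne-based Solver above; roidb and net are assumed to come from the surrounding project, and the iteration count and snapshot name are arbitrary:

solver = Solver(roidb, net, freeze=0)
for _ in range(1000):
    solver.step()               # one SGD step; prints loss/accuracy every display_freq iterations
solver.save('snapshot_final')   # delegates to net.save(); the filename is illustrative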
Example #8
0
class SolverWrapper(object):
  """
    A wrapper class for the training process
  """

  def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
    self.net = network      # instance of the network class
    self.imdb = imdb        # instance of the imdb class
    self.roidb = roidb      # roidb dictionary
    self.valroidb = valroidb     # validation roidb dictionary
    self.output_dir = output_dir  # directory where model snapshots are saved
    self.tbdir = tbdir            # tensorboard log directory
    # Simply put '_val' at the end to save the summaries from the validation set
    self.tbvaldir = tbdir + '_val'      # tensorboard log directory for validation
    if not os.path.exists(self.tbvaldir):
      os.makedirs(self.tbvaldir)
    self.pretrained_model = pretrained_model      # path to the pretrained weights

  # Save a snapshot: a ckpt file with the model weights and a pkl file with the training state
  def snapshot(self, sess, iter):
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
    # Save the model weights, e.g. the three files of shufflenetv2_faster_rcnn_iter_10000.ckpt
    # SNAPSHOT_PREFIX is set in the yml file, e.g. 'shufflenetv2_faster_rcnn'
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
    filename = os.path.join(self.output_dir, filename)
    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store some meta information, random state, etc.
    # Save the random state and other training-process parameters
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
    # save the numpy random number generator state
    st0 = np.random.get_state()
    # current position in the database
    # save the current image index
    cur = self.data_layer._cur
    # current shuffled indexes of the database
    # save the shuffled list of image indexes
    perm = self.data_layer._perm
    # current position in the validation database
    # same as above, for the validation set
    cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
    perm_val = self.data_layer_val._perm

    # Dump the meta info
    # write it out, e.g. shufflenetv2_faster_rcnn_iter_10000.pkl
    with open(nfilename, 'wb') as fid:
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename

  # Load a snapshot
  def from_snapshot(self, sess, sfile, nfile):
    print('Restoring model snapshots from {:s}'.format(sfile))
    self.saver.restore(sess, sfile)
    print('Restored.')
    # Need to restore the other hyper-parameters/states for training. (TODO xinlei) I have
    # tried my best to capture the random states so that training can be recovered exactly;
    # however, the TensorFlow random state is currently not available.
    with open(nfile, 'rb') as fid:
      st0 = pickle.load(fid)
      cur = pickle.load(fid)
      perm = pickle.load(fid)
      cur_val = pickle.load(fid)
      perm_val = pickle.load(fid)
      last_snapshot_iter = pickle.load(fid)

      np.random.set_state(st0)
      self.data_layer._cur = cur
      self.data_layer._perm = perm
      self.data_layer_val._cur = cur_val
      self.data_layer_val._perm = perm_val

    return last_snapshot_iter

  # Return the dictionary of variables found in the checkpoint file
  def get_variables_in_checkpoint_file(self, file_name):
    try:
      reader = pywrap_tensorflow.NewCheckpointReader(file_name)
      var_to_shape_map = reader.get_variable_to_shape_map()
      return var_to_shape_map 
    except Exception as e:  # pylint: disable=broad-except
      print(str(e))
      if "corrupted compressed block contents" in str(e):
        print("It's likely that your checkpoint file has been compressed "
              "with SNAPPY.")

  # Build the model graph
  def construct_graph(self, sess):
    with sess.graph.as_default():
      # Set the random seed for tensorflow
      # For reproducibility, the seed is set in cfg and the random state is stored in the pkl file when snapshotting
      tf.set_random_seed(cfg.RNG_SEED)
      # Build the main computation graph
      # layers is the model output: a dict containing the RoIs, loss values and predictions
      layers = self.net.create_architecture('TRAIN', self.imdb.num_classes, tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS)
      # Define the loss
      loss = layers['total_loss']
      # Set learning rate and momentum
      # set up the optimizer with the learning rate and momentum
      lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
      self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

      # compute the gradients; returns a list of (gradient, variable) pairs
      # Compute the gradients with regard to the loss
      gvs = self.optimizer.compute_gradients(loss)
      # Double the gradient of the bias if set
      # reset to false in the yml file; if true, the gradient of the biases is doubled
      if cfg.TRAIN.DOUBLE_BIAS:
        final_gvs = []
        with tf.variable_scope('Gradient_Mult') as scope:
          for grad, var in gvs:
            scale = 1.
            if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
              scale *= 2.
            if not np.allclose(scale, 1.0):
              grad = tf.multiply(grad, scale)
            final_gvs.append((grad, var))
        # apply the gradients
        train_op = self.optimizer.apply_gradients(final_gvs)
      else:
        # apply the gradients
        train_op = self.optimizer.apply_gradients(gvs)

      # create the Saver; effectively keep all checkpoints
      # We will handle the snapshots ourselves
      self.saver = tf.train.Saver(max_to_keep=100000)
      # Write the train and validation information to tensorboard
      # create the tensorboard FileWriters for training and validation
      self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
      self.valwriter = tf.summary.FileWriter(self.tbvaldir)

    return lr, train_op

  def find_previous(self):
    # build the glob pattern
    sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
    # list the files in the output directory that match the pattern
    sfiles = glob.glob(sfiles)
    # sort by last modification time
    sfiles.sort(key=os.path.getmtime)
    # Get the snapshot name in TensorFlow
    # The purpose of this part is not entirely clear
    # lb: presumably used to check whether a model was saved before; restore if so, otherwise initialize
    redfiles = []
    for stepsize in cfg.TRAIN.STEPSIZE:
      redfiles.append(os.path.join(self.output_dir, 
                      cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.ckpt.meta'.format(stepsize+1)))
    sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles]

    nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
    nfiles = glob.glob(nfiles)
    nfiles.sort(key=os.path.getmtime)
    redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles]
    nfiles = [nn for nn in nfiles if nn not in redfiles]

    lsf = len(sfiles)
    assert len(nfiles) == lsf

    return lsf, nfiles, sfiles

  # Initialize the model
  def initialize(self, sess):
    # Initial file lists are empty
    # this initializes the variables
    np_paths = []
    ss_paths = []
    # Fresh train directly from ImageNet weights
    print('Loading initial model weights from {:s}'.format(self.pretrained_model))
    # get all global variables
    variables = tf.global_variables()
    # Initialize all variables first
    # initialize all variables
    sess.run(tf.variables_initializer(variables, name='init'))
    # get the dict of variables stored in the pretrained checkpoint
    var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
    # Get the variables to restore, ignoring the variables to fix
    # implemented in the resnet (etc.) subclasses; returns the variables that need to be restored
    variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

    # reload the variables in variables_to_restore from the pretrained weights
    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, self.pretrained_model)
    print('Loaded.')
    # Need to fix the variables before loading, so that the RGB weights are changed to BGR
    # For VGG16 it also changes the convolutional weights fc6 and fc7 to
    # fully connected weights
    # some variables need conversion before reloading; implemented per sub-network
    self.net.fix_variables(sess, self.pretrained_model)
    print('Fixed.')
    last_snapshot_iter = 0
    rate = cfg.TRAIN.LEARNING_RATE
    stepsizes = list(cfg.TRAIN.STEPSIZE)

    return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

  # Get the most recent snapshot (for restoring)
  def restore(self, sess, sfile, nfile):
    # Get the most recent snapshot and restore
    np_paths = [nfile]
    ss_paths = [sfile]
    # Restore model from snapshots
    # restore the state from the pkl file
    last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
    # Set the learning rate
    # initial learning rate
    rate = cfg.TRAIN.LEARNING_RATE
    stepsizes = []

    # cfg.TRAIN.STEPSIZE currently holds a single entry, [30000]: after 30000 iterations the learning rate is multiplied by 0.1
    # It is overridden per dataset in train_faster_rcnn.sh; voc_2007_trainval defaults to [50000]
    for stepsize in cfg.TRAIN.STEPSIZE:
      if last_snapshot_iter > stepsize:
        rate *= cfg.TRAIN.GAMMA
      else:
        stepsizes.append(stepsize)

    return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

  # Remove surplus model snapshots; by default 3 are kept
  def remove_snapshot(self, np_paths, ss_paths):
    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      nfile = np_paths[0]
      os.remove(str(nfile))
      np_paths.remove(nfile)

    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      sfile = ss_paths[0]
      # To keep the code compatible with earlier versions of TensorFlow,
      # where the naming convention for checkpoints is different
      if os.path.exists(str(sfile)):
        os.remove(str(sfile))
      else:
        os.remove(str(sfile + '.data-00000-of-00001'))
        os.remove(str(sfile + '.index'))
      sfile_meta = sfile + '.meta'
      os.remove(str(sfile_meta))
      ss_paths.remove(sfile)

  # Train the model
  def train_model(self, sess, max_iters):     # max_iters is set in train_faster_rcnn.sh
    # Build data layers for both training and validation set
    # create the RoIDataLayer instances
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Construct the computation graph
    # build the computation graph
    lr, train_op = self.construct_graph(sess)

    # Find previous snapshots if there is any to restore from
    # return the snapshot files; lsf is the number of snapshots found
    lsf, nfiles, sfiles = self.find_previous()

    # Initialize the variables or restore them from the last snapshot
    # initialize from scratch if there are no snapshots, otherwise restore from the latest one
    if lsf == 0:
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
    else:
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess, 
                                                                            str(sfiles[-1]), 
                                                                            str(nfiles[-1]))
    timer = Timer()
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()

    # Make sure the lists are not empty
    # max_iters is appended at the end and the list reversed, so pop() yields the smallest remaining stepsize first
    # in initialize: stepsizes = list(cfg.TRAIN.STEPSIZE)
    # in restore: only the entries of cfg.TRAIN.STEPSIZE larger than last_snapshot_iter are kept
    stepsizes.append(max_iters)
    stepsizes.reverse()
    next_stepsize = stepsizes.pop()
    while iter < max_iters + 1:
      # Learning rate
      # each time the next stepsize is reached, save a snapshot and lower the learning rate
      if iter == next_stepsize + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        rate *= cfg.TRAIN.GAMMA
        sess.run(tf.assign(lr, rate))
        next_stepsize = stepsizes.pop()

      # start the timer
      timer.tic()
      # Get training data, one batch at a time
      # fetch one batch of blobs
      # blobs contains three key/value pairs {data, gt_boxes, im_info}; details in roi_data_layer/minibatch.py
      blobs = self.data_layer.forward()

      now = time.time()
      # cfg.TRAIN.SUMMARY_INTERVAL = 180: write summaries every 3 minutes
      if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
        # Compute the graph with summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
          self.net.train_step_with_summary(sess, blobs, train_op)
        # write the summary
        self.writer.add_summary(summary, float(iter))
        # Also check the summary on the validation set
        # fetch one batch from the validation set and write its summary
        blobs_val = self.data_layer_val.forward()
        summary_val = self.net.get_summary(sess, blobs_val)
        self.valwriter.add_summary(summary_val, float(iter))
        last_summary_time = now
      else:
        # Compute the graph without summary
        # train only, without computing or writing summaries
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
          self.net.train_step(sess, blobs, train_op)
      # stop the timer
      timer.toc()

      # Display training information
      if iter % (cfg.TRAIN.DISPLAY) == 0:
        print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))

      # Snapshotting
      # save a snapshot every SNAPSHOT_ITERS iterations
      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        ss_path, np_path = self.snapshot(sess, iter)
        np_paths.append(np_path)
        ss_paths.append(ss_path)

        # Remove the old snapshots if there are too many
        # remove surplus snapshots
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          self.remove_snapshot(np_paths, ss_paths)

      iter += 1
    # save a final snapshot at the end of training
    if last_snapshot_iter != iter - 1:
      self.snapshot(sess, iter - 1)

    self.writer.close()
    self.valwriter.close()
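The .pkl meta file written by snapshot() above can only be read back in the exact order it was dumped, which is what from_snapshot() relies on. Below is a standalone sketch of inspecting such a file; the path follows the shufflenetv2_faster_rcnn naming mentioned in the comments and is purely illustrative:

import pickle

with open('output/shufflenetv2_faster_rcnn_iter_10000.pkl', 'rb') as fid:
    st0 = pickle.load(fid)        # numpy RNG state
    cur = pickle.load(fid)        # cursor into the training roidb
    perm = pickle.load(fid)       # shuffled permutation of training indexes
    cur_val = pickle.load(fid)    # cursor into the validation roidb
    perm_val = pickle.load(fid)   # shuffled permutation of validation indexes
    last_iter = pickle.load(fid)  # iteration at which the snapshot was written
print('snapshot taken at iteration', last_iter)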
Example #9
0
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# training
train_loss = 0
tp, tf, fg, bg = 0., 0., 0, 0
step_cnt = 0
re_cnt = False
t = Timer()
t.tic()

for step in range(start_step, end_step + 1):
    net.train()
    # get one batch
    blobs = data_layer.forward()
    im_data = blobs['data']  # im_data = (1, 600, 901, 3)
    rois = blobs['rois']  #rois = (64, 5)
    im_info = blobs['im_info']  # im_info = (1,3)
    gt_vec = blobs['labels']  # gt_vec = (1,20)
    # import pdb; pdb.set_trace()
    #gt_boxes = blobs['gt_boxes']
    # forward
    # rois = (128,5) im_info = (1,3) gt_vec = (1, 20), one-hot label for two streams
    net(im_data, rois, im_info, gt_vec)
    loss = net.loss
    train_loss += loss.data[0]
    step_cnt += 1

    # backward pass and update
    optimizer.zero_grad()
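    # The fragment above breaks off right after the gradients are zeroed. The rest of a step
    # in this style of loop would typically look like the following sketch; the logging
    # interval shown is an assumption, not part of the listing.
    loss.backward()     # back-propagate the combined loss
    optimizer.step()    # apply the SGD update

    if step % 100 == 0:  # illustrative logging interval
        print('step {}, avg loss: {:.4f}'.format(step, train_loss / step_cnt))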
Example #10
0
class SolverWrapper(object):
  """
    A wrapper class for the training process
  """

  def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
    self.net = network
    self.imdb = imdb
    self.roidb = roidb
    self.valroidb = valroidb
    self.output_dir = output_dir
    self.tbdir = tbdir
    # Simply put '_val' at the end to save the summaries from the validation set
    self.tbvaldir = tbdir + '_val'
    if not os.path.exists(self.tbvaldir):
      os.makedirs(self.tbvaldir)
    self.pretrained_model = pretrained_model

  def snapshot(self, sess, iter):
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
    filename = os.path.join(self.output_dir, filename)
    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store some meta information, random state, etc.
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
    st0 = np.random.get_state()
    # current position in the database
    cur = self.data_layer._cur
    # current shuffled indexes of the database
    perm = self.data_layer._perm
    # current position in the validation database
    cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
    perm_val = self.data_layer_val._perm

    # Dump the meta info
    with open(nfilename, 'wb') as fid:
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename

  def get_variables_in_checkpoint_file(self, file_name):
    try:
      reader = pywrap_tensorflow.NewCheckpointReader(file_name)
      var_to_shape_map = reader.get_variable_to_shape_map()
      return var_to_shape_map 
    except Exception as e:  # pylint: disable=broad-except
      print(str(e))
      if "corrupted compressed block contents" in str(e):
        print("It's likely that your checkpoint file has been compressed "
              "with SNAPPY.")

  def train_model(self, sess, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Determine different scales for anchors, see paper
    with sess.graph.as_default():
      # Set the random seed for tensorflow
      tf.set_random_seed(cfg.RNG_SEED)
      # Build the main computation graph
      layers = self.net.create_architecture(sess, 'TRAIN', self.imdb.num_classes,
                                            tag='default', anchor_scales=cfg.ANCHOR_SCALES)
      # Define the loss
      loss = layers['total_loss']
      # Set learning rate and momentum
      lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
      momentum = cfg.TRAIN.MOMENTUM
      self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

      # Compute the gradients wrt the loss
      gvs = self.optimizer.compute_gradients(loss)
      # Double the gradient of the bias if set
      if cfg.TRAIN.DOUBLE_BIAS:
        final_gvs = []
        with tf.variable_scope('Gradient_Mult') as scope:
          for grad, var in gvs:
            scale = 1.
            if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
              scale *= 2.
            if not np.allclose(scale, 1.0):
              grad = tf.multiply(grad, scale)
            final_gvs.append((grad, var))
        train_op = self.optimizer.apply_gradients(final_gvs)
      else:
        train_op = self.optimizer.apply_gradients(gvs)

      # We will handle the snapshots ourselves
      self.saver = tf.train.Saver(max_to_keep=100000)
      # Write the train and validation information to tensorboard
      self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
      self.valwriter = tf.summary.FileWriter(self.tbvaldir)

    # Find previous snapshots if there is any to restore from
    sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
    sfiles = glob.glob(sfiles)
    sfiles.sort(key=os.path.getmtime)
    # Get the snapshot name in TensorFlow
    sfiles = [ss.replace('.meta', '') for ss in sfiles]

    nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
    nfiles = glob.glob(nfiles)
    nfiles.sort(key=os.path.getmtime)

    lsf = len(sfiles)
    assert len(nfiles) == lsf

    np_paths = nfiles
    ss_paths = sfiles

    if lsf == 0:
      # Fresh train directly from ImageNet weights
      print('Loading initial model weights from {:s}'.format(self.pretrained_model))
      variables = tf.global_variables()

      # Only initialize the variables that were not initialized when the graph was built
      sess.run(tf.variables_initializer(variables, name='init'))
      var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
      variables_to_restore = []
      var_to_dic = {}
      # print(var_keep_dic)
      for v in variables:
          # exclude the conv weights that are fc weights in vgg16
          if v.name == 'vgg_16/fc6/weights:0' or v.name == 'vgg_16/fc7/weights:0':
            var_to_dic[v.name] = v
            continue
          if v.name.split(':')[0] in var_keep_dic:
            print('Variables restored: %s' % v.name)
            variables_to_restore.append(v)

      restorer = tf.train.Saver(variables_to_restore)
      restorer.restore(sess, self.pretrained_model)
      print('Loaded.')
      sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
      # A temporary solution to fix the vgg16 issue from conv weights to fc weights
      if self.net._arch == 'vgg16':
        print('Converting VGG16 fc layers..')
        with tf.device("/cpu:0"):
          fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
          fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
          restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv, "vgg_16/fc7/weights": fc7_conv})
          restorer_fc.restore(sess, self.pretrained_model)

          sess.run(tf.assign(var_to_dic['vgg_16/fc6/weights:0'], tf.reshape(fc6_conv, 
                              var_to_dic['vgg_16/fc6/weights:0'].get_shape())))
          sess.run(tf.assign(var_to_dic['vgg_16/fc7/weights:0'], tf.reshape(fc7_conv, 
                              var_to_dic['vgg_16/fc7/weights:0'].get_shape())))
      last_snapshot_iter = 0
    else:
      # Get the most recent snapshot and restore
      ss_paths = [ss_paths[-1]]
      np_paths = [np_paths[-1]]

      print('Restoring model snapshots from {:s}'.format(sfiles[-1]))
      self.saver.restore(sess, str(sfiles[-1]))
      print('Restored.')
      # Need to restore the other hyperparameters/states for training. (TODO xinlei) I have
      # tried my best to capture the random states so that training can be recovered exactly;
      # however, the TensorFlow random state is currently not available.
      with open(str(nfiles[-1]), 'rb') as fid:
        st0 = pickle.load(fid)
        cur = pickle.load(fid)
        perm = pickle.load(fid)
        cur_val = pickle.load(fid)
        perm_val = pickle.load(fid)
        last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        # Set the learning rate, only reduce once
        if last_snapshot_iter > cfg.TRAIN.STEPSIZE:
          sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
        else:
          sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

    timer = Timer()
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    while iter < max_iters + 1:
      # Learning rate
      if iter == cfg.TRAIN.STEPSIZE + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))

      timer.tic()
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()

      now = time.time()
      if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
        # Compute the graph with summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
          self.net.train_step_with_summary(sess, blobs, train_op)
        self.writer.add_summary(summary, float(iter))
        # Also check the summary on the validation set
        blobs_val = self.data_layer_val.forward()
        summary_val = self.net.get_summary(sess, blobs_val)
        self.valwriter.add_summary(summary_val, float(iter))
        last_summary_time = now
      else:
        # Compute the graph without summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
          self.net.train_step(sess, blobs, train_op)
      timer.toc()

      # Display training information
      if iter % (cfg.TRAIN.DISPLAY) == 0:
        print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))

      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        snapshot_path, np_path = self.snapshot(sess, iter)
        np_paths.append(np_path)
        ss_paths.append(snapshot_path)

        # Remove the old snapshots if there are too many
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
          for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        if len(ss_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
          for c in range(to_remove):
            sfile = ss_paths[0]
            # To keep the code compatible with earlier versions of TensorFlow,
            # where the naming convention for checkpoints is different
            if os.path.exists(str(sfile)):
              os.remove(str(sfile))
            else:
              os.remove(str(sfile + '.data-00000-of-00001'))
              os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

      iter += 1

    if last_snapshot_iter != iter - 1:
      self.snapshot(sess, iter - 1)

    self.writer.close()
    self.valwriter.close()
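The VGG16 fc6/fc7 conversion in the initialization branch above works by reshaping convolutional kernels into fully connected weight matrices. A small standalone check of the shapes involved, in plain numpy and outside any TensorFlow session:

import numpy as np

# fc6 as a convolution is a 7x7x512 -> 4096 kernel; flattened into a fully
# connected layer it becomes a (7*7*512) x 4096 = 25088 x 4096 matrix, which
# is exactly what the tf.reshape calls above produce.
fc6_conv_shape = (7, 7, 512, 4096)
fc6_fc_shape = (int(np.prod(fc6_conv_shape[:3])), fc6_conv_shape[3])
print(fc6_fc_shape)  # (25088, 4096)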
Example #11
0
class SolverWrapper(object):
    """
      A wrapper class for the training process
    """

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        common.check_dir(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Need to restore the other hyper-parameters/states for training. (TODO xinlei) I have
        # tried my best to capture the random states so that training can be recovered exactly;
        # however, the TensorFlow random state is currently not available.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")

    def construct_graph(self, sess):
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture('TRAIN', self.imdb.num_classes, tag='default',
                                                  anchor_width=cfg.CTPN.ANCHOR_WIDTH,
                                                  anchor_h_ratio_step=cfg.CTPN.H_RADIO_STEP,
                                                  num_anchors=cfg.CTPN.NUM_ANCHORS)
            # Define the loss
            total_loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)


            if cfg.TRAIN.OPTIMIZER == 'Adam':
                self.optimizer = tf.train.AdamOptimizer(lr)
            elif cfg.TRAIN.OPTIMIZER == 'Momentum':
                self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)
            elif cfg.TRAIN.OPTIMIZER == 'RMS':
                self.optimizer = tf.train.RMSPropOptimizer(lr)
            else:
                raise NotImplementedError

            global_step = tf.Variable(0, trainable=False)
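            # with_clip is hard-coded to False here; setting it to True clips the
            # global gradient norm at 10.0 before applying the update instead of
            # calling minimize().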
            with_clip = False
            if with_clip:
                tvars = tf.trainable_variables()
                grads, norm = tf.clip_by_global_norm(tf.gradients(total_loss, tvars), 10.0)
                train_op = self.optimizer.apply_gradients(list(zip(grads, tvars)), global_step=global_step)
            else:
                # Required by tf.layers.batch_normalization(): add the update ops (for
                # moving_mean and moving_variance) as a dependency of train_op.
                # https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    train_op = self.optimizer.minimize(total_loss, global_step=global_step)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def find_previous(self):
        sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
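        # Exclude the extra snapshots taken at iteration (stepsize + 1), i.e. the ones
        # saved just before each learning-rate drop, so restoring always resumes from a
        # regular SNAPSHOT_ITERS snapshot.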
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(os.path.join(self.output_dir,
                                         cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.ckpt.meta'.format(stepsize + 1)))
        sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles]

        nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def restore_ckpt_from_dir(self, sess, net, checkpoint_dir):
        print("Restoring checkpoint from: " + checkpoint_dir)

        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt is None:
            print("Checkpoint not found")
            exit(-1)

        meta_file = ckpt + '.meta'
        try:
            print('Restore variables from {}'.format(ckpt))
            print('Restore meta_file from {}'.format(meta_file))
            saver = tf.train.Saver(net.variables_to_restore)
            saver.restore(sess, ckpt)
        except Exception:
            raise Exception("Can not restore from {}".format(checkpoint_dir))

    def initialize(self, sess):
        # Initial file lists are empty
        np_paths = []
        ss_paths = []

        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))

        if self.pretrained_model is not None:
            if self.pretrained_model.endswith('.ckpt'):
                # Fresh train directly from ImageNet weights
                print('Loading initial model weights from {:s}'.format(self.pretrained_model))

                var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
                # Get the variables to restore, ignoring the variables to fix
                variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

                restorer = tf.train.Saver(variables_to_restore)
                restorer.restore(sess, self.pretrained_model)
                print('Loaded.')
            else:
                # Restore from checkpoint and meta file
                self.restore_ckpt_from_dir(sess, self.net, self.pretrained_model)
                print('Loaded.')

        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
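        # Apply one GAMMA decay for every stepsize the restored iteration has already
        # passed; the stepsizes still ahead are returned for the training loop to handle.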
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible with earlier versions of TensorFlow,
            # where the naming convention for checkpoints is different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess,
                                                                                   str(sfiles[-1]),
                                                                                   str(nfiles[-1]))
        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Append max_iters as a sentinel so the stepsize list is never empty, then
        # reverse it so pop() always yields the next (smallest) stepsize first.
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, rpn_loss, total_loss, summary = \
                    self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, rpn_loss, total_loss, _ = \
                    self.net.train_step(sess, blobs, train_op)
            timer.toc()

            print('%d/%d time: %.3f total_loss: %.3f rpn_loss: %.3f rpn_loss_cls: %.3f '
                  'rpn_loss_box: %.3f lr: %f' % (
                      iter, max_iters, timer.diff, total_loss, rpn_loss, rpn_loss_cls, rpn_loss_box, lr.eval()))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
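
The listing above reads every hyper-parameter from a global cfg object. As a point of reference, a minimal sketch of the fields it touches is shown below; the values are placeholders for illustration, not the defaults of the original project, and the use of easydict is an assumption.

# Minimal cfg sketch for the listing above (placeholder values, not the project's defaults).
from easydict import EasyDict as edict

cfg = edict()
cfg.RNG_SEED = 3                      # seed passed to tf.set_random_seed()

cfg.TRAIN = edict()
cfg.TRAIN.LEARNING_RATE = 0.001       # initial value of the lr variable
cfg.TRAIN.GAMMA = 0.1                 # decay factor applied at each stepsize
cfg.TRAIN.STEPSIZE = [60000]          # iterations at which the learning rate is decayed
cfg.TRAIN.MOMENTUM = 0.9              # used when OPTIMIZER == 'Momentum'
cfg.TRAIN.OPTIMIZER = 'Momentum'      # 'Adam', 'Momentum' or 'RMS'
cfg.TRAIN.SUMMARY_INTERVAL = 180      # seconds between TensorBoard summary writes
cfg.TRAIN.SNAPSHOT_ITERS = 5000       # snapshot every N iterations
cfg.TRAIN.SNAPSHOT_KEPT = 3           # how many snapshots remove_snapshot() keeps
cfg.TRAIN.SNAPSHOT_PREFIX = 'ctpn'    # prefix of the .ckpt/.pkl snapshot files

cfg.CTPN = edict()
cfg.CTPN.ANCHOR_WIDTH = 16            # forwarded to create_architecture()
cfg.CTPN.H_RADIO_STEP = 0.7           # forwarded to create_architecture()
cfg.CTPN.NUM_ANCHORS = 10             # forwarded to create_architecture()
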
Example #12
class SolverWrapper(object):
    """
    A wrapper class for the training process
  """
    def __init__(self,
                 sess,
                 network,
                 imdb,
                 roidb,
                 valroidb,
                 output_dir,
                 tbdir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print 'Wrote snapshot to: {:s}'.format(filename)

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indices of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            cPickle.dump(st0, fid, cPickle.HIGHEST_PROTOCOL)
            cPickle.dump(cur, fid, cPickle.HIGHEST_PROTOCOL)
            cPickle.dump(perm, fid, cPickle.HIGHEST_PROTOCOL)
            cPickle.dump(cur_val, fid, cPickle.HIGHEST_PROTOCOL)
            cPickle.dump(perm_val, fid, cPickle.HIGHEST_PROTOCOL)
            cPickle.dump(iter, fid, cPickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Determine different scales for anchors, see paper
        if self.imdb.name.startswith('voc'):
            anchors = [8, 16, 32]
        else:
            anchors = [4, 8, 16, 32]

        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                sess,
                "TRAIN",
                self.imdb.num_classes,
                caffe_weight_path=self.pretrained_model,
                tag='default',
                anchor_scales=anchors)
            # Define the loss
            loss = layers['total_loss']

            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            momentum = cfg.TRAIN.MOMENTUM
            self.optimizer = tf.train.MomentumOptimizer(lr, momentum)
            # Compute the gradients wrt the loss
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult') as scope:
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.TRAIN.DOUBLE_BIAS and '/bias:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Find previous snapshots if there is any to restore from
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        sfiles = [ss.replace('.meta', '') for ss in sfiles]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        np_paths = nfiles
        ss_paths = sfiles

        if lsf == 0:
            # Fresh train directly from VGG weights
            print 'Loading initial model weights from {:s}'.format(
                self.pretrained_model)
            variables = tf.global_variables()
            # Only initialize the variables that were not initialized when the graph was built
            for vbs in self.net._initialized:
                variables.remove(vbs)
            sess.run(tf.variables_initializer(variables, name='init'))
            print 'Loaded.'
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
            last_snapshot_iter = 0
        else:
            # Get the most recent snapshot and restore
            ss_paths = [ss_paths[-1]]
            np_paths = [np_paths[-1]]

            print 'Restoring model snapshots from {:s}'.format(sfiles[-1])
            self.saver.restore(sess, str(sfiles[-1]))
            print 'Restored.'
            # Restore the remaining training state as well (TODO xinlei): the NumPy
            # random state and data-layer cursors are recovered so training resumes
            # where it left off, but the TensorFlow random state is currently not recoverable.
            with open(str(nfiles[-1]), 'rb') as fid:
                st0 = cPickle.load(fid)
                cur = cPickle.load(fid)
                perm = cPickle.load(fid)
                cur_val = cPickle.load(fid)
                perm_val = cPickle.load(fid)
                last_snapshot_iter = cPickle.load(fid)

                np.random.set_state(st0)
                self.data_layer._cur = cur
                self.data_layer._perm = perm
                self.data_layer_val._cur = cur_val
                self.data_layer_val._perm = perm_val

                # Set the learning rate, only reduce once
                if last_snapshot_iter > cfg.TRAIN.STEPSIZE:
                    sess.run(
                        tf.assign(lr,
                                  cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
                else:
                    sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        while iter < max_iters + 1:
            # Learning rate
            if iter == cfg.TRAIN.STEPSIZE + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                sess.run(
                    tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                                self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                                self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
                print 'speed: {:.3f}s / iter'.format(timer.average_time)

            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                snapshot_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(snapshot_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for c in xrange(to_remove):
                        nfile = np_paths[0]
                        os.remove(str(nfile))
                        np_paths.remove(nfile)

                if len(ss_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for c in xrange(to_remove):
                        sfile = ss_paths[0]
                        # To make the code compatible with earlier versions of TensorFlow,
                        # where the naming convention for checkpoints is different
                        if os.path.exists(str(sfile)):
                            os.remove(str(sfile))
                        else:
                            os.remove(str(sfile + '.data-00000-of-00001'))
                            os.remove(str(sfile + '.index'))
                        sfile_meta = sfile + '.meta'
                        os.remove(str(sfile_meta))
                        ss_paths.remove(sfile)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
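
Unlike the listings around it, this one is Python 2 code: it uses print statements, cPickle, and xrange. A minimal sketch of import shims that would let similar code run under Python 3 is shown below; the print statements themselves would still need to be rewritten as print() calls.

# Hypothetical Python 2/3 import shims for the modules used in the listing above.
try:
    import cPickle                  # Python 2
except ImportError:
    import pickle as cPickle        # Python 3: pickle fills the same role

try:
    xrange                          # Python 2 built-in
except NameError:
    xrange = range                  # Python 3: range is lazy, like Python 2's xrange
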
Example #13
class SolverWrapper(object):
    """
      A wrapper class for the training process
    """
    def __init__(self,
                 sess,
                 network,
                 imdb,
                 roidb,
                 valroidb,
                 output_dir,
                 tbdir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        common.check_dir(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Restore the remaining training state as well (TODO xinlei): the NumPy random
        # state and data-layer cursors are recovered so training resumes where it left
        # off, but the TensorFlow random state is currently not recoverable.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def construct_graph(self, sess):
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                'TRAIN',
                self.imdb.num_classes,
                tag='default',
                anchor_width=cfg.CTPN.ANCHOR_WIDTH,
                anchor_h_ratio_step=cfg.CTPN.H_RADIO_STEP,
                num_anchors=cfg.CTPN.NUM_ANCHORS)
            # Define the loss
            total_loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)

            if cfg.TRAIN.OPTIMIZER == 'Adam':
                self.optimizer = tf.train.AdamOptimizer(lr)
            elif cfg.TRAIN.OPTIMIZER == 'Momentum':
                self.optimizer = tf.train.MomentumOptimizer(
                    lr, cfg.TRAIN.MOMENTUM)
            elif cfg.TRAIN.OPTIMIZER == 'RMS':
                self.optimizer = tf.train.RMSPropOptimizer(lr)
            else:
                raise NotImplementedError

            global_step = tf.Variable(0, trainable=False)
            with_clip = False
            if with_clip:
                tvars = tf.trainable_variables()
                grads, norm = tf.clip_by_global_norm(
                    tf.gradients(total_loss, tvars), 10.0)
                train_op = self.optimizer.apply_gradients(
                    list(zip(grads, tvars)), global_step=global_step)
            else:
                # Required by tf.layers.batch_normalization(): add the update ops (for
                # moving_mean and moving_variance) as a dependency of train_op.
                # https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    train_op = self.optimizer.minimize(total_loss,
                                                       global_step=global_step)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def find_previous(self):
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(
                os.path.join(
                    self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX +
                    '_iter_{:d}.ckpt.meta'.format(stepsize + 1)))
        sfiles = [
            ss.replace('.meta', '') for ss in sfiles if ss not in redfiles
        ]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [
            redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles
        ]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def restore_ckpt_from_dir(self, sess, net, checkpoint_dir):
        print("Restoring checkpoint from: " + checkpoint_dir)

        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt is None:
            print("Checkpoint not found")
            exit(-1)

        meta_file = ckpt + '.meta'
        try:
            print('Restore variables from {}'.format(ckpt))
            print('Restore meta_file from {}'.format(meta_file))
            saver = tf.train.Saver(net.variables_to_restore)
            saver.restore(sess, ckpt)
        except Exception:
            raise Exception("Can not restore from {}".format(checkpoint_dir))

    def initialize(self, sess):
        # Initial file lists are empty
        np_paths = []
        ss_paths = []

        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))

        if self.pretrained_model is not None:
            if self.pretrained_model.endswith('.ckpt'):
                # Fresh train directly from ImageNet weights
                print('Loading initial model weights from {:s}'.format(
                    self.pretrained_model))

                var_keep_dic = self.get_variables_in_checkpoint_file(
                    self.pretrained_model)
                # Get the variables to restore, ignoring the variables to fix
                variables_to_restore = self.net.get_variables_to_restore(
                    variables, var_keep_dic)

                restorer = tf.train.Saver(variables_to_restore)
                restorer.restore(sess, self.pretrained_model)
                print('Loaded.')
            else:
                # Restore from checkpoint and meta file
                self.restore_ckpt_from_dir(sess, self.net,
                                           self.pretrained_model)
                print('Loaded.')

        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible with earlier versions of TensorFlow,
            # where the naming convention for checkpoints is different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
                sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                sess, str(sfiles[-1]), str(nfiles[-1]))
        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Append max_iters as a sentinel so the stepsize list is never empty, then
        # reverse it so pop() always yields the next (smallest) stepsize first.
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, rpn_loss, total_loss, summary = \
                    self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, rpn_loss, total_loss, _ = \
                    self.net.train_step(sess, blobs, train_op)
            timer.toc()

            print(
                '%d/%d time: %.3f total_loss: %.3f rpn_loss: %.3f rpn_loss_cls: %.3f '
                'rpn_loss_box: %.3f lr: %f' %
                (iter, max_iters, timer.diff, total_loss, rpn_loss,
                 rpn_loss_cls, rpn_loss_box, lr.eval()))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
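
For context, a hypothetical driver for a SolverWrapper like the ones above might look like the sketch below. The train_net name, the session configuration, and the way the network and roidbs are obtained are assumptions for illustration; the actual entry points differ between the projects these listings come from.

# Hypothetical driver sketch: build a session, wrap it in SolverWrapper, and train.
import tensorflow as tf

def train_net(network, imdb, roidb, valroidb, output_dir, tb_dir,
              pretrained_model=None, max_iters=40000):
    """Drive a SolverWrapper: construct it, then run its training loop."""
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True   # grab GPU memory on demand

    with tf.Session(config=tfconfig) as sess:
        sw = SolverWrapper(sess, network, imdb, roidb, valroidb,
                           output_dir, tb_dir, pretrained_model=pretrained_model)
        print('Solving...')
        sw.train_model(sess, max_iters)
        print('done solving')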