class SolverWrapper(object):
    """A wrapper class for the training process."""

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply append '_val' to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        net = self.net
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indices of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")

    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes,
                                           random=True)

        # Determine different scales for anchors, see paper
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                sess, 'TRAIN', self.imdb.num_classes, tag='default',
                anchor_scales=cfg.ANCHOR_SCALES,
                anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            momentum = cfg.TRAIN.MOMENTUM
            self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

            # Compute the gradients wrt the loss
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult') as scope:
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Find previous snapshots if there are any to restore from
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redstr = '_iter_{:d}.'.format(cfg.TRAIN.STEPSIZE + 1)
        sfiles = [ss.replace('.meta', '') for ss in sfiles]
        sfiles = [ss for ss in sfiles if redstr not in ss]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        nfiles = [nn for nn in nfiles if redstr not in nn]

        lsf = len(sfiles)
        assert len(nfiles) == lsf
        np_paths = nfiles
        ss_paths = sfiles

        if lsf == 0:
            # Fresh train directly from ImageNet weights
            print('Loading initial model weights from {:s}'.format(
                self.pretrained_model))
            variables = tf.global_variables()
            # Initialize all variables first
            sess.run(tf.variables_initializer(variables, name='init'))
            var_keep_dic = self.get_variables_in_checkpoint_file(
                self.pretrained_model)
            # Get the variables to restore, ignoring the variables to fix
            variables_to_restore = self.net.get_variables_to_restore(
                variables, var_keep_dic)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, self.pretrained_model)
            print('Loaded.')
            # Need to fix the variables before loading, so that the RGB weights
            # are changed to BGR. For VGG16 it also converts the convolutional
            # weights fc6 and fc7 to fully connected weights.
            self.net.fix_variables(sess, self.pretrained_model)
            print('Fixed.')
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
            last_snapshot_iter = 0
        else:
            # Get the most recent snapshot and restore
            ss_paths = [ss_paths[-1]]
            np_paths = [np_paths[-1]]
            print('Restoring model snapshots from {:s}'.format(sfiles[-1]))
            self.saver.restore(sess, str(sfiles[-1]))
            print('Restored.')
            # Needs to restore the other hyperparameters/states for training,
            # (TODO xinlei) I have tried my best to find the random states so
            # that it can be recovered exactly.
            # However the Tensorflow state is currently not available.
            with open(str(nfiles[-1]), 'rb') as fid:
                st0 = pickle.load(fid)
                cur = pickle.load(fid)
                perm = pickle.load(fid)
                cur_val = pickle.load(fid)
                perm_val = pickle.load(fid)
                last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

            # Set the learning rate, only reduce once
            if last_snapshot_iter > cfg.TRAIN.STEPSIZE:
                sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
            else:
                sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        while iter < max_iters + 1:
            # Learning rate
            if iter == cfg.TRAIN.STEPSIZE + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                    self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                    self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' %
                      (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box,
                       loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                snapshot_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(snapshot_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for c in range(to_remove):
                        nfile = np_paths[0]
                        os.remove(str(nfile))
                        np_paths.remove(nfile)

                if len(ss_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for c in range(to_remove):
                        sfile = ss_paths[0]
                        # To keep the code compatible with earlier versions of
                        # TensorFlow, where the checkpoint naming convention differs
                        if os.path.exists(str(sfile)):
                            os.remove(str(sfile))
                        else:
                            os.remove(str(sfile + '.data-00000-of-00001'))
                            os.remove(str(sfile + '.index'))
                        sfile_meta = sfile + '.meta'
                        os.remove(str(sfile_meta))
                        ss_paths.remove(sfile)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
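The snapshot() method above pickles the numpy RNG state and the data-layer cursors next to every checkpoint, and the restore path reads them back in the same order. Below is a self-contained sketch of that round trip in isolation; the file name and cursor values are hypothetical, not taken from the repo.

import pickle
import numpy as np

def dump_meta(path, cur, perm, cur_val, perm_val, it):
    # Pickle the numpy RNG state and the cursors in a fixed order
    st0 = np.random.get_state()
    with open(path, 'wb') as fid:
        for obj in (st0, cur, perm, cur_val, perm_val, it):
            pickle.dump(obj, fid, pickle.HIGHEST_PROTOCOL)

def load_meta(path):
    # Read the objects back in exactly the same order they were written
    with open(path, 'rb') as fid:
        st0 = pickle.load(fid)
        cur = pickle.load(fid)
        perm = pickle.load(fid)
        cur_val = pickle.load(fid)
        perm_val = pickle.load(fid)
        it = pickle.load(fid)
    np.random.set_state(st0)  # resume the exact RNG sequence
    return cur, perm, cur_val, perm_val, it

# Example round trip with made-up cursor values
dump_meta('/tmp/demo_iter_100.pkl', 3, np.arange(10), 0, np.arange(4), 100)
print(load_meta('/tmp/demo_iter_100.pkl')[-1])  # -> 100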
class SolverWrapper(object):
    def __init__(self, network, imdb, valimdb, roidb, valroidb, model_dir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.valimdb = valimdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.model_dir = model_dir
        self.tbdir = os.path.join(model_dir, 'train_log')
        if not os.path.exists(self.tbdir):
            os.makedirs(self.tbdir)
        self.pretrained_model = pretrained_model

    def set_learn_strategy(self, learn_dict):
        self._disp_interval = learn_dict['disp_interval']
        self._valid_interval = learn_dict['disp_interval'] * 5
        self._use_tensorboard = learn_dict['use_tensorboard']
        self._use_valid = learn_dict['use_valid']
        self._evaluate = learn_dict['evaluate']
        self._save_point_interval = learn_dict['save_point_interval']
        self._lr_decay_steps = learn_dict['lr_decay_steps']
        if self._evaluate:
            self._begin_eval_point = learn_dict['begin_eval_point']
            self.evaluate_dir = os.path.join(self.model_dir, 'evaluate')
            self._evaluate_thresh = 0.05
            self._evaluate_max_per_image = 1

    def train_model(self, resume=None, max_iters=100000):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes,
                                           random=True)
        # Load the checkpoint, initialize weights, set up the optimizer and the
        # per-parameter learning rates
        self.prepare_construct(resume)

        net = self.net
        # training
        train_loss = 0
        rpn_cls_loss = 0
        rpn_bbox_loss = 0
        fast_rcnn_cls_loss = 0
        fast_rcnn_bbox_loss = 0
        tp, tf, fg, bg = 0., 0., 0, 0
        step_cnt = 0
        re_cnt = False
        t = Timer()
        t.tic()

        for step in range(self.start_step, max_iters + 1):
            blobs = self.data_layer.forward()
            if step % self._valid_interval == 0 and self._use_tensorboard:
                loss_r, image_r = net.train_operation(
                    blobs, self._optimizer, image_if=True,
                    clip_parameters=self._parameters)
                self._tensor_writer.add_image('Image', image_r, step)
            else:
                try:
                    loss_r, image_r = net.train_operation(
                        blobs, self._optimizer, image_if=False,
                        clip_parameters=self._parameters)
                except:
                    # If the training step fails, log the offending image; note
                    # that loss_r then keeps the values from the previous step.
                    print('=' * 40)
                    print('=' * 40)
                    print('=' * 40)
                    print(blobs['im_name'])

            train_loss += loss_r[0]
            rpn_cls_loss += loss_r[1]
            rpn_bbox_loss += loss_r[2]
            fast_rcnn_cls_loss += loss_r[3]
            fast_rcnn_bbox_loss += loss_r[4]
            # fg: foreground objects, bg: background,
            # tp: true positives, tf: true negatives
            fg += net.metrics_dict['fg']
            bg += net.metrics_dict['bg']
            tp += net.metrics_dict['tp']
            tf += net.metrics_dict['tf']
            step_cnt += 1

            if step % self._disp_interval == 0:
                duration = t.toc(average=False)
                fps = step_cnt / duration
                log_text = 'step %d, image: %s, loss: %.4f, fps: %.2f (%.2fs per batch)' % (
                    step, blobs['im_name'], train_loss / step_cnt, fps, 1. / fps)
                tp_text = 'step {}, tp: {}/{}, tf: {}/{}'.format(
                    step, int(tp / step_cnt), int(fg / step_cnt),
                    int(tf / step_cnt), int(bg / step_cnt))
                pprint.pprint(log_text)
                pprint.pprint(tp_text)
                if self._use_tensorboard:
                    self._tensor_writer.add_text('Train', log_text, global_step=step)
                    # Train
                    avg_rpn_cls_loss = rpn_cls_loss / step_cnt
                    avg_rpn_bbox_loss = rpn_bbox_loss / step_cnt
                    avg_fast_rcnn_cls_loss = fast_rcnn_cls_loss / step_cnt
                    avg_fast_rcnn_bbox_loss = fast_rcnn_bbox_loss / step_cnt
                    self._tensor_writer.add_scalars(
                        'TrainSetLoss', {
                            'RPN_cls_loss': avg_rpn_cls_loss,
                            'RPN_bbox_loss': avg_rpn_bbox_loss,
                            'FastRcnn_cls_loss': avg_fast_rcnn_cls_loss,
                            'FastRcnn_bbox_loss': avg_fast_rcnn_bbox_loss
                        }, global_step=step)
                    self._tensor_writer.add_scalar('Learning_rate', self._lr,
                                                   global_step=step)
                re_cnt = True

            if self._use_valid and step % self._valid_interval == 0 and step != 0:
                total_valid_loss = 0.0
                valid_rpn_cls_loss = 0.0
                valid_rpn_bbox_loss = 0.0
                valid_fast_rcnn_cls_loss = 0.0
                valid_fast_rcnn_bbox_loss = 0.0
                valid_step_cnt = 0
                valid_tp, valid_tf, valid_fg, valid_bg = 0., 0., 0, 0
                start_time = time.time()
                valid_length = self._disp_interval
                for valid_batch in range(valid_length):
                    # get one batch
                    blobs = self.data_layer_val.forward()
                    if self._use_tensorboard and valid_batch % valid_length == 0:
                        # No optimizer is passed here, so the network is not
                        # updated; only the loss is computed
                        loss_r, image_r = net.train_operation(blobs, None,
                                                              image_if=True)
                        self._tensor_writer.add_image('Image_Valid', image_r, step)
                    else:
                        loss_r, image_r = net.train_operation(blobs, None,
                                                              image_if=False)
                    total_valid_loss += loss_r[0]
                    valid_rpn_cls_loss += loss_r[1]
                    valid_rpn_bbox_loss += loss_r[2]
                    valid_fast_rcnn_cls_loss += loss_r[3]
                    valid_fast_rcnn_bbox_loss += loss_r[4]
                    valid_fg += net.metrics_dict['fg']
                    valid_bg += net.metrics_dict['bg']
                    valid_tp += net.metrics_dict['tp']
                    valid_tf += net.metrics_dict['tf']
                    valid_step_cnt += 1
                duration = time.time() - start_time
                fps = valid_step_cnt / duration
                log_text = 'step %d, valid average loss: %.4f, fps: %.2f (%.2fs per batch)' % (
                    step, total_valid_loss / valid_step_cnt, fps, 1. / fps)
                pprint.pprint(log_text)
                if self._use_tensorboard:
                    # Valid
                    avg_rpn_cls_loss_valid = valid_rpn_cls_loss / valid_step_cnt
                    avg_rpn_bbox_loss_valid = valid_rpn_bbox_loss / valid_step_cnt
                    avg_fast_rcnn_cls_loss_valid = valid_fast_rcnn_cls_loss / valid_step_cnt
                    avg_fast_rcnn_bbox_loss_valid = valid_fast_rcnn_bbox_loss / valid_step_cnt
                    valid_tpr = valid_tp * 1.0 / valid_fg
                    valid_tfr = valid_tf * 1.0 / valid_bg
                    real_total_loss_valid = (valid_rpn_cls_loss + valid_rpn_bbox_loss
                                             + valid_fast_rcnn_cls_loss
                                             + valid_fast_rcnn_bbox_loss)
                    # Train
                    avg_rpn_cls_loss = rpn_cls_loss / step_cnt
                    avg_rpn_bbox_loss = rpn_bbox_loss / step_cnt
                    avg_fast_rcnn_cls_loss = fast_rcnn_cls_loss / step_cnt
                    avg_fast_rcnn_bbox_loss = fast_rcnn_bbox_loss / step_cnt
                    tpr = tp * 1.0 / fg
                    tfr = tf * 1.0 / bg
                    real_total_loss = (rpn_cls_loss + rpn_bbox_loss
                                       + fast_rcnn_cls_loss + fast_rcnn_bbox_loss)
                    self._tensor_writer.add_text('Valid', log_text, global_step=step)
                    self._tensor_writer.add_scalars(
                        'Total_Loss', {
                            'train': train_loss / step_cnt,
                            'valid': total_valid_loss / valid_step_cnt
                        }, global_step=step)
                    self._tensor_writer.add_scalars(
                        'Real_loss', {
                            'train': real_total_loss / step_cnt,
                            'valid': real_total_loss_valid / valid_step_cnt
                        }, global_step=step)
                    self._tensor_writer.add_scalars(
                        'RPN_cls_loss', {
                            'train': avg_rpn_cls_loss,
                            'valid': avg_rpn_cls_loss_valid
                        }, global_step=step)
                    self._tensor_writer.add_scalars(
                        'RPN_bbox_loss', {
                            'train': avg_rpn_bbox_loss,
                            'valid': avg_rpn_bbox_loss_valid
                        }, global_step=step)
                    self._tensor_writer.add_scalars(
                        'FastRcnn_cls_loss', {
                            'train': avg_fast_rcnn_cls_loss,
                            'valid': avg_fast_rcnn_cls_loss_valid
                        }, global_step=step)
                    self._tensor_writer.add_scalars(
                        'FastRcnn_bbox_loss', {
                            'train': avg_fast_rcnn_bbox_loss,
                            'valid': avg_fast_rcnn_bbox_loss_valid
                        }, global_step=step)
                    self._tensor_writer.add_scalars('tpr', {
                        'train': tpr,
                        'valid': valid_tpr
                    }, global_step=step)
                    self._tensor_writer.add_scalars('tfr', {
                        'train': tfr,
                        'valid': valid_tfr
                    }, global_step=step)
                    self._tensor_writer.add_scalars(
                        'ValidSetLoss', {
                            'RPN_cls_loss': avg_rpn_cls_loss_valid,
                            'RPN_bbox_loss': avg_rpn_bbox_loss_valid,
                            'FastRcnn_cls_loss': avg_fast_rcnn_cls_loss_valid,
                            'FastRcnn_bbox_loss': avg_fast_rcnn_bbox_loss_valid
                        }, global_step=step)

            if (step % self._save_point_interval == 0) and step != 0:
                save_name, _ = self.save_check_point(step)
                print('save model: {}'.format(save_name))

            if self._evaluate:
                if step > self._begin_eval_point and step % cfg.TRAIN.EVALUATE_POINT == 0:
                    self.net.eval()
                    evaluate_solverwrapper = EvaluateSolverWrapper(
                        network=self.net,
                        imdb=self.valimdb,
                        model_dir=None,
                        output_dir=self.evaluate_dir)
                    metrics_cls, metrics_reg = evaluate_solverwrapper.eval_model(
                        step, self._evaluate_max_per_image, self._evaluate_thresh)
                    self.after_model_mode()
                    del evaluate_solverwrapper
                    for key in metrics_cls.keys():
                        metrics_cls[key] = metrics_cls[key][0]
                    for key in metrics_reg.keys():
                        metrics_reg[key] = metrics_reg[key][0]
                    if self._use_tensorboard:
                        self._tensor_writer.add_scalars('metrics_cls', metrics_cls,
                                                        global_step=step)
                        self._tensor_writer.add_scalars('metrics_reg', metrics_reg,
                                                        global_step=step)

            if step in self._lr_decay_steps:
                self._lr *= self._lr_decay
                self._optimizer = self._train_optimizer()

            if re_cnt:
                tp, tf, fg, bg = 0., 0., 0., 0.
                train_loss = 0
                rpn_cls_loss = 0
                rpn_bbox_loss = 0
                fast_rcnn_cls_loss = 0
                fast_rcnn_bbox_loss = 0
                step_cnt = 0
                t.tic()
                re_cnt = False

        if self._use_tensorboard:
            self._tensor_writer.export_scalars_to_json(
                os.path.join(self.tbdir, 'all_scalars.json'))

    def save_check_point(self, step):
        net = self.net
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        # store the model snapshot
        filename = os.path.join(self.model_dir, 'fasterRcnn_iter_{}.h5'.format(step))
        h5f = h5py.File(filename, mode='w')
        for k, v in net.state_dict().items():
            h5f.create_dataset(k, data=v.cpu().numpy())

        # store data information
        nfilename = os.path.join(self.model_dir, 'fasterRcnn_iter_{}.pkl'.format(step))
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm
        # current learning rate
        lr = self._lr
        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(lr, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(step, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def load_check_point(self, step):
        net = self.net
        filename = os.path.join(self.model_dir, 'fasterRcnn_iter_{}.h5'.format(step))
        nfilename = os.path.join(self.model_dir, 'fasterRcnn_iter_{}.pkl'.format(step))
        print('Restoring model snapshots from {:s}'.format(filename))
        if not os.path.exists(filename):
            print('The checkpoint does not exist')
            sys.exit(1)

        # load model
        h5f = h5py.File(filename, mode='r')
        for k, v in net.state_dict().items():
            param = torch.from_numpy(np.asarray(h5f[k]))
            v.copy_(param)

        # load data information
        with open(nfilename, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            lr = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val
        self._lr = lr

        if last_snapshot_iter == step:
            print('Restore over')
        else:
            print('The checkpoint does not match the requested step')
            raise ValueError
        return last_snapshot_iter

    # Initialize the network weights
    def weights_normal_init(self, model, dev=0.01):
        import math

        def _gaussian_init(m, dev):
            m.weight.data.normal_(0.0, dev)
            if hasattr(m.bias, 'data'):
                m.bias.data.zero_()

        def _xaiver_init(m):
            nn.init.xavier_normal(m.weight.data)
            if hasattr(m.bias, 'data'):
                m.bias.data.zero_()

        def _hekaiming_init(m):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
            if hasattr(m.bias, 'data'):
                m.bias.data.zero_()

        def _resnet_init(model, dev):
            if isinstance(model, list):
                for m in model:
                    self.weights_normal_init(m, dev)
            else:
                for m in model.modules():
                    if isinstance(m, nn.Conv2d):
                        _hekaiming_init(m)
                    elif isinstance(m, nn.BatchNorm2d):
                        m.weight.data.fill_(1)
                        m.bias.data.zero_()
                    elif isinstance(m, nn.Linear):
                        _gaussian_init(m, dev)

        def _vgg_init(model, dev):
            if isinstance(model, list):
                for m in model:
                    self.weights_normal_init(m, dev)
            else:
                for m in model.modules():
                    if isinstance(m, nn.Conv2d):
                        _gaussian_init(m, dev)
                    elif isinstance(m, nn.Linear):
                        _gaussian_init(m, dev)
                    elif isinstance(m, nn.BatchNorm2d):
                        m.weight.data.fill_(1)
                        m.bias.data.zero_()

        if cfg.TRAIN.INIT_WAY == 'resnet':
            # NOTE: the original code also uses _vgg_init for the resnet branch;
            # _resnet_init is defined above but never called.
            _vgg_init(model, dev)
        elif cfg.TRAIN.INIT_WAY == 'vgg':
            _vgg_init(model, dev)
        else:
            raise NotImplementedError

    # Load the checkpoint, initialize weights, set up the optimizer and the
    # per-parameter learning rates
    def prepare_construct(self, resume_iter):
        # init network
        self.net.init_fasterRCNN()

        # Set the random seed
        torch.manual_seed(cfg.RNG_SEED)
        np.random.seed(cfg.RNG_SEED)

        # Set learning rate and momentum
        self._lr = cfg.TRAIN.LEARNING_RATE
        self._lr_decay = 0.1
        self._momentum = cfg.TRAIN.MOMENTUM
        self._weight_decay = cfg.TRAIN.WEIGHT_DECAY

        # load model
        if resume_iter:
            self.start_step = resume_iter + 1
            self.load_check_point(resume_iter)
        else:
            self.start_step = 0
            self.weights_normal_init(self.net, dev=0.01)
            # refer to caffe faster RCNN
            self.net.init_special_bbox_fc(dev=0.001)
            if self.pretrained_model is not None:
                self.net._rpn._network._load_pre_trained_model(self.pretrained_model)
                print('Load parameters from Path: {}'.format(self.pretrained_model))
            else:
                pass

        if cfg.CUDA_IF:
            self.net.cuda()

        # BN should be fixed
        self.after_model_mode()

        # set optimizer
        self._parameters = [
            params for params in self.net.parameters() if params.requires_grad
        ]
        self._optimizer = self._train_optimizer()

        # tensorboard
        if self._use_tensorboard:
            import tensorboardX as tbx
            self._tensor_writer = tbx.SummaryWriter(log_dir=self.tbdir)

    def after_model_mode(self):
        # model
        self.net.train()
        # resnet: the fixed BN layers should stay in eval mode
        if cfg.TRAIN.INIT_WAY == 'resnet':
            self.net._rpn._network._bn_eval()

    def _train_optimizer(self):
        parameters = self._train_parameter()
        optimizer = torch.optim.SGD(parameters, momentum=self._momentum)
        return optimizer

    def _train_parameter(self):
        params = []
        for key, value in self.net.named_parameters():
            if value.requires_grad:
                if 'bias' in key:
                    params += [{
                        'params': [value],
                        'lr': self._lr * (cfg.TRAIN.DOUBLE_BIAS + 1),
                        'weight_decay': 0
                    }]
                else:
                    params += [{
                        'params': [value],
                        'lr': self._lr,
                        'weight_decay': self._weight_decay
                    }]
        return params
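The _train_parameter() method above builds per-parameter groups: bias terms get a doubled learning rate and no weight decay, everything else gets the base rate plus weight decay. A small, self-contained sketch of that strategy (not the repo's exact code; the toy model and hyperparameter values are illustrative):

import torch
import torch.nn as nn

def build_param_groups(model, base_lr=1e-3, weight_decay=5e-4, double_bias=True):
    # One optimizer group per parameter, mirroring the bias handling above
    groups = []
    for name, value in model.named_parameters():
        if not value.requires_grad:
            continue
        if 'bias' in name:
            groups.append({'params': [value],
                           'lr': base_lr * (2.0 if double_bias else 1.0),
                           'weight_decay': 0.0})
        else:
            groups.append({'params': [value],
                           'lr': base_lr,
                           'weight_decay': weight_decay})
    return groups

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 2, 3))
optimizer = torch.optim.SGD(build_param_groups(model), lr=1e-3, momentum=0.9)
print([g['lr'] for g in optimizer.param_groups])  # biases show the doubled rate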
class SolverWrapper(object): """ A wrapper class for the training process """ def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None): self.net = network self.imdb = imdb self.roidb = roidb self.valroidb = valroidb self.output_dir = output_dir self.tbdir = tbdir # Simply put '_val' at the end to save the summaries from the validation set self.tbvaldir = tbdir + '_val' if not os.path.exists(self.tbvaldir): os.makedirs(self.tbvaldir) self.pretrained_model = pretrained_model def snapshot(self, sess, iter): net = self.net if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) # Store the model snapshot filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt' filename = os.path.join(self.output_dir, filename) self.saver.save(sess, filename) print('Wrote snapshot to: {:s}'.format(filename)) # Also store some meta information, random state, etc. nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl' nfilename = os.path.join(self.output_dir, nfilename) # current state of numpy random st0 = np.random.get_state() # current position in the database cur = self.data_layer._cur # current shuffled indexes of the database perm = self.data_layer._perm # current position in the validation database cur_val = self.data_layer_val._cur # current shuffled indexes of the validation database perm_val = self.data_layer_val._perm # Dump the meta info with open(nfilename, 'wb') as fid: pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL) return filename, nfilename def from_snapshot(self, sess, sfile, nfile): print('Restoring model snapshots from {:s}'.format(sfile)) self.saver.restore(sess, sfile) print('Restored.') # Needs to restore the other hyper-parameters/states for training, (TODO xinlei) I have # tried my best to find the random states so that it can be recovered exactly # However the Tensorflow state is currently not available with open(nfile, 'rb') as fid: st0 = pickle.load(fid) cur = pickle.load(fid) perm = pickle.load(fid) cur_val = pickle.load(fid) perm_val = pickle.load(fid) last_snapshot_iter = pickle.load(fid) np.random.set_state(st0) self.data_layer._cur = cur self.data_layer._perm = perm self.data_layer_val._cur = cur_val self.data_layer_val._perm = perm_val return last_snapshot_iter def get_variables_in_checkpoint_file(self, file_name): try: print('&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&') # reader = tf.train.NewCheckpointReader(file_name) reader = pywrap_tensorflow.NewCheckpointReader(file_name) var_to_shape_map = reader.get_variable_to_shape_map() return var_to_shape_map except Exception as e: # pylint: disable=broad-except print(str(e)) if "corrupted compressed block contents" in str(e): print("It's likely that your checkpoint file has been compressed " "with SNAPPY.") def construct_graph(self, sess): with sess.graph.as_default(): # Set the random seed for tensorflow tf.set_random_seed(cfg.RNG_SEED) # Build the main computation graph layers = self.net.create_architecture('TRAIN', self.imdb.num_classes, tag='default', anchor_scales=cfg.ANCHOR_SCALES, anchor_ratios=cfg.ANCHOR_RATIOS) # Define the loss losses = layers['all_losses'] loss = losses['total_loss'] m1_loss = losses['M1']['total_loss'] m2_loss = losses['M2']['total_loss'] m3_loss = losses['M3']['total_loss'] # Set 
learning rate and momentum lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False) self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM) self.optimizer_m1 = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM) self.optimizer_m2 = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM) self.optimizer_m3 = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM) # Compute the gradients with regard to the loss gvs = self.optimizer.compute_gradients(loss) gvs_m1 = self.optimizer_m1.compute_gradients(m1_loss) gvs_m2 = self.optimizer_m2.compute_gradients(m2_loss) gvs_m3 = self.optimizer_m3.compute_gradients(m3_loss) # Double the gradient of the bias if set if cfg.TRAIN.DOUBLE_BIAS: final_gvs = [] final_gvs_m1 = [] final_gvs_m2 = [] final_gvs_m3 = [] with tf.variable_scope('Gradient_Mult') as scope: for grad, var in gvs: # print("grad, var:", grad, var) scale = 1. if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name: scale *= 2. if not np.allclose(scale, 1.0): grad = tf.multiply(grad, scale) final_gvs.append((grad, var)) train_op = self.optimizer.apply_gradients(final_gvs) with tf.variable_scope('Gradient_Mult_m1') as scope: for grad, var in gvs_m1: scale = 1. if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name: scale *= 2. if not np.allclose(scale, 1.0): grad = tf.multiply(grad, scale) final_gvs_m1.append((grad, var)) train_m1_op = self.optimizer.apply_gradients(final_gvs_m1) with tf.variable_scope('Gradient_Mult_m2') as scope: for grad, var in gvs_m2: scale = 1. if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name: scale *= 2. if not np.allclose(scale, 1.0): grad = tf.multiply(grad, scale) final_gvs_m2.append((grad, var)) train_m2_op = self.optimizer.apply_gradients(final_gvs_m2) with tf.variable_scope('Gradient_Mult_m3') as scope: for grad, var in gvs_m3: scale = 1. if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name: scale *= 2. 
if not np.allclose(scale, 1.0): grad = tf.multiply(grad, scale) final_gvs_m3.append((grad, var)) train_m3_op = self.optimizer.apply_gradients(final_gvs_m3) else: train_op = self.optimizer.apply_gradients(gvs) train_m1_op = self.optimizer_m1.apply_gradients(gvs_m1) train_m2_op = self.optimizer_m2.apply_gradients(gvs_m2) train_m3_op = self.optimizer_m3.apply_gradients(gvs_m3) # group the three independent train_op final_train_op = tf.group(train_m1_op, train_m2_op, train_m3_op) # We will handle the snapshots ourselves self.saver = tf.train.Saver(max_to_keep=100000) # Write the train and validation information to tensorboard self.writer = tf.summary.FileWriter(self.tbdir, sess.graph) self.valwriter = tf.summary.FileWriter(self.tbvaldir) return lr, train_op, train_m1_op, train_m2_op, train_m3_op, final_train_op def find_previous(self): sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta') sfiles = glob.glob(sfiles) sfiles.sort(key=os.path.getmtime) # Get the snapshot name in TensorFlow redfiles = [] for stepsize in cfg.TRAIN.STEPSIZE: redfiles.append(os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.ckpt.meta'.format(stepsize + 1))) sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles] nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl') nfiles = glob.glob(nfiles) nfiles.sort(key=os.path.getmtime) redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles] nfiles = [nn for nn in nfiles if nn not in redfiles] lsf = len(sfiles) assert len(nfiles) == lsf return lsf, nfiles, sfiles def initialize(self, sess): # Initial file lists are empty np_paths = [] ss_paths = [] # Fresh train directly from ImageNet weights print('Loading initial model weights from {:s}'.format(self.pretrained_model)) variables = tf.global_variables() print("variables:", variables) if 'darknet53' in self.pretrained_model: print('the base network is Darknet53!!!') sess.run(tf.variables_initializer(variables, name='init')) self.net.restored_from_npz(sess) print('Loaded.') # print('>>>>>>>', variables[0].eval()) last_snapshot_iter = 0 rate = cfg.TRAIN.LEARNING_RATE stepsizes = list(cfg.TRAIN.STEPSIZE) else: sess.run(tf.variables_initializer(variables, name='init')) var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model) # Get the variables to restore, ignoring the variables to fix variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic) restorer = tf.train.Saver(variables_to_restore) restorer.restore(sess, self.pretrained_model) print('Loaded.') # Need to fix the variables before loading, so that the RGB weights are changed to BGR # For VGG16 it also changes the convolutional weights self.net.fix_variables(sess, self.pretrained_model) print('Fixed.') last_snapshot_iter = 0 rate = cfg.TRAIN.LEARNING_RATE stepsizes = list(cfg.TRAIN.STEPSIZE) return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths def restore(self, sess, sfile, nfile): # Get the most recent snapshot and restore variables = tf.global_variables() print("variables:", variables) np_paths = [nfile] ss_paths = [sfile] # Restore model from snapshots last_snapshot_iter = self.from_snapshot(sess, sfile, nfile) # Set the learning rate rate = cfg.TRAIN.LEARNING_RATE stepsizes = [] for stepsize in cfg.TRAIN.STEPSIZE: if last_snapshot_iter > stepsize: rate *= cfg.TRAIN.GAMMA else: stepsizes.append(stepsize) return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths def remove_snapshot(self, np_paths, 
ss_paths): to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT for c in range(to_remove): nfile = np_paths[0] os.remove(str(nfile)) np_paths.remove(nfile) to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT for c in range(to_remove): sfile = ss_paths[0] # To make the code compatible to earlier versions of Tensorflow, # where the naming tradition for checkpoints are different if os.path.exists(str(sfile)): os.remove(str(sfile)) else: os.remove(str(sfile + '.data-00000-of-00001')) os.remove(str(sfile + '.index')) sfile_meta = sfile + '.meta' os.remove(str(sfile_meta)) ss_paths.remove(sfile) def train_model(self, sess, max_iters): # Build data layers for both training and validation set self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes) self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True) # Construct the computation graph # lr, train_op = self.construct_graph(sess) lr, train_op, train_m1_op, train_m2_op, train_m3_op, final_train_op = self.construct_graph(sess) # Find previous snapshots if there is any to restore from lsf, nfiles, sfiles = self.find_previous() # Initialize the variables or restore them from the last snapshot if lsf == 0: rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess) else: rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess, str(sfiles[-1]), str(nfiles[-1])) timer = Timer() iter = last_snapshot_iter + 1 last_summary_time = time.time() # Make sure the lists are not empty stepsizes.append(max_iters) stepsizes.reverse() next_stepsize = stepsizes.pop() while iter < max_iters + 1: # Learning rate if iter == next_stepsize + 1: # Add snapshot here before reducing the learning rate self.snapshot(sess, iter) rate *= cfg.TRAIN.GAMMA sess.run(tf.assign(lr, rate)) next_stepsize = stepsizes.pop() timer.tic() # Get training data, one batch at a time blobs = self.data_layer.forward() now = time.time() if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL: # Compute the graph with summary losses, summary = self.net.train_step_with_summary(sess, blobs, final_train_op) self.writer.add_summary(summary, float(iter)) # Also check the summary on the validation set blobs_val = self.data_layer_val.forward() summary_val = self.net.get_summary(sess, blobs_val) self.valwriter.add_summary(summary_val, float(iter)) last_summary_time = now else: # Compute the graph without summary losses = self.net.train_step(sess, blobs, final_train_op) timer.toc() # get the corresponding loss to show m1_cls_loss = losses['M1']['rpn_cross_entropy'] m1_box_loss = losses['M1']['rpn_loss_box'] m1_kp_loss = losses['M1']['kpoints_loss'] m1_total_loss = losses['M1']['total_loss'] m2_cls_loss = losses['M2']['rpn_cross_entropy'] m2_box_loss = losses['M2']['rpn_loss_box'] m2_kp_loss = losses['M2']['kpoints_loss'] m2_total_loss = losses['M2']['total_loss'] m3_cls_loss = losses['M3']['rpn_cross_entropy'] m3_box_loss = losses['M3']['rpn_loss_box'] m3_kp_loss = losses['M3']['kpoints_loss'] m3_total_loss = losses['M3']['total_loss'] total_loss = losses['total_loss'] # Display training information if iter % (cfg.TRAIN.DISPLAY) == 0: print('iter: %d / %d \n >>> m1_cls_loss: %.6f, m1_box_loss: %.6f, m1_kp_loss: %.6f, m1_total_loss: %.6f\n ' '>>> m2_cls_loss: %.6f, m2_box_loss: %.6f, m2_kp_loss: %.6f, m2_total_loss: %.6f\n ' '>>> m3_cls_loss: %.6f, m3_box_loss: %.6f, m3_kp_loss: %.6f, m3_total_loss: %.6f\n ' '>>> total_loss: %.6f, lr: %f' % \ (iter, max_iters, m1_cls_loss, m1_box_loss, m1_kp_loss, m1_total_loss, 
m2_cls_loss, m2_box_loss, m2_kp_loss, m2_total_loss, m3_cls_loss, m3_box_loss, m3_kp_loss, m3_total_loss, total_loss, lr.eval())) print('speed: {:.3f}s / iter'.format(timer.average_time)) # Snapshotting if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0: last_snapshot_iter = iter ss_path, np_path = self.snapshot(sess, iter) np_paths.append(np_path) ss_paths.append(ss_path) # Remove the old snapshots if there are too many if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT: self.remove_snapshot(np_paths, ss_paths) iter += 1 if last_snapshot_iter != iter - 1: self.snapshot(sess, iter - 1) self.writer.close() self.valwriter.close() def train_model_old(self, sess, max_iters): # Build data layers for both training and validation set self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes) self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True) # Construct the computation graph lr, train_op, train_m1_op, train_m2_op, train_m3_op, _ = self.construct_graph(sess) # Find previous snapshots if there is any to restore from lsf, nfiles, sfiles = self.find_previous() # Initialize the variables or restore them from the last snapshot if lsf == 0: rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess) else: rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess, str(sfiles[-1]), str(nfiles[-1])) timer = Timer() iter = last_snapshot_iter + 1 last_summary_time = time.time() # Make sure the lists are not empty stepsizes.append(max_iters) stepsizes.reverse() next_stepsize = stepsizes.pop() m1_iters = 0 m2_iters = 0 m3_iters = 0 while iter < max_iters + 1: random_seed = np.random.rand() if random_seed < 0.33: module = "M1" train_op = train_m1_op m1_iters += 1 elif 0.33 <= random_seed < 0.67: module = "M2" train_op = train_m2_op m2_iters += 1 else: module = "M3" train_op = train_m3_op m3_iters += 1 # Learning rate if iter == next_stepsize + 1: # Add snapshot here before reducing the learning rate self.snapshot(sess, iter) rate *= cfg.TRAIN.GAMMA sess.run(tf.assign(lr, rate)) next_stepsize = stepsizes.pop() timer.tic() # Get training data, one batch at a time blobs = self.data_layer.forward() now = time.time() if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL: # Compute the graph with summary rpn_loss_cls, rpn_loss_box, landmarks_loss, total_loss, summary = self.net.train_step_with_summary_old(sess, blobs, train_op, module) self.writer.add_summary(summary, float(iter)) # Also check the summary on the validation set blobs_val = self.data_layer_val.forward() summary_val = self.net.get_summary(sess, blobs_val) self.valwriter.add_summary(summary_val, float(iter)) last_summary_time = now else: # Compute the graph without summary rpn_loss_cls, rpn_loss_box, kpoints_loss, total_loss = self.net.train_step_old(sess, blobs, train_op, module) timer.toc() # Display training information if iter % (cfg.TRAIN.DISPLAY) == 0: if module == 'M1': iters = m1_iters elif module == 'M2': iters = m2_iters else: iters = m3_iters print('iter: %d / %d, now training module: %s, iters: %d,\n >>> total loss: %.6f\n >>> rpn_loss_cls: %.6f\n ' '>>> rpn_loss_box: %.6f\n >>> kpoints_loss: %.6f\n >>> lr: %f' % \ (iter, max_iters, module, iters, total_loss, rpn_loss_cls, rpn_loss_box, kpoints_loss, lr.eval())) print('speed: {:.3f}s / iter'.format(timer.average_time)) # Snapshotting if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0: last_snapshot_iter = iter ss_path, np_path = self.snapshot(sess, iter) np_paths.append(np_path) ss_paths.append(ss_path) # Remove the old 
snapshots if there are too many if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT: self.remove_snapshot(np_paths, ss_paths) iter += 1 if last_snapshot_iter != iter - 1: self.snapshot(sess, iter - 1) self.writer.close() self.valwriter.close()
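The train_model() loop above drives the learning rate with a stack of step sizes: max_iters is appended so the stack never runs dry, the list is reversed, and the rate is multiplied by GAMMA each time the iteration counter passes the next boundary. A framework-free sketch of that schedule, with hypothetical values standing in for cfg.TRAIN.STEPSIZE and cfg.TRAIN.GAMMA:

def lr_schedule(max_iters, stepsizes, base_lr=0.001, gamma=0.1):
    # Treat the step sizes as a stack, guarded by max_iters so pop() never fails
    stepsizes = sorted(stepsizes) + [max_iters]
    stepsizes.reverse()
    next_stepsize = stepsizes.pop()
    rate = base_lr
    rates = []
    for it in range(1, max_iters + 1):
        if it == next_stepsize + 1:
            rate *= gamma
            next_stepsize = stepsizes.pop()
        rates.append(rate)
    return rates

rates = lr_schedule(max_iters=100, stepsizes=[30, 60])
print(rates[29], rates[30], rates[60])  # 0.001, then 0.0001, then 1e-05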
class SolverWrapper(object): def __init__(self, network, imdb, roidb, valroidb, model_dir, pretrained_model=None): self.net = network self.imdb = imdb self.roidb = roidb self.valroidb = valroidb self.model_dir = model_dir self.tbdir = os.path.join(model_dir, 'train_log') if not os.path.exists(self.tbdir): os.makedirs(self.tbdir) self.pretrained_model = pretrained_model def set_learn_strategy(self, learn_dict): self._disp_interval = learn_dict['disp_interval'] self._valid_interval = learn_dict['disp_interval'] * 5 self._use_tensorboard = learn_dict['use_tensorboard'] self._use_valid = learn_dict['use_valid'] self._save_point_interval = learn_dict['save_point_interval'] self._lr_decay_steps = learn_dict['lr_decay_steps'] def train_model(self, resume=None, max_iters=100000): # Build data layers for both training and validation set self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes) self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True) self.prepare_construct(resume) net = self.net # training train_loss = 0 rpn_cls_loss = 0 rpn_bbox_loss = 0 fast_rcnn_cls_loss = 0 fast_rcnn_bbox_loss = 0 tp, tf, fg, bg = 0., 0., 0, 0 step_cnt = 0 re_cnt = False t = Timer() t.tic() for step in range(self.start_step, max_iters + 1): blobs = self.data_layer.forward() im_data = blobs['data'] im_info = blobs['im_info'] gt_boxes = blobs['gt_boxes'] # forward result_cls_prob, result_bbox_pred, result_rois = net( im_data, im_info, gt_boxes) loss = net.loss + net._rpn.loss train_loss += loss.data.cpu()[0] rpn_cls_loss += net._rpn.cross_entropy.data.cpu()[0] rpn_bbox_loss += net._rpn.loss_box.data.cpu()[0] fast_rcnn_cls_loss += net.cross_entropy.data.cpu()[0] fast_rcnn_bbox_loss += net.loss_box.data.cpu()[0] step_cnt += 1 # backward self._optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm(self._parameters, max_norm=10) self._optimizer.step() # clear middle memory net._delete_cache() if step % self._disp_interval == 0: duration = t.toc(average=False) fps = step_cnt / duration log_text = 'step %d, image: %s, loss: %.4f, fps: %.2f (%.2fs per batch)' % ( step, blobs['im_name'], train_loss / step_cnt, fps, 1. 
/ fps) pprint.pprint(log_text) if self._use_tensorboard: self._tensor_writer.add_text('Train', log_text, global_step=step) # Train avg_rpn_cls_loss = rpn_cls_loss / step_cnt avg_rpn_bbox_loss = rpn_bbox_loss / step_cnt avg_fast_rcnn_cls_loss = fast_rcnn_cls_loss / step_cnt avg_fast_rcnn_bbox_loss = fast_rcnn_bbox_loss / step_cnt self._tensor_writer.add_scalars( 'TrainSetLoss', { 'RPN_cls_loss': avg_rpn_cls_loss, 'RPN_bbox_loss': avg_rpn_bbox_loss, 'FastRcnn_cls_loss': avg_fast_rcnn_cls_loss, 'FastRcnn_bbox_loss': avg_fast_rcnn_bbox_loss }, global_step=step) self._tensor_writer.add_scalar('Learning_rate', self._lr, global_step=step) re_cnt = True if self._use_tensorboard and step % self._valid_interval == 0: new_gt_boxes = gt_boxes.copy() new_gt_boxes[:, :4] = new_gt_boxes[:, :4] image = self.back_to_image(blobs['data']).astype(np.uint8) im_shape = image.shape pred_boxes, scores, classes = net.interpret_faster_rcnn_scale( result_cls_prob, result_bbox_pred, result_rois, im_shape, min_score=0.1) image = self.draw_photo(image, pred_boxes, scores, classes, new_gt_boxes) image = torchtrans.ToTensor()(image) image = vutils.make_grid([image]) self._tensor_writer.add_image('Image', image, step) if self._use_valid and step % self._valid_interval == 0: total_valid_loss = 0.0 valid_rpn_cls_loss = 0.0 valid_rpn_bbox_loss = 0.0 valid_fast_rcnn_cls_loss = 0.0 valid_fast_rcnn_bbox_loss = 0.0 valid_step_cnt = 0 start_time = time.time() valid_length = self._disp_interval net.eval() for valid_batch in range(valid_length): # get one batch blobs = self.data_layer_val.forward() im_data = blobs['data'] im_info = blobs['im_info'] gt_boxes = blobs['gt_boxes'] # forward result_cls_prob, result_bbox_pred, result_rois = net( im_data, im_info, gt_boxes) valid_loss = net.loss + net._rpn.loss total_valid_loss += valid_loss.data.cpu()[0] valid_rpn_cls_loss += net._rpn.cross_entropy.data.cpu()[0] valid_rpn_bbox_loss += net._rpn.loss_box.data.cpu()[0] valid_fast_rcnn_cls_loss += net.cross_entropy.data.cpu()[0] valid_fast_rcnn_bbox_loss += net.loss_box.data.cpu()[0] valid_step_cnt += 1 net.train() duration = time.time() - start_time fps = valid_step_cnt / duration log_text = 'step %d, valid average loss: %.4f, fps: %.2f (%.2fs per batch)' % ( step, total_valid_loss / valid_step_cnt, fps, 1. 
/ fps) pprint.pprint(log_text) if self._use_tensorboard: self._tensor_writer.add_text('Valid', log_text, global_step=step) new_gt_boxes = gt_boxes.copy() new_gt_boxes[:, :4] = new_gt_boxes[:, :4] image = self.back_to_image(blobs['data']).astype(np.uint8) im_shape = image.shape pred_boxes, scores, classes = net.interpret_faster_rcnn_scale( result_cls_prob, result_bbox_pred, result_rois, im_shape, min_score=0.1) image = self.draw_photo(image, pred_boxes, scores, classes, new_gt_boxes) image = torchtrans.ToTensor()(image) image = vutils.make_grid([image]) self._tensor_writer.add_image('Image_Valid', image, step) if self._use_tensorboard: # Valid avg_rpn_cls_loss_valid = valid_rpn_cls_loss / valid_step_cnt avg_rpn_bbox_loss_valid = valid_rpn_bbox_loss / valid_step_cnt avg_fast_rcnn_cls_loss_valid = valid_fast_rcnn_cls_loss / valid_step_cnt avg_fast_rcnn_bbox_loss_valid = valid_fast_rcnn_bbox_loss / valid_step_cnt real_total_loss_valid = valid_rpn_cls_loss + valid_rpn_bbox_loss + valid_fast_rcnn_cls_loss + valid_fast_rcnn_bbox_loss # Train avg_rpn_cls_loss = rpn_cls_loss / step_cnt avg_rpn_bbox_loss = rpn_bbox_loss / step_cnt avg_fast_rcnn_cls_loss = fast_rcnn_cls_loss / step_cnt avg_fast_rcnn_bbox_loss = fast_rcnn_bbox_loss / step_cnt real_total_loss = rpn_cls_loss + rpn_bbox_loss + fast_rcnn_cls_loss + fast_rcnn_bbox_loss self._tensor_writer.add_scalars( 'Total_Loss', { 'train': train_loss / step_cnt, 'valid': total_valid_loss / valid_step_cnt }, global_step=step) self._tensor_writer.add_scalars( 'Real_loss', { 'train': real_total_loss / step_cnt, 'valid': real_total_loss_valid / valid_step_cnt }, global_step=step) self._tensor_writer.add_scalars( 'RPN_cls_loss', { 'train': avg_rpn_cls_loss, 'valid': avg_rpn_cls_loss_valid }, global_step=step) self._tensor_writer.add_scalars( 'RPN_bbox_loss', { 'train': avg_rpn_bbox_loss, 'valid': avg_rpn_bbox_loss_valid }, global_step=step) self._tensor_writer.add_scalars( 'FastRcnn_cls_loss', { 'train': avg_fast_rcnn_cls_loss, 'valid': avg_fast_rcnn_cls_loss_valid }, global_step=step) self._tensor_writer.add_scalars( 'FastRcnn_bbox_loss', { 'train': avg_fast_rcnn_bbox_loss, 'valid': avg_fast_rcnn_bbox_loss_valid }, global_step=step) self._tensor_writer.add_scalars( 'ValidSetLoss', { 'RPN_cls_loss': avg_rpn_cls_loss_valid, 'RPN_bbox_loss': avg_rpn_bbox_loss_valid, 'FastRcnn_cls_loss': avg_fast_rcnn_cls_loss_valid, 'FastRcnn_bbox_loss': avg_fast_rcnn_bbox_loss_valid }, global_step=step) # self._tensor_writer.add_scalars('TrainSetLoss', { # 'RPN_cls_loss': avg_rpn_cls_loss, # 'RPN_bbox_loss': avg_rpn_bbox_loss, # 'FastRcnn_cls_loss': avg_fast_rcnn_cls_loss, # 'FastRcnn_bbox_loss': avg_fast_rcnn_bbox_loss # }, global_step=step) # self._tensor_writer.add_scalar('Learning_rate', self._lr, global_step=step) if (step % self._save_point_interval == 0) and step > 0: save_name, _ = self.save_check_point(step) print('save model: {}'.format(save_name)) if step in self._lr_decay_steps: self._lr *= self._lr_decay self._optimizer = self._train_optimizer() if re_cnt: tp, tf, fg, bg = 0., 0., 0, 0 train_loss = 0 rpn_cls_loss = 0 rpn_bbox_loss = 0 fast_rcnn_cls_loss = 0 fast_rcnn_bbox_loss = 0 step_cnt = 0 t.tic() re_cnt = False if self._use_tensorboard: self._tensor_writer.export_scalars_to_json( os.path.join(self.tbdir, 'all_scalars.json')) def draw_photo(self, image, dets, scores, classes, gt_boxes): # im2show = np.copy(image) im2show = image # color_b = (0, 191, 255) for i, det in enumerate(dets): det = tuple(int(x) for x in det) r = min(0 + i * 10, 255) r_i = i / 5 g = 
min(150 + r_i * 10, 255) g_i = r_i / 5 b = min(200 + g_i, 255) color_b_c = (r, g, b) cv2.rectangle(im2show, det[0:2], det[2:4], color_b_c, 2) cv2.putText(im2show, '%s: %.3f' % (classes[i], scores[i]), (det[0], det[1] + 15), cv2.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 255), thickness=1) for i, det in enumerate(gt_boxes): det = tuple(int(x) for x in det) gt_class = self.net._classes[det[-1]] cv2.rectangle(im2show, det[0:2], det[2:4], (255, 0, 0), 2) cv2.putText(im2show, '%s' % (gt_class), (det[0], det[1] + 15), cv2.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 255), thickness=1) return im2show def back_to_image(self, img): image = img[0] + cfg.PIXEL_MEANS image = image[:, :, ::-1].copy(order='C') return image def save_check_point(self, step): net = self.net if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) # store the model snapshot filename = os.path.join(self.model_dir, 'fasterRcnn_iter_{}.h5'.format(step)) h5f = h5py.File(filename, mode='w') for k, v in net.state_dict().items(): h5f.create_dataset(k, data=v.cpu().numpy()) # store data information nfilename = os.path.join(self.model_dir, 'fasterRcnn_iter_{}.pkl'.format(step)) # current state of numpy random st0 = np.random.get_state() # current position in the database cur = self.data_layer._cur # current shuffled indexes of the database perm = self.data_layer._perm # current position in the validation database cur_val = self.data_layer_val._cur # current shuffled indexes of the validation database perm_val = self.data_layer_val._perm # current learning rate lr = self._lr # Dump the meta info with open(nfilename, 'wb') as fid: pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(lr, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(step, fid, pickle.HIGHEST_PROTOCOL) return filename, nfilename def load_check_point(self, step): net = self.net filename = os.path.join(self.model_dir, 'fasterRcnn_iter_{}.h5'.format(step)) nfilename = os.path.join(self.model_dir, 'fasterRcnn_iter_{}.pkl'.format(step)) print('Restoring model snapshots from {:s}'.format(filename)) if not os.path.exists(filename): print('The checkPoint is not Right') sys.exit(1) # load model h5f = h5py.File(filename, mode='r') for k, v in net.state_dict().items(): param = torch.from_numpy(np.asarray(h5f[k])) v.copy_(param) # load data information with open(nfilename, 'rb') as fid: st0 = pickle.load(fid) cur = pickle.load(fid) perm = pickle.load(fid) cur_val = pickle.load(fid) perm_val = pickle.load(fid) lr = pickle.load(fid) last_snapshot_iter = pickle.load(fid) np.random.set_state(st0) self.data_layer._cur = cur self.data_layer._perm = perm self.data_layer_val._cur = cur_val self.data_layer_val._perm = perm_val self._lr = lr if last_snapshot_iter == step: print('Restore over ') else: print('The checkPoint is not Right') raise ValueError return last_snapshot_iter def weights_normal_init(self, model, dev=0.01): import math def _gaussian_init(m, dev): m.weight.data.normal_(0.0, dev) if hasattr(m.bias, 'data'): m.bias.data.zero_() def _xaiver_init(m): nn.init.xavier_normal(m.weight.data) if hasattr(m.bias, 'data'): m.bias.data.zero_() def _hekaiming_init(m): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels m.weight.data.normal_(0, math.sqrt(2. 
/ n)) if hasattr(m.bias, 'data'): m.bias.data.zero_() def _resnet_init(model, dev): if isinstance(model, list): for m in model: self.weights_normal_init(m, dev) else: for m in model.modules(): if isinstance(m, nn.Conv2d): _hekaiming_init(m) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): _gaussian_init(m, dev) def _vgg_init(model, dev): if isinstance(model, list): for m in model: self.weights_normal_init(m, dev) else: for m in model.modules(): if isinstance(m, nn.Conv2d): _gaussian_init(m, dev) elif isinstance(m, nn.Linear): _gaussian_init(m, dev) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() if cfg.TRAIN.INIT_WAY == 'resnet': _vgg_init(model, dev) elif cfg.TRAIN.INIT_WAY == 'vgg': _vgg_init(model, dev) else: raise NotImplementedError def prepare_construct(self, resume_iter): # init network self.net.init_fasterRCNN() # Set the random seed torch.manual_seed(cfg.RNG_SEED) np.random.seed(cfg.RNG_SEED) # Set learning rate and momentum self._lr = cfg.TRAIN.LEARNING_RATE self._lr_decay = 0.1 self._momentum = cfg.TRAIN.MOMENTUM self._weight_decay = cfg.TRAIN.WEIGHT_DECAY # load model if resume_iter: self.start_step = resume_iter + 1 self.load_check_point(resume_iter) else: self.start_step = 0 self.weights_normal_init(self.net, dev=0.01) # refer to caffe faster RCNN self.net.init_special_bbox_fc(dev=0.001) if self.pretrained_model != None: self.net._rpn._network._load_pre_trained_model( self.pretrained_model) print('Load parameters from Path: {}'.format( self.pretrained_model)) else: pass # model self.net.train() if cfg.CUDA_IF: self.net.cuda() # resnet fixed BN should be eval if cfg.TRAIN.INIT_WAY == 'resnet': self.net._rpn._network._bn_eval() # set optimizer self._parameters = [ params for params in self.net.parameters() if params.requires_grad == True ] self._optimizer = self._train_optimizer() # tensorboard if self._use_tensorboard: import tensorboardX as tbx self._tensor_writer = tbx.SummaryWriter(log_dir=self.tbdir) def _train_optimizer(self): parameters = self._train_parameter() optimizer = torch.optim.SGD(parameters, momentum=self._momentum) return optimizer def _train_parameter(self): params = [] for key, value in self.net.named_parameters(): if value.requires_grad == True: if 'bias' in key: params += [{ 'params': [value], 'lr': self._lr * (cfg.TRAIN.DOUBLE_BIAS + 1), 'weight_decay': 0 }] else: params += [{ 'params': [value], 'lr': self._lr, 'weight_decay': self._weight_decay }] return params
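The save_check_point() / load_check_point() pair above stores the PyTorch state_dict as one HDF5 dataset per parameter, keyed by parameter name. A stripped-down sketch of that format; the file path and the toy model are illustrative only:

import h5py
import numpy as np
import torch
import torch.nn as nn

def save_h5(model, filename):
    # One dataset per state_dict entry, named after the parameter
    with h5py.File(filename, mode='w') as h5f:
        for k, v in model.state_dict().items():
            h5f.create_dataset(k, data=v.cpu().numpy())

def load_h5(model, filename):
    # Copy each stored array back into the matching state_dict tensor
    with h5py.File(filename, mode='r') as h5f:
        for k, v in model.state_dict().items():
            v.copy_(torch.from_numpy(np.asarray(h5f[k])))

net = nn.Linear(4, 2)
save_h5(net, '/tmp/demo_iter_0.h5')
load_h5(net, '/tmp/demo_iter_0.h5')  # parameters are restored in place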
class SolverWrapper(object):
    """
    A wrapper class for the training process.
    According to the author, this class exists so that the training process can
    be conveniently controlled from Python code.
    """

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir,
                 pretrained_model):
        self.net = network                  # the backbone network, vgg or resnet
        self.imdb = imdb                    # the image database
        self.roidb = roidb                  # regions of interest
        self.valroidb = valroidb            # validation roidb
        self.output_dir = output_dir        # output directory for results
        # self.tbdir = tbdir                # tensorboard output directory
        # Simply put '_val' at the end to save the summaries from the validation set
        # self.tbvaldir = tbdir + '_val'
        # if not os.path.exists(self.tbvaldir):
        #     os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        # This function handles saving: the model checkpoint plus the meta
        # information that from_snapshot() and find_previous() below expect
        net = self.net
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        # __C.TRAIN.SNAPSHOT_PREFIX = 'res101_faster_rcnn' is the default prefix
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)   # saved under output_dir
        self.saver.save(sess, filename)                      # save the tensor variables
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store the numpy random state and the data-layer cursors, so that
        # restoring from this snapshot can resume exactly
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        st0 = np.random.get_state()
        cur = self.data_layer._cur
        perm = self.data_layer._perm
        cur_val = self.data_layer_val._cur
        perm_val = self.data_layer_val._perm
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)
        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Needs to restore the other hyper-parameters/states for training,
        # (TODO xinlei) I have tried my best to find the random states so that
        # it can be recovered exactly.
        # However the Tensorflow state is currently not available.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val
        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        # Read the variables stored in the pretrained model file
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")

    def construct_graph(self, sess):
        # Build the computation graph
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                'TRAIN', self.imdb.num_classes,
                anchor_scales=cfg.ANCHOR_SCALES,
                anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)           # learning rate
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)  # momentum optimizer

            # Compute the gradients with regard to the loss
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult') as scope:
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            # self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            # self.valwriter = tf.summary.FileWriter(self.tbvaldir)
            print("Graph construction finished")
        return lr, train_op

    def find_previous(self):
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(os.path.join(
                self.output_dir,
                cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.ckpt.meta'.format(stepsize + 1)))
        sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf
        return lsf, nfiles, sfiles

    def initialize(self, sess):
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(self.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        # Pull the parameters out of the pretrained checkpoint
        var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)
        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, self.pretrained_model)
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are
        # changed to BGR. For VGG16 it also converts the convolutional weights
        # fc6 and fc7 to fully connected weights.
        self.net.fix_variables(sess, self.pretrained_model)
        print('Fixed.')
        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)
        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)
        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            sfile = ss_paths[0]
            # To keep the code compatible with earlier versions of TensorFlow,
            # where the checkpoint naming convention differs
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def train_model(self, sess, max_iters):
        # This is the core training routine
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        print("Training data layer ready")
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes,
                                           random=True)
        print("Validation data layer ready")

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there are any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                sess, str(sfiles[-1]), str(nfiles[-1]))

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()

        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            # (an error used to occur here; the cause was a missing image
            # attribute, so no data was read. Note that the validation set needs
            # no data augmentation.)
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                    self.net.train_step_with_summary(sess, blobs, train_op)
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()   # forward pass on the validation set
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                    self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' %
                      (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box,
                       loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        # The tensorboard writers are disabled in this annotated copy (see the
        # commented-out FileWriter lines in construct_graph), so there is
        # nothing to close here.
        # self.writer.close()
        # self.valwriter.close()
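All of the TensorFlow variants above cap the number of snapshots on disk with remove_snapshot(): once the path lists exceed cfg.TRAIN.SNAPSHOT_KEPT, the oldest files are deleted first. A self-contained sketch of that pruning policy with illustrative paths:

import os
import tempfile

def prune_snapshots(paths, keep):
    # The lists are kept in creation order, so index 0 is always the oldest
    to_remove = len(paths) - keep
    for _ in range(max(0, to_remove)):
        oldest = paths.pop(0)
        if os.path.exists(oldest):
            os.remove(oldest)
    return paths

tmp = tempfile.mkdtemp()
paths = []
for i in range(5):
    p = os.path.join(tmp, 'demo_iter_{}.pkl'.format(i))
    open(p, 'w').close()
    paths.append(p)
print(prune_snapshots(paths, keep=3))  # only the three newest remain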