class SolverWrapper(object):
    """A wrapper class for the staged training of the RPN and R-FCN networks.

    ``train_model`` builds both networks in the session graph (variable scopes
    ``rpn_network`` and ``rfcn_network``), loads either the pretrained weights
    or the most recent snapshot, and then runs the training stages in order.
    The stage boundaries and most learning rates are hard-coded in
    ``train_model``.
    """

    def __init__(self, sess, network, rfcn_network, imdb, roidb, valroidb, output_dir, tbdir,
                 pretrained_model=None):
        """Store the training configuration and create the validation summary dir.

        Args:
            sess: accepted for interface compatibility; not used here.
            network: network object built under the ``rpn_network`` scope.
            rfcn_network: network object built under the ``rfcn_network`` scope.
            imdb: dataset object; only ``num_classes`` is read by this class.
            roidb: training roidb handed to ``RoIDataLayer``.
            valroidb: validation roidb handed to ``RoIDataLayer``.
            output_dir: directory that receives the snapshots.
            tbdir: directory for the training summaries; the validation
                summaries go to ``tbdir + '_val'``, which is created here.
            pretrained_model: checkpoint path used when no snapshot exists.
        """
        self.net = network
        self.rfcn_network = rfcn_network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        """Write a model checkpoint plus a pickle with the data/RNG state.

        Requires ``self.saver``, ``self.data_layer`` and ``self.data_layer_val``,
        which are set by ``train_model``.

        Args:
            sess: session whose variables are saved.
            iter: iteration number, used in both file names and stored in the pickle.

        Returns:
            Tuple ``(checkpoint_path, pickle_path)``.
        """
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)

        # The dump order below is the on-disk format read back in train_model:
        # numpy RNG state, training cursor and permutation, validation cursor
        # and permutation, iteration number.
        state = (
            np.random.get_state(),
            self.data_layer._cur,
            self.data_layer._perm,
            self.data_layer_val._cur,
            self.data_layer_val._perm,
            iter,
        )
        with open(nfilename, 'wb') as fid:
            for item in state:
                pickle.dump(item, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def get_variables_in_checkpoint_file(self, file_name):
        """Return the variable-name -> shape map stored in a checkpoint file.

        Args:
            file_name: path of the checkpoint to inspect.

        Returns:
            The mapping returned by the checkpoint reader.

        Raises:
            Exception: whatever the checkpoint reader raised. The error is
                printed (with a hint for SNAPPY-compressed files) and then
                re-raised, instead of silently returning ``None`` to a caller
                that needs the mapping.
        """
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            return reader.get_variable_to_shape_map()
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")
            raise

    def _run_rpn_step(self, sess, blobs, train_op, writer, valwriter, iter, timer,
                      last_summary_time):
        """Run one timed training step on ``self.net``, writing summaries when due.

        Summaries (training and one validation batch) are written when more
        than ``cfg.TRAIN.SUMMARY_INTERVAL`` seconds passed since the last ones.

        Returns:
            Tuple ``(losses, last_summary_time)`` where ``losses`` is
            ``(rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss)``.
        """
        timer.tic()
        now = time.time()
        if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
            # Compute the graph with summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                self.net.train_rpn_step_with_summary(sess, blobs, train_op)
            writer.add_summary(summary, float(iter))
            # Also check the summary on the validation set
            blobs_val = self.data_layer_val.forward()
            summary_val = self.net.get_summary(sess, blobs_val)
            valwriter.add_summary(summary_val, float(iter))
            last_summary_time = now
        else:
            # Compute the graph without summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                self.net.train_rpn_step(sess, blobs, train_op)
        timer.toc()
        return (rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss), last_summary_time

    def _run_rfcn_step_stage2(self, sess, blobs, train_op, iter, timer, last_summary_time):
        """Run one timed stage-2 training step on ``self.rfcn_network``.

        Returns:
            Tuple ``(losses, last_summary_time)`` shaped like the result of
            ``_run_rpn_step``; the two RPN losses are reported as 0.
        """
        timer.tic()
        now = time.time()
        if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
            # Compute the graph with summary
            loss_cls, loss_box, total_loss, summary = \
                self.rfcn_network.train_rfcn_step_with_summary_stage2(sess, blobs, train_op,
                                                                      self.net)
            self.writer_stage2.add_summary(summary, float(iter))
            # Also check the summary on the validation set
            blobs_val = self.data_layer_val.forward()
            summary_val = self.rfcn_network.get_summary_stage2(sess, blobs_val, self.net)
            self.valwriter_stage2.add_summary(summary_val, float(iter))
            last_summary_time = now
        else:
            # Compute the graph without summary
            loss_cls, loss_box, total_loss = \
                self.rfcn_network.train_rfcn_step_stage2(sess, blobs, train_op, self.net)
        timer.toc()
        return (0, 0, loss_cls, loss_box, total_loss), last_summary_time

    def _print_progress(self, iter, max_iters, stage_infor, losses, lr, timer):
        """Print the current losses, learning rate and average step time."""
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = losses
        print('iter: %d / %d, stage: %s, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' %
              (iter, max_iters, stage_infor, total_loss, rpn_loss_cls, rpn_loss_box,
               loss_cls, loss_box, lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))

    def train_model(self, sess, max_iters):
        """Build the graph, restore state and run all training stages.

        Stages by iteration number (all boundaries are hard-coded):
            1..80000         stage1: ``self.net`` with the stage-1 train op
            80001..200000    stage2: ``self.rfcn_network``
            (merge of ``rfcn_network`` into ``rpn_network``)
            200001..280000   stage3: ``self.net`` with the stage-3 train op
            280001..400000   stage4: ``self.net`` with the stage-4 train op
            400001..480000   stage5: ``self.net`` with the stage-3 train op

        Args:
            sess: TensorFlow session used for graph construction and training.
            max_iters: only used in the progress output; the loop bounds are
                the hard-coded values above.
        """
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            rpn_layers, proposal_targets = self.net.create_architecture(
                sess, 'TRAIN', self.imdb.num_classes, scope='rpn_network', tag='default',
                anchor_scales=cfg.ANCHOR_SCALES, anchor_ratios=cfg.ANCHOR_RATIOS)
            rfcn_layers, _ = self.rfcn_network.create_architecture(
                sess, 'TRAIN', self.imdb.num_classes, scope='rfcn_network', tag='default',
                anchor_scales=cfg.ANCHOR_SCALES, anchor_ratios=cfg.ANCHOR_RATIOS,
                input_rois=rpn_layers['rois'], roi_scores=rpn_layers['roi_scores'],
                proposal_targets=proposal_targets)

            # Define the loss
            rpn_loss = rpn_layers['rpn_loss']
            rfcn_loss = rfcn_layers['rfcn_loss']
            rpn_rfcn_loss = rpn_layers['rfcn_loss']

            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            momentum = cfg.TRAIN.MOMENTUM
            self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

            # Compute the gradients wrt the loss, one variable set per stage
            rpn_trainable_variables_stage1 = self.net.get_train_variables('rpn_network')
            rfcn_trainable_variables_stage2 = \
                self.rfcn_network.get_train_variables('rfcn_network')
            rpn_trainable_variables_stage3 = self.net.get_train_variables_stage3('rpn_network')
            rpn_trainable_variables_stage4 = self.net.get_train_variables_stage4('rpn_network')

            gvs_rpn_stage1 = self.optimizer.compute_gradients(
                rpn_loss, rpn_trainable_variables_stage1)
            gvs_rfcn_stage2 = self.optimizer.compute_gradients(
                rfcn_loss, rfcn_trainable_variables_stage2)
            gvs_rpn_stage3 = self.optimizer.compute_gradients(
                rpn_loss, rpn_trainable_variables_stage3)
            gvs_rfcn_stage4 = self.optimizer.compute_gradients(
                rpn_rfcn_loss, rpn_trainable_variables_stage4)

            train_op_stage1 = self.optimizer.apply_gradients(gvs_rpn_stage1)
            train_op_stage2 = self.optimizer.apply_gradients(gvs_rfcn_stage2)
            train_op_stage3 = self.optimizer.apply_gradients(gvs_rpn_stage3)
            train_op_stage4 = self.optimizer.apply_gradients(gvs_rfcn_stage4)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=1000000)
            # Write the train and validation information to tensorboard
            self.writer_stage1 = tf.summary.FileWriter(self.tbdir+'/stage1', sess.graph)
            self.writer_stage2 = tf.summary.FileWriter(self.tbdir+'/stage2', sess.graph)
            self.writer_stage3 = tf.summary.FileWriter(self.tbdir+'/stage3', sess.graph)
            self.writer_stage4 = tf.summary.FileWriter(self.tbdir+'/stage4', sess.graph)
            self.valwriter_stage1 = tf.summary.FileWriter(self.tbvaldir+'/stage1')
            self.valwriter_stage2 = tf.summary.FileWriter(self.tbvaldir+'/stage2')
            self.valwriter_stage3 = tf.summary.FileWriter(self.tbvaldir+'/stage3')
            self.valwriter_stage4 = tf.summary.FileWriter(self.tbvaldir+'/stage4')

        # Find previous snapshots if there is any to restore from
        sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow; the snapshot taken at STEPSIZE+1
        # is excluded from the restore candidates.
        redstr = '_iter_{:d}.'.format(cfg.TRAIN.STEPSIZE+1)
        sfiles = [ss.replace('.meta', '') for ss in sfiles]
        sfiles = [ss for ss in sfiles if redstr not in ss]

        nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        nfiles = [nn for nn in nfiles if redstr not in nn]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        np_paths = nfiles
        ss_paths = sfiles

        if lsf == 0:
            # Fresh train directly from ImageNet weights
            print('Loading initial model weights from {:s}'.format(self.pretrained_model))
            variables = tf.global_variables()
            # Initialize all variables first
            sess.run(tf.variables_initializer(variables, name='init'))
            var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
            # Get the variables to restore, ignoring the variables to fix
            variables_to_restore_rpn, variables_to_restore_rfcn = \
                self.net.get_variables_to_restore(variables, var_keep_dic)

            self.saver_rpn = tf.train.Saver(variables_to_restore_rpn)
            self.saver_rpn.restore(sess, self.pretrained_model)
            self.saver_rfcn = tf.train.Saver(variables_to_restore_rfcn)
            self.saver_rfcn.restore(sess, self.pretrained_model)
            print('Loaded.')
            # Need to fix the variables before loading, so that the RGB weights are changed to BGR
            # For VGG16 it also changes the convolutional weights fc6 and fc7 to
            # fully connected weights
            self.net.fix_variables(sess, self.pretrained_model)
            print('Fixed.')
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
            last_snapshot_iter = 0
        else:
            # Get the most recent snapshot and restore
            ss_paths = [ss_paths[-1]]
            np_paths = [np_paths[-1]]

            print('Restorining model snapshots from {:s}'.format(sfiles[-1]))
            self.saver.restore(sess, str(sfiles[-1]))
            print('Restored.')
            # Restore the other training state. The numpy random state and the
            # data layer positions are recovered; the TensorFlow random state
            # is not available.
            with open(str(nfiles[-1]), 'rb') as fid:
                st0 = pickle.load(fid)
                cur = pickle.load(fid)
                perm = pickle.load(fid)
                cur_val = pickle.load(fid)
                perm_val = pickle.load(fid)
                last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

            # Set the learning rate, only reduce once
            if last_snapshot_iter > cfg.TRAIN.STEPSIZE:
                sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
            else:
                sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        stage_infor = ''

        # ---- stage 1 and stage 2 ----
        while iter < 200001:
            # Learning rate
            if iter == 80001:
                # Add snapshot here before changing the learning rate
                self.snapshot(sess, iter)
                sess.run(tf.assign(lr, 0.001))

            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            if iter < 80001:
                # stage 1 training rpn layers and backbones in rpn network
                stage_infor = 'stage1'
                if iter == 60001:
                    sess.run(tf.assign(lr, 0.0001))
                losses, last_summary_time = self._run_rpn_step(
                    sess, blobs, train_op_stage1, self.writer_stage1, self.valwriter_stage1,
                    iter, timer, last_summary_time)
                if iter == 80000:
                    self.writer_stage1.close()
                    self.valwriter_stage1.close()
            elif 80001 <= iter < 200001:
                # stage 2 training rfcn layers and backbones in rfcn network
                stage_infor = 'stage2'
                if iter == 160001:
                    self.snapshot(sess, iter)
                    sess.run(tf.assign(lr, 0.0001))
                losses, last_summary_time = self._run_rfcn_step_stage2(
                    sess, blobs, train_op_stage2, iter, timer, last_summary_time)
                if iter == 200000:
                    self.writer_stage2.close()
                    self.valwriter_stage2.close()
            else:
                raise ValueError('illeagle input iter value')

            # Display training information (fixed interval of 20 iterations here)
            if iter % 20 == 0:
                self._print_progress(iter, max_iters, stage_infor, losses, lr, timer)

            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                snapshot_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(snapshot_path)

            iter += 1

        # ---- merge rfcn_network to rpn_network ----
        # Skipped when training was resumed from a snapshot past iteration 200000.
        if iter <= 200001:
            merged_ops = merged_networks_rfcn2rpn('rfcn_network', 'rpn_network')
            checked_weights = (
                'resnet_v1_101/block2/unit_1/bottleneck_v1/shortcut/weights',
                'resnet_v1_101/refined_reduce_depth/weights',
                'resnet_v1_101/block4/unit_3/bottleneck_v1/conv3/weights',
                'resnet_v1_101/block3/unit_18/bottleneck_v1/conv2/weights',
                'resnet_v1_101/block3/unit_14/bottleneck_v1/conv3/weights',
            )
            with tf.variable_scope('rpn_network', reuse=True):
                rpn_convs = [tf.get_variable(name) for name in checked_weights]
            with tf.variable_scope('rfcn_network', reuse=True):
                rfcn_convs = [tf.get_variable(name) for name in checked_weights]

            with tf.control_dependencies(merged_ops):
                rpn_convs[0] = tf.identity(rpn_convs[0])
                equal_ops = [tf.equal(rpn_w, rfcn_w)
                             for rpn_w, rfcn_w in zip(rpn_convs, rfcn_convs)]

            # Fetching these tensors is what executes merged_ops (through the
            # control dependency above). The comparison results themselves are
            # not inspected, matching the previous behaviour.
            sess.run(equal_ops)

        # ---- stage 3, stage 4 and stage 5 ----
        sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
        while iter < 480001:
            if iter == 280001:
                # Add snapshot here before changing the learning rate
                self.snapshot(sess, iter)
                sess.run(tf.assign(lr, 0.001))
            if iter == 400001:
                self.snapshot(sess, iter)
                sess.run(tf.assign(lr, 0.001))

            blobs = self.data_layer.forward()

            if 200001 <= iter < 280001:
                # stage 3 training rpn layers only in rpn network
                stage_infor = 'stage3'
                if iter == 260001:
                    self.snapshot(sess, iter)
                    sess.run(tf.assign(lr, 0.0001))
                losses, last_summary_time = self._run_rpn_step(
                    sess, blobs, train_op_stage3, self.writer_stage3, self.valwriter_stage3,
                    iter, timer, last_summary_time)
                if iter == 280000:
                    # Only flush here: stage 5 writes to the same writers, and
                    # writes to a closed FileWriter are dropped.
                    self.writer_stage3.flush()
                    self.valwriter_stage3.flush()
            elif 280001 <= iter < 400001:
                # stage 4 training rfcn layers only in rpn network
                stage_infor = 'stage4'
                if iter == 360001:
                    sess.run(tf.assign(lr, 0.0001))
                losses, last_summary_time = self._run_rpn_step(
                    sess, blobs, train_op_stage4, self.writer_stage4, self.valwriter_stage4,
                    iter, timer, last_summary_time)
                if iter == 400000:
                    self.writer_stage4.close()
                    self.valwriter_stage4.close()
            elif 400001 <= iter < 480001:
                # stage 5 reuses the stage-3 train op and summary writers
                stage_infor = 'stage5'
                if iter == 460001:
                    self.snapshot(sess, iter)
                    sess.run(tf.assign(lr, 0.0001))
                losses, last_summary_time = self._run_rpn_step(
                    sess, blobs, train_op_stage3, self.writer_stage3, self.valwriter_stage3,
                    iter, timer, last_summary_time)
                if iter == 480000:
                    self.writer_stage3.close()
                    self.valwriter_stage3.close()
            else:
                raise ValueError('iter is not allowed')

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                self._print_progress(iter, max_iters, stage_infor, losses, lr, timer)

            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                snapshot_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(snapshot_path)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)
class SolverWrapper(object):
    """A wrapper class for the training process.

    ``train_model`` builds the training graph, loads either the pretrained
    weights or the most recent snapshot, and runs the optimisation loop with
    periodic summaries and snapshots. Validation summaries are not produced by
    this class; ``valroidb`` and ``tbvaldir`` are only stored.
    """

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir,
                 pretrained_model=None):
        """Store the training configuration and create the validation summary dir.

        Args:
            sess: accepted for interface compatibility; not used here.
            network: network object to train.
            imdb: dataset object; only ``num_classes`` is read by this class.
            roidb: training roidb handed to ``RoIDataLayer``.
            valroidb: validation roidb (stored only).
            output_dir: directory that receives the snapshots.
            tbdir: directory for the training summaries; ``tbdir + '_val'``
                is created here.
            pretrained_model: checkpoint path used when no snapshot exists.
        """
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        """Write a model checkpoint plus a pickle with the data/RNG state.

        Requires ``self.saver`` and ``self.data_layer`` (set by
        ``construct_graph`` and ``train_model``).

        Args:
            sess: session whose variables are saved.
            iter: iteration number, used in both file names and stored in the pickle.

        Returns:
            Tuple ``(checkpoint_path, pickle_path)``.
        """
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)

        # The dump order is the on-disk format read back by from_snapshot:
        # numpy RNG state, training cursor, training permutation, iteration.
        state = (
            np.random.get_state(),
            self.data_layer._cur,
            self.data_layer._perm,
            iter,
        )
        with open(nfilename, 'wb') as fid:
            for item in state:
                pickle.dump(item, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        """Restore the model variables and the data/RNG state from a snapshot.

        Args:
            sess: session to restore the variables into.
            sfile: checkpoint path written by ``snapshot``.
            nfile: pickle path written by ``snapshot``.

        Returns:
            The iteration number stored in the pickle.
        """
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Restore the other training state. The numpy random state and the
        # data layer position are recovered; the TensorFlow random state is
        # not available.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        """Return the variable-name -> shape map stored in a checkpoint file.

        Args:
            file_name: path of the checkpoint to inspect.

        Returns:
            The mapping returned by the checkpoint reader.

        Raises:
            Exception: whatever the checkpoint reader raised. The error is
                printed (with a hint for SNAPPY-compressed files) and then
                re-raised, instead of silently returning ``None`` to a caller
                that needs the mapping.
        """
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            return reader.get_variable_to_shape_map()
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")
            raise

    def construct_graph(self, sess):
        """Build the network, the optimiser, the saver and the summary writer.

        Sets ``self.optimizer``, ``self.saver`` and ``self.writer``.

        Returns:
            Tuple ``(lr, train_op)``: the learning-rate variable and the op
            that applies one gradient update.
        """
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture('TRAIN', self.imdb.num_classes, tag='default',
                                                  anchor_scales=cfg.ANCHOR_SCALES,
                                                  anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

            # Compute the gradients with regard to the loss
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        if '/biases:' in var.name:
                            grad = tf.multiply(grad, 2.)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)

        return lr, train_op

    def find_previous(self):
        """Collect the existing snapshots in ``output_dir``, oldest first.

        Snapshots taken at ``stepsize + 1`` for each value in
        ``cfg.TRAIN.STEPSIZE`` are excluded from the result.

        Returns:
            Tuple ``(count, pickle_paths, checkpoint_paths)``; both lists are
            sorted by file modification time.
        """
        prefix = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX)

        sfiles = glob.glob(prefix + '_iter_*.ckpt.meta')
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redfiles = [os.path.join(self.output_dir,
                                 cfg.TRAIN.SNAPSHOT_PREFIX +
                                 '_iter_{:d}.ckpt.meta'.format(stepsize+1))
                    for stepsize in cfg.TRAIN.STEPSIZE]
        sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles]

        nfiles = glob.glob(prefix + '_iter_*.pkl')
        nfiles.sort(key=os.path.getmtime)
        redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def initialize(self, sess):
        """Initialise all variables and load the pretrained weights.

        Returns:
            Tuple ``(rate, last_snapshot_iter, stepsizes, np_paths, ss_paths)``
            for a fresh run: the configured learning rate, iteration 0, a copy
            of ``cfg.TRAIN.STEPSIZE`` and two empty snapshot path lists.
        """
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(self.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, self.pretrained_model)
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are changed to BGR
        # For VGG16 it also changes the convolutional weights fc6 and fc7 to
        # fully connected weights
        self.net.fix_variables(sess, self.pretrained_model)
        print('Fixed.')
        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        """Restore training from a snapshot and recompute the learning rate.

        The learning rate is multiplied by ``cfg.TRAIN.GAMMA`` once for every
        step size the restored iteration has already passed.

        Returns:
            Tuple ``(rate, last_snapshot_iter, stepsizes, np_paths, ss_paths)``
            where ``stepsizes`` holds the step sizes not yet passed and the
            path lists contain only the restored snapshot.
        """
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        """Delete the oldest snapshots so at most ``cfg.TRAIN.SNAPSHOT_KEPT`` remain.

        Both lists are modified in place; entries are removed from the front.
        """
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            nfile = np_paths.pop(0)
            os.remove(str(nfile))

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible to earlier versions of Tensorflow,
            # where the naming tradition for checkpoints are different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.pop(0)

    def train_model(self, sess, max_iters):
        """Train the network up to ``max_iters`` iterations.

        Resumes from the most recent snapshot in ``output_dir`` when one
        exists, otherwise starts from the pretrained model. The summary writer
        is closed when training finishes or fails.

        Args:
            sess: TensorFlow session used for graph construction and training.
            max_iters: number of the last iteration to run.
        """
        # Build the data layer for the training set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        try:
            # Find previous snapshots if there is any to restore from
            lsf, nfiles, sfiles = self.find_previous()

            # Initialize the variables or restore them from the last snapshot
            if lsf == 0:
                rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
            else:
                rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = \
                    self.restore(sess, str(sfiles[-1]), str(nfiles[-1]))
            timer = Timer()
            iter = last_snapshot_iter + 1
            last_summary_time = time.time()
            # Make sure the lists are not empty
            stepsizes.append(max_iters)
            stepsizes.reverse()
            next_stepsize = stepsizes.pop()
            while iter < max_iters + 1:
                # Learning rate
                if iter == next_stepsize + 1:
                    # Add snapshot here before reducing the learning rate
                    self.snapshot(sess, iter)
                    rate *= cfg.TRAIN.GAMMA
                    sess.run(tf.assign(lr, rate))
                    next_stepsize = stepsizes.pop()

                timer.tic()
                # Get training data, one batch at a time
                blobs = self.data_layer.forward()

                now = time.time()
                if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                    # Compute the graph with summary
                    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                        self.net.train_step_with_summary(sess, blobs, train_op)
                    self.writer.add_summary(summary, float(iter))
                    last_summary_time = now
                else:
                    # Compute the graph without summary
                    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                        self.net.train_step(sess, blobs, train_op)
                timer.toc()

                # Display training information
                if iter % (cfg.TRAIN.DISPLAY) == 0:
                    print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                          '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' %
                          (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box,
                           loss_cls, loss_box, lr.eval()))
                    print('speed: {:.3f}s / iter'.format(timer.average_time))

                # Snapshotting
                if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                    last_snapshot_iter = iter
                    ss_path, np_path = self.snapshot(sess, iter)
                    np_paths.append(np_path)
                    ss_paths.append(ss_path)

                    # Remove the old snapshots if there are too many
                    if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                        self.remove_snapshot(np_paths, ss_paths)

                iter += 1

            if last_snapshot_iter != iter - 1:
                self.snapshot(sess, iter - 1)
        finally:
            # Release the event file even when training raised.
            self.writer.close()
class SolverWrapper(object): """ A wrapper class for the training process """ def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, pretrained_model=None): self.net = network self.imdb = imdb self.roidb = roidb self.valroidb = valroidb self.output_dir = output_dir self.pretrained_model = pretrained_model def snapshot(self, sess, iter): net = self.net if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) # Store the model snapshot filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt' filename = os.path.join(self.output_dir, filename) self.saver.save(sess, filename) print('Wrote snapshot to: {:s}'.format(filename)) # Also store some meta information, random state, etc. nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl' nfilename = os.path.join(self.output_dir, nfilename) # current state of numpy random st0 = np.random.get_state() # current position in the database cur = self.data_layer._cur # current shuffled indexes of the database perm = self.data_layer._perm # current position in the validation database cur_val = self.data_layer_val._cur # current shuffled indexes of the validation database perm_val = self.data_layer_val._perm # Dump the meta info with open(nfilename, 'wb') as fid: pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL) pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL) return filename, nfilename def from_snapshot(self, sess, sfile, nfile): print('Restoring model snapshots from {:s}'.format(sfile)) self.saver.restore(sess, sfile) print('Restored.') # Needs to restore the other hyper-parameters/states for training, I have # tried my best to find the random states so that it can be recovered exactly # However the Tensorflow state is currently not available with open(nfile, 'rb') as fid: st0 = 
pickle.load(fid) cur = pickle.load(fid) perm = pickle.load(fid) cur_val = pickle.load(fid) perm_val = pickle.load(fid) last_snapshot_iter = pickle.load(fid) np.random.set_state(st0) self.data_layer._cur = cur self.data_layer._perm = perm self.data_layer_val._cur = cur_val self.data_layer_val._perm = perm_val return last_snapshot_iter def get_variables_in_checkpoint_file(self, file_name): try: reader = pywrap_tensorflow.NewCheckpointReader(file_name) var_to_shape_map = reader.get_variable_to_shape_map() return var_to_shape_map except Exception as e: # pylint: disable=broad-except print(str(e)) if "corrupted compressed block contents" in str(e): print("It's likely that your checkpoint file has been compressed " "with SNAPPY.") def construct_graph(self, sess): # Set the random seed for tensorflow tf.set_random_seed(cfg.RNG_SEED) with sess.graph.as_default(): # Build the main computation graph layers = self.net.create_architecture('TRAIN', tag='default', anchor_scales=cfg.ANCHOR_SCALES, anchor_ratios=cfg.ANCHOR_RATIOS) # Define the loss loss = layers['total_loss'] # Set learning rate and momentum lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False) self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM) # Compute the gradients with regard to the loss gvs = self.optimizer.compute_gradients(loss) # Double the gradient of the bias if set if cfg.TRAIN.DOUBLE_BIAS: final_gvs = [] with tf.variable_scope('Gradient_Mult') as scope: for grad, var in gvs: scale = 1. if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name: scale *= 2. 
if not np.allclose(scale, 1.0): grad = tf.multiply(grad, scale) final_gvs.append((grad, var)) train_op = self.optimizer.apply_gradients(final_gvs) else: train_op = self.optimizer.apply_gradients(gvs) # Initialize post-hist module of drl-RPN if cfg.DRL_RPN.USE_POST: loss_post = layers['total_loss_hist'] lr_post = tf.Variable(cfg.DRL_RPN_TRAIN.POST_LR, trainable=False) self.optimizer_post = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM) gvs_post = self.optimizer_post.compute_gradients(loss_post) train_op_post = self.optimizer_post.apply_gradients(gvs_post) else: lr_post = None train_op_post = None # Initialize main drl-RPN network self.net.build_drl_rpn_network() return lr, train_op, lr_post, train_op_post def initialize(self, sess): # Initial file lists are empty np_paths = [] ss_paths = [] # Fresh train directly from ImageNet weights print('Loading initial model weights from {:s}'.format(self.pretrained_model)) variables = tf.global_variables() # Initialize all variables first sess.run(tf.variables_initializer(variables, name='init')) var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model) #print(self.pretrained_model) #sleep(100) # Get the variables to restore, ignoring the variables to fix variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic) restorer = tf.train.Saver(variables_to_restore) restorer.restore(sess, self.pretrained_model) print('Loaded.') # Need to fix the variables before loading, so that the RGB weights are # hanged to BGR For VGG16 it also changes the convolutional weights fc6 # and fc7 to fully connected weights # # NOTE: IF YOU WANT TO TRAIN FROM EXISTING FASTER # R-CNN WEIGHTS, AND NOT FROM IMAGENET WEIGHTS, SET BELOW FLAG TO FALSE!!!! 
self.net.fix_variables(sess, self.pretrained_model, False) print('Fixed.') last_snapshot_iter = 0 rate = cfg.TRAIN.LEARNING_RATE stepsizes = list(cfg.TRAIN.STEPSIZE) return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths def restore(self, sess, sfile, nfile): # Get the most recent snapshot and restore np_paths = [nfile] ss_paths = [sfile] # Restore model from snapshots last_snapshot_iter = self.from_snapshot(sess, sfile, nfile) # Set the learning rate rate = cfg.TRAIN.LEARNING_RATE stepsizes = [] for stepsize in cfg.TRAIN.STEPSIZE: if last_snapshot_iter > stepsize: rate *= cfg.TRAIN.GAMMA else: stepsizes.append(stepsize) return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths def remove_snapshot(self, np_paths, ss_paths): to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT for c in range(to_remove): nfile = np_paths[0] os.remove(str(nfile)) np_paths.remove(nfile) to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT for c in range(to_remove): sfile = ss_paths[0] # To make the code compatible to earlier versions of Tensorflow, # where the naming tradition for checkpoints are different if os.path.exists(str(sfile)): os.remove(str(sfile)) else: os.remove(str(sfile + '.data-00000-of-00001')) os.remove(str(sfile + '.index')) sfile_meta = sfile + '.meta' os.remove(str(sfile_meta)) ss_paths.remove(sfile) def _print_det_loss(self, iter, max_iters, tot_loss, loss_cls, loss_box, lr, timer, in_string='detector'): if (iter + 1) % (cfg.TRAIN.DISPLAY) == 0: if loss_box is not None: print('iter: %d / %d, total loss: %.6f\n ' '>>> loss_cls (%s): %.6f\n ' '>>> loss_box (%s): %.6f\n >>> lr: %f' % \ (iter + 1, max_iters, tot_loss, in_string, loss_cls, in_string, loss_box, lr)) else: print('iter: %d / %d, total loss (%s): %.6f\n >>> lr: %f' % \ (iter + 1, max_iters, in_string, tot_loss, lr)) print('speed: {:.3f}s / iter'.format(timer.average_time)) def _check_if_continue(self, iter, max_iters, snapshot_add): img_start_idx = cfg.DRL_RPN_TRAIN.IMG_START_IDX if iter > 
img_start_idx: return iter, max_iters, snapshot_add, False if iter < img_start_idx: print("iter %d < img_start_idx %d -- continuing" % (iter, img_start_idx)) iter += 1 return iter, max_iters, snapshot_add, True if iter == img_start_idx: print("Adjusting stepsize, train-det-start etcetera") snapshot_add = img_start_idx max_iters -= img_start_idx iter = 0 cfg_from_list(['DRL_RPN_TRAIN.IMG_START_IDX', -1]) cfg_from_list(['DRL_RPN_TRAIN.DET_START', cfg.DRL_RPN_TRAIN.DET_START - img_start_idx]) cfg_from_list(['DRL_RPN_TRAIN.STEPSIZE', cfg.DRL_RPN_TRAIN.STEPSIZE - img_start_idx]) cfg_from_list(['TRAIN.STEPSIZE', [cfg.TRAIN.STEPSIZE[0] - img_start_idx]]) cfg_from_list(['DRL_RPN_TRAIN.POST_SS', [cfg.DRL_RPN_TRAIN.POST_SS[0] - img_start_idx]]) print("Done adjusting stepsize, train-det-start etcetera") return iter, max_iters, snapshot_add, False def train_model(self, sess, max_iters): # Build data layers for both training and validation set self.data_layer = RoIDataLayer(self.roidb, cfg.NBR_CLASSES) self.data_layer_val = RoIDataLayer(self.valroidb, cfg.NBR_CLASSES, True) # Construct the computation graph corresponding to the original Faster R-CNN # architecture first (and potentially the post-hist module of drl-RPN) lr_det_op, train_op, lr_post_op, train_op_post \ = self.construct_graph(sess) # Initialize the variables or restore them from the last snapshot rate, last_snapshot_iter, stepsizes, np_paths, ss_paths \ = self.initialize(sess) # We will handle the snapshots ourselves self.saver = tf.train.Saver(max_to_keep=100000) # Initialize self.net.init_rl_train(sess) # Setup initial learning rates lr_rl = cfg.DRL_RPN_TRAIN.LEARNING_RATE lr_det = cfg.TRAIN.LEARNING_RATE sess.run(tf.assign(lr_det_op, lr_det)) if cfg.DRL_RPN.USE_POST: lr_post = cfg.DRL_RPN_TRAIN.POST_LR sess.run(tf.assign(lr_post_op, lr_post)) # Sample first beta if cfg.DRL_RPN_TRAIN.USE_POST: betas = cfg.DRL_RPN_TRAIN.POST_BETAS else: betas = cfg.DRL_RPN_TRAIN.BETAS beta_idx = 0 beta = betas[beta_idx] # Setup 
drl-RPN timers timers = {'init': Timer(), 'fulltraj': Timer(), 'upd-obs-vol': Timer(), 'upd-seq': Timer(), 'upd-rl': Timer(), 'action-rl': Timer(), 'coll-traj': Timer(), 'run-drl-rpn': Timer(), 'train-drl-rpn': Timer()} # Create StatCollector (tracks various RL training statistics) stat_strings = ['reward', 'rew-done', 'traj-len', 'frac-area', 'gt >= 0.5 frac', 'gt-IoU-frac'] sc = StatCollector(max_iters, stat_strings) timer = Timer() iter = 0 snapshot_add = 0 while iter < max_iters: # Get training data, one batch at a time (assumes batch size 1) blobs = self.data_layer.forward() # Allows the possibility to start at arbitrary image, rather # than always starting from first image in dataset. Useful if # want to load parameters and keep going from there, rather # than having those and encountering visited images again. iter, max_iters, snapshot_add, do_continue \ = self._check_if_continue(iter, max_iters, snapshot_add) if do_continue: continue if not cfg.DRL_RPN_TRAIN.USE_POST: # Potentially update drl-RPN learning rate if (iter + 1) % cfg.DRL_RPN_TRAIN.STEPSIZE == 0: lr_rl *= cfg.DRL_RPN_TRAIN.GAMMA # Run drl-RPN in training mode timers['run-drl-rpn'].tic() stats = run_drl_rpn(sess, self.net, blobs, timers, mode='train', beta=beta, im_idx=None, extra_args=lr_rl) timers['run-drl-rpn'].toc() if (iter + 1) % cfg.DRL_RPN_TRAIN.BATCH_SIZE == 0: print("\n##### DRL-RPN BATCH GRADIENT UPDATE - START ##### \n") print('iter: %d / %d' % (iter + 1, max_iters)) print('lr-rl: %f' % lr_rl) timers['train-drl-rpn'].tic() self.net.train_drl_rpn(sess, lr_rl, sc, stats) timers['train-drl-rpn'].toc() sc.print_stats() print('TIMINGS:') print('runnn-drl-rpn: %.4f' % timers['run-drl-rpn'].get_avg()) print('train-drl-rpn: %.4f' % timers['train-drl-rpn'].get_avg()) print("\n##### DRL-RPN BATCH GRADIENT UPDATE - DONE ###### \n") # Also sample new beta for next batch beta_idx += 1 beta_idx %= len(betas) beta = betas[beta_idx] else: sc.update(0, stats) # At this point we assume that an 
RL-trajectory has been performed. # We next train detector with drl-RPN running in deterministic mode. # Potentially train detector component of network if cfg.DRL_RPN_TRAIN.DET_START >= 0 and \ iter >= cfg.DRL_RPN_TRAIN.DET_START: # Run drl-RPN in deterministic mode net_conv, rois_drl_rpn, gt_boxes, im_info, timers, _ \ = run_drl_rpn(sess, self.net, blobs, timers, mode='train_det', beta=beta, im_idx=None) # Learning rate if (iter + 1) % cfg.TRAIN.STEPSIZE[0] == 0: lr_det *= cfg.TRAIN.GAMMA sess.run(tf.assign(lr_det_op, lr_det)) timer.tic() # Train detector part loss_cls, loss_box, tot_loss \ = self.net.train_step_det(sess, train_op, net_conv, rois_drl_rpn, gt_boxes, im_info) timer.toc() # Display training information self._print_det_loss(iter, max_iters, tot_loss, loss_cls, loss_box, lr_det, timer) # Train post-hist module AFTER we have trained rest of drl-RPN! Specifically # once rest of drl-RPN has been trained already, copy those weights into # the folder of pretrained weights and rerun training with those as initial # weights, which will then train only the posterior-history module else: # The very first time we need to assign the ordinary detector weights # as starting point if iter == 0: self.net.assign_post_hist_weights(sess) # Sample beta beta = betas[beta_idx] beta_idx += 1 beta_idx %= len(betas) # Run drl-RPN in deterministic mode net_conv, rois_drl_rpn, gt_boxes, im_info, timers, cls_hist \ = run_drl_rpn(sess, self.net, blobs, timers, mode='train_det', beta=beta, im_idx=None) # Learning rate (assume only one learning rate iter for now!) 
if (iter + 1) % cfg.DRL_RPN_TRAIN.POST_SS[0] == 0: lr_post *= cfg.TRAIN.GAMMA sess.run(tf.assign(lr_post_op, lr_post)) # Train post-hist detector part tot_loss = self.net.train_step_post(sess, train_op_post, net_conv, rois_drl_rpn, gt_boxes, im_info, cls_hist) # Display training information self._print_det_loss(iter, max_iters, tot_loss, None, None, lr_post, timer, 'post-hist') # Snapshotting if (iter + 1) % cfg.TRAIN.SNAPSHOT_ITERS == 0: last_snapshot_iter = iter + 1 ss_path, np_path = self.snapshot(sess, iter + 1 + snapshot_add) np_paths.append(np_path) ss_paths.append(ss_path) # Remove the old snapshots if there are too many if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT: self.remove_snapshot(np_paths, ss_paths) # Increase iteration counter iter += 1 # Potentially save one last time if last_snapshot_iter != iter: self.snapshot(sess, iter + snapshot_add)
class SolverWrapper(object):
    """Drive PyTorch training: build the optimizer, snapshot, restore, train.

    Each snapshot consists of two files: the network ``state_dict`` (``.pth``)
    and a pickle (``.pkl``) holding the numpy RNG state, the cursors and
    permutations of both data layers, and the iteration number.
    """

    def __init__(self, network, imdb, roidb, valroidb, output_dir, tbdir,
                 pretrained_model=None):
        """Store the training inputs and create the validation summary dir.

        Args:
            network: network object to train.
            imdb: image database; ``num_classes`` is read from it.
            roidb: training ROI database handed to the data layer.
            valroidb: validation ROI database handed to the data layer.
            output_dir: directory that receives the snapshots.
            tbdir: directory for the training summaries.
            pretrained_model: path of the weights used for a fresh start.
        """
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the
        # validation set.
        self.tbvaldir = tbdir + '_val'
        os.makedirs(self.tbvaldir, exist_ok=True)
        self.pretrained_model = pretrained_model

    def snapshot(self, iter):
        """Write the weights and the resume metadata for iteration ``iter``.

        Returns:
            ``(filename, nfilename)``: paths of the ``.pth`` weights file and
            of the ``.pkl`` metadata file.
        """
        # exist_ok avoids the check-then-create race when several processes
        # share one output directory.
        os.makedirs(self.output_dir, exist_ok=True)

        # Store the model snapshot.
        stem = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter)
        filename = os.path.join(self.output_dir, stem + '.pth')
        torch.save(self.net.state_dict(), filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.  The order of
        # the records must match the reads in from_snapshot().
        nfilename = os.path.join(self.output_dir, stem + '.pkl')
        records = (
            np.random.get_state(),      # current state of numpy random
            self.data_layer._cur,       # current position in the database
            self.data_layer._perm,      # current shuffled indexes
            self.data_layer_val._cur,   # position in the validation database
            self.data_layer_val._perm,  # shuffled validation indexes
            iter,
        )
        with open(nfilename, 'wb') as fid:
            for record in records:
                pickle.dump(record, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sfile, nfile):
        """Load weights from ``sfile`` and training state from ``nfile``.

        Returns:
            The iteration number stored in the metadata file.
        """
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.net.load_state_dict(torch.load(str(sfile)))
        print('Restored.')
        # Restore the other states needed for training.  Only the numpy RNG
        # state is stored in the metadata file; no torch RNG state is read
        # back here.
        # NOTE: pickle.load can execute arbitrary code, so only metadata files
        # written by snapshot() should ever be passed in.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def construct_graph(self):
        """Seed torch, build the network, the SGD optimizer and the writers.

        Returns:
            ``(lr, optimizer)``: the base learning rate and the optimizer
            (also stored as ``self.optimizer``).
        """
        # Set the random seed.
        torch.manual_seed(cfg.RNG_SEED)
        # Build the main computation graph.
        self.net.create_architecture(self.imdb.num_classes,
                                     tag='default',
                                     anchor_scales=cfg.ANCHOR_SCALES,
                                     anchor_ratios=cfg.ANCHOR_RATIOS)
        # Set learning rate and momentum; every trainable tensor gets its own
        # parameter group so that biases can use a different lr / decay.
        lr = cfg.TRAIN.LEARNING_RATE
        params = []
        for key, value in self.net.named_parameters():
            if not value.requires_grad:
                continue
            if 'bias' in key:
                params.append({
                    'params': [value],
                    'lr': lr * (cfg.TRAIN.DOUBLE_BIAS + 1),
                    'weight_decay': (cfg.TRAIN.WEIGHT_DECAY
                                     if cfg.TRAIN.BIAS_DECAY else 0),
                })
            else:
                params.append({
                    'params': [value],
                    'lr': lr,
                    'weight_decay': cfg.TRAIN.WEIGHT_DECAY,
                })

        self.optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)
        # Write the train and validation information to tensorboard.
        self.writer = tb.writer.FileWriter(self.tbdir)
        self.valwriter = tb.writer.FileWriter(self.tbvaldir)

        return lr, self.optimizer

    def find_previous(self):
        """Locate earlier snapshots in ``output_dir``, oldest first.

        Snapshots named after ``stepsize + 1`` are the extra ones that
        train_model() writes right before a learning-rate drop; they are
        excluded from the result.

        Returns:
            ``(count, nfiles, sfiles)``: number of usable snapshots, metadata
            paths and weight paths, both sorted by modification time.
        """
        prefix = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX)
        # Names of the pre-lr-drop snapshots to ignore.
        redfiles = [prefix + '_iter_{:d}.pth'.format(stepsize + 1)
                    for stepsize in cfg.TRAIN.STEPSIZE]

        sfiles = sorted(glob.glob(prefix + '_iter_*.pth'),
                        key=os.path.getmtime)
        sfiles = [ss for ss in sfiles if ss not in redfiles]

        redfiles = [redfile.replace('.pth', '.pkl') for redfile in redfiles]
        nfiles = sorted(glob.glob(prefix + '_iter_*.pkl'),
                        key=os.path.getmtime)
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf, \
            'every weights snapshot needs a matching metadata file'

        return lsf, nfiles, sfiles

    def initialize(self):
        """Start a fresh run from ``self.pretrained_model``.

        Returns:
            ``(lr, last_snapshot_iter, stepsizes, np_paths, ss_paths)`` with
            empty snapshot lists and ``last_snapshot_iter == 0``.
        """
        # Initial file lists are empty.
        np_paths = []
        ss_paths = []
        # Fresh train directly from the pretrained weights.
        print('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        self.net.load_pretrained_cnn(torch.load(self.pretrained_model))
        print('Loaded.')
        last_snapshot_iter = 0
        lr = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sfile, nfile):
        """Resume from a snapshot and re-derive the learning-rate schedule.

        Returns:
            ``(lr, last_snapshot_iter, stepsizes, np_paths, ss_paths)`` where
            ``stepsizes`` holds only the drops that are still ahead.
        """
        # Get the most recent snapshot and restore.
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots.
        last_snapshot_iter = self.from_snapshot(sfile, nfile)
        # Apply once every learning-rate drop that already happened.
        lr_scale = 1
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                lr_scale *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)
        scale_lr(self.optimizer, lr_scale)
        lr = cfg.TRAIN.LEARNING_RATE * lr_scale
        return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        """Delete the oldest snapshots beyond ``cfg.TRAIN.SNAPSHOT_KEPT``.

        Both lists are modified in place; the oldest entry comes first.
        """
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            sfile = ss_paths[0]
            os.remove(str(sfile))
            ss_paths.remove(sfile)

    def train_model(self, max_iters):
        """Run the training loop up to and including iteration ``max_iters``.

        Resumes from the newest snapshot in ``output_dir`` when one exists.
        The tensorboard writers are closed even when training fails.
        """
        # Build data layers for both training and validation set.
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph (also opens the writers).
        lr, _ = self.construct_graph()

        try:
            # Find previous snapshots if there is any to restore from.
            lsf, nfiles, sfiles = self.find_previous()

            # Initialize the variables or restore them from the last snapshot.
            if lsf == 0:
                lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = \
                    self.initialize()
            else:
                lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = \
                    self.restore(str(sfiles[-1]), str(nfiles[-1]))
            step = last_snapshot_iter + 1
            last_summary_time = time.time()
            # Make sure the list is never empty: max_iters acts as a sentinel
            # that is popped last.
            stepsizes.append(max_iters)
            stepsizes.reverse()
            next_stepsize = stepsizes.pop()

            self.net.train()
            self.net.cuda()

            while step < max_iters + 1:
                # Learning rate.
                if step == next_stepsize + 1:
                    # Add snapshot here before reducing the learning rate.
                    self.snapshot(step)
                    lr *= cfg.TRAIN.GAMMA
                    scale_lr(self.optimizer, cfg.TRAIN.GAMMA)
                    next_stepsize = stepsizes.pop()

                utils.timer.timer.tic()
                # Get training data, one batch at a time.
                blobs = self.data_layer.forward()

                now = time.time()
                if step == 1 or \
                        now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                    # Compute the graph with summary.
                    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, \
                        total_loss, summary = \
                        self.net.train_step_with_summary(blobs,
                                                         self.optimizer)
                    for _sum in summary:
                        self.writer.add_summary(_sum, float(step))
                    # Also check the summary on the validation set.
                    blobs_val = self.data_layer_val.forward()
                    summary_val = self.net.get_summary(blobs_val)
                    for _sum in summary_val:
                        self.valwriter.add_summary(_sum, float(step))
                    last_summary_time = now
                else:
                    # Compute the graph without summary.
                    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, \
                        total_loss = self.net.train_step(blobs,
                                                         self.optimizer)
                utils.timer.timer.toc()

                # Display training information.
                if step % (cfg.TRAIN.DISPLAY) == 0:
                    print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                          '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' %
                          (step, max_iters, total_loss, rpn_loss_cls,
                           rpn_loss_box, loss_cls, loss_box, lr))
                    print('speed: {:.3f}s / iter'.format(
                        utils.timer.timer.average_time()))

                # Snapshotting.
                if step % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                    last_snapshot_iter = step
                    ss_path, np_path = self.snapshot(step)
                    np_paths.append(np_path)
                    ss_paths.append(ss_path)

                    # Remove the old snapshots if there are too many.
                    if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                        self.remove_snapshot(np_paths, ss_paths)

                step += 1

            # Save one last time unless the final iteration was just saved.
            if last_snapshot_iter != step - 1:
                self.snapshot(step - 1)
        finally:
            self.writer.close()
            self.valwriter.close()
class SolverWrapper(object):
    """Drive PyTorch training: build the optimizer, snapshot, restore, train.

    Each snapshot consists of two files: the network ``state_dict`` (``.pth``)
    and a pickle (``.pkl``) holding the numpy RNG state, the cursors and
    permutations of both data layers, and the iteration number.
    """

    def __init__(self, network, imdb, roidb, valroidb, output_dir, tbdir,
                 pretrained_model=None):
        """Store the training inputs and create the validation summary dir.

        Args:
            network: network object to train.
            imdb: image database; ``num_classes`` is read from it.
            roidb: training ROI database handed to the data layer.
            valroidb: validation ROI database handed to the data layer.
            output_dir: directory that receives the snapshots.
            tbdir: directory for the training summaries.
            pretrained_model: path of the weights used for a fresh start.
        """
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the
        # validation set.
        self.tbvaldir = tbdir + '_val'
        # exist_ok replaces the racy "exists? then create" sequence.
        os.makedirs(self.tbvaldir, exist_ok=True)
        self.pretrained_model = pretrained_model

    def snapshot(self, iter):
        """Write the weights and the resume metadata for iteration ``iter``.

        Returns:
            ``(filename, nfilename)``: paths of the ``.pth`` weights file and
            of the ``.pkl`` metadata file.
        """
        # exist_ok avoids the check-then-create race when several processes
        # share one output directory.
        os.makedirs(self.output_dir, exist_ok=True)

        # Store the model snapshot.
        stem = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter)
        filename = os.path.join(self.output_dir, stem + '.pth')
        torch.save(self.net.state_dict(), filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.  The order of
        # the records must match the reads in from_snapshot().
        nfilename = os.path.join(self.output_dir, stem + '.pkl')
        records = (
            np.random.get_state(),      # current state of numpy random
            self.data_layer._cur,       # current position in the database
            self.data_layer._perm,      # current shuffled indexes
            self.data_layer_val._cur,   # position in the validation database
            self.data_layer_val._perm,  # shuffled validation indexes
            iter,
        )
        with open(nfilename, 'wb') as fid:
            for record in records:
                pickle.dump(record, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sfile, nfile):
        """Load weights from ``sfile`` and training state from ``nfile``.

        Returns:
            The iteration number stored in the metadata file.
        """
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.net.load_state_dict(torch.load(str(sfile)))
        print('Restored.')
        # Restore the other states needed for training.  Only the numpy RNG
        # state is stored in the metadata file; no torch RNG state is read
        # back here.
        # NOTE: pickle.load can execute arbitrary code, so only metadata files
        # written by snapshot() should ever be passed in.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def construct_graph(self):
        """Seed torch, build the network, the SGD optimizer and the writers.

        Returns:
            ``(lr, optimizer)``: the base learning rate and the optimizer
            (also stored as ``self.optimizer``).
        """
        # Set the random seed.
        torch.manual_seed(cfg.RNG_SEED)
        # Build the main computation graph.
        self.net.create_architecture(self.imdb.num_classes,
                                     tag='default',
                                     anchor_scales=cfg.ANCHOR_SCALES,
                                     anchor_ratios=cfg.ANCHOR_RATIOS)
        # Set learning rate and momentum; every trainable tensor gets its own
        # parameter group so that biases can use a different lr / decay.
        lr = cfg.TRAIN.LEARNING_RATE
        params = []
        for key, value in self.net.named_parameters():
            if not value.requires_grad:
                continue
            if 'bias' in key:
                params.append({
                    'params': [value],
                    'lr': lr * (cfg.TRAIN.DOUBLE_BIAS + 1),
                    'weight_decay': (cfg.TRAIN.WEIGHT_DECAY
                                     if cfg.TRAIN.BIAS_DECAY else 0),
                })
            else:
                params.append({
                    'params': [value],
                    'lr': lr,
                    'weight_decay': cfg.TRAIN.WEIGHT_DECAY,
                })

        self.optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)
        # Write the train and validation information to tensorboard.
        self.writer = tb.writer.FileWriter(self.tbdir)
        self.valwriter = tb.writer.FileWriter(self.tbvaldir)

        return lr, self.optimizer

    def find_previous(self):
        """Locate earlier snapshots in ``output_dir``, oldest first.

        Snapshots named after ``stepsize + 1`` are the extra ones that
        train_model() writes right before a learning-rate drop; they are
        excluded from the result.

        Returns:
            ``(count, nfiles, sfiles)``: number of usable snapshots, metadata
            paths and weight paths, both sorted by modification time.
        """
        prefix = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX)
        # Names of the pre-lr-drop snapshots to ignore.
        redfiles = [prefix + '_iter_{:d}.pth'.format(stepsize + 1)
                    for stepsize in cfg.TRAIN.STEPSIZE]

        sfiles = sorted(glob.glob(prefix + '_iter_*.pth'),
                        key=os.path.getmtime)
        sfiles = [ss for ss in sfiles if ss not in redfiles]

        redfiles = [redfile.replace('.pth', '.pkl') for redfile in redfiles]
        nfiles = sorted(glob.glob(prefix + '_iter_*.pkl'),
                        key=os.path.getmtime)
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf, \
            'every weights snapshot needs a matching metadata file'

        return lsf, nfiles, sfiles

    def initialize(self):
        """Start a fresh run from ``self.pretrained_model``.

        Returns:
            ``(lr, last_snapshot_iter, stepsizes, np_paths, ss_paths)`` with
            empty snapshot lists and ``last_snapshot_iter == 0``.
        """
        # Initial file lists are empty.
        np_paths = []
        ss_paths = []
        # Fresh train directly from the pretrained weights.
        print('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        self.net.load_pretrained_cnn(torch.load(self.pretrained_model))
        print('Loaded.')
        last_snapshot_iter = 0
        lr = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sfile, nfile):
        """Resume from a snapshot and re-derive the learning-rate schedule.

        Returns:
            ``(lr, last_snapshot_iter, stepsizes, np_paths, ss_paths)`` where
            ``stepsizes`` holds only the drops that are still ahead.
        """
        # Get the most recent snapshot and restore.
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots.
        last_snapshot_iter = self.from_snapshot(sfile, nfile)
        # Apply once every learning-rate drop that already happened.
        lr_scale = 1
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                lr_scale *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)
        scale_lr(self.optimizer, lr_scale)
        lr = cfg.TRAIN.LEARNING_RATE * lr_scale
        return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        """Delete the oldest snapshots beyond ``cfg.TRAIN.SNAPSHOT_KEPT``.

        Both lists are modified in place; the oldest entry comes first.
        """
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            sfile = ss_paths[0]
            os.remove(str(sfile))
            ss_paths.remove(sfile)

    def train_model(self, max_iters):
        """Run the training loop up to and including iteration ``max_iters``.

        Resumes from the newest snapshot in ``output_dir`` when one exists.
        The tensorboard writers are closed even when training fails.
        """
        # Build data layers for both training and validation set.
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph (also opens the writers).
        lr, _ = self.construct_graph()

        try:
            # Find previous snapshots if there is any to restore from.
            lsf, nfiles, sfiles = self.find_previous()

            # Initialize the variables or restore them from the last snapshot.
            if lsf == 0:
                lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = \
                    self.initialize()
            else:
                lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = \
                    self.restore(str(sfiles[-1]), str(nfiles[-1]))
            step = last_snapshot_iter + 1
            last_summary_time = time.time()
            # Make sure the list is never empty: max_iters acts as a sentinel
            # that is popped last.
            stepsizes.append(max_iters)
            stepsizes.reverse()
            next_stepsize = stepsizes.pop()

            self.net.train()
            self.net.cuda()

            while step < max_iters + 1:
                # Learning rate.
                if step == next_stepsize + 1:
                    # Add snapshot here before reducing the learning rate.
                    self.snapshot(step)
                    lr *= cfg.TRAIN.GAMMA
                    scale_lr(self.optimizer, cfg.TRAIN.GAMMA)
                    next_stepsize = stepsizes.pop()

                utils.timer.timer.tic()
                # Get training data, one batch at a time.
                blobs = self.data_layer.forward()

                now = time.time()
                if step == 1 or \
                        now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                    # Compute the graph with summary.
                    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, \
                        total_loss, summary = \
                        self.net.train_step_with_summary(blobs,
                                                         self.optimizer)
                    for _sum in summary:
                        self.writer.add_summary(_sum, float(step))
                    # Also check the summary on the validation set.
                    blobs_val = self.data_layer_val.forward()
                    summary_val = self.net.get_summary(blobs_val)
                    for _sum in summary_val:
                        self.valwriter.add_summary(_sum, float(step))
                    last_summary_time = now
                else:
                    # Compute the graph without summary.
                    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, \
                        total_loss = self.net.train_step(blobs,
                                                         self.optimizer)
                utils.timer.timer.toc()

                # Display training information.
                if step % (cfg.TRAIN.DISPLAY) == 0:
                    print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                          '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' %
                          (step, max_iters, total_loss, rpn_loss_cls,
                           rpn_loss_box, loss_cls, loss_box, lr))
                    print('speed: {:.3f}s / iter'.format(
                        utils.timer.timer.average_time()))

                # Snapshotting.
                if step % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                    last_snapshot_iter = step
                    ss_path, np_path = self.snapshot(step)
                    np_paths.append(np_path)
                    ss_paths.append(ss_path)

                    # Remove the old snapshots if there are too many.
                    if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                        self.remove_snapshot(np_paths, ss_paths)

                step += 1

            # Save one last time unless the final iteration was just saved.
            if last_snapshot_iter != step - 1:
                self.snapshot(step - 1)
        finally:
            self.writer.close()
            self.valwriter.close()
class SolverWrapper(object):
    """Drive TensorFlow training: build the graph, snapshot, restore, train.

    Each snapshot consists of a TensorFlow checkpoint (``.ckpt*``) and a
    pickle (``.pkl``) holding the numpy RNG state, the cursors and
    permutations of both data layers, and the iteration number.
    """

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir,
                 tbdir, pretrained_model=None):
        """Store the training inputs and create the validation summary dir.

        Args:
            sess: accepted for interface compatibility; not used here.
            network: network object to train.
            imdb: image database; ``num_classes`` is read from it.
            roidb: training ROI database handed to the data layer.
            valroidb: validation ROI database handed to the data layer.
            output_dir: directory that receives the snapshots.
            tbdir: directory for the training summaries.
            pretrained_model: checkpoint path used for a fresh start.
        """
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the
        # validation set.
        self.tbvaldir = tbdir + '_val'
        # exist_ok replaces the racy "exists? then create" sequence.
        os.makedirs(self.tbvaldir, exist_ok=True)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        """Write the checkpoint and the resume metadata for ``iter``.

        Returns:
            ``(filename, nfilename)``: the checkpoint prefix and the path of
            the ``.pkl`` metadata file.
        """
        # exist_ok avoids the check-then-create race when several processes
        # share one output directory.
        os.makedirs(self.output_dir, exist_ok=True)

        # Store the model snapshot.
        stem = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter)
        filename = os.path.join(self.output_dir, stem + '.ckpt')
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.  The order of
        # the records must match the reads in from_snapshot().
        nfilename = os.path.join(self.output_dir, stem + '.pkl')
        records = (
            np.random.get_state(),      # current state of numpy random
            self.data_layer._cur,       # current position in the database
            self.data_layer._perm,      # current shuffled indexes
            self.data_layer_val._cur,   # position in the validation database
            self.data_layer_val._perm,  # shuffled validation indexes
            iter,
        )
        with open(nfilename, 'wb') as fid:
            for record in records:
                pickle.dump(record, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        """Restore the checkpoint ``sfile`` and training state from ``nfile``.

        Returns:
            The iteration number stored in the metadata file.
        """
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Restore the other states needed for training.  Only the numpy RNG
        # state is available; the TensorFlow random state is not stored.
        # NOTE: pickle.load can execute arbitrary code, so only metadata files
        # written by snapshot() should ever be passed in.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        """Return the ``{variable name: shape}`` map of a checkpoint.

        Returns:
            The map, or ``None`` when the checkpoint cannot be read; the
            error is printed in that case, not raised.
        """
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            return reader.get_variable_to_shape_map()
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")
            # Best-effort by design: callers receive None on failure.
            return None

    def construct_graph(self, sess):
        """Build the training graph, the saver and the summary writers.

        Returns:
            ``(lr, train_op)``: the learning-rate variable and the op that
            applies one optimizer step.
        """
        with sess.graph.as_default():
            # Set the random seed for tensorflow.
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph.
            layers = self.net.create_architecture(
                'TRAIN',
                self.imdb.num_classes,
                tag='default',
                anchor_scales=cfg.ANCHOR_SCALES,
                anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss.
            loss = layers['total_loss']
            # Set learning rate and momentum.
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

            # Compute the gradients with regard to the loss.
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set.
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        if '/biases:' in var.name:
                            grad = tf.multiply(grad, 2.)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves.
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard.
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def find_previous(self):
        """Locate earlier snapshots in ``output_dir``, oldest first.

        Snapshots named after ``stepsize + 1`` are the extra ones that
        train_model() writes right before a learning-rate drop; they are
        excluded from the result.

        Returns:
            ``(count, nfiles, sfiles)``: number of usable snapshots, metadata
            paths and checkpoint prefixes, both sorted by modification time.
        """
        prefix = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX)
        # Names of the pre-lr-drop snapshots to ignore.
        redfiles = [prefix + '_iter_{:d}.ckpt.meta'.format(stepsize + 1)
                    for stepsize in cfg.TRAIN.STEPSIZE]

        # Checkpoints are discovered through their '.meta' file; the suffix
        # is stripped to obtain the prefix the saver expects.
        sfiles = sorted(glob.glob(prefix + '_iter_*.ckpt.meta'),
                        key=os.path.getmtime)
        sfiles = [ss.replace('.meta', '') for ss in sfiles
                  if ss not in redfiles]

        redfiles = [redfile.replace('.ckpt.meta', '.pkl')
                    for redfile in redfiles]
        nfiles = sorted(glob.glob(prefix + '_iter_*.pkl'),
                        key=os.path.getmtime)
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf, \
            'every checkpoint needs a matching metadata file'

        return lsf, nfiles, sfiles

    def initialize(self, sess):
        """Start a fresh run from ``self.pretrained_model``.

        Returns:
            ``(rate, last_snapshot_iter, stepsizes, np_paths, ss_paths)`` with
            empty snapshot lists and ``last_snapshot_iter == 0``.
        """
        # Initial file lists are empty.
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights.
        print('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first.
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(
            self.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix.
        variables_to_restore = self.net.get_variables_to_restore(
            variables, var_keep_dic)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, self.pretrained_model)
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights
        # are changed to BGR.  For VGG16 it also changes the convolutional
        # weights fc6 and fc7 to fully connected weights.
        self.net.fix_variables(sess, self.pretrained_model)
        print('Fixed.')
        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        """Resume from a snapshot and re-derive the learning-rate schedule.

        Returns:
            ``(rate, last_snapshot_iter, stepsizes, np_paths, ss_paths)``
            where ``stepsizes`` holds only the drops that are still ahead.
        """
        # Get the most recent snapshot and restore.
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots.
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Apply once every learning-rate drop that already happened.
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        """Delete the oldest snapshots beyond ``cfg.TRAIN.SNAPSHOT_KEPT``.

        Both lists are modified in place; the oldest entry comes first.
        """
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible to earlier versions of Tensorflow,
            # where the naming tradition for checkpoints are different.
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def train_model(self, sess, max_iters):
        """Run the training loop up to and including iteration ``max_iters``.

        Resumes from the newest snapshot in ``output_dir`` when one exists.
        The summary writers are closed even when training fails.
        """
        # Build data layers for both training and validation set.
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph (also opens the writers).
        lr, train_op = self.construct_graph(sess)

        try:
            # Find previous snapshots if there is any to restore from.
            lsf, nfiles, sfiles = self.find_previous()

            # Initialize the variables or restore them from the last snapshot.
            if lsf == 0:
                rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = \
                    self.initialize(sess)
            else:
                rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = \
                    self.restore(sess, str(sfiles[-1]), str(nfiles[-1]))
            timer = Timer()
            step = last_snapshot_iter + 1
            last_summary_time = time.time()
            # Make sure the list is never empty: max_iters acts as a sentinel
            # that is popped last.
            stepsizes.append(max_iters)
            stepsizes.reverse()
            next_stepsize = stepsizes.pop()

            while step < max_iters + 1:
                # Learning rate.
                if step == next_stepsize + 1:
                    # Add snapshot here before reducing the learning rate.
                    self.snapshot(sess, step)
                    rate *= cfg.TRAIN.GAMMA
                    sess.run(tf.assign(lr, rate))
                    next_stepsize = stepsizes.pop()

                timer.tic()
                # Get training data, one batch at a time.
                blobs = self.data_layer.forward()

                now = time.time()
                if step == 1 or \
                        now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                    # Compute the graph with summary.
                    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, \
                        total_loss, summary = \
                        self.net.train_step_with_summary(sess, blobs,
                                                         train_op)
                    self.writer.add_summary(summary, float(step))
                    # Also check the summary on the validation set.
                    blobs_val = self.data_layer_val.forward()
                    summary_val = self.net.get_summary(sess, blobs_val)
                    self.valwriter.add_summary(summary_val, float(step))
                    last_summary_time = now
                else:
                    # Compute the graph without summary.
                    rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, \
                        total_loss = self.net.train_step(sess, blobs,
                                                         train_op)
                timer.toc()

                # Display training information.
                if step % (cfg.TRAIN.DISPLAY) == 0:
                    print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                          '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' %
                          (step, max_iters, total_loss, rpn_loss_cls,
                           rpn_loss_box, loss_cls, loss_box, lr.eval()))
                    print('speed: {:.3f}s / iter'.format(timer.average_time))

                # Snapshotting.
                if step % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                    last_snapshot_iter = step
                    ss_path, np_path = self.snapshot(sess, step)
                    np_paths.append(np_path)
                    ss_paths.append(ss_path)

                    # Remove the old snapshots if there are too many.
                    if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                        self.remove_snapshot(np_paths, ss_paths)

                step += 1

            # Save one last time unless the final iteration was just saved.
            if last_snapshot_iter != step - 1:
                self.snapshot(sess, step - 1)
        finally:
            self.writer.close()
            self.valwriter.close()
class Solver(object):
    """Minimal SGD training driver built on Theano/Lasagne.

    Compiles a training-step function and a prediction function for ``net``
    and feeds them with batches taken from an ``RoIDataLayer``.
    """

    def __init__(self, roidb, net, freeze=0):
        """Set up the data layer and compile the step/prediction functions.

        Args:
            roidb: region-of-interest database handed to the data layer.
            net: network object; this class reads ``net.inp``,
                ``net.prediction`` and ``net.params`` and calls ``net.save``.
            freeze: when non-zero, only the gradients of the last ``freeze``
                entries of ``net.params`` are kept; the gradients of all
                earlier entries are multiplied by zero.
        """
        # Holds current iteration number.
        self.iter = 0
        # How frequently we should print the training info.
        self.display_freq = 1
        # Holds the path prefix for snapshots.
        self.snapshot_prefix = 'snapshot'

        self.roidb = roidb
        self.net = net
        self.freeze = freeze

        self.roi_data_layer = RoIDataLayer()
        self.roi_data_layer.setup()
        self.roi_data_layer.set_roidb(self.roidb)

        self.stepfn = self.build_step_fn(self.net)
        self.predfn = self.build_pred_fn(self.net)

    # Not a staticmethod: the gradient scaling depends on self.freeze.
    def build_step_fn(self, net):
        """Compile the function that performs one SGD update.

        Args:
            net: network exposing ``inp``, ``prediction`` and ``params``.

        Returns:
            A compiled function ``(input, target_y) -> [loss, accuracy]`` that
            also applies the SGD updates as a side effect.
        """
        target_y = T.vector("target Y", dtype='int64')
        tl = lasagne.objectives.categorical_crossentropy(net.prediction, target_y)
        loss = tl.mean()
        accuracy = lasagne.objectives.categorical_accuracy(net.prediction, target_y).mean()

        weights = net.params
        grads = theano.grad(loss, weights)

        # One multiplier per parameter; zero means "frozen".
        scales = np.ones(len(weights))
        if self.freeze:
            scales[:-self.freeze] = 0
        print('GRAD SCALE >>> {}'.format(scales))

        for idx, scale in enumerate(scales):
            grads[idx] *= scale
            # Multiplying by a float64 scale upcasts; cast back to float32.
            grads[idx] = grads[idx].astype('float32')

        updates_sgd = lasagne.updates.sgd(grads, net.params, learning_rate=0.0001)
        stepfn = theano.function([net.inp, target_y], [loss, accuracy],
                                 updates=updates_sgd, allow_input_downcast=True)
        return stepfn

    @staticmethod
    def build_pred_fn(net):
        """Compile and return the function mapping ``net.inp`` to ``net.prediction``."""
        predfn = theano.function([net.inp], net.prediction, allow_input_downcast=True)
        return predfn

    def get_training_batch(self):
        """Uses ROIDataLayer to fetch a training batch.

        Reads the first three entries of ``self.roi_data_layer.top`` (deep
        copied so the layer's buffers are not modified) and pools the data
        with ``roi_layer``.

        Returns:
            input_data (ndarray): input data suitable for R-CNN processing
            labels (ndarray): batch labels cast with ``astype('int')``
        """
        data, rois, labels = deepcopy(self.roi_data_layer.top[: 3])
        X = roi_layer(data, rois)
        y = labels.astype('int')

        return X, y

    def step(self):
        """Conducts a single step of SGD.

        Advances the data layer, runs the compiled step function on the new
        batch, stores the results in ``self.loss`` / ``self.acc``, increments
        ``self.iter`` and prints progress every ``self.display_freq`` steps.
        """
        self.roi_data_layer.forward()
        data, labels = self.get_training_batch()

        loss, acc = self.stepfn(data, labels)
        self.loss = loss
        self.acc = acc

        self.iter += 1
        if self.iter % self.display_freq == 0:
            print('Iteration {:<5} Train loss: {} Train acc: {}'.format(self.iter, self.loss, self.acc))

    def save(self, filename):
        """Delegate saving of the network to ``self.net.save(filename)``."""
        self.net.save(filename)
class SolverWrapper(object):
    """A wrapper class for the training process.

    Owns the data layers, the optimizer, the checkpoint saver and the
    TensorBoard writers, and drives the training loop including periodic
    snapshots and learning-rate decay.
    """

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
        """Store the training inputs and create the validation summary directory.

        Args:
            sess: accepted for interface compatibility; not used here.
            network: network instance (must provide ``create_architecture``,
                ``train_step`` etc. used below).
            imdb: image database instance; ``imdb.num_classes`` is read.
            roidb: training roidb.
            valroidb: validation roidb.
            output_dir: directory where model snapshots are written.
            tbdir: directory for the training TensorBoard summaries.
            pretrained_model: path of the pretrained weights.
        """
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        """Write a snapshot: the model checkpoint plus a pickle of training state.

        Args:
            sess: session whose variables are saved.
            iter: iteration number, used in both file names.

        Returns:
            Tuple ``(checkpoint_path, pickle_path)``.
        """
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot, named <SNAPSHOT_PREFIX>_iter_<iter>.ckpt
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info; from_snapshot() reads the objects back in this order.
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        """Restore the model and the pickled training state from a snapshot.

        Args:
            sess: session to restore the variables into.
            sfile: checkpoint path.
            nfile: path of the matching pickle written by ``snapshot``.

        Returns:
            The iteration number stored in the snapshot.
        """
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Needs to restore the other hyper-parameters/states for training, (TODO xinlei) I have
        # tried my best to find the random states so that it can be recovered exactly
        # However the Tensorflow state is currently not available
        # NOTE: pickle.load executes arbitrary code; only load snapshots you wrote yourself.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        """Return the variable-name -> shape map stored in a checkpoint file.

        Args:
            file_name: path of the checkpoint to inspect.

        Returns:
            The mapping produced by the reader's ``get_variable_to_shape_map()``.

        Raises:
            Exception: any error raised while reading the checkpoint is printed
                (with a hint for SNAPPY-compressed files) and then re-raised,
                instead of silently returning ``None`` to the caller.
        """
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            return reader.get_variable_to_shape_map()
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")
            raise

    def construct_graph(self, sess):
        """Build the training graph, optimizer, saver and summary writers.

        Args:
            sess: session whose graph is used as the default graph.

        Returns:
            Tuple ``(lr, train_op)``: the learning-rate variable and the op that
            applies one optimizer update.
        """
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph; `layers` is indexed by name
            # below to obtain the total loss.
            layers = self.net.create_architecture('TRAIN', self.imdb.num_classes, tag='default',
                                                  anchor_scales=cfg.ANCHOR_SCALES,
                                                  anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

            # Compute the gradients with regard to the loss; a list of
            # (gradient, variable) pairs.
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1.
                        if '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def find_previous(self):
        """Locate existing snapshots in ``self.output_dir``.

        Returns:
            Tuple ``(count, nfiles, sfiles)``: number of usable snapshots, the
            pickle paths and the checkpoint paths, each sorted by modification
            time (oldest first).
        """
        sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)

        # train_model() writes an extra snapshot at iteration `stepsize + 1`,
        # just before each learning-rate reduction. Those files are excluded
        # here so they are not counted as restore candidates.
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(os.path.join(self.output_dir,
                                         cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.ckpt.meta'.format(stepsize + 1)))
        # Get the snapshot name in TensorFlow (path without the '.meta' suffix)
        sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles]

        nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        # Every checkpoint must have a matching pickle.
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def initialize(self, sess):
        """Initialize all variables and load the pretrained weights.

        Args:
            sess: session to initialize.

        Returns:
            Tuple ``(rate, last_snapshot_iter, stepsizes, np_paths, ss_paths)``
            for a fresh run: the configured learning rate, iteration 0, a copy
            of the configured step sizes and two empty snapshot path lists.
        """
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(self.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        # Names and shapes of the variables stored in the pretrained checkpoint
        var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        # (the selection is delegated to the network object).
        variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, self.pretrained_model)
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are changed to BGR
        # For VGG16 it also changes the convolutional weights fc6 and fc7 to
        # fully connected weights
        self.net.fix_variables(sess, self.pretrained_model)
        print('Fixed.')
        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        """Resume from the given snapshot and recompute the learning-rate schedule.

        Args:
            sess: session to restore into.
            sfile: checkpoint path of the snapshot.
            nfile: pickle path of the snapshot.

        Returns:
            Tuple ``(rate, last_snapshot_iter, stepsizes, np_paths, ss_paths)``
            where ``rate`` has already been decayed once for every step size
            the snapshot iteration has passed and ``stepsizes`` holds the
            remaining ones.
        """
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        """Delete the oldest snapshots until at most SNAPSHOT_KEPT remain.

        Both lists are modified in place: removed entries are dropped from the
        front.

        Args:
            np_paths: pickle paths, oldest first.
            ss_paths: checkpoint paths, oldest first.
        """
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible to earlier versions of Tensorflow,
            # where the naming tradition for checkpoints are different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def train_model(self, sess, max_iters):
        """Run the training loop up to and including iteration ``max_iters``.

        Args:
            sess: session used for all graph executions.
            max_iters: last iteration to run.
        """
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess,
                                                                                   str(sfiles[-1]),
                                                                                   str(nfiles[-1]))
        timer = Timer()
        iteration = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty: max_iters is appended, and after
        # the reversal pop() hands out the step sizes in their configured
        # order with max_iters last.
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        while iteration < max_iters + 1:
            # Learning rate
            if iteration == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iteration)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            # Summaries are written on the first iteration and then whenever
            # more than SUMMARY_INTERVAL seconds passed since the last one.
            if iteration == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                    self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iteration))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iteration))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                    self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iteration % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (iteration, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Snapshotting
            if iteration % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iteration
                ss_path, np_path = self.snapshot(sess, iteration)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iteration += 1

        # Save a final snapshot unless the last iteration already wrote one.
        if last_snapshot_iter != iteration - 1:
            self.snapshot(sess, iteration - 1)

        self.writer.close()
        self.valwriter.close()
if not os.path.exists(output_dir): os.makedirs(output_dir) # training train_loss = 0 tp, tf, fg, bg = 0., 0., 0, 0 step_cnt = 0 re_cnt = False t = Timer() t.tic() for step in range(start_step, end_step + 1): net.train() # get one batch blobs = data_layer.forward() im_data = blobs['data'] # im_data = (1, 600, 901, 3) rois = blobs['rois'] #rois = (64, 5) im_info = blobs['im_info'] # im_info = (1,3) gt_vec = blobs['labels'] # gt_vec = (1,20) # import pdb; pdb.set_trace() #gt_boxes = blobs['gt_boxes'] # forward # rois = (128,5) im_info = (1,3) gt_vec = (1, 20), one-hot label for two streams net(im_data, rois, im_info, gt_vec) loss = net.loss train_loss += loss.data[0] step_cnt += 1 # backward pass and update optimizer.zero_grad()
class SolverWrapper(object):
    """A wrapper class for the training process.

    Owns the data layers, the optimizer, the checkpoint saver and the
    TensorBoard writers, and drives the training loop including periodic
    snapshots and a single learning-rate reduction.
    """

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
        """Store the training inputs and create the validation summary directory.

        Args:
            sess: accepted for interface compatibility; not used here.
            network: network instance to train.
            imdb: image database instance; ``imdb.num_classes`` is read.
            roidb: training roidb.
            valroidb: validation roidb.
            output_dir: directory where model snapshots are written.
            tbdir: directory for the training TensorBoard summaries.
            pretrained_model: path of the pretrained weights.
        """
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        """Write a snapshot: the model checkpoint plus a pickle of training state.

        Args:
            sess: session whose variables are saved.
            iter: iteration number, used in both file names.

        Returns:
            Tuple ``(checkpoint_path, pickle_path)``.
        """
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indices of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info; train_model() reads the objects back in this order.
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def get_variables_in_checkpoint_file(self, file_name):
        """Return the variable-name -> shape map stored in a checkpoint file.

        Args:
            file_name: path of the checkpoint to inspect.

        Returns:
            The mapping produced by the reader's ``get_variable_to_shape_map()``.

        Raises:
            Exception: any error raised while reading the checkpoint is printed
                (with a hint for SNAPPY-compressed files) and then re-raised,
                instead of silently returning ``None`` to the caller.
        """
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            return reader.get_variable_to_shape_map()
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")
            raise

    def train_model(self, sess, max_iters):
        """Build the graph, initialize or resume, and train until ``max_iters``.

        Args:
            sess: session used for all graph executions.
            max_iters: last iteration to run.
        """
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(sess, 'TRAIN', self.imdb.num_classes, tag='default',
                                                  anchor_scales=cfg.ANCHOR_SCALES)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            momentum = cfg.TRAIN.MOMENTUM
            self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

            # Compute the gradients wrt the loss
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1.
                        if '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Find previous snapshots if there is any to restore from
        sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow (path without the '.meta' suffix)
        sfiles = [ss.replace('.meta', '') for ss in sfiles]

        nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)

        lsf = len(sfiles)
        # Every checkpoint must have a matching pickle.
        assert len(nfiles) == lsf

        np_paths = nfiles
        ss_paths = sfiles

        if lsf == 0:
            # Fresh train directly from ImageNet weights
            print('Loading initial model weights from {:s}'.format(self.pretrained_model))
            variables = tf.global_variables()
            # Initialize every global variable; the pretrained values are
            # loaded on top of this below.
            sess.run(tf.variables_initializer(variables, name='init'))
            var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
            variables_to_restore = []
            var_to_dic = {}
            for v in variables:
                # exclude the conv weights that are fc weights in vgg16
                if v.name == 'vgg_16/fc6/weights:0' or v.name == 'vgg_16/fc7/weights:0':
                    var_to_dic[v.name] = v
                    continue
                if v.name.split(':')[0] in var_keep_dic:
                    print('Varibles restored: %s' % v.name)
                    variables_to_restore.append(v)

            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, self.pretrained_model)
            print('Loaded.')
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
            # A temporary solution to fix the vgg16 issue from conv weights to fc weights
            if self.net._arch == 'vgg16':
                print('Converting VGG16 fc layers..')
                with tf.device("/cpu:0"):
                    # Load the checkpoint's conv-shaped fc6/fc7 weights into
                    # temporaries, then reshape them into the graph variables.
                    fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
                    fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
                    restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
                                                  "vgg_16/fc7/weights": fc7_conv})
                    restorer_fc.restore(sess, self.pretrained_model)

                    sess.run(tf.assign(var_to_dic['vgg_16/fc6/weights:0'],
                                       tf.reshape(fc6_conv, var_to_dic['vgg_16/fc6/weights:0'].get_shape())))
                    sess.run(tf.assign(var_to_dic['vgg_16/fc7/weights:0'],
                                       tf.reshape(fc7_conv, var_to_dic['vgg_16/fc7/weights:0'].get_shape())))
            last_snapshot_iter = 0
        else:
            # Get the most recent snapshot and restore
            ss_paths = [ss_paths[-1]]
            np_paths = [np_paths[-1]]

            print('Restorining model snapshots from {:s}'.format(sfiles[-1]))
            self.saver.restore(sess, str(sfiles[-1]))
            print('Restored.')
            # Needs to restore the other hyperparameters/states for training, (TODO xinlei) I have
            # tried my best to find the random states so that it can be recovered exactly
            # However the Tensorflow state is currently not available
            # NOTE: pickle.load executes arbitrary code; only load snapshots you wrote yourself.
            with open(str(nfiles[-1]), 'rb') as fid:
                st0 = pickle.load(fid)
                cur = pickle.load(fid)
                perm = pickle.load(fid)
                cur_val = pickle.load(fid)
                perm_val = pickle.load(fid)
                last_snapshot_iter = pickle.load(fid)

                np.random.set_state(st0)
                self.data_layer._cur = cur
                self.data_layer._perm = perm
                self.data_layer_val._cur = cur_val
                self.data_layer_val._perm = perm_val

                # Set the learning rate, only reduce once
                if last_snapshot_iter > cfg.TRAIN.STEPSIZE:
                    sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
                else:
                    sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

        timer = Timer()
        iteration = last_snapshot_iter + 1
        last_summary_time = time.time()
        while iteration < max_iters + 1:
            # Learning rate
            if iteration == cfg.TRAIN.STEPSIZE + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iteration)
                sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                    self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iteration))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iteration))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                    self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iteration % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (iteration, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if iteration % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iteration
                snapshot_path, np_path = self.snapshot(sess, iteration)
                np_paths.append(np_path)
                ss_paths.append(snapshot_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for _ in range(to_remove):
                        nfile = np_paths[0]
                        os.remove(str(nfile))
                        np_paths.remove(nfile)

                if len(ss_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for _ in range(to_remove):
                        sfile = ss_paths[0]
                        # To make the code compatible to earlier versions of Tensorflow,
                        # where the naming tradition for checkpoints are different
                        if os.path.exists(str(sfile)):
                            os.remove(str(sfile))
                        else:
                            os.remove(str(sfile + '.data-00000-of-00001'))
                            os.remove(str(sfile + '.index'))
                        sfile_meta = sfile + '.meta'
                        os.remove(str(sfile_meta))
                        ss_paths.remove(sfile)

            iteration += 1

        # Save a final snapshot unless the last iteration already wrote one.
        if last_snapshot_iter != iteration - 1:
            self.snapshot(sess, iteration - 1)

        self.writer.close()
        self.valwriter.close()
class SolverWrapper(object):
    """A wrapper class for the training process.

    Owns the data layers, the optimizer, the checkpoint saver and the
    TensorBoard writers, and drives the training loop including periodic
    snapshots and learning-rate decay.
    """

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
        """Store the training inputs and prepare the validation summary directory.

        Args:
            sess: accepted for interface compatibility; not used here.
            network: network instance to train.
            imdb: image database instance; ``imdb.num_classes`` is read.
            roidb: training roidb.
            valroidb: validation roidb.
            output_dir: directory where model snapshots are written.
            tbdir: directory for the training TensorBoard summaries.
            pretrained_model: ``None``, a path ending in ``.ckpt``, or a
                checkpoint directory (see ``initialize``).
        """
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        common.check_dir(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        """Write a snapshot: the model checkpoint plus a pickle of training state.

        Args:
            sess: session whose variables are saved.
            iter: iteration number, used in both file names.

        Returns:
            Tuple ``(checkpoint_path, pickle_path)``.
        """
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info; from_snapshot() reads the objects back in this order.
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        """Restore the model and the pickled training state from a snapshot.

        Args:
            sess: session to restore the variables into.
            sfile: checkpoint path.
            nfile: path of the matching pickle written by ``snapshot``.

        Returns:
            The iteration number stored in the snapshot.
        """
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Needs to restore the other hyper-parameters/states for training, (TODO xinlei) I have
        # tried my best to find the random states so that it can be recovered exactly
        # However the Tensorflow state is currently not available
        # NOTE: pickle.load executes arbitrary code; only load snapshots you wrote yourself.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        """Return the variable-name -> shape map stored in a checkpoint file.

        Args:
            file_name: path of the checkpoint to inspect.

        Returns:
            The mapping produced by the reader's ``get_variable_to_shape_map()``.

        Raises:
            Exception: any error raised while reading the checkpoint is printed
                (with a hint for SNAPPY-compressed files) and then re-raised,
                instead of silently returning ``None`` to the caller.
        """
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            return reader.get_variable_to_shape_map()
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")
            raise

    def construct_graph(self, sess):
        """Build the training graph, optimizer, saver and summary writers.

        Args:
            sess: session whose graph is used as the default graph.

        Returns:
            Tuple ``(lr, train_op)``: the learning-rate variable and the op that
            applies one optimizer update.

        Raises:
            NotImplementedError: if ``cfg.TRAIN.OPTIMIZER`` is not one of
                ``'Adam'``, ``'Momentum'`` or ``'RMS'``.
        """
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture('TRAIN', self.imdb.num_classes, tag='default',
                                                  anchor_width=cfg.CTPN.ANCHOR_WIDTH,
                                                  anchor_h_ratio_step=cfg.CTPN.H_RADIO_STEP,
                                                  num_anchors=cfg.CTPN.NUM_ANCHORS)
            # Define the loss
            total_loss = layers['total_loss']

            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            if cfg.TRAIN.OPTIMIZER == 'Adam':
                self.optimizer = tf.train.AdamOptimizer(lr)
            elif cfg.TRAIN.OPTIMIZER == 'Momentum':
                self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)
            elif cfg.TRAIN.OPTIMIZER == 'RMS':
                self.optimizer = tf.train.RMSPropOptimizer(lr)
            else:
                raise NotImplementedError

            global_step = tf.Variable(0, trainable=False)
            # Gradient clipping is hard-wired off; the clipped path is kept for
            # experimentation. NOTE(review): that path does not add the
            # UPDATE_OPS dependency the other path uses. Confirm before enabling.
            with_clip = False
            if with_clip:
                tvars = tf.trainable_variables()
                grads, _ = tf.clip_by_global_norm(tf.gradients(total_loss, tvars), 10.0)
                train_op = self.optimizer.apply_gradients(list(zip(grads, tvars)), global_step=global_step)
            else:
                # required by tf.layers.batch_normalization()
                # add update ops(for moving_mean and moving_variance) as a dependency to the train_op
                # https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    train_op = self.optimizer.minimize(total_loss, global_step=global_step)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def find_previous(self):
        """Locate existing snapshots in ``self.output_dir``.

        Returns:
            Tuple ``(count, nfiles, sfiles)``: number of usable snapshots, the
            pickle paths and the checkpoint paths, each sorted by modification
            time (oldest first).
        """
        sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)

        # train_model() writes an extra snapshot at iteration `stepsize + 1`,
        # just before each learning-rate reduction. Those files are excluded
        # here so they are not counted as restore candidates.
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(os.path.join(self.output_dir,
                                         cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.ckpt.meta'.format(stepsize + 1)))
        # Get the snapshot name in TensorFlow (path without the '.meta' suffix)
        sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles]

        nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        # Every checkpoint must have a matching pickle.
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def restore_ckpt_from_dir(self, sess, net, checkpoint_dir):
        """Restore ``net.variables_to_restore`` from the latest checkpoint in a directory.

        Args:
            sess: session to restore into.
            net: network exposing ``variables_to_restore``.
            checkpoint_dir: directory searched with ``tf.train.latest_checkpoint``.

        Raises:
            SystemExit: with code -1 when the directory holds no checkpoint.
            Exception: when the restore itself fails.
        """
        print("Restoring checkpoint from: " + checkpoint_dir)
        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt is None:
            print("Checkpoint not found")
            # Same effect as the interactive `exit(-1)` helper, but does not
            # depend on the `site` module having injected that name.
            raise SystemExit(-1)
        # Only used for the log line below; the meta graph is not imported.
        meta_file = ckpt + '.meta'
        try:
            print('Restore variables from {}'.format(ckpt))
            print('Restore meta_filr from {}'.format(meta_file))
            saver = tf.train.Saver(net.variables_to_restore)
            saver.restore(sess, ckpt)
        except Exception:
            raise Exception("Can not restore from {}".format(checkpoint_dir))

    def initialize(self, sess):
        """Initialize all variables and optionally load pretrained weights.

        When ``self.pretrained_model`` ends with ``.ckpt`` the matching
        variables are restored from that file; any other non-``None`` value is
        treated as a checkpoint directory.

        Args:
            sess: session to initialize.

        Returns:
            Tuple ``(rate, last_snapshot_iter, stepsizes, np_paths, ss_paths)``
            for a fresh run: the configured learning rate, iteration 0, a copy
            of the configured step sizes and two empty snapshot path lists.
        """
        # Initial file lists are empty
        np_paths = []
        ss_paths = []

        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))

        if self.pretrained_model is not None:
            if self.pretrained_model.endswith('.ckpt'):
                # Fresh train directly from ImageNet weights
                print('Loading initial model weights from {:s}'.format(self.pretrained_model))

                var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
                # Get the variables to restore, ignoring the variables to fix
                variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

                restorer = tf.train.Saver(variables_to_restore)
                restorer.restore(sess, self.pretrained_model)
                print('Loaded.')
            else:
                # Restore from checkpoint and meta file
                self.restore_ckpt_from_dir(sess, self.net, self.pretrained_model)
                print('Loaded.')

        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        """Resume from the given snapshot and recompute the learning-rate schedule.

        Args:
            sess: session to restore into.
            sfile: checkpoint path of the snapshot.
            nfile: pickle path of the snapshot.

        Returns:
            Tuple ``(rate, last_snapshot_iter, stepsizes, np_paths, ss_paths)``
            where ``rate`` has already been decayed once for every step size
            the snapshot iteration has passed and ``stepsizes`` holds the
            remaining ones.
        """
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        """Delete the oldest snapshots until at most SNAPSHOT_KEPT remain.

        Both lists are modified in place: removed entries are dropped from the
        front.

        Args:
            np_paths: pickle paths, oldest first.
            ss_paths: checkpoint paths, oldest first.
        """
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible to earlier versions of Tensorflow,
            # where the naming tradition for checkpoints are different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def train_model(self, sess, max_iters):
        """Run the training loop up to and including iteration ``max_iters``.

        Args:
            sess: session used for all graph executions.
            max_iters: last iteration to run.
        """
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess,
                                                                                   str(sfiles[-1]),
                                                                                   str(nfiles[-1]))
        timer = Timer()
        iteration = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty: max_iters is appended, and after
        # the reversal pop() hands out the step sizes in their configured
        # order with max_iters last.
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        while iteration < max_iters + 1:
            # Learning rate
            if iteration == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iteration)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            # Summaries are written on the first iteration and then whenever
            # more than SUMMARY_INTERVAL seconds passed since the last one.
            if iteration == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, rpn_loss, total_loss, summary = \
                    self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iteration))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iteration))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, rpn_loss, total_loss, _ = \
                    self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Progress is printed on every iteration.
            print('%d/%d time: %.3f total_loss: %.3f rpn_loss: %.3f rpn_loss_cls: %.3f '
                  'rpn_loss_box: %.3f lr: %f' % (
                      iteration, max_iters, timer.diff, total_loss, rpn_loss, rpn_loss_cls, rpn_loss_box,
                      lr.eval()))

            # Snapshotting
            if iteration % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iteration
                ss_path, np_path = self.snapshot(sess, iteration)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iteration += 1

        # Save a final snapshot unless the last iteration already wrote one.
        if last_snapshot_iter != iteration - 1:
            self.snapshot(sess, iteration - 1)

        self.writer.close()
        self.valwriter.close()
class SolverWrapper(object):
    """A wrapper class for the training process.

    Builds the training graph for ``network``, resumes from the newest
    snapshot found in ``output_dir`` when one exists, runs the optimisation
    loop and writes checkpoints together with the NumPy random state and the
    data-layer cursors so that a later run can continue from them.
    """

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir,
                 tbdir, pretrained_model=None):
        """Store the training inputs and create the validation summary dir.

        Args:
            sess: Not used by the constructor; kept for caller compatibility.
            network: Network object; ``train_model`` calls its
                ``create_architecture``, ``train_step``,
                ``train_step_with_summary`` and ``get_summary`` methods and
                reads its ``_initialized`` attribute.
            imdb: Dataset object; ``name`` and ``num_classes`` are read.
            roidb: Training data handed to ``RoIDataLayer``.
            valroidb: Validation data handed to ``RoIDataLayer``.
            output_dir: Directory that receives the snapshots.
            tbdir: Directory for the training summaries; the validation
                summaries go to ``tbdir + '_val'``.
            pretrained_model: Forwarded to ``create_architecture`` as
                ``caffe_weight_path``.
        """
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        """Write a checkpoint and the matching training-state pickle.

        Args:
            sess: Session whose variables are saved through ``self.saver``.
            iter: Iteration number; embedded in both file names and stored
                as the last object of the pickle.

        Returns:
            Tuple ``(filename, nfilename)``: the checkpoint path and the
            pickle path.
        """
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # The order below is the order in which the restore code reads the
        # objects back: numpy random state, training cursor and permutation,
        # validation cursor and permutation, iteration number.
        state = (
            np.random.get_state(),
            self.data_layer._cur,
            self.data_layer._perm,
            self.data_layer_val._cur,
            self.data_layer_val._perm,
            iter,
        )

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            for item in state:
                cPickle.dump(item, fid, cPickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def train_model(self, sess, max_iters):
        """Train the network up to iteration ``max_iters``.

        Builds the data layers and the graph, restores the latest snapshot
        from ``self.output_dir`` if there is one, then runs the training
        loop. The summary writers are closed even if the loop raises.

        Args:
            sess: TensorFlow session used for every ``run`` call.
            max_iters: Last iteration number to train (inclusive).
        """
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes,
                                           random=True)

        lr, train_op = self._build_train_graph(sess)
        try:
            self._run_training(sess, max_iters, lr, train_op)
        finally:
            # Close the summary writers even when training is interrupted.
            self.writer.close()
            self.valwriter.close()

    def _build_train_graph(self, sess):
        """Create network, optimizer, saver and writers; return ``(lr, train_op)``."""
        # Determine different scales for anchors, see paper
        if self.imdb.name.startswith('voc'):
            anchors = [8, 16, 32]
        else:
            anchors = [4, 8, 16, 32]

        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                sess, "TRAIN", self.imdb.num_classes,
                caffe_weight_path=self.pretrained_model,
                tag='default', anchor_scales=anchors)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            momentum = cfg.TRAIN.MOMENTUM
            self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

            # Compute the gradients wrt the loss
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        if '/bias:' in var.name:
                            grad = tf.multiply(grad, 2.)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def _find_snapshots(self):
        """Return ``(sfiles, nfiles)``: checkpoint and pickle paths, oldest first."""
        pattern = os.path.join(self.output_dir,
                               cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(pattern)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        sfiles = [ss.replace('.meta', '') for ss in sfiles]

        pattern = os.path.join(self.output_dir,
                               cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(pattern)
        nfiles.sort(key=os.path.getmtime)

        assert len(nfiles) == len(sfiles)
        return sfiles, nfiles

    def _initialize_or_restore(self, sess, lr, sfiles, nfiles):
        """Start fresh or resume; return ``(last_snapshot_iter, np_paths, ss_paths)``."""
        if len(sfiles) == 0:
            # Fresh train directly from VGG weights
            print('Loading initial model weights from {:s}'.format(
                self.pretrained_model))
            variables = tf.global_variables()
            # Only initialize the variables that were not initialized when the graph was built
            for vbs in self.net._initialized:
                variables.remove(vbs)
            sess.run(tf.variables_initializer(variables, name='init'))
            print('Loaded.')
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
            return 0, nfiles, sfiles

        # Get the most recent snapshot and restore
        ss_paths = [sfiles[-1]]
        np_paths = [nfiles[-1]]

        print('Restoring model snapshots from {:s}'.format(sfiles[-1]))
        self.saver.restore(sess, str(sfiles[-1]))
        print('Restored.')
        # Needs to restore the other hyperparameters/states for training, (TODO xinlei) I have
        # tried my best to find the random states so that it can be recovered exactly
        # However the Tensorflow state is currently not available
        with open(str(nfiles[-1]), 'rb') as fid:
            st0 = cPickle.load(fid)
            cur = cPickle.load(fid)
            perm = cPickle.load(fid)
            cur_val = cPickle.load(fid)
            perm_val = cPickle.load(fid)
            last_snapshot_iter = cPickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        # Set the learning rate, only reduce once
        if last_snapshot_iter > cfg.TRAIN.STEPSIZE:
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
        else:
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

        return last_snapshot_iter, np_paths, ss_paths

    def _prune_snapshots(self, np_paths, ss_paths):
        """Delete the oldest snapshot files beyond ``cfg.TRAIN.SNAPSHOT_KEPT``.

        Both lists are modified in place.
        """
        kept = cfg.TRAIN.SNAPSHOT_KEPT

        while len(np_paths) > kept:
            os.remove(str(np_paths[0]))
            del np_paths[0]

        while len(ss_paths) > kept:
            sfile = ss_paths[0]
            # To make the code compatible to earlier versions of Tensorflow,
            # where the naming tradition for checkpoints are different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            os.remove(str(sfile + '.meta'))
            del ss_paths[0]

    def _run_training(self, sess, max_iters, lr, train_op):
        """Restore state if possible, then run the loop and the final snapshot."""
        # Find previous snapshots if there is any to restore from
        sfiles, nfiles = self._find_snapshots()
        last_snapshot_iter, np_paths, ss_paths = self._initialize_or_restore(
            sess, lr, sfiles, nfiles)

        timer = Timer()
        step = last_snapshot_iter + 1
        last_summary_time = time.time()
        while step < max_iters + 1:
            # Learning rate
            if step == cfg.TRAIN.STEPSIZE + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, step)
                sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                    self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(step))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(step))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                    self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if step % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' %
                      (step, max_iters, total_loss, rpn_loss_cls, rpn_loss_box,
                       loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if step % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = step
                snapshot_path, np_path = self.snapshot(sess, step)
                np_paths.append(np_path)
                ss_paths.append(snapshot_path)
                # Remove the old snapshots if there are too many
                self._prune_snapshots(np_paths, ss_paths)

            step += 1

        # Make sure the final state is on disk
        if last_snapshot_iter != step - 1:
            self.snapshot(sess, step - 1)
class SolverWrapper(object):
    """A wrapper class for the training process.

    Owns the training graph, the snapshot bookkeeping (checkpoint files plus
    a pickle with the NumPy random state and data-layer cursors) and the
    main optimisation loop.
    """

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir,
                 tbdir, pretrained_model=None):
        """Store the training inputs and prepare the validation summary dir.

        Args:
            sess: Not used by the constructor; kept for caller compatibility.
            network: Network object used to build and run the graph.
            imdb: Dataset object; ``num_classes`` is read.
            roidb: Training data handed to ``RoIDataLayer``.
            valroidb: Validation data handed to ``RoIDataLayer``.
            output_dir: Directory that receives the snapshots.
            tbdir: Directory for the training summaries; the validation
                summaries go to ``tbdir + '_val'``.
            pretrained_model: Optional path of initial weights, see
                ``initialize``.
        """
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        common.check_dir(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        """Write a checkpoint and the matching training-state pickle.

        Args:
            sess: Session whose variables are saved through ``self.saver``.
            iter: Iteration number; embedded in both file names and stored
                as the last object of the pickle.

        Returns:
            Tuple ``(filename, nfilename)``: the checkpoint path and the
            pickle path.
        """
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # Same order as the reads in from_snapshot(): numpy random state,
        # training cursor and permutation, validation cursor and
        # permutation, iteration number.
        state = (
            np.random.get_state(),
            self.data_layer._cur,
            self.data_layer._perm,
            self.data_layer_val._cur,
            self.data_layer_val._perm,
            iter,
        )

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            for item in state:
                pickle.dump(item, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        """Restore variables from ``sfile`` and training state from ``nfile``.

        Args:
            sess: Session to restore the variables into.
            sfile: Checkpoint path given to ``self.saver.restore``.
            nfile: Pickle written by ``snapshot``.

        Returns:
            The iteration number stored in the pickle.
        """
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Needs to restore the other hyper-parameters/states for training, (TODO xinlei) I have
        # tried my best to find the random states so that it can be recovered exactly
        # However the Tensorflow state is currently not available
        # NOTE: pickle.load must only be used on snapshot files this
        # trainer wrote itself, never on untrusted input.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        """Return the variable-name -> shape map of a checkpoint.

        Returns:
            The map from ``reader.get_variable_to_shape_map()``, or ``None``
            when reading the checkpoint fails (the error is printed).
        """
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")
            # Best effort: the failure was reported above, signal it to the caller.
            return None

    def construct_graph(self, sess):
        """Build network, optimizer, saver and summary writers.

        Args:
            sess: Session whose graph receives the new operations.

        Returns:
            Tuple ``(lr, train_op)``: the learning-rate variable and the
            training operation.

        Raises:
            NotImplementedError: ``cfg.TRAIN.OPTIMIZER`` is not one of
                ``'Adam'``, ``'Momentum'`` or ``'RMS'``.
        """
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                'TRAIN',
                self.imdb.num_classes,
                tag='default',
                anchor_width=cfg.CTPN.ANCHOR_WIDTH,
                anchor_h_ratio_step=cfg.CTPN.H_RADIO_STEP,
                num_anchors=cfg.CTPN.NUM_ANCHORS)
            # Define the loss
            total_loss = layers['total_loss']

            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            if cfg.TRAIN.OPTIMIZER == 'Adam':
                self.optimizer = tf.train.AdamOptimizer(lr)
            elif cfg.TRAIN.OPTIMIZER == 'Momentum':
                self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)
            elif cfg.TRAIN.OPTIMIZER == 'RMS':
                self.optimizer = tf.train.RMSPropOptimizer(lr)
            else:
                raise NotImplementedError(
                    'Unsupported cfg.TRAIN.OPTIMIZER: {!r}'.format(
                        cfg.TRAIN.OPTIMIZER))

            global_step = tf.Variable(0, trainable=False)
            # Hard-coded switch: gradient clipping is currently disabled.
            with_clip = False
            if with_clip:
                tvars = tf.trainable_variables()
                grads, norm = tf.clip_by_global_norm(
                    tf.gradients(total_loss, tvars), 10.0)
                train_op = self.optimizer.apply_gradients(
                    list(zip(grads, tvars)), global_step=global_step)
            else:
                # required by tf.layers.batch_normalization()
                # add update ops(for moving_mean and moving_variance) as a dependency to the train_op
                # https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    train_op = self.optimizer.minimize(
                        total_loss, global_step=global_step)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def find_previous(self):
        """Locate earlier snapshots in ``self.output_dir``.

        The extra snapshots that ``train_model`` writes at ``stepsize + 1``
        just before each learning-rate reduction are left out.

        Returns:
            Tuple ``(lsf, nfiles, sfiles)``: number of snapshots, pickle
            paths and checkpoint paths, each list sorted by modification
            time.
        """
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(
                os.path.join(
                    self.output_dir,
                    cfg.TRAIN.SNAPSHOT_PREFIX +
                    '_iter_{:d}.ckpt.meta'.format(stepsize + 1)))
        sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def restore_ckpt_from_dir(self, sess, net, checkpoint_dir):
        """Restore ``net.variables_to_restore`` from the newest checkpoint.

        Args:
            sess: Session to restore into.
            net: Object exposing ``variables_to_restore``.
            checkpoint_dir: Directory searched with
                ``tf.train.latest_checkpoint``.

        Raises:
            SystemExit: No checkpoint was found in ``checkpoint_dir``.
            Exception: Restoring the checkpoint failed.
        """
        # Local import: sys.exit is always available, whereas the bare
        # ``exit`` helper is only installed by the ``site`` module.
        import sys

        print("Restoring checkpoint from: " + checkpoint_dir)
        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt is None:
            print("Checkpoint not found")
            sys.exit(-1)
        meta_file = ckpt + '.meta'
        try:
            print('Restore variables from {}'.format(ckpt))
            print('Restore meta_file from {}'.format(meta_file))
            saver = tf.train.Saver(net.variables_to_restore)
            saver.restore(sess, ckpt)
        except Exception:
            raise Exception("Can not restore from {}".format(checkpoint_dir))

    def initialize(self, sess):
        """Initialise all variables and optionally load pretrained weights.

        Returns:
            Tuple ``(rate, last_snapshot_iter, stepsizes, np_paths,
            ss_paths)`` for a fresh run: the configured learning rate, ``0``,
            a copy of ``cfg.TRAIN.STEPSIZE`` and two empty path lists.
        """
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))

        if self.pretrained_model is not None:
            if self.pretrained_model.endswith('.ckpt'):
                # Fresh train directly from ImageNet weights
                print('Loading initial model weights from {:s}'.format(
                    self.pretrained_model))
                var_keep_dic = self.get_variables_in_checkpoint_file(
                    self.pretrained_model)
                # Get the variables to restore, ignoring the variables to fix
                variables_to_restore = self.net.get_variables_to_restore(
                    variables, var_keep_dic)
                restorer = tf.train.Saver(variables_to_restore)
                restorer.restore(sess, self.pretrained_model)
                print('Loaded.')
            else:
                # Restore from checkpoint and meta file
                self.restore_ckpt_from_dir(sess, self.net, self.pretrained_model)
                print('Loaded.')

        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        """Resume from a snapshot and recompute the learning-rate schedule.

        Args:
            sess: Session to restore into.
            sfile: Checkpoint path.
            nfile: Pickle path belonging to ``sfile``.

        Returns:
            Tuple ``(rate, last_snapshot_iter, stepsizes, np_paths,
            ss_paths)``; ``rate`` has been multiplied by ``cfg.TRAIN.GAMMA``
            once per step size already passed and ``stepsizes`` holds the
            remaining ones.
        """
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        """Delete the oldest snapshots beyond ``cfg.TRAIN.SNAPSHOT_KEPT``.

        Both lists are modified in place; the entries at the front are
        treated as the oldest.
        """
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for _ in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible to earlier versions of Tensorflow,
            # where the naming tradition for checkpoints are different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def train_model(self, sess, max_iters):
        """Train the network up to iteration ``max_iters``.

        Builds the data layers and the graph, resumes from the newest
        snapshot if one exists, then runs the training loop. The summary
        writers are closed even if the loop raises.

        Args:
            sess: TensorFlow session used for every ``run`` call.
            max_iters: Last iteration number to train (inclusive).
        """
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        try:
            # Find previous snapshots if there is any to restore from
            lsf, nfiles, sfiles = self.find_previous()

            # Initialize the variables or restore them from the last snapshot
            if lsf == 0:
                rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = \
                    self.initialize(sess)
            else:
                rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = \
                    self.restore(sess, str(sfiles[-1]), str(nfiles[-1]))

            timer = Timer()
            step = last_snapshot_iter + 1
            last_summary_time = time.time()
            # Make sure the lists are not empty
            stepsizes.append(max_iters)
            stepsizes.reverse()
            next_stepsize = stepsizes.pop()
            while step < max_iters + 1:
                # Learning rate
                if step == next_stepsize + 1:
                    # Add snapshot here before reducing the learning rate
                    self.snapshot(sess, step)
                    rate *= cfg.TRAIN.GAMMA
                    sess.run(tf.assign(lr, rate))
                    next_stepsize = stepsizes.pop()

                timer.tic()
                # Get training data, one batch at a time
                blobs = self.data_layer.forward()

                now = time.time()
                if step == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                    # Compute the graph with summary
                    rpn_loss_cls, rpn_loss_box, rpn_loss, total_loss, summary = \
                        self.net.train_step_with_summary(sess, blobs, train_op)
                    self.writer.add_summary(summary, float(step))
                    # Also check the summary on the validation set
                    blobs_val = self.data_layer_val.forward()
                    summary_val = self.net.get_summary(sess, blobs_val)
                    self.valwriter.add_summary(summary_val, float(step))
                    last_summary_time = now
                else:
                    # Compute the graph without summary
                    rpn_loss_cls, rpn_loss_box, rpn_loss, total_loss, _ = \
                        self.net.train_step(sess, blobs, train_op)
                timer.toc()

                print('%d/%d time: %.3f total_loss: %.3f rpn_loss: %.3f rpn_loss_cls: %.3f '
                      'rpn_loss_box: %.3f lr: %f' %
                      (step, max_iters, timer.diff, total_loss, rpn_loss,
                       rpn_loss_cls, rpn_loss_box, lr.eval()))

                # Snapshotting
                if step % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                    last_snapshot_iter = step
                    ss_path, np_path = self.snapshot(sess, step)
                    np_paths.append(np_path)
                    ss_paths.append(ss_path)

                    # Remove the old snapshots if there are too many
                    if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                        self.remove_snapshot(np_paths, ss_paths)

                step += 1

            # Make sure the final state is on disk
            if last_snapshot_iter != step - 1:
                self.snapshot(sess, step - 1)
        finally:
            # Close the summary writers even when training is interrupted.
            self.writer.close()
            self.valwriter.close()