Example #1
def load_database(args):
    print("Setting up image database: " + args.dataset)
    imdb = get_imdb(args.dataset)
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    roidb = get_training_roidb(imdb, args.use_flipped == "True")
    print('{:d} roidb entries'.format(len(roidb)))

    if args.dataset_validation != "no":
        print("Setting up validation image database: " +
              args.dataset_validation)
        imdb_val = get_imdb(args.dataset_validation)
        print('Loaded dataset `{:s}` for validation'.format(imdb_val.name))
        roidb_val = get_training_roidb(imdb_val, False)
        print('{:d} roidb entries'.format(len(roidb_val)))
    else:
        imdb_val = None
        roidb_val = None

    data_layer = RoIDataLayer(roidb, imdb.num_classes)

    if roidb_val is not None:
        data_layer_val = RoIDataLayer(roidb_val,
                                      imdb_val.num_classes,
                                      random=True)

    return imdb, roidb, imdb_val, roidb_val, data_layer, data_layer_val
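
load_database() above reads three string attributes from its argument: args.dataset, args.use_flipped (compared literally against "True"), and args.dataset_validation (where "no" disables validation). A minimal argparse sketch matching those attribute names; the flag defaults and the example imdb name are assumptions, not taken from any particular repo:

# Hypothetical argument setup for load_database(); only the attribute names
# (dataset, use_flipped, dataset_validation) are taken from the code above.
import argparse

parser = argparse.ArgumentParser(description='Set up imdb/roidb for training')
parser.add_argument('--dataset', default='voc_2007_trainval',
                    help='imdb name passed to get_imdb()')
parser.add_argument('--use_flipped', default='True',
                    help='string flag, compared literally against "True"')
parser.add_argument('--dataset_validation', default='no',
                    help='validation imdb name, or "no" to skip validation')
args = parser.parse_args()

# imdb, roidb, imdb_val, roidb_val, data_layer, data_layer_val = load_database(args)
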
Example #2
def get_data_layer(roidb, num_classes):
    """return a data layer."""
    if cfg.TRAIN.HAS_RPN:
        if cfg.IS_MULTISCALE:
            layer = GtDataLayer(roidb)
        else:
            layer = RoIDataLayer(roidb, num_classes)
    else:
        layer = RoIDataLayer(roidb, num_classes)
    return layer
Example #3
def get_data_layer(roidb, num_classes):
    """return a data layer."""
    if cfg.TRAIN.HAS_RPN:
        if cfg.IS_MULTISCALE:
            # obsolete
            # layer = GtDataLayer(roidb)
            raise "Calling caffe modules..."
        else:
            layer = RoIDataLayer(roidb, num_classes)
    else:
        layer = RoIDataLayer(roidb, num_classes)

    return layer
def get_data_layer(roidb, num_classes):
    """return a data layer."""
    # HAS_RPN = True when training
    if cfg.TRAIN.HAS_RPN:
        if cfg.IS_MULTISCALE:
            layer = GtDataLayer(roidb)
        else:
            layer = RoIDataLayer(roidb, num_classes)
    else:
        # layer._roidb = roidb
        # layer._num_classes = num_classes
        # layer._shuffle_roidb_inds()
        layer = RoIDataLayer(roidb, num_classes)

    return layer
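
The commented lines above hint at what the RoIDataLayer constructor does: store the roidb and class count, then call _shuffle_roidb_inds() to build a shuffled index permutation (_perm) with a cursor (_cur) that forward() advances. A minimal, self-contained sketch of just that index bookkeeping; the real layer additionally builds image blobs, so this is an illustration, not the actual class:

import numpy as np

class MinibatchSampler(object):
    """Toy stand-in for RoIDataLayer's _perm/_cur bookkeeping."""

    def __init__(self, roidb, batch_size=2, shuffle=True):
        self._roidb = roidb
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._shuffle_roidb_inds()

    def _shuffle_roidb_inds(self):
        # Draw a fresh permutation of the database and reset the cursor.
        self._perm = (np.random.permutation(len(self._roidb))
                      if self._shuffle else np.arange(len(self._roidb)))
        self._cur = 0

    def forward(self):
        # Reshuffle once the cursor would run past the end of the database.
        if self._cur + self._batch_size > len(self._roidb):
            self._shuffle_roidb_inds()
        inds = self._perm[self._cur:self._cur + self._batch_size]
        self._cur += self._batch_size
        return [self._roidb[i] for i in inds]

sampler = MinibatchSampler([{'image': 'img_%d.jpg' % i} for i in range(5)])
print(sampler.forward())
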
Example #5
def test_net(sess,
             net,
             imdb,
             weights_filename,
             max_per_image=100,
             thresh=0.05,
             roidb=None):
    """Test a Fast R-CNN network on an image database."""
    np.random.seed(cfg.RNG_SEED)
    num_images = len(imdb.image_index)
    # all detections are collected into:
    #  all_boxes[cls][image] = N x 5 array of detections in
    #  (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]

    output_dir = get_output_dir(imdb, weights_filename)
    # timers
    _t = {'im_detect': Timer(), 'misc': Timer()}

    data_layer = RoIDataLayer(roidb, imdb.num_classes, shuffle=False)
    for i in range(num_images):
        _t['im_detect'].tic()
        scores, boxes = im_detect(sess, net, None, data_layer=data_layer)
        _t['im_detect'].toc()

        _t['misc'].tic()

        # skip j = 0, because it's the background class
        for j in range(1, imdb.num_classes):
            inds = np.where(scores[:, j] > thresh)[0]
            cls_scores = scores[inds, j]
            cls_boxes = boxes[inds, j * 4:(j + 1) * 4]
            cls_dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) \
              .astype(np.float32, copy=False)
            keep = nms(cls_dets, cfg.TEST.NMS)
            cls_dets = cls_dets[keep, :]
            all_boxes[j][i] = cls_dets

        # Limit to max_per_image detections *over all classes*
        if max_per_image > 0:
            image_scores = np.hstack(
                [all_boxes[j][i][:, -1] for j in range(1, imdb.num_classes)])
            if len(image_scores) > max_per_image:
                image_thresh = np.sort(image_scores)[-max_per_image]
                for j in range(1, imdb.num_classes):
                    keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
                    all_boxes[j][i] = all_boxes[j][i][keep, :]
        _t['misc'].toc()

        print('im_detect: {:d}/{:d} {:.3f}s {:.3f}s' \
            .format(i + 1, num_images, _t['im_detect'].average_time,
                _t['misc'].average_time))

    det_file = os.path.join(output_dir, 'detections.pkl')
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    imdb.evaluate_detections(all_boxes, output_dir)
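
Two post-processing steps in test_net() are worth calling out: per-class NMS, and the max_per_image cap that keeps only the top-scoring detections across all foreground classes of an image. A self-contained numpy sketch of the capping step on dummy (x1, y1, x2, y2, score) arrays; class 0 is skipped as background, as above:

import numpy as np

def limit_detections(all_boxes, i, num_classes, max_per_image):
    """Keep at most max_per_image detections for image i across classes 1..num_classes-1."""
    image_scores = np.hstack(
        [all_boxes[j][i][:, -1] for j in range(1, num_classes)])
    if len(image_scores) <= max_per_image:
        return
    image_thresh = np.sort(image_scores)[-max_per_image]
    for j in range(1, num_classes):
        keep = np.where(all_boxes[j][i][:, -1] >= image_thresh)[0]
        all_boxes[j][i] = all_boxes[j][i][keep, :]

# dummy all_boxes[cls][image]: N x 5 arrays of (x1, y1, x2, y2, score)
rng = np.random.RandomState(0)
all_boxes = [[rng.rand(4, 5)] for _ in range(3)]  # 3 classes, 1 image
limit_detections(all_boxes, i=0, num_classes=3, max_per_image=3)
print([b[0].shape for b in all_boxes])  # background keeps 4, foreground classes are capped
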
Example #6
def get_data_layer(roidb, num_classes):
    if cfg.TRAIN.HAS_RPN:
        if cfg.IS_MULTISCALE:
            raise NotImplementedError(" error")
        else:
            layer = RoIDataLayer(roidb, num_classes)
    else:
        raise NotImplementedError(" error")
    return layer
Example #7
    def __init__(self, roidb, net, freeze=0):
        # Holds current iteration number.
        self.iter = 0

        # How frequently we should print the training info.
        self.display_freq = 1

        # Holds the path prefix for snapshots.
        self.snapshot_prefix = 'snapshot'

        self.roidb = roidb
        self.net = net
        self.freeze = freeze
        self.roi_data_layer = RoIDataLayer()
        self.roi_data_layer.setup()
        self.roi_data_layer.set_roidb(self.roidb)
        self.stepfn = self.build_step_fn(self.net)
        self.predfn = self.build_pred_fn(self.net)
Example #8
    def get_eval_summary(self, sess, num_entries=1000):
        val_subset = self.valroidb[:num_entries]
        data_layer_val = RoIDataLayer(self.valroidb,
                                      self.imdb.num_classes,
                                      random=False)
        # parse and accumulate over epoch
        summaries = defaultdict(list)
        for _ in range(len(val_subset)):
            blobs_val = data_layer_val.forward()
            summary_val = self.net.get_summary(sess, blobs_val)
            summary_proto = tf.Summary()
            summary_proto.ParseFromString(summary_val)
            for val in summary_proto.value:
                # Assuming all summaries are scalars.
                summaries[val.tag].append(val.simple_value)
        # create a new epoch mean summary
        epoch_summary = tf.Summary()
        epoch_summary.CopyFrom(summary_proto)
        for val in epoch_summary.value:
            val.simple_value = np.nanmean(summaries[val.tag])
        return epoch_summary.SerializeToString()
    def __init__(self,
                 mode='train',
                 roidb=None,
                 augment_en=False,
                 num_classes=0,
                 Thread=Thread):

        if (mode == 'train'):
            self.data_layer = RoIDataLayer(roidb, num_classes, 'train')
        elif (mode == 'val'):
            self.data_layer = RoIDataLayer(roidb,
                                           num_classes,
                                           'val',
                                           random=True)
        self._queue = Queue(maxsize=8)
        #self._ptr_queue = Queue(maxsize=32)
        #self._perm_queue = Queue(maxsize=32)
        #self._v_queue = Queue(maxsize=32)
        self._daemon_en = Value('b', False)
        self._lock = Lock()
        self.finished = False
        self._augment_en = Value('b', augment_en)
        self._cur = 0
        self._queue_count = 0
        self._perm = []
        if (cfg.DEBUG.EN):
            self._proc = threading.Thread(
                name='{} data generator'.format(mode),
                target=self._run,
                args=((self._lock, self._queue, self.data_layer,
                       self._daemon_en, self._augment_en)))
        else:
            self._proc = Process(name='{} data generator'.format(mode),
                                 target=self._run,
                                 args=((self._lock, self._queue,
                                        self.data_layer, self._daemon_en,
                                        self._augment_en)))
Example #10
  def __init__(self, roidb, net, freeze=0):
    # Holds current iteration number. 
    self.iter = 0

    # How frequently we should print the training info.
    self.display_freq = 1

    # Holds the path prefix for snapshots.
    self.snapshot_prefix = 'snapshot'
    
    self.roidb = roidb
    self.net = net
    self.freeze = freeze
    self.roi_data_layer = RoIDataLayer()
    self.roi_data_layer.setup()
    self.roi_data_layer.set_roidb(self.roidb)
    self.stepfn = self.build_step_fn(self.net)
    self.predfn = self.build_pred_fn(self.net)
class SolverWrapper(object):
    """
    A wrapper class for the training process
  """
    def __init__(self,
                 network,
                 imdb,
                 roidb,
                 imdb_T,
                 roidb_T,
                 valroidb,
                 output_dir,
                 tbdir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.imdb_T = imdb_T
        self.roidb_T = roidb_T
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pth'
        filename = os.path.join(self.output_dir, filename)
        torch.save(self.net.state_dict(), filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm
        # current position in the database
        curT = self.data_layer_T._cur
        # current shuffled indexes of the database
        permT = self.data_layer_T._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(curT, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(permT, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sfile, nfile):
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.net.load_state_dict(torch.load(str(sfile)))
        print('Restored.')
        # Need to restore the other hyper-parameters/states for training (TODO xinlei).
        # I have tried my best to find the random states so that it can be recovered
        # exactly; however, the Tensorflow state is currently not available.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            curT = pickle.load(fid)
            permT = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val
            self.data_layer_T._cur = curT
            self.data_layer_T._perm = permT

        return last_snapshot_iter

    def construct_graph(self):
        # Set the random seed
        torch.backends.cudnn.deterministic = True
        torch.manual_seed(cfg.RNG_SEED)
        torch.cuda.manual_seed_all(cfg.RNG_SEED)
        # Build the main computation graph
        self.net.create_architecture(self.imdb.num_classes,
                                     tag='default',
                                     anchor_scales=cfg.ANCHOR_SCALES,
                                     anchor_ratios=cfg.ANCHOR_RATIOS)
        # Define the loss
        # loss = layers['total_loss']
        # Set learning rate and momentum
        lr = cfg.TRAIN.LEARNING_RATE
        params = []

        for key, value in dict(self.net.named_parameters()).items():
            if 'D_img' in key:
                # print(key)
                continue

            if value.requires_grad:
                # print(key)
                if 'bias' in key:
                    params += [{
                        'params': [value],
                        'lr': lr,
                        'weight_decay': cfg.TRAIN.WEIGHT_DECAY
                    }]
                else:
                    params += [{
                        'params': [value],
                        'lr': lr,
                        'weight_decay': cfg.TRAIN.WEIGHT_DECAY
                    }]
        self.optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)

        self.D_img_op = torch.optim.SGD(self.net.D_img.parameters(),
                                        lr=lr * cfg.D_lr_mult,
                                        momentum=cfg.TRAIN.MOMENTUM)

        # Write the train and validation information to tensorboard
        self.writer = tb.writer.FileWriter(self.tbdir)
        self.valwriter = tb.writer.FileWriter(self.tbvaldir)

        return lr, self.optimizer

    def find_previous(self):
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pth')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in pytorch
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(
                os.path.join(
                    self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX +
                    '_iter_{:d}.pth'.format(stepsize + 1)))
        sfiles = [ss for ss in sfiles if ss not in redfiles]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [redfile.replace('.pth', '.pkl') for redfile in redfiles]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def initialize(self):
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        self.net.load_pretrained_cnn(torch.load(self.pretrained_model))
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are changed to BGR
        # For VGG16 it also changes the convolutional weights fc6 and fc7 to
        # fully connected weights
        last_snapshot_iter = 0
        lr = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sfile, nfile):
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sfile, nfile)
        # Set the learning rate
        lr_scale = 1
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                lr_scale *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)
        scale_lr(self.optimizer, lr_scale)
        lr = cfg.TRAIN.LEARNING_RATE * lr_scale
        return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            sfile = ss_paths[0]
            # To keep the code compatible with earlier versions of Tensorflow,
            # where the naming convention for checkpoints is different
            os.remove(str(sfile))
            ss_paths.remove(sfile)

    def train_model(self, max_iters):
        MIN_TOTAT_LOSS = np.inf
        MIN_D_LOSS_T = np.inf
        BEST_ITER = None
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)
        self.data_layer_T = RoIDataLayer(self.roidb_T, self.imdb.num_classes)

        # Construct the computation graph
        lr, train_op = self.construct_graph()

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
            )
        else:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                str(sfiles[-1]), str(nfiles[-1]))
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()

        self.net.train()
        self.net.cuda()

        self.net.D_img.train()
        self.net.D_img.cuda()

        #self.net.D_img2.train()
        #self.net.D_img2.cuda()
        mywriter = tb.SummaryWriter()
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(iter)
                lr *= cfg.TRAIN.GAMMA
                scale_lr(self.optimizer, cfg.TRAIN.GAMMA)
                #scale_lr(self.D_img_op, cfg.TRAIN.GAMMA)
                next_stepsize = stepsizes.pop()

            utils.timer.timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()
            blobsT = self.data_layer_T.forward()
            # print("#########################:blobs['data_path'][0]",blobs['data_path'][0])
            # print("synth_weight:",imdb.D_T_score[os.path.basename(blobs['data_path'][0])])
            # break
            now = time.time()
            #if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
            if False:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, D_inst_loss_S, D_img_loss_S, D_const_loss_S, D_inst_loss_T, D_img_loss_T, D_const_loss_T, summary = \
                  self.net.train_adapt_step_with_summary(blobs, blobsT, self.optimizer, self.D_inst_op, self.D_img_op)
                for _sum in summary:
                    self.writer.add_summary(_sum, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(blobs_val)
                for _sum in summary_val:
                    self.valwriter.add_summary(_sum, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                if 'train' in blobs['data_path'][0]:
                    synth_weight = self.imdb.D_T_score[os.path.basename(
                        blobs['data_path'][0])]
                else:
                    synth_weight = 1

                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, D_img_loss_S, D_img_loss_T = \
                    self.net.train_adapt_step_img(blobs, blobsT, self.optimizer, self.D_img_op, synth_weight)

            utils.timer.timer.toc()
            if (((loss_cls + loss_box) < MIN_TOTAT_LOSS)
                    and (D_img_loss_T < MIN_D_LOSS_T)):
                MIN_TOTAT_LOSS = loss_cls + loss_box
                MIN_D_LOSS_T = D_img_loss_T
                BEST_ITER = iter
                print("Curr MIN_TOTAT_LOSS=:{} and min_D_loss_T=:{}".format(
                    MIN_TOTAT_LOSS, MIN_D_LOSS_T))

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                mywriter.add_scalar("Total Loss", total_loss, iter)
                mywriter.add_scalar("cls Loss", loss_cls, iter)
                mywriter.add_scalar("Box loss", loss_box, iter)
                mywriter.add_scalar("D_img_loss_S", D_img_loss_S, iter)
                mywriter.add_scalar("D_img_loss_T", D_img_loss_T, iter)
                fp = open('training_log.txt', 'a+')
                print("Writing Training log")
                temp_log = str(iter) + ',' + str(total_loss) + ',' + str(
                    rpn_loss_cls) + ',' + str(rpn_loss_box) + ',' + str(
                        loss_cls) + ',' + str(loss_box) + ',' + str(
                            D_img_loss_S) + ',' + str(D_img_loss_T)
                fp.write(temp_log)
                fp.write('\n')
                fp.close()
                print("Done !!!!!!!!")
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n '
                      '>>> D_img_loss_S: %.6f\n >>> D_img_loss_T: %.6f\n '
                      '>>> lambda: %f >>> lr: %f ' % \
                      (iter, max_iters, total_loss, rpn_loss_cls, \
                        rpn_loss_box, loss_cls, loss_box, \
                        D_img_loss_S, D_img_loss_T, \
                        cfg.ADAPT_LAMBDA, lr))
                print('speed: {:.3f}s / iter'.format(
                    utils.timer.timer.average_time()))

                # for k in utils.timer.timer._average_time.keys():
                #   print(k, utils.timer.timer.average_time(k))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1
        print("##################### Best iteration = ", BEST_ITER,
              "###################################")
        print("##################### Optimal Total Loss=:", MIN_TOTAT_LOSS,
              "###########################")
        print("##################### Optimal lD_LOSS_T=:", MIN_D_LOSS_T,
              "##############################")
        _, _ = self.snapshot(BEST_ITER)
        if last_snapshot_iter != iter - 1:
            self.snapshot(iter - 1)

        self.writer.close()
        self.valwriter.close()
Example #12
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


bbox_dist = np.load(osp.join(cfg.VG_DIR, cfg.TRAIN.BBOX_TARGET_NORMALIZATION_FILE), encoding='latin1').item()
bbox_means = bbox_dist['means']
bbox_stds = bbox_dist['stds']

cfg.TRAIN.USE_FLIPPED = False
imdb, roidb = combined_roidb('visual_genome_train_rel')
num_images = len(roidb)
data_layer = RoIDataLayer(imdb, roidb, bbox_means, bbox_stds)

epoch = 10
thresh = 0.8
fg_bg = AverageMeter()
print_freq = 100
for e in range(epoch):
    for i in range(num_images):
        blobs = data_layer.forward()
        predicates = blobs['predicates']
        rel_rois = blobs['rel_rois']
        fg_rel_inds = np.where(predicates)[0]
        bg_rel_inds = np.where(predicates==0)[0]
        fg_rel_rois = rel_rois[fg_rel_inds, 1:]
        bg_rel_rois = rel_rois[bg_rel_inds, 1:]
        fg_bg_overlaps = bbox_overlaps(fg_rel_rois, bg_rel_rois)
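
update() above is the standard running-average helper: val holds the latest value and avg the mean weighted by n. A self-contained version with a toy usage; the constructor is an assumption, since only update() appears in the snippet:

class AverageMeter(object):
    """Tracks the latest value and a running average weighted by n."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

meter = AverageMeter()
for batch_size, fg_ratio in [(64, 0.25), (64, 0.50), (32, 0.40)]:
    meter.update(fg_ratio, n=batch_size)
print('last={:.2f} avg={:.3f}'.format(meter.val, meter.avg))  # last=0.40 avg=0.380
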
Example #13
class MemorySolverWrapper(SolverWrapper):
    """
      A wrapper class for the training process of spatial memory
    """
    def construct_graph(self, sess):
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture('TRAIN',
                                                  self.imdb.num_classes,
                                                  self.imdb.num_predicates,
                                                  tag='default')
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.RATE, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

            # Compute the gradients with regard to the loss
            gvs = self.optimizer.compute_gradients(loss)
            grad_summaries = []
            for grad, var in gvs:
                if 'SMN' not in var.name and 'GMN' not in var.name:
                    continue
                grad_summaries.append(
                    tf.summary.histogram('TRAIN/' + var.name, var))
                if grad is not None:
                    grad_summaries.append(
                        tf.summary.histogram('GRAD/' + var.name, grad))

            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult') as scope:
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)
            self.summary_grads = tf.summary.merge(grad_summaries)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.imdb, self.roidb, self.bbox_means,
                                       self.bbox_stds)
        self.data_layer_val = RoIDataLayer(self.valimdb,
                                           self.valroidb,
                                           self.bbox_means,
                                           self.bbox_stds,
                                           random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
                sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                sess, str(sfiles[-1]), str(nfiles[-1]))
        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_iter = iter
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or \
                    (now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL and \
                                     iter - last_summary_iter > cfg.TRAIN.SUMMARY_ITERS):
                # Compute the graph with summary
                # loss_cls, loss_bbox, loss_rel, loss_tag, total_loss, summary, gsummary = \
                #    self.net.train_step_with_summary(sess, blobs, train_op, self.summary_grads)
                loss_cls, loss_bbox, loss_rel, loss_tag, total_loss = self.net.train_step(
                    sess, blobs, train_op)
                # self.writer.add_summary(summary, float(iter))
                # self.writer.add_summary(gsummary, float(iter + 1))
                # Also check the summary on the validation set
                # blobs_val = self.data_layer_val.forward()
                # summary_val = self.net.get_summary(sess, blobs_val)
                # self.valwriter.add_summary(summary_val, float(iter))
                # last_summary_iter = iter
                # last_summary_time = now
            else:
                # Compute the graph without summary
                loss_cls, loss_bbox, loss_rel, loss_tag, total_loss = self.net.train_step(
                    sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print(
                    'iter: %d / %d, total loss: %.6f\n >>> loss_cls: %.6f\n >>> loss_bbox: %.6f\n '
                    '>>> loss_rel: %.6f\n >>> loss_tag: %.6f\n >>> lr: %f' % \
                    (iter, max_iters, total_loss, loss_cls, loss_bbox, loss_rel, loss_tag, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
Example #14
    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
                sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                sess, str(sfiles[-1]), str(nfiles[-1]))
        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                  self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                  self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
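
Note how the summary branch above is gated by wall-clock time rather than iteration count: a full train_step_with_summary() pass runs on the first iteration and then at most once per cfg.TRAIN.SUMMARY_INTERVAL seconds. A minimal sketch of that gating on its own; the interval value here is an assumption:

import time

SUMMARY_INTERVAL = 180.0  # seconds, standing in for cfg.TRAIN.SUMMARY_INTERVAL

def want_summary(it, now, last_summary_time):
    """Write a summary on the first iteration, then at most once per interval."""
    return it == 1 or now - last_summary_time > SUMMARY_INTERVAL

last_summary_time = time.time()
for it in range(1, 4):
    if want_summary(it, time.time(), last_summary_time):
        last_summary_time = time.time()
        print('iter %d: with summary' % it)
    else:
        print('iter %d: without summary' % it)
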
Example #15
class Solver(object):
  def __init__(self, roidb, net, freeze=0):
    # Holds current iteration number. 
    self.iter = 0

    # How frequently we should print the training info.
    self.display_freq = 1

    # Holds the path prefix for snapshots.
    self.snapshot_prefix = 'snapshot'
    
    self.roidb = roidb
    self.net = net
    self.freeze = freeze
    self.roi_data_layer = RoIDataLayer()
    self.roi_data_layer.setup()
    self.roi_data_layer.set_roidb(self.roidb)
    self.stepfn = self.build_step_fn(self.net)
    self.predfn = self.build_pred_fn(self.net)

  # This might be a useful static method to have.
  # @staticmethod -- not static anymore, since it now reads self.freeze
  def build_step_fn(self, net):
    target_y = T.vector("target Y",dtype='int64')
    tl = lasagne.objectives.categorical_crossentropy(net.prediction,target_y)
    loss = tl.mean()
    accuracy = lasagne.objectives.categorical_accuracy(net.prediction,target_y).mean()
   
    weights = net.params
    grads = theano.grad(loss, weights)
    
    scales = np.ones(len(weights))

    if self.freeze:
        scales[:-self.freeze] = 0
        
    print('GRAD SCALE >>>', scales)
    
    for idx, param in enumerate(weights):
        grads[idx] *= scales[idx]
        grads[idx] = grads[idx].astype('float32')
        
    #updates_sgd = lasagne.updates.sgd(loss, net.params, learning_rate=0.0001)
    updates_sgd = lasagne.updates.sgd(grads, net.params, learning_rate=0.0001)
    
    stepfn = theano.function([net.inp, target_y], [loss, accuracy], updates=updates_sgd, allow_input_downcast=True)
    return stepfn

  @staticmethod
  def build_pred_fn(net):
    predfn = theano.function([net.inp], net.prediction, allow_input_downcast=True)
    return predfn

  def get_training_batch(self):
    """Uses ROIDataLayer to fetch a training batch.

    Returns:
      input_data (ndarray): input data suitable for R-CNN processing
      labels (ndarray): batch labels (of type int32)
    """
    data, rois, labels = deepcopy(self.roi_data_layer.top[: 3])
    X = roi_layer(data, rois)
    y = labels.astype('int')
    
    return X, y

  def step(self):
    """Conducts a single step of SGD."""
    self.roi_data_layer.forward()
    data, labels = self.get_training_batch()
    
    loss, acc = self.stepfn(data, labels)
    
    self.loss = loss
    self.acc = acc
    ###################################################### Your code goes here.
    # Among other things, assign the current loss value to self.loss.

    self.iter += 1
    if self.iter % self.display_freq == 0:
      print('Iteration {:<5} Train loss: {} Train acc: {}'.format(self.iter, self.loss, self.acc))

  def save(self, filename):
    self.net.save(filename)
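
In build_step_fn() above, freezing is done by scaling gradients: scales[:-self.freeze] = 0 zeros the gradient of every parameter tensor except the last `freeze` ones, so only those keep training. A small numpy sketch of that masking, independent of Theano; the shapes are made up:

import numpy as np

def gradient_scales(num_params, freeze):
    """Per-parameter multipliers: 0 for frozen tensors, 1 for the last `freeze` ones."""
    scales = np.ones(num_params)
    if freeze:
        scales[:-freeze] = 0
    return scales

grads = [np.full(3, 0.5), np.full(2, 0.5), np.full(4, 0.5)]
scales = gradient_scales(len(grads), freeze=1)
scaled = [g * s for g, s in zip(grads, scales)]
print(scales)                             # [0. 0. 1.]
print([float(g.sum()) for g in scaled])   # only the last tensor keeps its gradient
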
Example #16
class SolverWrapper(object):
  """
    A wrapper class for the training process
  """

  def __init__(self, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
    self.net = network
    self.imdb = imdb
    self.roidb = roidb
    self.valroidb = valroidb
    self.output_dir = output_dir
    self.tbdir = tbdir
    # Simply put '_val' at the end to save the summaries from the validation set
    self.tbvaldir = tbdir + '_val'
    if not os.path.exists(self.tbvaldir):
      os.makedirs(self.tbvaldir)
    self.pretrained_model = pretrained_model

  def snapshot(self, iter):
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pth'
    filename = os.path.join(self.output_dir, filename)
    torch.save(self.net.state_dict(), filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store some meta information, random state, etc.
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
    st0 = np.random.get_state()
    # current position in the database
    cur = self.data_layer._cur
    # current shuffled indexes of the database
    perm = self.data_layer._perm
    # current position in the validation database
    cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
    perm_val = self.data_layer_val._perm

    # Dump the meta info
    with open(nfilename, 'wb') as fid:
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename

  def from_snapshot(self, sfile, nfile):
    print('Restoring model snapshots from {:s}'.format(sfile))
    self.net.load_state_dict(torch.load(str(sfile)))
    print('Restored.')
    # Need to restore the other hyper-parameters/states for training (TODO xinlei).
    # I have tried my best to find the random states so that it can be recovered
    # exactly; however, the Tensorflow state is currently not available.
    with open(nfile, 'rb') as fid:
      st0 = pickle.load(fid)
      cur = pickle.load(fid)
      perm = pickle.load(fid)
      cur_val = pickle.load(fid)
      perm_val = pickle.load(fid)
      last_snapshot_iter = pickle.load(fid)

      np.random.set_state(st0)
      self.data_layer._cur = cur
      self.data_layer._perm = perm
      self.data_layer_val._cur = cur_val
      self.data_layer_val._perm = perm_val

    return last_snapshot_iter

  def construct_graph(self):
    # Set the random seed
    torch.manual_seed(cfg.RNG_SEED)
    # Build the main computation graph
    self.net.create_architecture(self.imdb.num_classes, tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS)
    # Define the loss
    # loss = layers['total_loss']
    # Set learning rate and momentum
    lr = cfg.TRAIN.LEARNING_RATE
    params = []
    for key, value in dict(self.net.named_parameters()).items():
      if value.requires_grad:
        if 'bias' in key:
          params += [{'params':[value],'lr':lr*(cfg.TRAIN.DOUBLE_BIAS + 1), 'weight_decay': cfg.TRAIN.BIAS_DECAY and cfg.TRAIN.WEIGHT_DECAY or 0}]
        else:
          params += [{'params':[value],'lr':lr, 'weight_decay': cfg.TRAIN.WEIGHT_DECAY}]
    self.optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)
    # Write the train and validation information to tensorboard
    self.writer = tb.writer.FileWriter(self.tbdir)
    self.valwriter = tb.writer.FileWriter(self.tbvaldir)

    return lr, self.optimizer

  def find_previous(self):
    sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pth')
    sfiles = glob.glob(sfiles)
    sfiles.sort(key=os.path.getmtime)
    # Get the snapshot name in pytorch
    redfiles = []
    for stepsize in cfg.TRAIN.STEPSIZE:
      redfiles.append(os.path.join(self.output_dir, 
                      cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.pth'.format(stepsize+1)))
    sfiles = [ss for ss in sfiles if ss not in redfiles]

    nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
    nfiles = glob.glob(nfiles)
    nfiles.sort(key=os.path.getmtime)
    redfiles = [redfile.replace('.pth', '.pkl') for redfile in redfiles]
    nfiles = [nn for nn in nfiles if nn not in redfiles]

    lsf = len(sfiles)
    assert len(nfiles) == lsf

    return lsf, nfiles, sfiles

  def initialize(self):
    # Initial file lists are empty
    np_paths = []
    ss_paths = []
    # Fresh train directly from ImageNet weights
    print('Loading initial model weights from {:s}'.format(self.pretrained_model))
    self.net.load_pretrained_cnn(torch.load(self.pretrained_model))
    print('Loaded.')
    # Need to fix the variables before loading, so that the RGB weights are changed to BGR
    # For VGG16 it also changes the convolutional weights fc6 and fc7 to
    # fully connected weights
    last_snapshot_iter = 0
    lr = cfg.TRAIN.LEARNING_RATE
    stepsizes = list(cfg.TRAIN.STEPSIZE)

    return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

  def restore(self, sfile, nfile):
    # Get the most recent snapshot and restore
    np_paths = [nfile]
    ss_paths = [sfile]
    # Restore model from snapshots
    last_snapshot_iter = self.from_snapshot(sfile, nfile)
    # Set the learning rate
    lr_scale = 1
    stepsizes = []
    for stepsize in cfg.TRAIN.STEPSIZE:
      if last_snapshot_iter > stepsize:
        lr_scale *= cfg.TRAIN.GAMMA
      else:
        stepsizes.append(stepsize)
    scale_lr(self.optimizer, lr_scale)
    lr = cfg.TRAIN.LEARNING_RATE * lr_scale
    return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

  def remove_snapshot(self, np_paths, ss_paths):
    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      nfile = np_paths[0]
      os.remove(str(nfile))
      np_paths.remove(nfile)

    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      sfile = ss_paths[0]
      # To keep the code compatible with earlier versions of Tensorflow,
      # where the naming convention for checkpoints is different
      os.remove(str(sfile))
      ss_paths.remove(sfile)

  def train_model(self, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Construct the computation graph
    lr, train_op = self.construct_graph()

    # Find previous snapshots if there is any to restore from
    lsf, nfiles, sfiles = self.find_previous()

    # Initialize the variables or restore them from the last snapshot
    if lsf == 0:
      lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize()
    else:
      lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(str(sfiles[-1]), 
                                                                             str(nfiles[-1]))
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    # Make sure the lists are not empty
    stepsizes.append(max_iters)
    stepsizes.reverse()
    next_stepsize = stepsizes.pop()

    self.net.train()
    self.net.cuda()

    while iter < max_iters + 1:
      # Learning rate
      if iter == next_stepsize + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(iter)
        lr *= cfg.TRAIN.GAMMA
        scale_lr(self.optimizer, cfg.TRAIN.GAMMA)
        next_stepsize = stepsizes.pop()

      utils.timer.timer.tic()
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()

      now = time.time()
      if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
        # Compute the graph with summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
          self.net.train_step_with_summary(blobs, self.optimizer)
        for _sum in summary: self.writer.add_summary(_sum, float(iter))
        # Also check the summary on the validation set
        blobs_val = self.data_layer_val.forward()
        summary_val = self.net.get_summary(blobs_val)
        for _sum in summary_val: self.valwriter.add_summary(_sum, float(iter))
        last_summary_time = now
      else:
        # Compute the graph without summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
          self.net.train_step(blobs, self.optimizer)
      utils.timer.timer.toc()

      # Display training information
      if iter % (cfg.TRAIN.DISPLAY) == 0:
        print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr))
        print('speed: {:.3f}s / iter'.format(utils.timer.timer.average_time()))

        # for k in utils.timer.timer._average_time.keys():
        #   print(k, utils.timer.timer.average_time(k))

      # Snapshotting
      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        ss_path, np_path = self.snapshot(iter)
        np_paths.append(np_path)
        ss_paths.append(ss_path)

        # Remove the old snapshots if there are too many
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          self.remove_snapshot(np_paths, ss_paths)

      iter += 1

    if last_snapshot_iter != iter - 1:
      self.snapshot(iter - 1)

    self.writer.close()
    self.valwriter.close()
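
The stepsize handling in train_model() (append max_iters as a sentinel, reverse, pop) turns cfg.TRAIN.STEPSIZE into a stack of iteration counts at which the learning rate is multiplied by cfg.TRAIN.GAMMA. A self-contained sketch of just that schedule; the numbers are made up:

def lr_schedule(max_iters, stepsizes, base_lr, gamma):
    """Yield (iter, lr) pairs, decaying lr by gamma after each stepsize."""
    stepsizes = list(stepsizes)
    stepsizes.append(max_iters)   # sentinel so the stack never runs dry
    stepsizes.reverse()
    next_stepsize = stepsizes.pop()
    lr = base_lr
    for it in range(1, max_iters + 1):
        if it == next_stepsize + 1:
            lr *= gamma
            next_stepsize = stepsizes.pop()
        yield it, lr

for it, lr in lr_schedule(10, [4, 8], base_lr=0.001, gamma=0.1):
    print(it, lr)  # lr drops after iterations 4 and 8
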
Example #17
    def train_model(self, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph()

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
            )
        else:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                str(sfiles[-1]), str(nfiles[-1]))
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()

        self.net.train()
        self.net.to(self.net._device)

        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(iter)
                lr *= cfg.TRAIN.GAMMA
                scale_lr(self.optimizer, cfg.TRAIN.GAMMA)
                next_stepsize = stepsizes.pop()
            #if ((iter -1) % cfg.TRAIN.MIL_RECURRENT_STEP) == 0:
            #  num_epoch = int((iter - 1) / cfg.TRAIN.MIL_RECURRENT_STEP) + 1
            #  cfg.TRAIN.MIL_RECURRECT_WEIGHT = ((num_epoch - 1)/20.0)/1.5
            #if iter == cfg.TRAIN.MIL_RECURRENT_STEP + 1:
            #  cfg.TRAIN.MIL_RECURRECT_WEIGHT = cfg.TRAIN.MIL_RECURRECT_WEIGHT * 10

            utils.timer.timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                cls_det_loss, refine_loss_1, refine_loss_2, total_loss, summary = \
                  self.net.train_step_with_summary(blobs, self.optimizer)
                for _sum in summary:
                    self.writer.add_summary(_sum, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(blobs_val)
                for _sum in summary_val:
                    self.valwriter.add_summary(_sum, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                cls_det_loss, refine_loss_1, refine_loss_2, total_loss = self.net.train_step(
                    blobs, self.optimizer)
            utils.timer.timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> cls_det_loss: %.6f\n '
                      '>>> refine_loss_1: %.6f\n >>> refine_loss_2: %.6f\n >>> lr: %f' % \
                      (iter, max_iters, total_loss, cls_det_loss, refine_loss_1, refine_loss_2, lr))
                print('speed: {:.3f}s / iter'.format(
                    utils.timer.timer.average_time()))

                # for k in utils.timer.timer._average_time.keys():
                #   print(k, utils.timer.timer.average_time(k))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(iter - 1)

        self.writer.close()
        self.valwriter.close()
Example #18
  def train_model(self, sess, max_iters, mode):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes)

    # Construct the computation graph
    lr, train_op = self.construct_graph(sess, mode)

    total_parameters = 0
    for variable in tf.global_variables():
      if "BatchNorm" not in variable.name:
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
          variable_parameters *= dim.value
        total_parameters += variable_parameters

    print('Total params: %.2fM' % (total_parameters / 1000000.0))

    # Find previous snapshots if there is any to restore from
    lsf, nfiles, sfiles = self.find_previous()

    # Initialize the variables or restore them from the last snapshot
    if lsf == 0:
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
    else:
      sess.run(tf.variables_initializer(tf.global_variables(), name='init'))
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess, 
                                                                            str(sfiles[-1]), 
                                                                            str(nfiles[-1]))
    timer = Timer()
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    # Make sure the lists are not empty
    stepsizes.append(max_iters)
    stepsizes.reverse()
    next_stepsize = stepsizes.pop()
    while iter < max_iters + 1:
      # Learning rate
      if iter == next_stepsize + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        rate *= cfg.TRAIN.GAMMA
        sess.run(tf.assign(lr, rate))
        next_stepsize = stepsizes.pop()

      timer.tic()
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()

      now = time.time()
      if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
        # Compute the graph with summary
        out_blob, sim_train, qual_train = \
          self.net.train_step_with_summary(sess, blobs, train_op)
        if self.frcnn_training:
            self.writer.add_summary(out_blob['summary'], float(iter))
        elif self.quality_training:
          if qual_train:
            self.writer.add_summary(out_blob['summary'], float(iter))
        elif self.similarity_training:
          if sim_train:
            self.writer.add_summary(out_blob['summary'], float(iter))
        # Also check the summary on the validation set
        blobs_val = self.data_layer_val.forward()
        summary_val = self.net.get_summary(sess, blobs_val)
        self.valwriter.add_summary(summary_val, float(iter))
        last_summary_time = now
      else:
        # Compute the graph without summary
        out_blob, sim_train, qual_train = \
          self.net.train_step(sess, blobs, train_op)
      timer.toc()

      # Display training information
      if iter % (cfg.TRAIN.DISPLAY) == 0:
        print ('iter: %d / %d,' % (iter, max_iters))
        if sim_train:
          print(' >>> ID_loss: %.6f' % (out_blob['ID_loss']))
        if qual_train:
          try:
            print(' >>> SS_loss: %.6f' % (out_blob['SS_loss']))
          except:
            pass
        if self.frcnn_training or self.quality_training:
          print(' >>> rpn_loss_cls: %.6f\n >>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f' %
                (out_blob['rpn_loss_cls'], out_blob['rpn_loss_box'], out_blob['loss_cls'], out_blob['loss_box']))
        print (' >>> lr: %f' % (lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))
      # Snapshotting
      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        ss_path, np_path = self.snapshot(sess, iter)
        np_paths.append(np_path)
        ss_paths.append(ss_path)

        # Remove the old snapshots if there are too many
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          self.remove_snapshot(np_paths, ss_paths)

      iter += 1

    if last_snapshot_iter != iter - 1:
      self.snapshot(sess, iter - 1)

    self.writer.close()
    self.valwriter.close()

  def train_model(self, sess, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Determine different scales for anchors, see paper
    with sess.graph.as_default():
      # Set the random seed for tensorflow
      tf.set_random_seed(cfg.RNG_SEED)
      # Build the main computation graph
      rpn_layers, proposal_targets = self.net.create_architecture(sess, 'TRAIN', self.imdb.num_classes, scope='rpn_network', tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS)
      rfcn_layers, _ = self.rfcn_network.create_architecture(sess, 'TRAIN', self.imdb.num_classes, scope='rfcn_network', tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS,
                                            input_rois=rpn_layers['rois'],
                                            roi_scores=rpn_layers['roi_scores'],
                                            proposal_targets=proposal_targets)

      # Define the loss
      rpn_loss = rpn_layers['rpn_loss']
      rfcn_loss = rfcn_layers['rfcn_loss']
      rpn_rfcn_loss = rpn_layers['rfcn_loss']
      # Set learning rate and momentum
      lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
      momentum = cfg.TRAIN.MOMENTUM
      self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

      # Compute the gradients wrt the loss
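      # Four train ops are built, one per alternating-training stage:
      # stage 1 updates rpn_network variables on its RPN loss, stage 2 updates
      # rfcn_network variables on its R-FCN loss, and stages 3/4 fine-tune the
      # variable subsets returned by get_train_variables_stage3/stage4 of
      # rpn_network on the RPN loss and on the R-FCN loss computed inside
      # rpn_network (rpn_layers['rfcn_loss']), respectively.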
      rpn_trainable_variables_stage1 = self.net.get_train_variables('rpn_network')
      rfcn_trainable_variables_stage2 = self.rfcn_network.get_train_variables('rfcn_network')
      rpn_trainable_variables_stage3 = self.net.get_train_variables_stage3('rpn_network')
      rpn_trainable_variables_stage4 = self.net.get_train_variables_stage4('rpn_network')
      gvs_rpn_stage1 = self.optimizer.compute_gradients(rpn_loss, rpn_trainable_variables_stage1)
      gvs_rfcn_stage2 = self.optimizer.compute_gradients(rfcn_loss, rfcn_trainable_variables_stage2)
      gvs_rpn_stage3 = self.optimizer.compute_gradients(rpn_loss, rpn_trainable_variables_stage3)
      gvs_rfcn_stage4 = self.optimizer.compute_gradients(rpn_rfcn_loss, rpn_trainable_variables_stage4)

      train_op_stage1 = self.optimizer.apply_gradients(gvs_rpn_stage1)
      train_op_stage2 = self.optimizer.apply_gradients(gvs_rfcn_stage2)
      train_op_stage3 = self.optimizer.apply_gradients(gvs_rpn_stage3)
      train_op_stage4 = self.optimizer.apply_gradients(gvs_rfcn_stage4)

      # We will handle the snapshots ourselves
      self.saver = tf.train.Saver(max_to_keep=1000000)
      # Write the train and validation information to tensorboard
      self.writer_stage1 = tf.summary.FileWriter(self.tbdir+'/stage1', sess.graph)
      self.writer_stage2 = tf.summary.FileWriter(self.tbdir+'/stage2', sess.graph)
      self.writer_stage3 = tf.summary.FileWriter(self.tbdir+'/stage3', sess.graph)
      self.writer_stage4 = tf.summary.FileWriter(self.tbdir+'/stage4', sess.graph)
      self.valwriter_stage1 = tf.summary.FileWriter(self.tbvaldir+'/stage1')
      self.valwriter_stage2 = tf.summary.FileWriter(self.tbvaldir+'/stage2')
      self.valwriter_stage3 = tf.summary.FileWriter(self.tbvaldir+'/stage3')
      self.valwriter_stage4 = tf.summary.FileWriter(self.tbvaldir+'/stage4')

    # Find previous snapshots if there is any to restore from
    sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
    sfiles = glob.glob(sfiles)
    sfiles.sort(key=os.path.getmtime)
    # Get the snapshot name in TensorFlow
    redstr = '_iter_{:d}.'.format(cfg.TRAIN.STEPSIZE+1)
    sfiles = [ss.replace('.meta', '') for ss in sfiles]
    sfiles = [ss for ss in sfiles if redstr not in ss]

    nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
    nfiles = glob.glob(nfiles)
    nfiles.sort(key=os.path.getmtime)
    nfiles = [nn for nn in nfiles if redstr not in nn]

    lsf = len(sfiles)
    assert len(nfiles) == lsf

    np_paths = nfiles
    ss_paths = sfiles

    if lsf == 0:
      # Fresh train directly from ImageNet weights
      print('Loading initial model weights from {:s}'.format(self.pretrained_model))
      variables = tf.global_variables()
      # Initialize all variables first
      sess.run(tf.variables_initializer(variables, name='init'))
      var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
      # Get the variables to restore, ignoring the variables to fix
      variables_to_restore_rpn, variables_to_restore_rfcn = self.net.get_variables_to_restore(variables, var_keep_dic)

      self.saver_rpn = tf.train.Saver(variables_to_restore_rpn)
      self.saver_rpn.restore(sess, self.pretrained_model)

      self.saver_rfcn = tf.train.Saver(variables_to_restore_rfcn)
      self.saver_rfcn.restore(sess, self.pretrained_model)

      print('Loaded.')
      # Need to fix the variables before loading, so that the RGB weights are changed to BGR
      # For VGG16 it also changes the convolutional weights fc6 and fc7 to
      # fully connected weights
      self.net.fix_variables(sess, self.pretrained_model)
      print('Fixed.')
      sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
      last_snapshot_iter = 0
    else:
      # Get the most recent snapshot and restore
      ss_paths = [ss_paths[-1]]
      np_paths = [np_paths[-1]]

      print('Restoring model snapshots from {:s}'.format(sfiles[-1]))
      self.saver.restore(sess, str(sfiles[-1]))
      print('Restored.')
      # Restore the other hyperparameters/states for training (TODO xinlei).
      # I have tried my best to find the random states so that it can be
      # recovered exactly; however, the TensorFlow state is currently not
      # available.
      with open(str(nfiles[-1]), 'rb') as fid:
        st0 = pickle.load(fid)
        cur = pickle.load(fid)
        perm = pickle.load(fid)
        cur_val = pickle.load(fid)
        perm_val = pickle.load(fid)
        last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        # Set the learning rate, only reduce once
        if last_snapshot_iter > cfg.TRAIN.STEPSIZE:
          sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
        else:
          sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

    timer = Timer()
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    stage_infor = ''
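    # The stage boundaries below are hard-coded rather than derived from
    # max_iters: stage 1 (RPN) covers iterations up to 80000 with a
    # learning-rate drop at 60001, stage 2 (R-FCN) covers 80001-200000 with a
    # drop at 160001; the commented-out small values appear to be leftover
    # quick-debug settings.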
    # while iter < max_iters + 1:
    while iter < 200001:
    # while iter < 201:
      # Learning rate
      if iter == 80001:
      # if iter == 81:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, 0.001))
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()
      # Stage 1: train the RPN layers and backbone in rpn_network
      if iter < 80001:
      # if iter < 81:
        stage_infor = 'stage1'
        if iter == 60001:
        # if iter == 61:
          sess.run(tf.assign(lr, 0.0001))
        timer.tic()
        now = time.time()
        if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
          # Compute the graph with summary
          rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
            self.net.train_rpn_step_with_summary(sess, blobs, train_op_stage1)
          self.writer_stage1.add_summary(summary, float(iter))
          # Also check the summary on the validation set
          blobs_val = self.data_layer_val.forward()
          summary_val = self.net.get_summary(sess, blobs_val)
          self.valwriter_stage1.add_summary(summary_val, float(iter))
          last_summary_time = now
        else:
          # Compute the graph without summary
          rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
            self.net.train_rpn_step(sess, blobs, train_op_stage1)
        timer.toc()
        if iter == 80000:
          self.writer_stage1.close()
          self.valwriter_stage1.close()
      # Stage 2: train the R-FCN layers and backbone in rfcn_network
      elif 80001 <= iter < 200001:
      # elif 81 <= iter < 201:
        stage_infor = 'stage2'
        rpn_loss_cls = 0
        rpn_loss_box = 0
        if iter == 160001:
        # if iter == 161:
          self.snapshot(sess, iter)
          sess.run(tf.assign(lr, 0.0001))
        timer.tic()
        now = time.time()
        if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
          # Compute the graph with summary
          loss_cls, loss_box, total_loss, summary = \
            self.rfcn_network.train_rfcn_step_with_summary_stage2(sess, blobs, train_op_stage2, self.net)
          self.writer_stage2.add_summary(summary, float(iter))
          # Also check the summary on the validation set
          blobs_val = self.data_layer_val.forward()
          summary_val = self.rfcn_network.get_summary_stage2(sess, blobs_val, self.net)
          self.valwriter_stage2.add_summary(summary_val, float(iter))
          last_summary_time = now
        else:
          # Compute the graph without summary
          loss_cls, loss_box, total_loss = \
            self.rfcn_network.train_rfcn_step_stage2(sess, blobs, train_op_stage2, self.net)
        timer.toc()
        if iter == 200000:
          self.writer_stage2.close()
          self.valwriter_stage2.close()
      else:
        raise ValueError('illegal input iter value')

      # Display training information
      # if iter % (cfg.TRAIN.DISPLAY) == 0:
      if iter % 20 == 0:
        print('iter: %d / %d, stage: %s, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, stage_infor, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))

      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        snapshot_path, np_path = self.snapshot(sess, iter)
        np_paths.append(np_path)
        ss_paths.append(snapshot_path)
      iter += 1

    if iter <= 200001:
    # if iter <= 201:
      ###############################################
      #####  merge rfcn_network to rpn_network  #####
      ###############################################
      merged_ops = merged_networks_rfcn2rpn('rfcn_network', 'rpn_network')
      with tf.variable_scope('rpn_network', reuse=True):
        rpn_conv1 = tf.get_variable('resnet_v1_101/block2/unit_1/bottleneck_v1/shortcut/weights')
        rpn_conv2 = tf.get_variable('resnet_v1_101/refined_reduce_depth/weights')
        rpn_conv3 = tf.get_variable('resnet_v1_101/block4/unit_3/bottleneck_v1/conv3/weights')
        rpn_conv4 = tf.get_variable('resnet_v1_101/block3/unit_18/bottleneck_v1/conv2/weights')
        rpn_conv5 = tf.get_variable('resnet_v1_101/block3/unit_14/bottleneck_v1/conv3/weights')
      with tf.variable_scope('rfcn_network', reuse=True):
        rfcn_conv1 = tf.get_variable('resnet_v1_101/block2/unit_1/bottleneck_v1/shortcut/weights')
        rfcn_conv2 = tf.get_variable('resnet_v1_101/refined_reduce_depth/weights')
        rfcn_conv3 = tf.get_variable('resnet_v1_101/block4/unit_3/bottleneck_v1/conv3/weights')
        rfcn_conv4 = tf.get_variable('resnet_v1_101/block3/unit_18/bottleneck_v1/conv2/weights')
        rfcn_conv5 = tf.get_variable('resnet_v1_101/block3/unit_14/bottleneck_v1/conv3/weights')
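      # Running merged_ops and then comparing a handful of shared convolution
      # weights from the two scopes below acts as a sanity check that the
      # rfcn_network weights were actually copied into rpn_network before the
      # later stages start.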
      with tf.control_dependencies(merged_ops):
        rpn_conv1 = tf.identity(rpn_conv1)
        bool1 = tf.equal(rpn_conv1, rfcn_conv1)
        bool2 = tf.equal(rpn_conv2, rfcn_conv2)
        bool3 = tf.equal(rpn_conv3, rfcn_conv3)
        bool4 = tf.equal(rpn_conv4, rfcn_conv4)
        bool5 = tf.equal(rpn_conv5, rfcn_conv5)
        bool1_val, bool2_val, bool3_val, bool4_val, bool5_val = \
          sess.run([bool1, bool2, bool3, bool4, bool5])
    # Stages 3-5 (stage 5 reuses the stage-3 train op and summary writers)
    sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
    # while iter < max_iters + 1:
    while iter < 480001:
    # while iter < 401:
        if iter == 280001:
        # if iter == 261:
          # Add snapshot here before reducing the learning rate
          self.snapshot(sess, iter)
          sess.run(tf.assign(lr, 0.001))
        if iter == 400001:
          self.snapshot(sess, iter)
          sess.run(tf.assign(lr, 0.001))

        blobs = self.data_layer.forward()
        # Stage 3: train only the RPN layers of rpn_network
        if 200001 <= iter < 280001:
        # if 201 <= iter < 581:
          stage_infor = 'stage3'
          if iter == 260001:
            self.snapshot(sess, iter)
            sess.run(tf.assign(lr, 0.0001))
          timer.tic()
          now = time.time()
          if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
            # Compute the graph with summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
              self.net.train_rpn_step_with_summary(sess, blobs, train_op_stage3)
            self.writer_stage3.add_summary(summary, float(iter))
            # Also check the summary on the validation set
            blobs_val = self.data_layer_val.forward()
            summary_val = self.net.get_summary(sess, blobs_val)
            self.valwriter_stage3.add_summary(summary_val, float(iter))
            last_summary_time = now
          else:
            # Compute the graph without summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
              self.net.train_rpn_step(sess, blobs, train_op_stage3)
          timer.toc()
          if iter == 280000:
            self.writer_stage3.close()
            self.valwriter_stage3.close()
        # Stage 4: train only the R-FCN layers of rpn_network
        elif 280001 <= iter < 400001:
        # elif 581 <= iter < 1401:
          stage_infor = 'stage4'
          if iter == 360001:
          # if iter == 1361:
            sess.run(tf.assign(lr, 0.0001))
          timer.tic()
          now = time.time()
          if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
            # Compute the graph with summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
              self.net.train_rpn_step_with_summary(sess, blobs, train_op_stage4)
            self.writer_stage4.add_summary(summary, float(iter))
            # Also check the summary on the validation set
            blobs_val = self.data_layer_val.forward()
            summary_val = self.net.get_summary(sess, blobs_val)
            self.valwriter_stage4.add_summary(summary_val, float(iter))
            last_summary_time = now
          else:
            # Compute the graph without summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
              self.net.train_rpn_step(sess, blobs, train_op_stage4)
          timer.toc()
          if iter == 400000:
            self.writer_stage4.close()
            self.valwriter_stage4.close()
        elif 400001 <= iter < 480001:
          # if 401 <= iter < 481:
          stage_infor = 'stage5'
          # if iter == 461:
          if iter == 460001:
            self.snapshot(sess, iter)
            sess.run(tf.assign(lr, 0.0001))
          timer.tic()
          now = time.time()
          if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
            # Compute the graph with summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
              self.net.train_rpn_step_with_summary(sess, blobs, train_op_stage3)
            self.writer_stage3.add_summary(summary, float(iter))
            # Also check the summary on the validation set
            blobs_val = self.data_layer_val.forward()
            summary_val = self.net.get_summary(sess, blobs_val)
            self.valwriter_stage3.add_summary(summary_val, float(iter))
            last_summary_time = now
          else:
            # Compute the graph without summary
            rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
              self.net.train_rpn_step(sess, blobs, train_op_stage3)
          timer.toc()
          if iter == 480000:
            self.writer_stage3.close()
            self.valwriter_stage3.close()
        else:
          raise ValueError('iter value outside the allowed training range')

        # Display training information
        if iter % (cfg.TRAIN.DISPLAY) == 0:
          print('iter: %d / %d, stage: %s, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                (iter, max_iters, stage_infor, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
          print('speed: {:.3f}s / iter'.format(timer.average_time))

        if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
          last_snapshot_iter = iter
          snapshot_path, np_path = self.snapshot(sess, iter)
          np_paths.append(np_path)
          ss_paths.append(snapshot_path)
        iter += 1

    if last_snapshot_iter != iter - 1:
      self.snapshot(sess, iter - 1)

  def train_model(self, sess, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    #self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Construct the computation graph
    lr, train_op = self.construct_graph(sess)

    # Find previous snapshots if there is any to restore from
    lsf, nfiles, sfiles = self.find_previous()

    # Initialize the variables or restore them from the last snapshot
    if lsf == 0:
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
    else:
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess, 
                                                                            str(sfiles[-1]), 
                                                                            str(nfiles[-1]))
    timer = Timer()
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    # Make sure the lists are not empty
    stepsizes.append(max_iters)
    stepsizes.reverse()
    next_stepsize = stepsizes.pop()
    while iter < max_iters + 1:
      # Learning rate
      if iter == next_stepsize + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        rate *= cfg.TRAIN.GAMMA
        sess.run(tf.assign(lr, rate))
        next_stepsize = stepsizes.pop()

      timer.tic()
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()

      now = time.time()
      if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
        # Compute the graph with summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
          self.net.train_step_with_summary(sess, blobs, train_op)
        self.writer.add_summary(summary, float(iter))
        # Also check the summary on the validation set
        #blobs_val = self.data_layer_val.forward()
        #summary_val = self.net.get_summary(sess, blobs_val)
        #self.valwriter.add_summary(summary_val, float(iter))
        last_summary_time = now
      else:
        # Compute the graph without summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
          self.net.train_step(sess, blobs, train_op)
      timer.toc()

      # Display training information
      if iter % (cfg.TRAIN.DISPLAY) == 0:
        print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))

      # Snapshotting
      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        ss_path, np_path = self.snapshot(sess, iter)
        np_paths.append(np_path)
        ss_paths.append(ss_path)

        # Remove the old snapshots if there are too many
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          self.remove_snapshot(np_paths, ss_paths)

      iter += 1

    if last_snapshot_iter != iter - 1:
      self.snapshot(sess, iter - 1)

    self.writer.close()
Beispiel #21
0
    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)
        self.data_layer_val_copy = RoIDataLayer(self.valroidb,
                                                self.imdb.num_classes)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
                sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                sess, str(sfiles[-1]), str(nfiles[-1]))
        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        max_AP = 0.0
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                    self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                    self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                if not cfg.ONLY_RPN:
                    print(
                        'iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                        '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f'
                        % (iter, max_iters, total_loss, rpn_loss_cls,
                           rpn_loss_box, loss_cls, loss_box, lr.eval()))
                else:
                    print(
                        'iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                        '>>> rpn_loss_box: %.6f\n >>> lr: %f' %
                        (iter, max_iters, total_loss, rpn_loss_cls,
                         rpn_loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Validation full dataset
            if iter == 1 or iter % cfg.VAL_ITERS == 0:
                if cfg.ONLY_RPN:
                    mAP = DetectionMAP(self.imdb.num_classes - 1)
                    for i in range(cfg.VAL_NUM):
                        blobs_val = self.data_layer_val_copy.forward()
                        rois, scores = self.net.test_rpn_image(sess, blobs_val)
                        boxes = rois[:, 1:5]

                        # apply threshold
                        inds = np.where(scores[:, 1] > 0.5)[0]
                        cls_scores = scores[inds, 1]
                        cls_boxes = boxes[inds, :]
                        cls_dets = np.hstack(
                            (cls_boxes,
                             cls_scores[:, np.newaxis])).astype(np.float32,
                                                                copy=False)
                        keep = nms(cls_dets, cfg.TEST.NMS)
                        cls_scores = cls_scores[keep]
                        cls_boxes = cls_boxes[keep, :]

                        if len(cls_scores) > 1:
                            image_thresh = np.sort(cls_scores)[-1]
                            keep = np.where(cls_scores >= image_thresh)[0]
                            cls_scores = cls_scores[keep]
                            cls_boxes = cls_boxes[keep, :]

                        mAP.evaluate(cls_boxes, np.zeros_like(cls_scores),
                                     cls_scores, blobs_val["gt_boxes"][:, 0:4],
                                     blobs_val["gt_boxes"][:, 4] - 1)

                    precisions, recalls = mAP.compute_precision_recall_(
                        0, True)
                    AP = mAP.compute_ap(precisions, recalls)
                    print("Evaluate AP: {:.3f}".format(AP))
                    summary_scalar(self.writer, iter, tags=["AP"], values=[AP])
                    if AP > max_AP:
                        max_AP = AP
                        self.snapshot_best(sess, iter)
                else:
                    raise NotImplementedError

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()

    def train_model(self, max_iters):
        MIN_TOTAL_LOSS = np.inf
        MIN_D_LOSS_T = np.inf
        BEST_ITER = None
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)
        self.data_layer_T = RoIDataLayer(self.roidb_T, self.imdb.num_classes)
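        # data_layer_T serves the target-domain images; its blobs (blobsT) are
        # fed to the domain-adaptation training steps together with the
        # source-domain blobs from data_layer.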

        # Construct the computation graph
        lr, train_op = self.construct_graph()

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
            )
        else:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                str(sfiles[-1]), str(nfiles[-1]))
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()

        self.net.train()
        self.net.cuda()

        self.net.D_img.train()
        self.net.D_img.cuda()

        #self.net.D_img2.train()
        #self.net.D_img2.cuda()
        mywriter = tb.SummaryWriter()
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(iter)
                lr *= cfg.TRAIN.GAMMA
                scale_lr(self.optimizer, cfg.TRAIN.GAMMA)
                #scale_lr(self.D_img_op, cfg.TRAIN.GAMMA)
                next_stepsize = stepsizes.pop()

            utils.timer.timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()
            blobsT = self.data_layer_T.forward()
            # print("#########################:blobs['data_path'][0]",blobs['data_path'][0])
            # print("synth_weight:",imdb.D_T_score[os.path.basename(blobs['data_path'][0])])
            # break
            now = time.time()
            #if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
            if False:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, D_inst_loss_S, D_img_loss_S, D_const_loss_S, D_inst_loss_T, D_img_loss_T, D_const_loss_T, summary = \
                  self.net.train_adapt_step_with_summary(blobs, blobsT, self.optimizer, self.D_inst_op, self.D_img_op)
                for _sum in summary:
                    self.writer.add_summary(_sum, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(blobs_val)
                for _sum in summary_val:
                    self.valwriter.add_summary(_sum, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
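                # synth_weight scales this source image in the adaptation step:
                # images whose path contains 'train' use their per-image score
                # from imdb.D_T_score (presumably a target-domain discriminator
                # score); all other images get a weight of 1.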
                if 'train' in blobs['data_path'][0]:
                    synth_weight = self.imdb.D_T_score[os.path.basename(
                        blobs['data_path'][0])]
                else:
                    synth_weight = 1

                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, D_img_loss_S, D_img_loss_T = \
                    self.net.train_adapt_step_img(blobs, blobsT, self.optimizer, self.D_img_op, synth_weight)

            utils.timer.timer.toc()
            if (((loss_cls + loss_box) < MIN_TOTAL_LOSS)
                    and (D_img_loss_T < MIN_D_LOSS_T)):
                MIN_TOTAL_LOSS = loss_cls + loss_box
                MIN_D_LOSS_T = D_img_loss_T
                BEST_ITER = iter
                print("Curr MIN_TOTAL_LOSS=:{} and min_D_loss_T=:{}".format(
                    MIN_TOTAL_LOSS, MIN_D_LOSS_T))

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                mywriter.add_scalar("Total Loss", total_loss, iter)
                mywriter.add_scalar("cls Loss", loss_cls, iter)
                mywriter.add_scalar("Box loss", loss_box, iter)
                mywriter.add_scalar("D_img_loss_S", D_img_loss_S, iter)
                mywriter.add_scalar("D_img_loss_T", D_img_loss_T, iter)
                print("Writing Training log")
                temp_log = ','.join(
                    str(v) for v in (iter, total_loss, rpn_loss_cls,
                                     rpn_loss_box, loss_cls, loss_box,
                                     D_img_loss_S, D_img_loss_T))
                with open('training_log.txt', 'a+') as fp:
                    fp.write(temp_log)
                    fp.write('\n')
                print("Done !!!!!!!!")
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n '
                      '>>> D_img_loss_S: %.6f\n >>> D_img_loss_T: %.6f\n '
                      '>>> lambda: %f >>> lr: %f ' % \
                      (iter, max_iters, total_loss, rpn_loss_cls, \
                        rpn_loss_box, loss_cls, loss_box, \
                        D_img_loss_S, D_img_loss_T, \
                        cfg.ADAPT_LAMBDA, lr))
                print('speed: {:.3f}s / iter'.format(
                    utils.timer.timer.average_time()))

                # for k in utils.timer.timer._average_time.keys():
                #   print(k, utils.timer.timer.average_time(k))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1
        print("##################### Best iteration = ", BEST_ITER,
              "###################################")
        print("##################### Optimal Total Loss=:", MIN_TOTAT_LOSS,
              "###########################")
        print("##################### Optimal lD_LOSS_T=:", MIN_D_LOSS_T,
              "##############################")
        _, _ = self.snapshot(BEST_ITER)
        if last_snapshot_iter != iter - 1:
            self.snapshot(iter - 1)

        self.writer.close()
        self.valwriter.close()

    def train_model(self, max_iters):
        # Build data layers for both training and validation set
        # (Build the RoI data sets; the order is randomly shuffled)
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph (TensorBoard)
        # Initialize the weights and each layer
        lr, train_op = self.construct_graph()

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()
        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
            )
        else:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                str(sfiles[-1]), str(nfiles[-1]))
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()

        self.net.train()
        self.net.to(self.net._device)  # move the network to the device
        ################################### training officially starts ###################################
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(iter)
                lr *= cfg.TRAIN.GAMMA
                scale_lr(self.optimizer, cfg.TRAIN.GAMMA)
                next_stepsize = stepsizes.pop()

            utils.timer.timer.tic()
            # Get training data, one batch at a time; image reading is modified here
            #########################################################
            blobs = self.data_layer.forward()  # goes to layer.forward

            #########################################################

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                #############################################################################################
                # Compute the loss and backpropagate; rpn_loss is only used when RPN and detector are trained separately
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                  self.net.train_step_with_summary(blobs, self.optimizer)
                #############################################################################################
                for _sum in summary:
                    self.writer.add_summary(_sum, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(blobs_val)
                for _sum in summary_val:
                    self.valwriter.add_summary(_sum, float(iter))
                last_summary_time = now
            else:
                #############################################################################################
                # Compute the graph without summary; rpn_loss is only used when RPN and detector are trained separately
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                  self.net.train_step(blobs, self.optimizer)
                #############################################################################################
            utils.timer.timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr))

                print('speed: {:.3f}s / iter'.format(
                    utils.timer.timer.average_time()))

                #torch.cuda.empty_cache()
                # for k in utils.timer.timer._average_time.keys():
                #   print(k, utils.timer.timer.average_time(k))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0 or iter == 1:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(iter - 1)

        self.writer.close()
        self.valwriter.close()
Beispiel #24
0
class SolverWrapper(object):
    """ A wrapper class for the training process """
    def __init__(self,
                 sess,
                 network,
                 imdb,
                 roidb,
                 valroidb,
                 output_dir,
                 pretrained_model=None,
                 logger=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.pretrained_model = pretrained_model
        self.logger = logger

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        self.logger.info('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        self.logger.info('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        self.logger.info('Restored.')
        # Restore the other hyper-parameters/states for training. I have
        # tried my best to find the random states so that it can be recovered
        # exactly; however, the TensorFlow state is currently not available.
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def _return_gradients(self, gvs):
        # grads, vars = gvs
        grads = [g for g, _ in gvs]
        vars = [v for _, v in gvs]
        return [
            grad if grad is not None else tf.zeros_like(var)
            for var, grad in zip(vars, grads)
        ]

    def _compute_gradients(self, tensor, var_list):
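        # tf.gradients returns None for variables that do not influence the
        # loss; replace those entries with zero tensors so that callers can
        # treat every variable uniformly.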
        grads = tf.gradients(tensor, var_list)
        return [
            grad if grad is not None else tf.zeros_like(var)
            for var, grad in zip(var_list, grads)
        ]

    def construct_graph(self, sess):
        # Set the random seed for tensorflow
        tf.set_random_seed(cfg.RNG_SEED)
        with sess.graph.as_default():
            # Build the main computation graph
            layers = self.net.create_architecture(
                'TRAIN',
                self.imdb.num_classes,
                tag='default',
                anchor_sizes=cfg.ANCHOR_SIZES,
                anchor_strides=cfg.ANCHOR_STRIDES,
                anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

            # Compute the gradients with regard to the loss
            gvs = self.optimizer.compute_gradients(loss)
            self.net.grads = self._compute_gradients(loss, self.net.fr_tvars)
            self.return_grads = self._return_gradients(gvs)
            train_op = self.optimizer.apply_gradients(gvs)

            # Initialize main LRP-HAI network
            self.net.build_LRP_HAI_network()

        return lr, train_op

    def find_previous(self):
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(
                os.path.join(
                    self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX +
                    '_iter_{:d}.ckpt.meta'.format(stepsize + 1)))
        sfiles = [
            ss.replace('.meta', '') for ss in sfiles if ss not in redfiles
        ]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [
            redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles
        ]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def initialize(self, sess):
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights
        self.logger.info('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(
            self.pretrained_model)
        # print(self.pretrained_model)
        # sleep(100)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(
            variables, var_keep_dic)
        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, self.pretrained_model)
        self.logger.info('Loaded.')

        last_snapshot_iter = 0
        fr_rate = cfg.TRAIN.LEARNING_RATE
        fr_stepsize = cfg.TRAIN.STEPSIZE
        drl_rate = cfg.LRP_HAI_TRAIN.LEARNING_RATE
        drl_stepsize = cfg.LRP_HAI_TRAIN.STEPSIZE
        return fr_rate, drl_rate, last_snapshot_iter, fr_stepsize, drl_stepsize, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        fr_rate = cfg.TRAIN.LEARNING_RATE
        fr_stepsize = cfg.TRAIN.STEPSIZE[0]
        drl_rate = cfg.LRP_HAI_TRAIN.LEARNING_RATE
        drl_stepsize = cfg.LRP_HAI_TRAIN.STEPSIZE
        if last_snapshot_iter > fr_stepsize:
            fr_rate *= cfg.TRAIN.GAMMA
        if last_snapshot_iter > drl_stepsize:
            drl_rate *= cfg.LRP_HAI_TRAIN.GAMMA

        return fr_rate, drl_rate, last_snapshot_iter, fr_stepsize, drl_stepsize, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)
        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            sfile = ss_paths[0]
            # To keep the code compatible with earlier versions of TensorFlow,
            # where the naming convention for checkpoints is different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def _print_det_loss(self,
                        iter,
                        max_iters,
                        tot_loss,
                        loss_cls,
                        loss_box,
                        lr,
                        timer,
                        in_string='detector'):
        if (iter + 1) % (cfg.TRAIN.DISPLAY) == 0:
            if loss_box is not None:
                self.logger.info('iter: %d / %d, total loss: %.6f\n '
                                 '>>> loss_cls (%s): %.6f\n '
                                 '>>> loss_box (%s): %.6f\n >>> lr: %f' % \
                                 (iter + 1, max_iters, tot_loss, in_string, loss_cls, in_string,
                                  loss_box, lr))
            else:
                self.logger.info('iter: %d / %d, total loss (%s): %.6f\n >>> lr: %f' % \
                                 (iter + 1, max_iters, in_string, tot_loss, lr))
            self.logger.info('speed: {:.3f}s / iter'.format(
                timer.average_time))

    def _check_if_continue(self, iter, max_iters, snapshot_add):
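        # Helper for delayed LRP-HAI/detector training: before IMG_START_IDX
        # the main loop only steps the data layer, and once iter reaches
        # IMG_START_IDX the counters and the relevant stepsize settings are
        # shifted so that training effectively restarts from iteration 0.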
        img_start_idx = cfg.LRP_HAI_TRAIN.IMG_START_IDX
        if iter > img_start_idx:
            return iter, max_iters, snapshot_add, False
        if iter < img_start_idx:
            self.logger.info("iter %d < img_start_idx %d -- continuing" %
                             (iter, img_start_idx))
            iter += 1
            return iter, max_iters, snapshot_add, True
        if iter == img_start_idx:
            self.logger.info("Adjusting stepsize, train-det-start etcetera")
            snapshot_add = img_start_idx
            max_iters -= img_start_idx
            iter = 0
            cfg_from_list(['LRP_HAI_TRAIN.IMG_START_IDX', -1])
            cfg_from_list([
                'LRP_HAI_TRAIN.DET_START',
                cfg.LRP_HAI_TRAIN.DET_START - img_start_idx
            ])
            cfg_from_list([
                'LRP_HAI_TRAIN.STEPSIZE',
                cfg.LRP_HAI_TRAIN.STEPSIZE - img_start_idx
            ])
            cfg_from_list(
                ['TRAIN.STEPSIZE', [cfg.TRAIN.STEPSIZE[0] - img_start_idx]])
            self.logger.info(
                "Done adjusting stepsize, train-det-start etcetera")
            return iter, max_iters, snapshot_add, False

    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, cfg.NBR_CLASSES)
        self.data_layer_val = RoIDataLayer(self.valroidb, cfg.NBR_CLASSES,
                                           True)

        # Construct the computation graph corresponding to the original Faster R-CNN
        # architecture first
        lr_det_op, train_op = self.construct_graph(sess)

        # We will handle the snapshots ourselves
        self.saver = tf.train.Saver(max_to_keep=100000)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            fr_rate, drl_rate, last_snapshot_iter, \
            fr_stepsize, drl_stepsize, np_paths, ss_paths = self.initialize(sess)
        else:
            fr_rate, drl_rate, last_snapshot_iter, \
            fr_stepsize, drl_stepsize, np_paths, ss_paths = self.restore(sess,
                                                                         str(sfiles[-1]),
                                                                         str(nfiles[-1]))

        # Initialize
        self.net.init_rl_train(sess)

        # Setup initial learning rates
        # 0.00002
        lr_rl = drl_rate
        # 0.00025
        lr_det = fr_rate
        sess.run(tf.assign(lr_det_op, lr_det))

        # Sample first beta
        beta = cfg.LRP_HAI_TRAIN.BETA

        # Setup LRP-HAI timers
        timers = {
            'init': Timer(),
            'fulltraj': Timer(),
            'upd-obs-vol': Timer(),
            'upd-seq': Timer(),
            'upd-rl': Timer(),
            'action-rl': Timer(),
            'coll-traj': Timer(),
            'run-LRP-HAI': Timer(),
            'train-LRP-HAI': Timer(),
            'batch_time': Timer(),
            'total': Timer()
        }

        # Create StatCollector (tracks various RL training statistics)
        stat_strings = [
            'rews_total_traj', 'traj-len', 'frac-area', 'gt >= 0.5 frac',
            'gt-IoU-frac'
        ]
        sc = StatCollector(max_iters, stat_strings,
                           cfg.LRP_HAI_TRAIN.BATCH_SIZE, self.output_dir)

        timer = Timer()
        iter = last_snapshot_iter
        snapshot_add = 0
        timers['total'].tic()
        timers['batch_time'].tic()
        while iter < max_iters:
            # Get training data, one batch at a time (assumes batch size 1)
            blobs = self.data_layer.forward()

            iter, max_iters, snapshot_add, do_continue \
                = self._check_if_continue(iter, max_iters, snapshot_add)
            if do_continue:
                continue
            # Potentially update LRP-HAI learning rate
            # 90000
            if (iter + 1) % cfg.LRP_HAI_TRAIN.STEPSIZE == 0:
                # lr_rl = lr_rl * 0.2
                lr_rl *= cfg.LRP_HAI_TRAIN.GAMMA

            # Run LRP-HAI in training mode
            timers['run-LRP-HAI'].tic()
            net_conv, rois_LRP_HAI, gt_boxes, im_info, timers, stats = run_LRP_HAI(
                sess,
                self.net,
                blobs,
                timers,
                mode='train',
                beta=beta,
                im_idx=None,
                extra_args=lr_rl,
                alpha=cfg.LRP_HAI.ALPHA)
            timers['run-LRP-HAI'].toc()

            # BATCH_SIZE = 50
            if (iter + 1) % cfg.LRP_HAI_TRAIN.BATCH_SIZE == 0:
                self.logger.info(
                    "\n##### LRP-HAI BATCH GRADIENT UPDATE - START ##### \n")
                self.logger.info('iter: %d / %d' % (iter + 1, max_iters))
                self.logger.info('lr-rl: %f' % lr_rl)
                timers['train-LRP-HAI'].tic()
                self.net.train_LRP_HAI(sess, lr_rl, sc, stats)
                timers['train-LRP-HAI'].toc()
                sc.print_stats(iter=iter + 1, logger=self.logger)

                batch_time = timers['batch_time'].toc()
                self.logger.info('TIMINGS:')
                self.logger.info('run-LRP-HAI: %.4f' %
                                 timers['run-LRP-HAI'].get_avg())
                self.logger.info('train-LRP-HAI: %.4f' %
                                 timers['train-LRP-HAI'].get_avg())
                self.logger.info('train-LRP-HAI-batch: %.4f' % batch_time)
                self.logger.info(
                    "\n##### LRP-HAI BATCH GRADIENT UPDATE - DONE ###### \n")

                timers['batch_time'].tic()
            else:
                sc.update(0, 0, 0, stats)

            # At this point we assume that an RL-trajectory has been performed.
            # We next train detector with LRP-HAI running in deterministic mode.
            # Potentially train detector component of network
            # DET_START = 40000
            if 0 <= cfg.LRP_HAI_TRAIN.DET_START <= iter:

                # Run LRP-HAI in deterministic mode
                # net_conv, rois_LRP_HAI, gt_boxes, im_info, timers \
                #     = run_LRP_HAI(sess, self.net, blobs, timers, mode='train_det',
                #                   beta=beta, im_idx=None, alpha=cfg.LRP_HAI.ALPHA)

                # Learning rate
                if (iter + 1) % cfg.TRAIN.STEPSIZE[0] == 0:
                    lr_det *= cfg.TRAIN.GAMMA
                    sess.run(tf.assign(lr_det_op, lr_det))

                if rois_LRP_HAI is not None:
                    timer.tic()
                    # Train detector part
                    loss_cls, loss_box, tot_loss \
                        = self.net.train_step_det(sess, train_op, net_conv, rois_LRP_HAI,
                                                gt_boxes, im_info)
                    timer.toc()

                # Display training information
                self._print_det_loss(iter, max_iters, tot_loss, loss_cls,
                                     loss_box, lr_det, timer)

            # Snapshotting
            if (iter + 1) % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter + 1
                ss_path, np_path = self.snapshot(sess, iter + 1 + snapshot_add)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            # Increase iteration counter
            iter += 1

        # Potentially save one last time
        if last_snapshot_iter != iter:
            self.snapshot(sess, iter + snapshot_add)

        timers['total'].toc()
        total_time = timers['total'].total_time
        m, s = divmod(total_time, 60)
        h, m = divmod(m, 60)
        self.logger.info("total time: %02d:%02d:%02d" % (h, m, s))
Beispiel #25
0
  def train_model(self, sess, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Determine different scales for anchors, see paper
    with sess.graph.as_default():
      # Set the random seed for tensorflow
      tf.set_random_seed(cfg.RNG_SEED)
      # Build the main computation graph
      layers = self.net.create_architecture(sess, 'TRAIN', self.imdb.num_classes, tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS)
      # Define the loss
      loss = layers['total_loss']
      # Set learning rate and momentum
      lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
      momentum = cfg.TRAIN.MOMENTUM
      self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

      # Compute the gradients wrt the loss
      gvs = self.optimizer.compute_gradients(loss)

      # Double the gradient of the bias if set
      if cfg.TRAIN.DOUBLE_BIAS:
        final_gvs = []
        with tf.variable_scope('Gradient_Mult') as scope:
          for grad, var in gvs:
            scale = 1.
            if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
              scale *= 2.
            if not np.allclose(scale, 1.0):
              grad = tf.multiply(grad, scale)
            final_gvs.append((grad, var))
        train_op = self.optimizer.apply_gradients(final_gvs)
      else:
        train_op = self.optimizer.apply_gradients(gvs)

      # We will handle the snapshots ourselves
      self.saver = tf.train.Saver(max_to_keep=100000)
      # Write the train and validation information to tensorboard
      self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
      self.valwriter = tf.summary.FileWriter(self.tbvaldir)

    # Find previous snapshots if there is any to restore from
    sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
    sfiles = glob.glob(sfiles)
    sfiles.sort(key=os.path.getmtime)
    # Get the snapshot name in TensorFlow
    redstr = '_iter_{:d}.'.format(cfg.TRAIN.STEPSIZE+1)
    sfiles = [ss.replace('.meta', '') for ss in sfiles]
    sfiles = [ss for ss in sfiles if redstr not in ss]

    nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
    nfiles = glob.glob(nfiles)
    nfiles.sort(key=os.path.getmtime)
    nfiles = [nn for nn in nfiles if redstr not in nn]

    lsf = len(sfiles)
    assert len(nfiles) == lsf

    np_paths = nfiles
    ss_paths = sfiles

    if lsf == 0:
      # Fresh train directly from ImageNet weights
      print('Loading initial model weights from {:s}'.format(self.pretrained_model))
      variables = tf.global_variables()
      # Initialize all variables first
      sess.run(tf.variables_initializer(variables, name='init'))
      var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
      # Get the variables to restore, ignoring the variables to fix
      variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

      restorer = tf.train.Saver(variables_to_restore)
      restorer.restore(sess, self.pretrained_model)
      print('Loaded.')
      # Need to fix the variables before loading, so that the RGB weights are changed to BGR
      # For VGG16 it also changes the convolutional weights fc6 and fc7 to
      # fully connected weights
      self.net.fix_variables(sess, self.pretrained_model)
      print('Fixed.')
      sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
      last_snapshot_iter = 0
    else:
      # Get the most recent snapshot and restore
      ss_paths = [ss_paths[-1]]
      np_paths = [np_paths[-1]]

      print('Restoring model snapshots from {:s}'.format(sfiles[-1]))
      self.saver.restore(sess, str(sfiles[-1]))
      print('Restored.')
      # Need to restore the other hyperparameters/states for training (TODO xinlei): I have
      # tried my best to find the random states so that they can be recovered exactly.
      # However, the TensorFlow state is currently not available
      with open(str(nfiles[-1]), 'rb') as fid:
        st0 = pickle.load(fid)
        cur = pickle.load(fid)
        perm = pickle.load(fid)
        cur_val = pickle.load(fid)
        perm_val = pickle.load(fid)
        last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        # Set the learning rate, only reduce once
        if last_snapshot_iter > cfg.TRAIN.STEPSIZE:
          sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
        else:
          sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

    timer = Timer()
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    while iter < max_iters + 1:
      # Learning rate
      if iter == cfg.TRAIN.STEPSIZE + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))

      timer.tic()
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()

      now = time.time()
      if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
        # Compute the graph with summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
          self.net.train_step_with_summary(sess, blobs, train_op)

        # self.writer.flush()
        self.writer.add_summary(summary, float(iter))
        # Also check the summary on the validation set
        blobs_val = self.data_layer_val.forward()
        summary_val = self.net.get_summary(sess, blobs_val)
        self.valwriter.add_summary(summary_val, float(iter))
        last_summary_time = now
      else:
        # Compute the graph without summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
          self.net.train_step(sess, blobs, train_op)
        # rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, center_loss, total_loss = \
        #   self.net.train_step_center(sess, blobs, train_op)
      timer.toc()

      # Display training information
      if iter % (cfg.TRAIN.DISPLAY) == 0:
        print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))


      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        snapshot_path, np_path = self.snapshot(sess, iter)
        np_paths.append(np_path)
        ss_paths.append(snapshot_path)

        # Remove the old snapshots if there are too many
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
          for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        if len(ss_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
          for c in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible with earlier versions of TensorFlow,
            # where the naming convention for checkpoints is different
            if os.path.exists(str(sfile)):
              os.remove(str(sfile))
            else:
              os.remove(str(sfile + '.data-00000-of-00001'))
              os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

      iter += 1

    if last_snapshot_iter != iter - 1:
      self.snapshot(sess, iter - 1)

    self.writer.close()
    self.valwriter.close()
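The restore branch above reads the pickled meta information back in exactly the order it was written by the corresponding snapshot() method (a full version of snapshot() appears in a later example). Below is a small self-contained sketch of that pairing; dump_meta/load_meta are names made up for this sketch, while the _cur/_perm fields follow the RoIDataLayer attributes used throughout these examples.

import pickle
import numpy as np

def dump_meta(path, data_layer, data_layer_val, it):
    # Write the numpy RNG state, the data-layer cursors/permutations and the
    # iteration counter in a fixed order ...
    with open(path, 'wb') as fid:
        pickle.dump(np.random.get_state(), fid, pickle.HIGHEST_PROTOCOL)
        pickle.dump(data_layer._cur, fid, pickle.HIGHEST_PROTOCOL)
        pickle.dump(data_layer._perm, fid, pickle.HIGHEST_PROTOCOL)
        pickle.dump(data_layer_val._cur, fid, pickle.HIGHEST_PROTOCOL)
        pickle.dump(data_layer_val._perm, fid, pickle.HIGHEST_PROTOCOL)
        pickle.dump(it, fid, pickle.HIGHEST_PROTOCOL)

def load_meta(path, data_layer, data_layer_val):
    # ... and read everything back in the same order, restoring the RNG state
    # and the data-layer cursors before returning the saved iteration counter.
    with open(path, 'rb') as fid:
        np.random.set_state(pickle.load(fid))
        data_layer._cur = pickle.load(fid)
        data_layer._perm = pickle.load(fid)
        data_layer_val._cur = pickle.load(fid)
        data_layer_val._perm = pickle.load(fid)
        return pickle.load(fid)  # last snapshot iteration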
Beispiel #26
0
    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, cfg.NBR_CLASSES)
        self.data_layer_val = RoIDataLayer(self.valroidb, cfg.NBR_CLASSES,
                                           True)

        # Construct the computation graph corresponding to the original Faster R-CNN
        # architecture first (and potentially the post-hist module of drl-RPN)
        lr_det_op, train_op, lr_post_op, train_op_post \
          = self.construct_graph(sess)

        # Initialize the variables or restore them from the last snapshot
        rate, last_snapshot_iter, stepsizes, np_paths, ss_paths \
          = self.initialize(sess)

        # We will handle the snapshots ourselves
        self.saver = tf.train.Saver(max_to_keep=100000)

        # Initialize
        self.net.init_rl_train(sess)

        # Setup initial learning rates
        lr_rl = cfg.DRL_RPN_TRAIN.LEARNING_RATE
        lr_det = cfg.TRAIN.LEARNING_RATE
        sess.run(tf.assign(lr_det_op, lr_det))
        if cfg.DRL_RPN.USE_POST:
            lr_post = cfg.DRL_RPN_TRAIN.POST_LR
            sess.run(tf.assign(lr_post_op, lr_post))

        # Sample first beta
        if cfg.DRL_RPN_TRAIN.USE_POST:
            betas = cfg.DRL_RPN_TRAIN.POST_BETAS
        else:
            betas = cfg.DRL_RPN_TRAIN.BETAS
        beta_idx = 0
        beta = betas[beta_idx]

        # Setup drl-RPN timers
        timers = {
            'init': Timer(),
            'fulltraj': Timer(),
            'upd-obs-vol': Timer(),
            'upd-seq': Timer(),
            'upd-rl': Timer(),
            'action-rl': Timer(),
            'coll-traj': Timer(),
            'run-drl-rpn': Timer(),
            'train-drl-rpn': Timer()
        }

        # Create StatCollector (tracks various RL training statistics)
        stat_strings = [
            'reward', 'rew-done', 'traj-len', 'frac-area', 'gt >= 0.5 frac',
            'gt-IoU-frac'
        ]
        sc = StatCollector(max_iters, stat_strings)

        timer = Timer()
        iter = 0
        snapshot_add = 0
        while iter < max_iters:

            # Get training data, one batch at a time (assumes batch size 1)
            blobs = self.data_layer.forward()

            # Allows starting at an arbitrary image rather than always
            # starting from the first image in the dataset. Useful when
            # loading parameters and continuing from there, instead of
            # revisiting images that have already been seen.
            iter, max_iters, snapshot_add, do_continue \
              = self._check_if_continue(iter, max_iters, snapshot_add)
            if do_continue:
                continue

            if not cfg.DRL_RPN_TRAIN.USE_POST:

                # Potentially update drl-RPN learning rate
                if (iter + 1) % cfg.DRL_RPN_TRAIN.STEPSIZE == 0:
                    lr_rl *= cfg.DRL_RPN_TRAIN.GAMMA

                # Run drl-RPN in training mode
                timers['run-drl-rpn'].tic()
                stats = run_drl_rpn(sess,
                                    self.net,
                                    blobs,
                                    timers,
                                    mode='train',
                                    beta=beta,
                                    im_idx=None,
                                    extra_args=lr_rl)
                timers['run-drl-rpn'].toc()

                if (iter + 1) % cfg.DRL_RPN_TRAIN.BATCH_SIZE == 0:
                    print(
                        "\n##### DRL-RPN BATCH GRADIENT UPDATE - START ##### \n"
                    )
                    print('iter: %d / %d' % (iter + 1, max_iters))
                    print('lr-rl: %f' % lr_rl)
                    timers['train-drl-rpn'].tic()
                    self.net.train_drl_rpn(sess, lr_rl, sc, stats)
                    timers['train-drl-rpn'].toc()
                    sc.print_stats()
                    print('TIMINGS:')
                    print('run-drl-rpn:   %.4f' %
                          timers['run-drl-rpn'].get_avg())
                    print('train-drl-rpn: %.4f' %
                          timers['train-drl-rpn'].get_avg())
                    print(
                        "\n##### DRL-RPN BATCH GRADIENT UPDATE - DONE ###### \n"
                    )

                    # Also sample new beta for next batch
                    beta_idx += 1
                    beta_idx %= len(betas)
                    beta = betas[beta_idx]
                else:
                    sc.update(0, stats)

                # At this point we assume that an RL-trajectory has been performed.
                # We next train detector with drl-RPN running in deterministic mode.
                # Potentially train detector component of network
                if cfg.DRL_RPN_TRAIN.DET_START >= 0 and \
                  iter >= cfg.DRL_RPN_TRAIN.DET_START:

                    # Run drl-RPN in deterministic mode
                    net_conv, rois_drl_rpn, gt_boxes, im_info, timers, _ \
                      = run_drl_rpn(sess, self.net, blobs, timers, mode='train_det',
                                    beta=beta, im_idx=None)

                    # Learning rate
                    if (iter + 1) % cfg.TRAIN.STEPSIZE[0] == 0:
                        lr_det *= cfg.TRAIN.GAMMA
                        sess.run(tf.assign(lr_det_op, lr_det))

                    timer.tic()
                    # Train detector part
                    loss_cls, loss_box, tot_loss \
                      = self.net.train_step_det(sess, train_op, net_conv, rois_drl_rpn,
                                                gt_boxes, im_info)
                    timer.toc()

                    # Display training information
                    self._print_det_loss(iter, max_iters, tot_loss, loss_cls,
                                         loss_box, lr_det, timer)

            # Train the post-hist module AFTER the rest of drl-RPN has been trained!
            # Specifically: once the rest of drl-RPN is trained, copy those weights into
            # the folder of pretrained weights and rerun training with them as initial
            # weights, which will then train only the posterior-history module
            # (a sketch of this two-stage workflow follows this example).
            else:

                # The very first time we need to assign the ordinary detector weights
                # as starting point
                if iter == 0:
                    self.net.assign_post_hist_weights(sess)

                # Sample beta
                beta = betas[beta_idx]
                beta_idx += 1
                beta_idx %= len(betas)

                # Run drl-RPN in deterministic mode
                net_conv, rois_drl_rpn, gt_boxes, im_info, timers, cls_hist \
                  = run_drl_rpn(sess, self.net, blobs, timers, mode='train_det',
                                beta=beta, im_idx=None)

                # Learning rate (assume only one learning rate iter for now!)
                if (iter + 1) % cfg.DRL_RPN_TRAIN.POST_SS[0] == 0:
                    lr_post *= cfg.TRAIN.GAMMA
                    sess.run(tf.assign(lr_post_op, lr_post))

                # Train post-hist detector part
                tot_loss = self.net.train_step_post(sess, train_op_post,
                                                    net_conv, rois_drl_rpn,
                                                    gt_boxes, im_info,
                                                    cls_hist)

                # Display training information
                self._print_det_loss(iter, max_iters, tot_loss, None, None,
                                     lr_post, timer, 'post-hist')

            # Snapshotting
            if (iter + 1) % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter + 1
                ss_path, np_path = self.snapshot(sess, iter + 1 + snapshot_add)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            # Increase iteration counter
            iter += 1

        # Potentially save one last time
        if last_snapshot_iter != iter:
            self.snapshot(sess, iter + snapshot_add)
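The USE_POST branch above implies a two-stage workflow: first train drl-RPN with the post-hist flags disabled, then point the pretrained-model path at the resulting checkpoint and rerun training with them enabled, so that only the posterior-history module is trained. Below is a hedged sketch of a driver for that workflow; the stage_* helpers, the checkpoint paths and the import path are assumptions made for illustration, only cfg_from_list and the config flags come from the examples in this file.

from model.config import cfg_from_list  # import path assumed, as in tf-faster-rcnn layouts

def stage_one(train_fn):
    # Stage 1: train drl-RPN itself, starting from ImageNet weights.
    cfg_from_list(['DRL_RPN.USE_POST', False, 'DRL_RPN_TRAIN.USE_POST', False])
    train_fn(pretrained_model='data/imagenet_weights/vgg16.ckpt')

def stage_two(train_fn):
    # Stage 2: reuse the stage-one checkpoint as the starting point; only the
    # posterior-history module is trained in this pass.
    cfg_from_list(['DRL_RPN.USE_POST', True, 'DRL_RPN_TRAIN.USE_POST', True])
    train_fn(pretrained_model='output/drl_rpn_stage1.ckpt')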
Beispiel #27
0
class SolverWrapper(object):
  """
    A wrapper class for the training process
  """

  def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
    self.net = network
    self.imdb = imdb
    self.roidb = roidb
    self.valroidb = valroidb
    self.output_dir = output_dir
    self.tbdir = tbdir
    # Simply put '_val' at the end to save the summaries from the validation set
    self.tbvaldir = tbdir + '_val'
    if not os.path.exists(self.tbvaldir):
      os.makedirs(self.tbvaldir)
    self.pretrained_model = pretrained_model

  def snapshot(self, sess, iter):
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
    filename = os.path.join(self.output_dir, filename)
    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store some meta information, random state, etc.
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
    st0 = np.random.get_state()
    # current position in the database
    cur = self.data_layer._cur
    # current shuffled indexes of the database
    perm = self.data_layer._perm
    # current position in the validation database
    #cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
    #perm_val = self.data_layer_val._perm

    # Dump the meta info
    with open(nfilename, 'wb') as fid:
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      #pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      #pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename

  def from_snapshot(self, sess, sfile, nfile):
    print('Restoring model snapshots from {:s}'.format(sfile))
    self.saver.restore(sess, sfile)
    print('Restored.')
    # Need to restore the other hyper-parameters/states for training (TODO xinlei): I have
    # tried my best to find the random states so that they can be recovered exactly.
    # However, the TensorFlow state is currently not available
    with open(nfile, 'rb') as fid:
      st0 = pickle.load(fid)
      cur = pickle.load(fid)
      perm = pickle.load(fid)
      #cur_val = pickle.load(fid)
      #perm_val = pickle.load(fid)
      last_snapshot_iter = pickle.load(fid)

      np.random.set_state(st0)
      self.data_layer._cur = cur
      self.data_layer._perm = perm
      #self.data_layer_val._cur = cur_val
      #self.data_layer_val._perm = perm_val

    return last_snapshot_iter

  def get_variables_in_checkpoint_file(self, file_name):
    try:
      reader = pywrap_tensorflow.NewCheckpointReader(file_name)
      var_to_shape_map = reader.get_variable_to_shape_map()
      return var_to_shape_map 
    except Exception as e:  # pylint: disable=broad-except
      print(str(e))
      if "corrupted compressed block contents" in str(e):
        print("It's likely that your checkpoint file has been compressed "
              "with SNAPPY.")

  def construct_graph(self, sess):
    with sess.graph.as_default():
      # Set the random seed for tensorflow
      tf.set_random_seed(cfg.RNG_SEED)
      # Build the main computation graph
      layers = self.net.create_architecture('TRAIN', self.imdb.num_classes, tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS)
      # Define the loss
      loss = layers['total_loss']
      # Set learning rate and momentum
      lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
      self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

      # Compute the gradients with regard to the loss
      gvs = self.optimizer.compute_gradients(loss)
      # Double the gradient of the bias if set
      if cfg.TRAIN.DOUBLE_BIAS:
        final_gvs = []
        with tf.variable_scope('Gradient_Mult') as scope:
          for grad, var in gvs:
            scale = 1.
            if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
              scale *= 2.
            if not np.allclose(scale, 1.0):
              grad = tf.multiply(grad, scale)
            final_gvs.append((grad, var))
        train_op = self.optimizer.apply_gradients(final_gvs)
      else:
        train_op = self.optimizer.apply_gradients(gvs)

      # We will handle the snapshots ourselves
      self.saver = tf.train.Saver(max_to_keep=100000)
      # Write the train and validation information to tensorboard
      self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
      #self.valwriter = tf.summary.FileWriter(self.tbvaldir)

    return lr, train_op

  def find_previous(self):
    sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
    sfiles = glob.glob(sfiles)
    sfiles.sort(key=os.path.getmtime)
    # Get the snapshot name in TensorFlow
    redfiles = []
    for stepsize in cfg.TRAIN.STEPSIZE:
      redfiles.append(os.path.join(self.output_dir, 
                      cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.ckpt.meta'.format(stepsize+1)))
    sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles]

    nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
    nfiles = glob.glob(nfiles)
    nfiles.sort(key=os.path.getmtime)
    redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles]
    nfiles = [nn for nn in nfiles if nn not in redfiles]

    lsf = len(sfiles)
    assert len(nfiles) == lsf

    return lsf, nfiles, sfiles

  def initialize(self, sess):
    # Initial file lists are empty
    np_paths = []
    ss_paths = []
    # Fresh train directly from ImageNet weights
    print('Loading initial model weights from {:s}'.format(self.pretrained_model))
    variables = tf.global_variables()
    # Initialize all variables first
    sess.run(tf.variables_initializer(variables, name='init'))
    var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
    # Get the variables to restore, ignoring the variables to fix
    variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, self.pretrained_model)
    print('Loaded.')
    # Need to fix the variables before loading, so that the RGB weights are changed to BGR
    # For VGG16 it also changes the convolutional weights fc6 and fc7 to
    # fully connected weights
    self.net.fix_variables(sess, self.pretrained_model)
    print('Fixed.')
    last_snapshot_iter = 0
    rate = cfg.TRAIN.LEARNING_RATE
    stepsizes = list(cfg.TRAIN.STEPSIZE)

    return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

  def restore(self, sess, sfile, nfile):
    # Get the most recent snapshot and restore
    np_paths = [nfile]
    ss_paths = [sfile]
    # Restore model from snapshots
    last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
    # Set the learning rate
    rate = cfg.TRAIN.LEARNING_RATE
    stepsizes = []
    for stepsize in cfg.TRAIN.STEPSIZE:
      if last_snapshot_iter > stepsize:
        rate *= cfg.TRAIN.GAMMA
      else:
        stepsizes.append(stepsize)

    return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

  def remove_snapshot(self, np_paths, ss_paths):
    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      nfile = np_paths[0]
      os.remove(str(nfile))
      np_paths.remove(nfile)

    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
    for c in range(to_remove):
      sfile = ss_paths[0]
      # To make the code compatible with earlier versions of TensorFlow,
      # where the naming convention for checkpoints is different
      if os.path.exists(str(sfile)):
        os.remove(str(sfile))
      else:
        os.remove(str(sfile + '.data-00000-of-00001'))
        os.remove(str(sfile + '.index'))
      sfile_meta = sfile + '.meta'
      os.remove(str(sfile_meta))
      ss_paths.remove(sfile)

  def train_model(self, sess, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    #self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Construct the computation graph
    lr, train_op = self.construct_graph(sess)

    # Find previous snapshots if there is any to restore from
    lsf, nfiles, sfiles = self.find_previous()

    # Initialize the variables or restore them from the last snapshot
    if lsf == 0:
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
    else:
      rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess, 
                                                                            str(sfiles[-1]), 
                                                                            str(nfiles[-1]))
    timer = Timer()
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    # Make sure the lists are not empty
    stepsizes.append(max_iters)
    stepsizes.reverse()
    next_stepsize = stepsizes.pop()
    while iter < max_iters + 1:
      # Learning rate
      if iter == next_stepsize + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        rate *= cfg.TRAIN.GAMMA
        sess.run(tf.assign(lr, rate))
        next_stepsize = stepsizes.pop()

      timer.tic()
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()

      now = time.time()
      if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
        # Compute the graph with summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
          self.net.train_step_with_summary(sess, blobs, train_op)
        self.writer.add_summary(summary, float(iter))
        # Also check the summary on the validation set
        #blobs_val = self.data_layer_val.forward()
        #summary_val = self.net.get_summary(sess, blobs_val)
        #self.valwriter.add_summary(summary_val, float(iter))
        last_summary_time = now
      else:
        # Compute the graph without summary
        rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
          self.net.train_step(sess, blobs, train_op)
      timer.toc()

      # Display training information
      if iter % (cfg.TRAIN.DISPLAY) == 0:
        print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))

      # Snapshotting
      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        ss_path, np_path = self.snapshot(sess, iter)
        np_paths.append(np_path)
        ss_paths.append(ss_path)

        # Remove the old snapshots if there are too many
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          self.remove_snapshot(np_paths, ss_paths)

      iter += 1

    if last_snapshot_iter != iter - 1:
      self.snapshot(sess, iter - 1)

    self.writer.close()
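The stepsize handling in train_model above (append max_iters as a sentinel, reverse, pop) implements a piecewise-constant learning-rate schedule: the rate is multiplied by GAMMA once each time a stepsize boundary is crossed. A tiny standalone sketch of the same logic; all numbers below are made up for illustration, not values from any config in this file.

rate = 0.001                 # initial learning rate
gamma = 0.1                  # decay factor
stepsizes = [50000, 70000]   # iterations after which to decay
max_iters = 80000

schedule = list(stepsizes) + [max_iters]  # sentinel so pop() never runs dry
schedule.reverse()
next_stepsize = schedule.pop()
for it in range(1, max_iters + 1):
    if it == next_stepsize + 1:
        rate *= gamma                     # decay once per boundary
        next_stepsize = schedule.pop()
# rate ends up at 0.001 * 0.1 * 0.1 = 1e-05 after both boundaries.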
Beispiel #28
0
class SolverWrapper(object):
    """ A wrapper class for the training process """
    def __init__(self,
                 sess,
                 network,
                 imdb,
                 roidb,
                 valroidb,
                 output_dir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Need to restore the other hyper-parameters/states for training; I have
        # tried my best to find the random states so that they can be recovered exactly.
        # However, the TensorFlow state is currently not available
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def construct_graph(self, sess):
        # Set the random seed for tensorflow
        tf.set_random_seed(cfg.RNG_SEED)
        with sess.graph.as_default():
            # Build the main computation graph
            layers = self.net.create_architecture(
                'TRAIN',
                tag='default',
                anchor_scales=cfg.ANCHOR_SCALES,
                anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

            # Compute the gradients with regard to the loss
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult') as scope:
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # Initialize post-hist module of drl-RPN
            if cfg.DRL_RPN.USE_POST:
                loss_post = layers['total_loss_hist']
                lr_post = tf.Variable(cfg.DRL_RPN_TRAIN.POST_LR,
                                      trainable=False)
                self.optimizer_post = tf.train.MomentumOptimizer(
                    lr_post, cfg.TRAIN.MOMENTUM)
                gvs_post = self.optimizer_post.compute_gradients(loss_post)
                train_op_post = self.optimizer_post.apply_gradients(gvs_post)
            else:
                lr_post = None
                train_op_post = None

            # Initialize main drl-RPN network
            self.net.build_drl_rpn_network()

        return lr, train_op, lr_post, train_op_post

    def initialize(self, sess):
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(
            self.pretrained_model)
        #print(self.pretrained_model)
        #sleep(100)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(
            variables, var_keep_dic)
        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, self.pretrained_model)
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are
        # changed to BGR. For VGG16 it also changes the convolutional weights fc6
        # and fc7 to fully connected weights
        #
        # NOTE: IF YOU WANT TO TRAIN FROM EXISTING FASTER
        # R-CNN WEIGHTS, AND NOT FROM IMAGENET WEIGHTS, SET BELOW FLAG TO FALSE!!!!
        self.net.fix_variables(sess, self.pretrained_model, False)
        print('Fixed.')
        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)
        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)
        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)
        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible with earlier versions of TensorFlow,
            # where the naming convention for checkpoints is different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def _print_det_loss(self,
                        iter,
                        max_iters,
                        tot_loss,
                        loss_cls,
                        loss_box,
                        lr,
                        timer,
                        in_string='detector'):
        if (iter + 1) % (cfg.TRAIN.DISPLAY) == 0:
            if loss_box is not None:
                print('iter: %d / %d, total loss: %.6f\n '
                      '>>> loss_cls (%s): %.6f\n '
                      '>>> loss_box (%s): %.6f\n >>> lr: %f' % \
                      (iter + 1, max_iters, tot_loss, in_string, loss_cls, in_string,
                       loss_box, lr))
            else:
                print('iter: %d / %d, total loss (%s): %.6f\n >>> lr: %f' % \
                      (iter + 1, max_iters, in_string, tot_loss, lr))
            print('speed: {:.3f}s / iter'.format(timer.average_time))

    def _check_if_continue(self, iter, max_iters, snapshot_add):
        img_start_idx = cfg.DRL_RPN_TRAIN.IMG_START_IDX
        if iter > img_start_idx:
            return iter, max_iters, snapshot_add, False
        if iter < img_start_idx:
            print("iter %d < img_start_idx %d -- continuing" %
                  (iter, img_start_idx))
            iter += 1
            return iter, max_iters, snapshot_add, True
        if iter == img_start_idx:
            print("Adjusting stepsize, train-det-start etcetera")
            snapshot_add = img_start_idx
            max_iters -= img_start_idx
            iter = 0
            cfg_from_list(['DRL_RPN_TRAIN.IMG_START_IDX', -1])
            cfg_from_list([
                'DRL_RPN_TRAIN.DET_START',
                cfg.DRL_RPN_TRAIN.DET_START - img_start_idx
            ])
            cfg_from_list([
                'DRL_RPN_TRAIN.STEPSIZE',
                cfg.DRL_RPN_TRAIN.STEPSIZE - img_start_idx
            ])
            cfg_from_list(
                ['TRAIN.STEPSIZE', [cfg.TRAIN.STEPSIZE[0] - img_start_idx]])
            cfg_from_list([
                'DRL_RPN_TRAIN.POST_SS',
                [cfg.DRL_RPN_TRAIN.POST_SS[0] - img_start_idx]
            ])
            print("Done adjusting stepsize, train-det-start etcetera")
            return iter, max_iters, snapshot_add, False

    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, cfg.NBR_CLASSES)
        self.data_layer_val = RoIDataLayer(self.valroidb, cfg.NBR_CLASSES,
                                           True)

        # Construct the computation graph corresponding to the original Faster R-CNN
        # architecture first (and potentially the post-hist module of drl-RPN)
        lr_det_op, train_op, lr_post_op, train_op_post \
          = self.construct_graph(sess)

        # Initialize the variables or restore them from the last snapshot
        rate, last_snapshot_iter, stepsizes, np_paths, ss_paths \
          = self.initialize(sess)

        # We will handle the snapshots ourselves
        self.saver = tf.train.Saver(max_to_keep=100000)

        # Initialize
        self.net.init_rl_train(sess)

        # Setup initial learning rates
        lr_rl = cfg.DRL_RPN_TRAIN.LEARNING_RATE
        lr_det = cfg.TRAIN.LEARNING_RATE
        sess.run(tf.assign(lr_det_op, lr_det))
        if cfg.DRL_RPN.USE_POST:
            lr_post = cfg.DRL_RPN_TRAIN.POST_LR
            sess.run(tf.assign(lr_post_op, lr_post))

        # Sample first beta
        if cfg.DRL_RPN_TRAIN.USE_POST:
            betas = cfg.DRL_RPN_TRAIN.POST_BETAS
        else:
            betas = cfg.DRL_RPN_TRAIN.BETAS
        beta_idx = 0
        beta = betas[beta_idx]

        # Setup drl-RPN timers
        timers = {
            'init': Timer(),
            'fulltraj': Timer(),
            'upd-obs-vol': Timer(),
            'upd-seq': Timer(),
            'upd-rl': Timer(),
            'action-rl': Timer(),
            'coll-traj': Timer(),
            'run-drl-rpn': Timer(),
            'train-drl-rpn': Timer()
        }

        # Create StatCollector (tracks various RL training statistics)
        stat_strings = [
            'reward', 'rew-done', 'traj-len', 'frac-area', 'gt >= 0.5 frac',
            'gt-IoU-frac'
        ]
        sc = StatCollector(max_iters, stat_strings)

        timer = Timer()
        iter = 0
        snapshot_add = 0
        while iter < max_iters:

            # Get training data, one batch at a time (assumes batch size 1)
            blobs = self.data_layer.forward()

            # Allows starting at an arbitrary image rather than always
            # starting from the first image in the dataset. Useful when
            # loading parameters and continuing from there, instead of
            # revisiting images that have already been seen.
            iter, max_iters, snapshot_add, do_continue \
              = self._check_if_continue(iter, max_iters, snapshot_add)
            if do_continue:
                continue

            if not cfg.DRL_RPN_TRAIN.USE_POST:

                # Potentially update drl-RPN learning rate
                if (iter + 1) % cfg.DRL_RPN_TRAIN.STEPSIZE == 0:
                    lr_rl *= cfg.DRL_RPN_TRAIN.GAMMA

                # Run drl-RPN in training mode
                timers['run-drl-rpn'].tic()
                stats = run_drl_rpn(sess,
                                    self.net,
                                    blobs,
                                    timers,
                                    mode='train',
                                    beta=beta,
                                    im_idx=None,
                                    extra_args=lr_rl)
                timers['run-drl-rpn'].toc()

                if (iter + 1) % cfg.DRL_RPN_TRAIN.BATCH_SIZE == 0:
                    print(
                        "\n##### DRL-RPN BATCH GRADIENT UPDATE - START ##### \n"
                    )
                    print('iter: %d / %d' % (iter + 1, max_iters))
                    print('lr-rl: %f' % lr_rl)
                    timers['train-drl-rpn'].tic()
                    self.net.train_drl_rpn(sess, lr_rl, sc, stats)
                    timers['train-drl-rpn'].toc()
                    sc.print_stats()
                    print('TIMINGS:')
                    print('run-drl-rpn:   %.4f' %
                          timers['run-drl-rpn'].get_avg())
                    print('train-drl-rpn: %.4f' %
                          timers['train-drl-rpn'].get_avg())
                    print(
                        "\n##### DRL-RPN BATCH GRADIENT UPDATE - DONE ###### \n"
                    )

                    # Also sample new beta for next batch
                    beta_idx += 1
                    beta_idx %= len(betas)
                    beta = betas[beta_idx]
                else:
                    sc.update(0, stats)

                # At this point we assume that an RL-trajectory has been performed.
                # We next train detector with drl-RPN running in deterministic mode.
                # Potentially train detector component of network
                if cfg.DRL_RPN_TRAIN.DET_START >= 0 and \
                  iter >= cfg.DRL_RPN_TRAIN.DET_START:

                    # Run drl-RPN in deterministic mode
                    net_conv, rois_drl_rpn, gt_boxes, im_info, timers, _ \
                      = run_drl_rpn(sess, self.net, blobs, timers, mode='train_det',
                                    beta=beta, im_idx=None)

                    # Learning rate
                    if (iter + 1) % cfg.TRAIN.STEPSIZE[0] == 0:
                        lr_det *= cfg.TRAIN.GAMMA
                        sess.run(tf.assign(lr_det_op, lr_det))

                    timer.tic()
                    # Train detector part
                    loss_cls, loss_box, tot_loss \
                      = self.net.train_step_det(sess, train_op, net_conv, rois_drl_rpn,
                                                gt_boxes, im_info)
                    timer.toc()

                    # Display training information
                    self._print_det_loss(iter, max_iters, tot_loss, loss_cls,
                                         loss_box, lr_det, timer)

            # Train the post-hist module AFTER the rest of drl-RPN has been trained!
            # Specifically: once the rest of drl-RPN is trained, copy those weights into
            # the folder of pretrained weights and rerun training with them as initial
            # weights, which will then train only the posterior-history module
            else:

                # The very first time we need to assign the ordinary detector weights
                # as starting point
                if iter == 0:
                    self.net.assign_post_hist_weights(sess)

                # Sample beta
                beta = betas[beta_idx]
                beta_idx += 1
                beta_idx %= len(betas)

                # Run drl-RPN in deterministic mode
                net_conv, rois_drl_rpn, gt_boxes, im_info, timers, cls_hist \
                  = run_drl_rpn(sess, self.net, blobs, timers, mode='train_det',
                                beta=beta, im_idx=None)

                # Learning rate (assume only one learning rate iter for now!)
                if (iter + 1) % cfg.DRL_RPN_TRAIN.POST_SS[0] == 0:
                    lr_post *= cfg.TRAIN.GAMMA
                    sess.run(tf.assign(lr_post_op, lr_post))

                # Train post-hist detector part
                tot_loss = self.net.train_step_post(sess, train_op_post,
                                                    net_conv, rois_drl_rpn,
                                                    gt_boxes, im_info,
                                                    cls_hist)

                # Display training information
                self._print_det_loss(iter, max_iters, tot_loss, None, None,
                                     lr_post, timer, 'post-hist')

            # Snapshotting
            if (iter + 1) % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter + 1
                ss_path, np_path = self.snapshot(sess, iter + 1 + snapshot_add)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            # Increase iteration counter
            iter += 1

        # Potentially save one last time
        if last_snapshot_iter != iter:
            self.snapshot(sess, iter + snapshot_add)
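The loop above interleaves several updates on different schedules: an RL trajectory is collected every iteration, the policy is updated once per BATCH_SIZE trajectories, the detector is updated every iteration once DET_START has been reached, and a snapshot is written every SNAPSHOT_ITERS iterations. A compact sketch of just that scheduling skeleton; the constants are illustrative placeholders and the actual update calls are stubbed out as comments.

BATCH_SIZE = 50        # policy-gradient update interval (cf. DRL_RPN_TRAIN.BATCH_SIZE)
DET_START = 40000      # first iteration at which the detector is trained
SNAPSHOT_ITERS = 5000  # snapshot interval (cf. TRAIN.SNAPSHOT_ITERS)
max_iters = 110000

for it in range(max_iters):
    # 1) collect an RL trajectory for the current image (run_drl_rpn, mode='train')
    # 2) policy-gradient update once per BATCH_SIZE trajectories
    if (it + 1) % BATCH_SIZE == 0:
        pass  # net.train_drl_rpn(...)
    # 3) detector update on every iteration once DET_START is reached
    if DET_START >= 0 and it >= DET_START:
        pass  # run_drl_rpn(mode='train_det') followed by net.train_step_det(...)
    # 4) periodic snapshot
    if (it + 1) % SNAPSHOT_ITERS == 0:
        pass  # snapshot(...)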
Beispiel #29
0
class Solver(object):
    def __init__(self, roidb, net, freeze=0):
        # Holds current iteration number.
        self.iter = 0

        # How frequently we should print the training info.
        self.display_freq = 1

        # Holds the path prefix for snapshots.
        self.snapshot_prefix = 'snapshot'

        self.roidb = roidb
        self.net = net
        self.freeze = freeze
        self.roi_data_layer = RoIDataLayer()
        self.roi_data_layer.setup()
        self.roi_data_layer.set_roidb(self.roidb)
        self.stepfn = self.build_step_fn(self.net)
        self.predfn = self.build_pred_fn(self.net)

    # This might be a useful static method to have.
    #@staticmethod not so static anymore
    def build_step_fn(self, net):
        target_y = T.vector("target Y", dtype='int64')
        tl = lasagne.objectives.categorical_crossentropy(
            net.prediction, target_y)
        loss = tl.mean()
        accuracy = lasagne.objectives.categorical_accuracy(
            net.prediction, target_y).mean()

        weights = net.params
        grads = theano.grad(loss, weights)

        scales = np.ones(len(weights))

        if self.freeze:
            scales[:-self.freeze] = 0

        print('GRAD SCALE >>>', scales)

        for idx, param in enumerate(weights):
            grads[idx] *= scales[idx]
            grads[idx] = grads[idx].astype('float32')

        #updates_sgd = lasagne.updates.sgd(loss, net.params, learning_rate=0.0001)
        updates_sgd = lasagne.updates.sgd(grads,
                                          net.params,
                                          learning_rate=0.0001)

        stepfn = theano.function([net.inp, target_y], [loss, accuracy],
                                 updates=updates_sgd,
                                 allow_input_downcast=True)
        return stepfn

    @staticmethod
    def build_pred_fn(net):
        predfn = theano.function([net.inp],
                                 net.prediction,
                                 allow_input_downcast=True)
        return predfn

    def get_training_batch(self):
        """Uses ROIDataLayer to fetch a training batch.

    Returns:
      input_data (ndarray): input data suitable for R-CNN processing
      labels (ndarray): batch labels (of type int32)
    """
        data, rois, labels = deepcopy(self.roi_data_layer.top[:3])
        X = roi_layer(data, rois)
        y = labels.astype('int')

        return X, y

    def step(self):
        """Conducts a single step of SGD."""
        self.roi_data_layer.forward()
        data, labels = self.get_training_batch()

        # Run one SGD step and keep the current loss/accuracy around.
        loss, acc = self.stepfn(data, labels)
        self.loss = loss
        self.acc = acc

        self.iter += 1
        if self.iter % self.display_freq == 0:
            print('Iteration {:<5} Train loss: {} Train acc: {}'.format(
                self.iter, self.loss, self.acc))

    def save(self, filename):
        self.net.save(filename)
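A short usage sketch for the Lasagne-based Solver above. build_my_net and my_roidb are hypothetical placeholders for whatever network and roidb construction the surrounding project provides; any net exposing .inp, .prediction, .params and a .save method fits the class.

net = build_my_net()                      # hypothetical network constructor
solver = Solver(my_roidb, net, freeze=2)  # freeze=2 zeroes gradients of all but the last two params
for _ in range(1000):
    solver.step()                         # one SGD step; prints loss/acc every display_freq iters
solver.save('snapshot_final.npz')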
Beispiel #30
0
    def train_model(self, sess, max_iters, imdb_name):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Determine different scales for anchors, see paper
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                sess,
                'TRAIN',
                self.imdb.num_classes,
                tag='default',
                anchor_scales=cfg.ANCHOR_SCALES,
                anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            momentum = cfg.TRAIN.MOMENTUM
            self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult') as scope:
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((tf.clip_by_value(grad, -5.0,
                                                           5.0), var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                #final_gvs = []
                #with tf.variable_scope('Gradient_Mult') as scope:
                #for grad, var in gvs:
                #final_gvs.append((tf.clip_by_value(grad,-50.0,50.0), var))
                train_op = self.optimizer.apply_gradients(gvs)
            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            #print(sess.graph)
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            #print('====write done======================')
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Find previous snapshots if there is any to restore from
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redstr = '_iter_{:d}.'.format(cfg.TRAIN.STEPSIZE[0] + 1)
        sfiles = [ss.replace('.meta', '') for ss in sfiles]
        sfiles = [ss for ss in sfiles if redstr not in ss]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        nfiles = [nn for nn in nfiles if redstr not in nn]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        np_paths = nfiles
        ss_paths = sfiles

        if lsf == 0:
            # Fresh train directly from ImageNet weights
            print('Loading initial model weights from {:s}'.format(
                self.pretrained_model))
            variables = tf.global_variables()
            for var in variables:
                print(var.name)
            # Initialize all variables first
            sess.run(tf.variables_initializer(variables, name='init'))
            print(self.pretrained_model)
            var_keep_dic = self.get_variables_in_checkpoint_file(
                self.pretrained_model)
            for v in var_keep_dic:
                if v.split('/')[0] == 'feature_fuse':
                    print('Pretrained variables restored: %s' % v)
            # Get the variables to restore, ignoring the variables to fix
            variables_to_restore = self.net.get_variables_to_restore(
                variables, var_keep_dic)

            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, self.pretrained_model)
            print('Loaded.')
            # Need to fix the variables before loading, so that the RGB weights are changed to BGR
            # For VGG16 it also changes the convolutional weights fc6 and fc7 to
            # fully connected weights
            self.net.fix_variables(sess, self.pretrained_model)
            print('Fixed.')
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
            last_snapshot_iter = 0
        else:
            # Get the most recent snapshot and restore
            ss_paths = [ss_paths[-1]]
            np_paths = [np_paths[-1]]

            print('Restoring model snapshots from {:s}'.format(sfiles[-1]))
            self.saver.restore(sess, str(sfiles[-1]))
            print('Restored.')
            # Need to restore the other hyperparameters/states for training (TODO xinlei): I have
            # tried my best to find the random states so that they can be recovered exactly.
            # However, the TensorFlow state is currently not available
            with open(str(nfiles[-1]), 'rb') as fid:
                st0 = pickle.load(fid)
                cur = pickle.load(fid)
                perm = pickle.load(fid)
                cur_val = pickle.load(fid)
                perm_val = pickle.load(fid)
                last_snapshot_iter = pickle.load(fid)

                np.random.set_state(st0)
                self.data_layer._cur = cur
                self.data_layer._perm = perm
                self.data_layer_val._cur = cur_val
                self.data_layer_val._perm = perm_val

                # Set the learning rate, only reduce once
                if last_snapshot_iter > cfg.TRAIN.STEPSIZE[
                        0] and last_snapshot_iter <= cfg.TRAIN.STEPSIZE[1]:
                    sess.run(
                        tf.assign(lr,
                                  cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
                elif last_snapshot_iter > cfg.TRAIN.STEPSIZE[1]:
                    sess.run(
                        tf.assign(
                            lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA *
                            cfg.TRAIN.GAMMA))
                else:
                    sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()

        # Reset the constrained conv layer: re-normalize its weights and zero its biases

        tmp_tensor1 = self.set_constrain()
        update_weights1 = tf.assign(
            tf.get_default_graph().get_tensor_by_name(
                'noise/constrained_conv/weights:0'), tmp_tensor1)
        biase11 = tf.get_default_graph().get_tensor_by_name(
            'noise/constrained_conv/biases:0')
        biase1 = tf.multiply(biase11, 0)

        update_biases1 = tf.assign(
            tf.get_default_graph().get_tensor_by_name(
                'noise/constrained_conv/biases:0'), biase1)
        update1 = [update_weights1, update_biases1]
        sess.run(update1)

        # Build the constraint-update ops that will run after train_op inside the loop
        with tf.control_dependencies([train_op]):
            tmp_tensor = self.set_constrain()
            update_weights = tf.assign(
                tf.get_default_graph().get_tensor_by_name(
                    'noise/constrained_conv/weights:0'), tmp_tensor)

        biase1 = tf.get_default_graph().get_tensor_by_name(
            'noise/constrained_conv/biases:0')
        biase = tf.multiply(biase1, 0)
        with tf.control_dependencies([update_weights]):
            update_biases = tf.assign(
                tf.get_default_graph().get_tensor_by_name(
                    'noise/constrained_conv/biases:0'), biase)

        while iter < max_iters + 1:

            timer.tic()

            if iter == cfg.TRAIN.STEPSIZE[0] + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                sess.run(
                    tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
            elif iter == cfg.TRAIN.STEPSIZE[1] + 1:
                self.snapshot(sess, iter)
                sess.run(tf.assign(lr, lr.eval() * cfg.TRAIN.GAMMA))

            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box,  total_loss, summary = \
                    self.net.train_step_with_summary_without_mask(sess, update_weights, update_biases, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box,  total_loss = \
                    self.net.train_step_without_mask(sess, update_weights, update_biases, blobs, train_op)

            timer.toc()
            #  print(sess.run(tf.get_default_graph().get_tensor_by_name('noise/constrained_conv/weights:0')[2, 2, :, :]))

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))
                print('remaining time: {:.3f}h'.format(
                    ((max_iters - iter) * timer.average_time) / 3600))

            wandb.log({
                'iter': iter,
                'total_loss': total_loss,
                'rpn_loss_cls': rpn_loss_cls,
                'rpn_loss_box': rpn_loss_box,
                'loss_cls': loss_cls,
                'loss_box': loss_box,
                'speed': timer.average_time,
                'remaining_time': ((max_iters - iter) * timer.average_time) / 3600
            })

            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                snapshot_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(snapshot_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for c in range(to_remove):
                        nfile = np_paths[0]
                        os.remove(str(nfile))
                        np_paths.remove(nfile)

                if len(ss_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for c in range(to_remove):
                        sfile = ss_paths[0]
                        # To make the code compatible with earlier versions of TensorFlow,
                        # where the naming convention for checkpoints is different
                        if os.path.exists(str(sfile)):
                            os.remove(str(sfile))
                        else:
                            os.remove(str(sfile + '.data-00000-of-00001'))
                            os.remove(str(sfile + '.index'))
                        sfile_meta = sfile + '.meta'
                        os.remove(str(sfile_meta))
                        ss_paths.remove(sfile)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
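
Aside: the two stepsize checks at the top of the loop above implement a piecewise-constant learning-rate schedule, multiplying the base rate by cfg.TRAIN.GAMMA once after each entry of cfg.TRAIN.STEPSIZE. A minimal standalone sketch of that schedule, with illustrative assumed numbers rather than the real config values:

def lr_at(iteration, base_lr=0.001, gamma=0.1, stepsizes=(80000, 110000)):
    # Reduce the learning rate by `gamma` once after each stepsize boundary
    lr = base_lr
    for step in stepsizes:
        if iteration > step:
            lr *= gamma
    return lr

print(lr_at(50000))   # 0.001
print(lr_at(90000))   # roughly 0.0001
print(lr_at(120000))  # roughly 1e-05
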
Beispiel #31
0
    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        if self.valroidb is not None:
            self.data_layer_val = RoIDataLayer(self.valroidb,
                                               self.imdb.num_classes,
                                               random=True)

        # Determine different scales for anchors, see paper
        if self.imdb.name.startswith('voc'):
            anchors = [8, 16, 32]
        else:
            anchors = [4, 8, 16, 32]

        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                sess,
                'TRAIN',
                self.imdb.num_classes,
                caffe_weight_path=self.pretrained_model,
                tag='default',
                anchor_scales=anchors)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            momentum = cfg.TRAIN.MOMENTUM
            self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

            # Compute the gradients wrt the loss
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult') as scope:
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.TRAIN.DOUBLE_BIAS and '/bias:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Find previous snapshots if there is any to restore from
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        sfiles = [ss.replace('.meta', '') for ss in sfiles]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        np_paths = nfiles
        ss_paths = sfiles

        if lsf == 0:
            # Fresh train directly from VGG weights
            print('Loading initial model weights from {:s}'.format(
                self.pretrained_model))
            variables = tf.global_variables()

            # Only initialize the variables that were not initialized when the graph was built
            sess.run(tf.variables_initializer(variables, name='init'))
            var_keep_dic = self.get_variables_in_checkpoint_file(
                self.pretrained_model)
            variables_to_restore = []
            # print(var_keep_dic)
            for v in variables:
                if v.name.split(':')[0] in var_keep_dic:
                    variables_to_restore.append(v)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, self.pretrained_model)
            print('Loaded.')
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
            last_snapshot_iter = 0
        else:
            # Get the most recent snapshot and restore
            ss_paths = [ss_paths[-1]]
            np_paths = [np_paths[-1]]

            print('Restoring model snapshots from {:s}'.format(sfiles[-1]))
            self.saver.restore(sess, str(sfiles[-1]))
            print('Restored.')
            # Need to restore the other hyperparameters/states for training (TODO xinlei): I have
            # tried my best to find the random states so that training can be recovered exactly;
            # however, the TensorFlow state is currently not available
            with open(str(nfiles[-1]), 'rb') as fid:
                st0 = pickle.load(fid)
                cur = pickle.load(fid)
                perm = pickle.load(fid)
                if self.valroidb is not None:
                    cur_val = pickle.load(fid)
                    perm_val = pickle.load(fid)
                last_snapshot_iter = pickle.load(fid)

                np.random.set_state(st0)
                self.data_layer._cur = cur
                self.data_layer._perm = perm
                if self.valroidb is not None:
                    self.data_layer_val._cur = cur_val
                    self.data_layer_val._perm = perm_val

                # Set the learning rate, only reduce once
                if last_snapshot_iter >= cfg.TRAIN.STEPSIZE:
                    sess.run(
                        tf.assign(lr,
                                  cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
                else:
                    sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        while iter < max_iters + 1:
            # Learning rate
            if iter == cfg.TRAIN.STEPSIZE:
                sess.run(
                    tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                  self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                if self.valroidb is not None:
                    blobs_val = self.data_layer_val.forward()
                    summary_val = self.net.get_summary(sess, blobs_val)
                    self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                  self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                snapshot_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(snapshot_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for c in range(to_remove):
                        nfile = np_paths[0]
                        os.remove(str(nfile))
                        np_paths.remove(nfile)

                if len(ss_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for c in range(to_remove):
                        sfile = ss_paths[0]
                        # To make the code compatible with earlier versions of TensorFlow,
                        # where the naming convention for checkpoints is different
                        if os.path.exists(str(sfile)):
                            os.remove(str(sfile))
                        else:
                            os.remove(str(sfile + '.data-00000-of-00001'))
                            os.remove(str(sfile + '.index'))
                        sfile_meta = sfile + '.meta'
                        os.remove(str(sfile_meta))
                        ss_paths.remove(sfile)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
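
Aside: the restore branch above reads the snapshot's .pkl file with a sequence of pickle.load calls that must mirror the order in which the corresponding snapshot() method dumped them (NumPy RNG state first, then the data-layer cursors, then the iteration). A self-contained sketch of that round trip; the file name and cursor values here are illustrative only:

import pickle
import numpy as np

st0 = np.random.get_state()             # current NumPy RNG state
cur, perm = 7, np.arange(10)            # stand-ins for the data-layer cursors

with open('meta_sketch.pkl', 'wb') as fid:
    pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
    pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
    pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)

with open('meta_sketch.pkl', 'rb') as fid:
    st0_restored = pickle.load(fid)     # loads must follow the dump order
    cur_restored = pickle.load(fid)
    perm_restored = pickle.load(fid)

np.random.set_state(st0_restored)       # the RNG resumes exactly where it stopped
assert cur_restored == cur and (perm_restored == perm).all()
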
Beispiel #32
0
class SolverWrapper(object):
    """
    A wrapper class for the training process
  """
    def __init__(self,
                 sess,
                 network,
                 imdb,
                 roidb,
                 valroidb,
                 output_dir,
                 tbdir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def set_value(self, matrix, x, y, val):
        # Get the width and height of the tensor, i.e. the sizes of its first and second dimensions
        batch = 3
        w = 5
        h = 5
        # Build a sparse matrix that is non-zero only at the target position; its value is the
        # difference between the target value and the original value
        val_diff = val - matrix[x][y]
        diff_matrix = tf.sparse_tensor_to_dense(
            tf.SparseTensor(indices=[x, y],
                            values=[val_diff],
                            shape=[batch, w, h]))
        # Use Variable.assign_add to add the two matrices together
        matrix.assign_add(diff_matrix)
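
        # Plain-NumPy illustration of the sparse-difference idea above
        # (a sketch only, independent of the hard-coded batch/w/h values):
        #   diff = np.zeros_like(matrix)
        #   diff[x, y] = val - matrix[x, y]
        #   matrix = matrix + diff      # only the (x, y) entry changes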

    def set_constrain(self):
        with tf.device("/gpu:0"):
            tmp_np = tf.get_default_graph().get_tensor_by_name(
                'noise/constrained_conv/weights:0').eval()
            tmp_np[2, 2, :, :] = 0
            for i in range(3):
                # Normalize each input-channel slice so its off-centre entries sum to 1
                tmp_np[:, :, 0, i] = tmp_np[:, :, 0, i] / tmp_np[:, :, 0, i].sum()
                tmp_np[:, :, 1, i] = tmp_np[:, :, 1, i] / tmp_np[:, :, 1, i].sum()
                tmp_np[:, :, 2, i] = tmp_np[:, :, 2, i] / tmp_np[:, :, 2, i].sum()
            # Pin the centre tap to -1 (constrained convolution)
            tmp_np[2, 2, :, :] = -1

            tmp_tensor = tf.convert_to_tensor(tmp_np, dtype=tf.float32)
            return tmp_tensor

    def train_model(self, sess, max_iters, imdb_name):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Determine different scales for anchors, see paper
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                sess,
                'TRAIN',
                self.imdb.num_classes,
                tag='default',
                anchor_scales=cfg.ANCHOR_SCALES,
                anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            momentum = cfg.TRAIN.MOMENTUM
            self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult') as scope:
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((tf.clip_by_value(grad, -5.0,
                                                           5.0), var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                #final_gvs = []
                #with tf.variable_scope('Gradient_Mult') as scope:
                #for grad, var in gvs:
                #final_gvs.append((tf.clip_by_value(grad,-50.0,50.0), var))
                train_op = self.optimizer.apply_gradients(gvs)
            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            #print(sess.graph)
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            #print('====write done======================')
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Find previous snapshots if there is any to restore from
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redstr = '_iter_{:d}.'.format(cfg.TRAIN.STEPSIZE[0] + 1)
        sfiles = [ss.replace('.meta', '') for ss in sfiles]
        sfiles = [ss for ss in sfiles if redstr not in ss]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        nfiles = [nn for nn in nfiles if redstr not in nn]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        np_paths = nfiles
        ss_paths = sfiles

        if lsf == 0:
            # Fresh train directly from ImageNet weights
            print('Loading initial model weights from {:s}'.format(
                self.pretrained_model))
            variables = tf.global_variables()
            for var in variables:
                print(var.name)
            # Initialize all variables first
            sess.run(tf.variables_initializer(variables, name='init'))
            print(self.pretrained_model)
            var_keep_dic = self.get_variables_in_checkpoint_file(
                self.pretrained_model)
            for v in var_keep_dic:
                if v.split('/')[0] == 'feature_fuse':
                    print('Pretrained variable restored: %s' % v)
            # Get the variables to restore, ignoring the variables to fix
            variables_to_restore = self.net.get_variables_to_restore(
                variables, var_keep_dic)

            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, self.pretrained_model)
            print('Loaded.')
            # Need to fix the variables before loading, so that the RGB weights are changed to BGR
            # For VGG16 it also changes the convolutional weights fc6 and fc7 to
            # fully connected weights
            self.net.fix_variables(sess, self.pretrained_model)
            print('Fixed.')
            sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
            last_snapshot_iter = 0
        else:
            # Get the most recent snapshot and restore
            ss_paths = [ss_paths[-1]]
            np_paths = [np_paths[-1]]

            print('Restoring model snapshots from {:s}'.format(sfiles[-1]))
            self.saver.restore(sess, str(sfiles[-1]))
            print('Restored.')
            # Need to restore the other hyperparameters/states for training (TODO xinlei): I have
            # tried my best to find the random states so that training can be recovered exactly;
            # however, the TensorFlow state is currently not available
            with open(str(nfiles[-1]), 'rb') as fid:
                st0 = pickle.load(fid)
                cur = pickle.load(fid)
                perm = pickle.load(fid)
                cur_val = pickle.load(fid)
                perm_val = pickle.load(fid)
                last_snapshot_iter = pickle.load(fid)

                np.random.set_state(st0)
                self.data_layer._cur = cur
                self.data_layer._perm = perm
                self.data_layer_val._cur = cur_val
                self.data_layer_val._perm = perm_val

                # Set the learning rate, only reduce once
                if last_snapshot_iter > cfg.TRAIN.STEPSIZE[
                        0] and last_snapshot_iter <= cfg.TRAIN.STEPSIZE[1]:
                    sess.run(
                        tf.assign(lr,
                                  cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
                elif last_snapshot_iter > cfg.TRAIN.STEPSIZE[1]:
                    sess.run(
                        tf.assign(
                            lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA *
                            cfg.TRAIN.GAMMA))
                else:
                    sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()

        # Reset the constrained conv layer: re-normalize its weights and zero its biases

        tmp_tensor1 = self.set_constrain()
        update_weights1 = tf.assign(
            tf.get_default_graph().get_tensor_by_name(
                'noise/constrained_conv/weights:0'), tmp_tensor1)
        biase11 = tf.get_default_graph().get_tensor_by_name(
            'noise/constrained_conv/biases:0')
        biase1 = tf.multiply(biase11, 0)

        update_biases1 = tf.assign(
            tf.get_default_graph().get_tensor_by_name(
                'noise/constrained_conv/biases:0'), biase1)
        update1 = [update_weights1, update_biases1]
        sess.run(update1)

        # Build the constraint-update ops that will run after train_op inside the loop
        with tf.control_dependencies([train_op]):
            tmp_tensor = self.set_constrain()
            update_weights = tf.assign(
                tf.get_default_graph().get_tensor_by_name(
                    'noise/constrained_conv/weights:0'), tmp_tensor)

        biase1 = tf.get_default_graph().get_tensor_by_name(
            'noise/constrained_conv/biases:0')
        biase = tf.multiply(biase1, 0)
        with tf.control_dependencies([update_weights]):
            update_biases = tf.assign(
                tf.get_default_graph().get_tensor_by_name(
                    'noise/constrained_conv/biases:0'), biase)

        while iter < max_iters + 1:

            timer.tic()

            if iter == cfg.TRAIN.STEPSIZE[0] + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                sess.run(
                    tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
            elif iter == cfg.TRAIN.STEPSIZE[1] + 1:
                self.snapshot(sess, iter)
                sess.run(tf.assign(lr, lr.eval() * cfg.TRAIN.GAMMA))

            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box,  total_loss, summary = \
                    self.net.train_step_with_summary_without_mask(sess, update_weights, update_biases, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box,  total_loss = \
                    self.net.train_step_without_mask(sess, update_weights, update_biases, blobs, train_op)

            timer.toc()
            #  print(sess.run(tf.get_default_graph().get_tensor_by_name('noise/constrained_conv/weights:0')[2, 2, :, :]))

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))
                print('remaining time: {:.3f}h'.format(
                    ((max_iters - iter) * timer.average_time) / 3600))

            wandb.log({
                'iter': iter,
                'total_loss': total_loss,
                'rpn_loss_cls': rpn_loss_cls,
                'rpn_loss_box': rpn_loss_box,
                'loss_cls': loss_cls,
                'loss_box': loss_box,
                'speed': timer.average_time,
                'remaining_time': ((max_iters - iter) * timer.average_time) / 3600
            })

            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                snapshot_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(snapshot_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for c in range(to_remove):
                        nfile = np_paths[0]
                        os.remove(str(nfile))
                        np_paths.remove(nfile)

                if len(ss_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
                    for c in range(to_remove):
                        sfile = ss_paths[0]
                        # To make the code compatible with earlier versions of TensorFlow,
                        # where the naming convention for checkpoints is different
                        if os.path.exists(str(sfile)):
                            os.remove(str(sfile))
                        else:
                            os.remove(str(sfile + '.data-00000-of-00001'))
                            os.remove(str(sfile + '.index'))
                        sfile_meta = sfile + '.meta'
                        os.remove(str(sfile_meta))
                        ss_paths.remove(sfile)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
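
Aside: set_constrain above re-projects the 5x5 kernel of noise/constrained_conv after every training step: the centre tap is zeroed, each input-channel slice is normalized so it sums to 1, and the centre is then pinned to -1, so every filter behaves like a prediction-error (high-pass) kernel. A standalone NumPy sketch of the same projection, assuming the same kernel layout as the TF variable:

import numpy as np

def constrain_kernel(weights):
    # weights has shape (5, 5, in_channels, num_filters), like the TF variable above
    w = weights.copy()
    w[2, 2, :, :] = 0                    # exclude the centre tap from the normalization
    for f in range(w.shape[3]):
        for c in range(w.shape[2]):
            w[:, :, c, f] /= w[:, :, c, f].sum()
    w[2, 2, :, :] = -1                   # pin the centre tap to -1
    return w

kernel = constrain_kernel(np.random.randn(5, 5, 3, 3))
print(abs(kernel[:, :, 0, 0].sum()))     # close to 0, since each slice sums to 1 + (-1)
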
if rand_seed is not None:
    np.random.seed(rand_seed)

# load config file and get hyperparameters
cfg_from_file(cfg_file)
lr = cfg.TRAIN.LEARNING_RATE
momentum = cfg.TRAIN.MOMENTUM
weight_decay = cfg.TRAIN.WEIGHT_DECAY
disp_interval = cfg.TRAIN.DISPLAY
log_interval = cfg.TRAIN.LOG_IMAGE_ITERS

# load imdb and create data layer
imdb = get_imdb(imdb_name)
rdl_roidb.prepare_roidb(imdb)
roidb = imdb.roidb
data_layer = RoIDataLayer(roidb, imdb.num_classes)

#pdb.set_trace()

# Create network and initialize
net = WSDDN(classes=imdb.classes, debug=_DEBUG)
network.weights_normal_init(net, dev=0.001)
if os.path.exists('pretrained_alexnet.pkl'):
    pret_net = pkl.load(open('pretrained_alexnet.pkl', 'rb'))
else:
    pret_net = model_zoo.load_url(
        'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth')
    pkl.dump(pret_net, open('pretrained_alexnet.pkl', 'wb'),
             pkl.HIGHEST_PROTOCOL)
own_state = net.state_dict()
for name, param in pret_net.items():
  def train_model(self, sess, max_iters):
    # Build data layers for both training and validation set
    self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
    self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

    # Determine different scales for anchors, see paper
    with sess.graph.as_default():
      # Set the random seed for tensorflow
      tf.set_random_seed(cfg.RNG_SEED)
      # Build the main computation graph
      layers = self.net.create_architecture(sess, 'TRAIN', self.imdb.num_classes, tag='default',
                                            anchor_scales=cfg.ANCHOR_SCALES,
                                            anchor_ratios=cfg.ANCHOR_RATIOS)
      # Define the loss
      loss = layers['total_loss']
      rpn_loss = layers['rpn_loss']
      class_loss = layers['class_loss']
      # Set learning rate and momentum
      lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
      momentum = cfg.TRAIN.MOMENTUM
      self.optimizer = tf.train.MomentumOptimizer(lr, momentum)

      # Compute the gradients wrt the loss
      gvs = self.optimizer.compute_gradients(loss)
      gvs_rpn = self.optimizer.compute_gradients(rpn_loss)
      gvs_class = self.optimizer.compute_gradients(class_loss)
      # Double the gradient of the bias if set
      if cfg.TRAIN.DOUBLE_BIAS:
        final_gvs = []
        with tf.variable_scope('Gradient_Mult') as scope:
          for grad, var in gvs:
            scale = 1.
            if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
              scale *= 2.
            if not np.allclose(scale, 1.0):
              grad = tf.multiply(grad, scale)
            final_gvs.append((grad, var))
        train_op = self.optimizer.apply_gradients(final_gvs)
        train_op_rpn = self.optimizer.apply_gradients(gvs_rpn)
        train_op_class = self.optimizer.apply_gradients(gvs_class)
      else:
        train_op = self.optimizer.apply_gradients(gvs)
        train_op_rpn = self.optimizer.apply_gradients(gvs_rpn)
        train_op_class = self.optimizer.apply_gradients(gvs_class)

      # We will handle the snapshots ourselves
      self.saver = tf.train.Saver(max_to_keep=100000)
      # Write the train and validation information to tensorboard
      self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
      self.valwriter = tf.summary.FileWriter(self.tbvaldir)

    # Find previous snapshots if there is any to restore from
    sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
    sfiles = glob.glob(sfiles)
    sfiles.sort(key=os.path.getmtime)
    # Get the snapshot name in TensorFlow
    redstr = '_iter_{:d}.'.format(cfg.TRAIN.STEPSIZE+1)
    sfiles = [ss.replace('.meta', '') for ss in sfiles]
    sfiles = [ss for ss in sfiles if redstr not in ss]

    nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
    nfiles = glob.glob(nfiles)
    nfiles.sort(key=os.path.getmtime)
    nfiles = [nn for nn in nfiles if redstr not in nn]

    lsf = len(sfiles)
    assert len(nfiles) == lsf

    np_paths = nfiles
    ss_paths = sfiles

    if lsf == 0:
      # Fresh train directly from ImageNet weights
      print('Loading initial model weights from {:s}'.format(self.pretrained_model))
      variables = tf.global_variables()
      # Initialize all variables first
      sess.run(tf.variables_initializer(variables, name='init'))
      var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
      # Get the variables to restore, ignoring the variables to fix
      variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

      restorer = tf.train.Saver(variables_to_restore)
      restorer.restore(sess, self.pretrained_model)
      print('Loaded.')
      # Need to fix the variables before loading, so that the RGB weights are changed to BGR
      # For VGG16 it also changes the convolutional weights fc6 and fc7 to
      # fully connected weights
      self.net.fix_variables(sess, self.pretrained_model)
      print('Fixed.')
      sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))
      last_snapshot_iter = 0
    else:
      # Get the most recent snapshot and restore
      ss_paths = [ss_paths[-1]]
      np_paths = [np_paths[-1]]

      print('Restoring model snapshots from {:s}'.format(sfiles[-1]))
      self.saver.restore(sess, str(sfiles[-1]))
      print('Restored.')
      # Need to restore the other hyperparameters/states for training (TODO xinlei): I have
      # tried my best to find the random states so that training can be recovered exactly;
      # however, the TensorFlow state is currently not available
      with open(str(nfiles[-1]), 'rb') as fid:
        st0 = pickle.load(fid)
        cur = pickle.load(fid)
        perm = pickle.load(fid)
        cur_val = pickle.load(fid)
        perm_val = pickle.load(fid)
        last_snapshot_iter = pickle.load(fid)

        np.random.set_state(st0)
        self.data_layer._cur = cur
        self.data_layer._perm = perm
        self.data_layer_val._cur = cur_val
        self.data_layer_val._perm = perm_val

        # Set the learning rate, only reduce once
        if last_snapshot_iter > cfg.TRAIN.STEPSIZE:
          sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA))
        else:
          sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE))

    timer = Timer()
    iter = last_snapshot_iter + 1
    last_summary_time = time.time()
    while iter < max_iters + 1:
      # Learning rate
      # if iter == cfg.TRAIN.STEPSIZE + 1:
      if iter == 60000+1:        # rpn
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, 0.0001))
      elif iter == 80000+1:      # rfcn  80000-160000-200000
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, 0.001))
      elif iter == 160000 + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, 0.0001))
      elif iter == 200000 + 1:    # rpn 200000-260000-280000
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, 0.001))
      elif iter == 260000 + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, 0.0001))
      elif iter == 280000 + 1:   # rfcn 280000-360000-400000
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, 0.001))
      elif iter == 360000 + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, 0.0001))
      elif iter == 400000 + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, 0.001))
      elif iter == 460000 + 1:
        # Add snapshot here before reducing the learning rate
        self.snapshot(sess, iter)
        sess.run(tf.assign(lr, 0.0001))

      timer.tic()
      # Get training data, one batch at a time
      blobs = self.data_layer.forward()

      now = time.time()
      if now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
        # Compute the graph with summary
        if (iter < 80000) or (200000 <= iter < 280000) or (400000 <= iter < max_iters+1):
        # if (iter < 10) or (100 <= iter < 200):
        #   print(str(iter), '@@@'*10, 'train rpn layers')
          rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
            self.net.train_step_with_summary(sess, blobs, train_op_rpn)
        # elif (80000 <= iter < 120000) or (200000 <= iter < max_iters+1):
        else:
          # print(str(iter), '=====' * 10, 'train rfcn layers')
          rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
            self.net.train_step_with_summary(sess, blobs, train_op_class)
        self.writer.add_summary(summary, float(iter))
        # Also check the summary on the validation set
        blobs_val = self.data_layer_val.forward()
        summary_val = self.net.get_summary(sess, blobs_val)
        self.valwriter.add_summary(summary_val, float(iter))
        last_summary_time = now
      else:
        # Compute the graph without summary
        if (iter < 80000) or (200000 <= iter < 280000) or (400000 <= iter < max_iters + 1):
        # if (iter < 10) or (100 <= iter < 200):
        #   print(str(iter), '@@@' * 10, 'train rpn layers')
          rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
            self.net.train_step(sess, blobs, train_op_rpn)
        else:
          # print(str(iter), '=====' * 10, 'train rfcn layers')
          rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
            self.net.train_step(sess, blobs, train_op_class)

      timer.toc()

      # Display training information
      if iter % (cfg.TRAIN.DISPLAY) == 0:
        print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
              '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
              (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
        print('speed: {:.3f}s / iter'.format(timer.average_time))

      if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
        last_snapshot_iter = iter
        snapshot_path, np_path = self.snapshot(sess, iter)
        np_paths.append(np_path)
        ss_paths.append(snapshot_path)

        # Remove the old snapshots if there are too many
        if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
          for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        if len(ss_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
          to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
          for c in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible with earlier versions of TensorFlow,
            # where the naming convention for checkpoints is different
            if os.path.exists(str(sfile)):
              os.remove(str(sfile))
            else:
              os.remove(str(sfile + '.data-00000-of-00001'))
              os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)
      iter += 1
    if last_snapshot_iter != iter - 1:
      self.snapshot(sess, iter - 1)

    self.writer.close()
    self.valwriter.close()
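
Aside: the hard-coded iteration boundaries above alternate between optimizing the RPN branch (train_op_rpn) and the R-FCN/classification branch (train_op_class). A small sketch of the same phase test, with the boundaries copied from the comments and an assumed value for max_iters:

def trains_rpn(iteration, max_iters=480000):
    # True in the RPN phases, False in the R-FCN/classification phases
    return (iteration < 80000
            or 200000 <= iteration < 280000
            or 400000 <= iteration < max_iters + 1)

print(trains_rpn(10000))    # True  -> train_op_rpn
print(trains_rpn(150000))   # False -> train_op_class
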
Beispiel #35
0
class SolverWrapper(object):
    """
    A wrapper class for the training process
  """
    def __init__(self,
                 network,
                 imdb,
                 roidb,
                 valroidb,
                 output_dir,
                 tbdir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pth'
        filename = os.path.join(self.output_dir, filename)
        torch.save(self.net.state_dict(), filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sfile, nfile):
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.net.load_state_dict(torch.load(str(sfile)))
        print('Restored.')
        # Need to restore the other hyper-parameters/states for training (TODO xinlei): I have
        # tried my best to find the random states so that training can be recovered exactly;
        # however, the TensorFlow state is currently not available
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def construct_graph(self):
        # Set the random seed
        torch.manual_seed(cfg.RNG_SEED)
        # Build the main computation graph
        self.net.create_architecture(self.imdb.num_classes,
                                     tag='default',
                                     anchor_scales=cfg.ANCHOR_SCALES,
                                     anchor_ratios=cfg.ANCHOR_RATIOS)
        # Define the loss
        # loss = layers['total_loss']
        # Set learning rate and momentum
        lr = cfg.TRAIN.LEARNING_RATE
        params = []
        for key, value in dict(self.net.named_parameters()).items():
            if value.requires_grad:
                if 'refine' in key and 'bias' in key:
                    params += [{
                        'params': [value],
                        'lr': 10 * lr * (cfg.TRAIN.DOUBLE_BIAS + 1),
                        'weight_decay': cfg.TRAIN.BIAS_DECAY and cfg.TRAIN.WEIGHT_DECAY or 0
                    }]
                elif 'refine' in key and 'bias' not in key:
                    params += [{
                        'params': [value],
                        'lr': 10 * lr,
                        'weight_decay': getattr(value, 'weight_decay', cfg.TRAIN.WEIGHT_DECAY)
                    }]
                elif 'refine' not in key and 'bias' in key:
                    params += [{
                        'params': [value],
                        'lr': lr * (cfg.TRAIN.DOUBLE_BIAS + 1),
                        'weight_decay': cfg.TRAIN.BIAS_DECAY and cfg.TRAIN.WEIGHT_DECAY or 0
                    }]
                else:
                    params += [{
                        'params': [value],
                        'lr': lr,
                        'weight_decay': getattr(value, 'weight_decay', cfg.TRAIN.WEIGHT_DECAY)
                    }]
        self.optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)
        # Write the train and validation information to tensorboard
        self.writer = tb.writer.FileWriter(self.tbdir)
        self.valwriter = tb.writer.FileWriter(self.tbvaldir)

        return lr, self.optimizer

    def find_previous(self):
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pth')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in pytorch
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(
                os.path.join(
                    self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX +
                    '_iter_{:d}.pth'.format(stepsize + 1)))
        sfiles = [ss for ss in sfiles if ss not in redfiles]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [redfile.replace('.pth', '.pkl') for redfile in redfiles]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def initialize(self):
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        self.net.load_pretrained_cnn(torch.load(self.pretrained_model))
        print('Loaded.')
        # Need to fix the variables before loading, so that the RGB weights are changed to BGR
        # For VGG16 it also changes the convolutional weights fc6 and fc7 to
        # fully connected weights
        last_snapshot_iter = 0
        lr = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sfile, nfile):
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sfile, nfile)
        # Set the learning rate
        lr_scale = 1
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                lr_scale *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)
        scale_lr(self.optimizer, lr_scale)
        lr = cfg.TRAIN.LEARNING_RATE * lr_scale
        return lr, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            sfile = ss_paths[0]
            # To make the code compatible with earlier versions of TensorFlow,
            # where the naming convention for checkpoints is different
            os.remove(str(sfile))
            ss_paths.remove(sfile)

    def train_model(self, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph()

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
            )
        else:
            lr, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                str(sfiles[-1]), str(nfiles[-1]))
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()

        self.net.train()
        self.net.to(self.net._device)

        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(iter)
                lr *= cfg.TRAIN.GAMMA
                scale_lr(self.optimizer, cfg.TRAIN.GAMMA)
                next_stepsize = stepsizes.pop()
            #if ((iter -1) % cfg.TRAIN.MIL_RECURRENT_STEP) == 0:
            #  num_epoch = int((iter - 1) / cfg.TRAIN.MIL_RECURRENT_STEP) + 1
            #  cfg.TRAIN.MIL_RECURRECT_WEIGHT = ((num_epoch - 1)/20.0)/1.5
            #if iter == cfg.TRAIN.MIL_RECURRENT_STEP + 1:
            #  cfg.TRAIN.MIL_RECURRECT_WEIGHT = cfg.TRAIN.MIL_RECURRECT_WEIGHT * 10

            utils.timer.timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                cls_det_loss, refine_loss_1, refine_loss_2, total_loss, summary = \
                  self.net.train_step_with_summary(blobs, self.optimizer)
                for _sum in summary:
                    self.writer.add_summary(_sum, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(blobs_val)
                for _sum in summary_val:
                    self.valwriter.add_summary(_sum, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                cls_det_loss, refine_loss_1, refine_loss_2, total_loss = self.net.train_step(
                    blobs, self.optimizer)
            utils.timer.timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> cls_det_loss: %.6f\n '
                      '>>> refine_loss_1: %.6f\n >>> refine_loss_2: %.6f\n >>> lr: %f' % \
                      (iter, max_iters, total_loss, cls_det_loss, refine_loss_1, refine_loss_2, lr))
                print('speed: {:.3f}s / iter'.format(
                    utils.timer.timer.average_time()))

                # for k in utils.timer.timer._average_time.keys():
                #   print(k, utils.timer.timer.average_time(k))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(iter - 1)

        self.writer.close()
        self.valwriter.close()
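
A minimal sketch of how the PyTorch-style train_model above might be driven; run_training and build_solver are hypothetical names assumed only for illustration, not part of the example.

def run_training(build_solver, max_iters=70000):
    """Hypothetical driver: build a configured solver and run its training loop."""
    solver = build_solver()          # assumed to return a wrapper wired like the one above
    solver.train_model(max_iters)    # runs the loop shown above, snapshotting as it goes
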
Beispiel #36
0
    def get_data_layer(self):
        """return a data layer."""
        layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        return layer
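
For context, a hedged sketch of how the layer returned by get_data_layer is typically consumed; fetch_one_batch is a hypothetical helper, and the exact blob keys depend on the configuration.

def fetch_one_batch(solver):
    """Hypothetical helper: pull one minibatch from the data layer built above."""
    layer = solver.get_data_layer()   # RoIDataLayer over solver.roidb
    blobs = layer.forward()           # dict of input blobs for a single training step
    return blobs
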
Beispiel #37
0
class SolverWrapper(object):
    """
    A wrapper class for the training process
  """
    def __init__(self,
                 sess,
                 network,
                 imdb,
                 roidb,
                 valroidb,
                 output_dir,
                 tbdir,
                 pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        if not os.path.exists(self.tbvaldir):
            os.makedirs(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(
            iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Needs to restore the other hyper-parameters/states for training. (TODO xinlei)
        # I have tried my best to find the random states so that training can be
        # recovered exactly; however, the TensorFlow random state is currently not available
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print(
                    "It's likely that your checkpoint file has been compressed "
                    "with SNAPPY.")

    def construct_graph(self, sess):
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture(
                'TRAIN',
                self.imdb.num_classes,
                tag='default',
                anchor_scales=cfg.ANCHOR_SCALES,
                anchor_ratios=cfg.ANCHOR_RATIOS)
            # Define the loss
            loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)
            self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

            # Compute the gradients with regard to the loss
            gvs = self.optimizer.compute_gradients(loss)
            # Double the gradient of the bias if set
            if cfg.TRAIN.DOUBLE_BIAS:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult') as scope:
                    for grad, var in gvs:
                        scale = 1.
                        if cfg.TRAIN.DOUBLE_BIAS and '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = self.optimizer.apply_gradients(final_gvs)
            else:
                train_op = self.optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def find_previous(self):
        sfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(
                os.path.join(
                    self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX +
                    '_iter_{:d}.ckpt.meta'.format(stepsize + 1)))
        sfiles = [
            ss.replace('.meta', '') for ss in sfiles if ss not in redfiles
        ]

        nfiles = os.path.join(self.output_dir,
                              cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [
            redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles
        ]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def initialize(self, sess):
        # Initial file lists are empty
        np_paths = []
        ss_paths = []
        # Fresh train directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(
            self.pretrained_model))
        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(
            self.pretrained_model)
        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(
            variables, var_keep_dic)

        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, self.pretrained_model)
        print('Loaded.')
        # Fix the loaded variables so that the RGB weights are converted to BGR.
        # For VGG16 this also converts the convolutional fc6 and fc7 weights to
        # fully connected weights
        self.net.fix_variables(sess, self.pretrained_model)
        print('Fixed.')
        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            sfile = ss_paths[0]
            # To keep the code compatible with earlier versions of TensorFlow,
            # where the naming convention for checkpoints is different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb,
                                           self.imdb.num_classes,
                                           random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(
                sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(
                sess, str(sfiles[-1]), str(nfiles[-1]))
        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                  self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                  self.net.train_step(sess, blobs, train_op)
            timer.toc()

            # Display training information
            if iter % (cfg.TRAIN.DISPLAY) == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' % \
                      (iter, max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr.eval()))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
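
A minimal, assumed entry point for the TensorFlow SolverWrapper above; the session options and the name train_net are illustrative and not taken from the example.

import tensorflow as tf

def train_net(network, imdb, roidb, valroidb, output_dir, tb_dir,
              pretrained_model=None, max_iters=40000):
    """Hypothetical driver: open a session and run the wrapper's training loop."""
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session(config=tfconfig) as sess:
        sw = SolverWrapper(sess, network, imdb, roidb, valroidb,
                           output_dir, tb_dir, pretrained_model=pretrained_model)
        print('Solving...')
        sw.train_model(sess, max_iters)
        print('done solving')
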
Beispiel #38
0
class SolverWrapper(object):
    """
      A wrapper class for the training process
    """

    def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
        self.net = network
        self.imdb = imdb
        self.roidb = roidb
        self.valroidb = valroidb
        self.output_dir = output_dir
        self.tbdir = tbdir
        # Simply put '_val' at the end to save the summaries from the validation set
        self.tbvaldir = tbdir + '_val'
        common.check_dir(self.tbvaldir)
        self.pretrained_model = pretrained_model

    def snapshot(self, sess, iter):
        net = self.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indexes of the database
        perm = self.data_layer._perm
        # current position in the validation database
        cur_val = self.data_layer_val._cur
        # current shuffled indexes of the validation database
        perm_val = self.data_layer_val._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename

    def from_snapshot(self, sess, sfile, nfile):
        print('Restoring model snapshots from {:s}'.format(sfile))
        self.saver.restore(sess, sfile)
        print('Restored.')
        # Needs to restore the other hyper-parameters/states for training. (TODO xinlei)
        # I have tried my best to find the random states so that training can be
        # recovered exactly; however, the TensorFlow random state is currently not available
        with open(nfile, 'rb') as fid:
            st0 = pickle.load(fid)
            cur = pickle.load(fid)
            perm = pickle.load(fid)
            cur_val = pickle.load(fid)
            perm_val = pickle.load(fid)
            last_snapshot_iter = pickle.load(fid)

            np.random.set_state(st0)
            self.data_layer._cur = cur
            self.data_layer._perm = perm
            self.data_layer_val._cur = cur_val
            self.data_layer_val._perm = perm_val

        return last_snapshot_iter

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")

    def construct_graph(self, sess):
        with sess.graph.as_default():
            # Set the random seed for tensorflow
            tf.set_random_seed(cfg.RNG_SEED)
            # Build the main computation graph
            layers = self.net.create_architecture('TRAIN', self.imdb.num_classes, tag='default',
                                                  anchor_width=cfg.CTPN.ANCHOR_WIDTH,
                                                  anchor_h_ratio_step=cfg.CTPN.H_RADIO_STEP,
                                                  num_anchors=cfg.CTPN.NUM_ANCHORS)
            # Define the loss
            total_loss = layers['total_loss']
            # Set learning rate and momentum
            lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False)


            if cfg.TRAIN.OPTIMIZER == 'Adam':
                self.optimizer = tf.train.AdamOptimizer(lr)
            elif cfg.TRAIN.OPTIMIZER == 'Momentum':
                self.optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)
            elif cfg.TRAIN.OPTIMIZER == 'RMS':
                self.optimizer = tf.train.RMSPropOptimizer(lr)
            else:
                raise NotImplementedError

            global_step = tf.Variable(0, trainable=False)
            with_clip = False
            if with_clip:
                tvars = tf.trainable_variables()
                grads, norm = tf.clip_by_global_norm(tf.gradients(total_loss, tvars), 10.0)
                train_op = self.optimizer.apply_gradients(list(zip(grads, tvars)), global_step=global_step)
            else:
                # required by tf.layers.batch_normalization()
                # add update ops (for moving_mean and moving_variance) as a dependency of the train_op
                # https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    train_op = self.optimizer.minimize(total_loss, global_step=global_step)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)
            # Write the train and validation information to tensorboard
            self.writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            self.valwriter = tf.summary.FileWriter(self.tbvaldir)

        return lr, train_op

    def find_previous(self):
        sfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.ckpt.meta')
        sfiles = glob.glob(sfiles)
        sfiles.sort(key=os.path.getmtime)
        # Get the snapshot name in TensorFlow
        redfiles = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            redfiles.append(os.path.join(self.output_dir,
                                         cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}.ckpt.meta'.format(stepsize + 1)))
        sfiles = [ss.replace('.meta', '') for ss in sfiles if ss not in redfiles]

        nfiles = os.path.join(self.output_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_*.pkl')
        nfiles = glob.glob(nfiles)
        nfiles.sort(key=os.path.getmtime)
        redfiles = [redfile.replace('.ckpt.meta', '.pkl') for redfile in redfiles]
        nfiles = [nn for nn in nfiles if nn not in redfiles]

        lsf = len(sfiles)
        assert len(nfiles) == lsf

        return lsf, nfiles, sfiles

    def restore_ckpt_from_dir(self, sess, net, checkpoint_dir):
        print("Restoring checkpoint from: " + checkpoint_dir)

        ckpt = tf.train.latest_checkpoint(checkpoint_dir)
        if ckpt is None:
            print("Checkpoint not found")
            exit(-1)

        meta_file = ckpt + '.meta'
        try:
            print('Restore variables from {}'.format(ckpt))
            print('Restore meta_file from {}'.format(meta_file))
            saver = tf.train.Saver(net.variables_to_restore)
            saver.restore(sess, ckpt)
        except Exception:
            raise Exception("Can not restore from {}".format(checkpoint_dir))

    def initialize(self, sess):
        # Initial file lists are empty
        np_paths = []
        ss_paths = []

        variables = tf.global_variables()
        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))

        if self.pretrained_model is not None:
            if self.pretrained_model.endswith('.ckpt'):
                # Fresh train directly from ImageNet weights
                print('Loading initial model weights from {:s}'.format(self.pretrained_model))

                var_keep_dic = self.get_variables_in_checkpoint_file(self.pretrained_model)
                # Get the variables to restore, ignoring the variables to fix
                variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)

                restorer = tf.train.Saver(variables_to_restore)
                restorer.restore(sess, self.pretrained_model)
                print('Loaded.')
            else:
                # Restore from checkpoint and meta file
                self.restore_ckpt_from_dir(sess, self.net, self.pretrained_model)
                print('Loaded.')

        last_snapshot_iter = 0
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = list(cfg.TRAIN.STEPSIZE)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def restore(self, sess, sfile, nfile):
        # Get the most recent snapshot and restore
        np_paths = [nfile]
        ss_paths = [sfile]
        # Restore model from snapshots
        last_snapshot_iter = self.from_snapshot(sess, sfile, nfile)
        # Set the learning rate
        rate = cfg.TRAIN.LEARNING_RATE
        stepsizes = []
        for stepsize in cfg.TRAIN.STEPSIZE:
            if last_snapshot_iter > stepsize:
                rate *= cfg.TRAIN.GAMMA
            else:
                stepsizes.append(stepsize)

        return rate, last_snapshot_iter, stepsizes, np_paths, ss_paths

    def remove_snapshot(self, np_paths, ss_paths):
        to_remove = len(np_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            nfile = np_paths[0]
            os.remove(str(nfile))
            np_paths.remove(nfile)

        to_remove = len(ss_paths) - cfg.TRAIN.SNAPSHOT_KEPT
        for c in range(to_remove):
            sfile = ss_paths[0]
            # To keep the code compatible with earlier versions of TensorFlow,
            # where the naming convention for checkpoints is different
            if os.path.exists(str(sfile)):
                os.remove(str(sfile))
            else:
                os.remove(str(sfile + '.data-00000-of-00001'))
                os.remove(str(sfile + '.index'))
            sfile_meta = sfile + '.meta'
            os.remove(str(sfile_meta))
            ss_paths.remove(sfile)

    def train_model(self, sess, max_iters):
        # Build data layers for both training and validation set
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.data_layer_val = RoIDataLayer(self.valroidb, self.imdb.num_classes, random=True)

        # Construct the computation graph
        lr, train_op = self.construct_graph(sess)

        # Find previous snapshots if there is any to restore from
        lsf, nfiles, sfiles = self.find_previous()

        # Initialize the variables or restore them from the last snapshot
        if lsf == 0:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.initialize(sess)
        else:
            rate, last_snapshot_iter, stepsizes, np_paths, ss_paths = self.restore(sess,
                                                                                   str(sfiles[-1]),
                                                                                   str(nfiles[-1]))
        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        # Make sure the lists are not empty
        stepsizes.append(max_iters)
        stepsizes.reverse()
        next_stepsize = stepsizes.pop()
        while iter < max_iters + 1:
            # Learning rate
            if iter == next_stepsize + 1:
                # Add snapshot here before reducing the learning rate
                self.snapshot(sess, iter)
                rate *= cfg.TRAIN.GAMMA
                sess.run(tf.assign(lr, rate))
                next_stepsize = stepsizes.pop()

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            now = time.time()
            if iter == 1 or now - last_summary_time > cfg.TRAIN.SUMMARY_INTERVAL:
                # Compute the graph with summary
                rpn_loss_cls, rpn_loss_box, rpn_loss, total_loss, summary = \
                    self.net.train_step_with_summary(sess, blobs, train_op)
                self.writer.add_summary(summary, float(iter))
                # Also check the summary on the validation set
                blobs_val = self.data_layer_val.forward()
                summary_val = self.net.get_summary(sess, blobs_val)
                self.valwriter.add_summary(summary_val, float(iter))
                last_summary_time = now
            else:
                # Compute the graph without summary
                rpn_loss_cls, rpn_loss_box, rpn_loss, total_loss, _ = \
                    self.net.train_step(sess, blobs, train_op)
            timer.toc()

            print('%d/%d time: %.3f total_loss: %.3f rpn_loss: %.3f rpn_loss_cls: %.3f '
                  'rpn_loss_box: %.3f lr: %f' % (
                      iter, max_iters, timer.diff, total_loss, rpn_loss, rpn_loss_cls, rpn_loss_box, lr.eval()))

            # Snapshotting
            if iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = iter
                ss_path, np_path = self.snapshot(sess, iter)
                np_paths.append(np_path)
                ss_paths.append(ss_path)

                # Remove the old snapshots if there are too many
                if len(np_paths) > cfg.TRAIN.SNAPSHOT_KEPT:
                    self.remove_snapshot(np_paths, ss_paths)

            iter += 1

        if last_snapshot_iter != iter - 1:
            self.snapshot(sess, iter - 1)

        self.writer.close()
        self.valwriter.close()
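
As a side note, the step-decay bookkeeping used by the train_model methods above (appending max_iters so that stepsizes.pop() never runs dry) can be isolated as follows; lr_schedule is a hypothetical helper written only to illustrate the schedule, not part of the example.

def lr_schedule(base_lr, gamma, stepsizes, max_iters):
    """Hypothetical helper: yield (iteration, lr) pairs under the same step-decay rule."""
    lr = base_lr
    steps = list(stepsizes) + [max_iters]   # max_iters acts as a sentinel
    steps.reverse()
    next_step = steps.pop()
    for it in range(1, max_iters + 1):
        if it == next_step + 1:
            lr *= gamma
            next_step = steps.pop()
        yield it, lr

For example, with base_lr=0.001, gamma=0.1, stepsizes=[30000] and max_iters=50000, the yielded rate is 0.001 up to iteration 30000 and 0.0001 afterwards.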