Code Example #1
def SW_add_scalars2(self, main_tag, tag_scalar_dict, global_step=None):
    """Adds many scalar data to summary.
    Args:
        tag (string): Data identifier
        main_tag (string): The parent name for the tags
        tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values
        global_step (int): Global step value to record
    Examples::
        writer.add_scalars('run_14h',{'xsinx':i*np.sin(i/r),
                                      'xcosx':i*np.cos(i/r),
                                      'arctanx': numsteps*np.arctan(i/r)}, i)
        # This function adds three values to the same scalar plot with the tag
        # 'run_14h' in TensorBoard's scalar section.
    """
    timestamp = time.time()
    fw_logdir = self.file_writer.get_logdir()
    for tag, scalar_value in tag_scalar_dict.items():
        fw_tag = fw_logdir + "/" + tag
        #fw_tag_full = fw_logdir + "/" + main_tag + "/" + tag
        if fw_tag in self.all_writers:
            fw = self.all_writers[fw_tag]
        else:
            fw = tensorboardX.FileWriter(logdir=fw_tag)
            self.all_writers[fw_tag] = fw
        fw.add_summary(tensorboardX.summary.scalar(main_tag, scalar_value), global_step)
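A minimal usage sketch of the helper above, assuming a tensorboardX version whose SummaryWriter exposes file_writer and all_writers (as the self references imply); the method binding and all values are illustrative:

import numpy as np
import tensorboardX

writer = tensorboardX.SummaryWriter("runs/demo")
# Bind the module-level function as a method so `self` resolves.
writer.SW_add_scalars2 = SW_add_scalars2.__get__(writer)

r = 5.0
for i in range(1, 100):
    writer.SW_add_scalars2('run_14h', {'xsinx': i * np.sin(i / r),
                                       'xcosx': i * np.cos(i / r)}, i)
# Each tag gets its own FileWriter in a subdirectory of the log dir, so the
# curves appear as separate runs overlaid on one plot in TensorBoard.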
Code Example #2
def add_tsb_image_last(entry, data, path=".", subfolder="", **kwargs):
    """Handler that append the last image of the data to a tensorboard event file.

    :param string entry: Name of the log entry
    :param Dict data: Data should be a numpy array / torch tensor of shape [3, W, H]
    :param string subfolder: Subfolder in which put the data
    :param string path: Root path. Set by DataLogger if used as handler.
    """
    tsb_dir = os.path.join(path, subfolder)
    os.makedirs(tsb_dir, exist_ok=True)
    if tsb_writer.file_writer.get_logdir() != tsb_dir:
        tsb_writer.file_writer = tensorboardX.FileWriter(tsb_dir)
    last_time = max(data.keys())
    value = data[last_time]
    tsb_writer.add_image(entry, value, last_time)
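A minimal sketch of how this handler might be invoked; the snippet implies a module-level tsb_writer, assumed here to be a tensorboardX SummaryWriter (paths and data are illustrative):

import numpy as np
import tensorboardX

# Assumed global: the handler references `tsb_writer` without defining it.
tsb_writer = tensorboardX.SummaryWriter()

# Images keyed by timestamp; only the most recent one is logged.
frames = {0: np.zeros((3, 64, 64)), 10: np.random.rand(3, 64, 64)}
add_tsb_image_last("reconstruction", frames, path="logs", subfolder="images")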
Code Example #3
def add_tsb_scalars_last(entry,
                         data,
                         labels=None,
                         path=".",
                         subfolder="",
                         **kwargs):
    """Handler that appends the last item of the data dictionary to a tensorboard event file.

    :param string entry: Name of the log entry.
    :param Dict data: Data should be numpy arrays of size [n] with constant n.
    :param List[string] labels: Labels to use for the lines
    :param string subfolder: Subfolder in which to put the data
    :param string path: Root path. Set by DataLogger if used as handler.
    """
    tsb_dir = os.path.join(path, subfolder)
    os.makedirs(tsb_dir, exist_ok=True)
    if tsb_writer.file_writer.get_logdir() != tsb_dir:
        tsb_writer.file_writer = tensorboardX.FileWriter(tsb_dir)
    last_time = max(data.keys())
    value = data[last_time]
    if labels is None:
        labels = [str(a) for a in range(len(value))]
    scalars_dict = dict(zip(labels, value))
    tsb_writer.add_scalars(entry, scalars_dict, last_time)
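A matching usage sketch under the same assumption of a module-level tsb_writer (values are illustrative):

import numpy as np
import tensorboardX

tsb_writer = tensorboardX.SummaryWriter()  # assumed module-level writer

# Scalar vectors keyed by timestamp; the handler logs only the latest entry.
losses = {1: np.array([0.9, 1.2]), 2: np.array([0.7, 1.0])}
add_tsb_scalars_last("losses", losses, labels=["train", "val"],
                     path="logs", subfolder="scalars")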
Code Example #4
def __init__(self, model_dir):
    self.summary_writer = tb.FileWriter(model_dir)
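A sketch of the wrapper class this __init__ presumably belongs to, assuming tb aliases tensorboardX; the Logger name and log_scalar helper are hypothetical:

import tensorboardX as tb

class Logger:
    def __init__(self, model_dir):
        self.summary_writer = tb.FileWriter(model_dir)

    def log_scalar(self, tag, value, step):
        # FileWriter consumes summary protobufs, so build one explicitly.
        self.summary_writer.add_summary(tb.summary.scalar(tag, value), step)

logger = Logger("runs/exp1")
logger.log_scalar("loss", 0.42, step=1)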
Code Example #5
def main(argv):
  (opts, args) = parser.parse_args(argv)
  if 'estimate' in opts.mode:
    mode_idx = int(opts.mode[-1])

  global colorPlatte, bones, Evaluation

  if 'nyu' in opts.config:
    colorPlatte = utils.util.nyuColorIdx
    bones = utils.util.nyuBones
    Evaluation = NYUHandposeEvaluation
  elif 'icvl' in opts.config:
    colorPlatte = utils.util.icvlColorIdx
    bones = utils.util.icvlBones
    Evaluation = ICVLHandposeEvaluation



  # Load experiment setting
  assert isinstance(opts, object)
  config = NetConfig(opts.config)

  batch_size = config.hyperparameters['batch_size'] if 'estimate' in opts.mode else 1
  test_batch_size = batch_size * 32
  max_iterations = config.hyperparameters['max_iterations']
  frac = opts.frac

  dataset_a = get_dataset(config.datasets['train_a'])
  dataset_b = get_dataset(config.datasets['train_b'])
  dataset_test = get_dataset(config.datasets['test_b'])


  train_loader_a = get_data_loader(dataset_a, batch_size, shuffle=True)
  train_loader_b = get_data_loader(dataset_b, batch_size, shuffle=True)
  test_loader_real = get_data_loader(dataset_test, test_batch_size, shuffle=False)

  cmd = "trainer=%s(config.hyperparameters)" % config.hyperparameters['trainer']
  local_dict = locals()
  exec(cmd,globals(),local_dict)
  trainer = local_dict['trainer']

  di_a = dataset_a.di
  di_b = dataset_b.di

  # Check if resume training
  iterations = 0
  if opts.resume == 1:
    iterations = trainer.resume(config.snapshot_prefix, idx=-1, load_opt=True)
    for i in range(iterations//1000):
      trainer.dis_sch.step()
      trainer.gen_sch.step()
  trainer.cuda(opts.gpu)


  print('using %.2f percent of the labeled real data' % frac)
  try:
    if 'estimate' in opts.mode and (mode_idx == 3 or mode_idx == 4):
      trainer.load_vae(config.snapshot_prefix, 2+frac)
    else:
      trainer.load_vae(config.snapshot_prefix, frac)
  except Exception:
    print('Failed to load the parameters of vae')

  if 'estimate' in opts.mode:
    if opts.idx != 0:
      trainer.resume(config.snapshot_prefix, idx=opts.idx, est=mode_idx==5)
    if frac > 0. and frac < 1.:
      dataset_b.set_nmax(frac)
    #trainer.dis.freeze_layers()

  ###############################################################################################
  # Set up logger and prepare image outputs
  train_writer = tensorboardX.FileWriter("%s/%s" % (opts.log,os.path.splitext(os.path.basename(opts.config))[0]))
  image_directory, snapshot_directory = prepare_snapshot_and_image_folder(config.snapshot_prefix, iterations, config.image_save_iterations)

  best_err, best_acc = 100., 0.
  start_time = time.time()
  for ep in range(0, MAX_EPOCHS):
    for it, ((images_a, labels_a, com_a, M_a, cube_a, _), (images_b, labels_b, com_b, M_b, cube_b, _)) in \
        enumerate(izip(train_loader_a, train_loader_b)):
      if images_a.size(0) != batch_size or images_b.size(0) != batch_size:
        continue
      images_a = Variable(images_a.cuda(opts.gpu))
      images_b = Variable(images_b.cuda(opts.gpu))
      labels_a = Variable(labels_a.cuda(opts.gpu))
      labels_b = Variable(labels_b.cuda(opts.gpu))
      com_a = Variable(com_a.cuda(opts.gpu))
      com_b = Variable(com_b.cuda(opts.gpu))

      trainer.dis.train()
      if opts.mode == 'pretrain':
        if (iterations+1) % 1000 == 0:
          trainer.dis_sch.step()
          trainer.gen_sch.step()
          print('lr %.8f' % trainer.dis_sch.get_lr()[0])

        trainer.dis_update(images_a, labels_a, images_b, labels_b, com_a, com_b, config.hyperparameters)
        image_outputs = trainer.gen_update(images_a, labels_a, images_b, labels_b, config.hyperparameters)
        assembled_images = trainer.assemble_outputs(images_a, images_b, image_outputs)
      else:
        if (iterations+1) % 100 == 0:
          trainer.dis_sch.step()
        image_outputs = trainer.post_update(images_a, labels_a, images_b, labels_b, com_a, com_b, mode_idx, config.hyperparameters)
        assembled_images = trainer.assemble_outputs(images_a, images_b, image_outputs)

      # Dump training stats in log file
      if (iterations+1) % config.display == 0:
        elapsed_time = time.time() - start_time
        write_loss(iterations, max_iterations, trainer, train_writer, elapsed_time)
        start_time = time.time()

      if (iterations + 1) % config.image_display_iterations == 0:
        img_filename = '%s/gen.jpg' % (image_directory)
        torchvision.utils.save_image(assembled_images.data / 2 + 0.5, img_filename, nrow=1)

      if (iterations+1) % config.image_save_iterations == 0:

        if opts.mode == 'pretrain':  # and (iterations+1) % (2*config.image_save_iterations) != 0:
          img_filename = '%s/gen_%08d.jpg' % (image_directory, iterations + 1)
          torchvision.utils.save_image(assembled_images.data / 2 + 0.5, img_filename, nrow=1)
          write_html(snapshot_directory + "/index.html", iterations + 1,
                     config.image_save_iterations, image_directory)
        else:
          trainer.dis.eval()
          score, maxerr = 0, 0
          num_samples = 0
          maxJntError = []
          meanJntError = 0
          img2sav = None
          gt3D = []
          joints = []
          joints_imgcord = []
          codec = cv2.VideoWriter_fourcc(*'XVID')
          vid = cv2.VideoWriter(os.path.join(image_directory, 'gen.avi'), codec, 25, (128*2, 128))
          for tit, (test_images_b, test_labels_b, com_b, trans_b, cube_b, fn) in enumerate(test_loader_real):
            test_images_b = Variable(test_images_b.cuda(opts.gpu))
            test_labels_b = Variable(test_labels_b.cuda(opts.gpu))
            if mode_idx == 0:
              pred_pose, pred_post, _ = trainer.dis.regress_a(test_images_b)
            else:
              pred_pose, pred_post, _ = trainer.dis.regress_b(test_images_b)

            pred_pose = trainer.vae.decode(pred_post)

            n = test_labels_b.size(0)

            gt_pose = test_labels_b.data.cpu().numpy().reshape((n, -1, 3))
            pr_pose = pred_pose.data.cpu().numpy().reshape((n, -1, 3))

            if tit < 20:
              for i in range(0, n, 4):
                real_img = visPair(di_b, test_images_b[i].data.cpu().numpy(), gt_pose[i].reshape((-1)),
                                   trans_b[i].numpy(), com_b[i].numpy(), cube_b[i].numpy(), 50.0)
                est_img = visPair(di_b, test_images_b[i].data.cpu().numpy(), pr_pose[i].reshape((-1)),
                                  trans_b[i].numpy(), com_b[i].numpy(), cube_b[i].numpy(), 50.0)

                vid.write(np.hstack((real_img, est_img)).astype('uint8'))

            both_img = np.vstack((real_img, est_img))

            if tit < 8:
              if img2sav is None:
                img2sav = both_img
              else:
                img2sav = np.hstack((img2sav, both_img))

            if 'nyu' in opts.config:
              restrictedJointsEval = np.array([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 27, 30, 31, 32])
              gt_pose = gt_pose[:, restrictedJointsEval]
              pr_pose = pr_pose[:, restrictedJointsEval]

            for i in range(n):
              gt3D.append(gt_pose[i]*(cube_b.numpy()[0]/2.) + com_b[i].numpy())
              joints.append(pr_pose[i]*(cube_b.numpy()[0]/2.) + com_b[i].numpy())
              joints_imgcord.append(di_b.joints3DToImg(pr_pose[i]*(cube_b.numpy()[0]/2.) + com_b[i].numpy()))

            score += meanJntError
            num_samples += test_images_b.size(0)

          cv2.imwrite(image_directory + '/_test.jpg', img2sav.astype('uint8'))
          vid.release()

          hpe = Evaluation(np.array(gt3D), np.array(joints))
          mean_err = hpe.getMeanError()
          over_40 = 100. * hpe.getNumFramesWithinMaxDist(40) / len(gt3D)
          best_err = np.minimum(best_err, mean_err)
          best_acc = np.maximum(best_acc, over_40)
          print("------------ Mean err: {:.4f} ({:.4f}) mm, Max over 40mm: {:.2f} ({:.2f}) %".format(mean_err, best_err, over_40, best_acc))


      # Save network weights
      if (iterations+1) % config.snapshot_save_iterations == 0:
        if opts.mode == 'pretrain':
          trainer.save(config.snapshot_prefix, iterations)
        elif 'estimate' in opts.mode:
          trainer.save(config.snapshot_prefix+'_est', iterations)

      iterations += 1
      if iterations >= max_iterations:
        return
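The exec-based construction above instantiates whatever trainer class the config names. A minimal sketch of the same pattern through an explicit registry, with hypothetical class and key names:

# Resolve the trainer class through a registry instead of exec().
class PoseTrainer:  # hypothetical trainer class
    def __init__(self, hyperparameters):
        self.hp = hyperparameters

TRAINERS = {'PoseTrainer': PoseTrainer}

hyperparameters = {'trainer': 'PoseTrainer', 'batch_size': 32}
trainer = TRAINERS[hyperparameters['trainer']](hyperparameters)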
Code Example #6
# The excerpt opens mid-conditional; assuming the usual solver selection:
if cfg.solver == 'SGD':
    optimizer = torch.optim.SGD(model.parameters(), lr=cfg.lr, momentum=0.9)
elif cfg.solver == 'Adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
else:
    raise ValueError('unsupported solver: {}'.format(cfg.solver))

model.cuda()

# # DATA LOADER
get_loader = get_data_loader(cfg.datasetname)
train_data = get_loader()
class_names = train_data.dataset.classes
print('dataset len: {}'.format(len(train_data.dataset)))

tb_dir = os.path.join(cfg.train_dir, cfg.backbone + '_' + cfg.datasetname,
                      time.strftime("%h%d_%H"))
writer = tbx.FileWriter(tb_dir)
summary_out = []

global_step = 0
timer = Timer()

for ep in range(start_epoch, cfg.max_epoch):
    if ep in cfg.lr_decay_epoches and cfg.solver == 'SGD':
        lr *= cfg.lr_decay
        adjust_learning_rate(optimizer, lr)
        print('adjusting learning rate {:.6f}'.format(lr))

    for step, batch in enumerate(train_data):
        timer.tic()

        input, anchors_np, im_scale_list, image_ids, gt_boxes_list, rpn_targets, _, _ = batch
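The loop body is truncated here; in this codebase the pattern is to collect summary protobufs into summary_out and hand them to writer.add_summary, as the test scripts below do. A minimal self-contained sketch of that pattern (values illustrative):

import tensorboardX as tbx

writer = tbx.FileWriter("runs/sketch")
for global_step in range(100):
    loss = 1.0 / (global_step + 1)  # stand-in for a real training loss
    writer.add_summary(tbx.summary.scalar('train/loss', loss), global_step)
writer.flush()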
Code Example #7
def __init__(self, model_dir, sub_dir=""):
    self.summary_writer = tb.FileWriter(os.path.join(model_dir, sub_dir))
Code Example #8
File: test.py Project: onejiin/Detectron-PYTORCH
def main():
    # config model and lr
    num_anchors = len(cfg.anchor_ratios) * len(cfg.anchor_scales[0]) * len(cfg.anchor_shift) \
        if isinstance(cfg.anchor_scales[0], list) else \
        len(cfg.anchor_ratios) * len(cfg.anchor_scales)

    resnet = resnet50 if cfg.backbone == 'resnet50' else resnet101
    detection_model = MaskRCNN if cfg.model_type.lower() == 'maskrcnn' else RetinaNet

    model = detection_model(resnet(pretrained=True, maxpool5=cfg.maxpool5),
                            num_classes=cfg.num_classes,
                            num_anchors=num_anchors,
                            strides=cfg.strides,
                            in_channels=cfg.in_channels,
                            f_keys=cfg.f_keys,
                            num_channels=256,
                            is_training=False,
                            activation=cfg.class_activation)

    lr = cfg.lr
    start_epoch = 0
    if cfg.restore is not None:
        meta = load_net(cfg.restore, model)
        print(meta)
        if meta[0] >= 0:
            start_epoch = meta[0] + 1
            lr = meta[1]
        print('Restored from %s, starting from %d epoch, lr:%.6f' %
              (cfg.restore, start_epoch, lr))
    else:
        raise ValueError('restore is not set')

    model.cuda()
    model.eval()

    class_names = test_data.dataset.classes
    print('dataset len: {}'.format(len(test_data.dataset)))

    tb_dir = os.path.join(cfg.train_dir, cfg.backbone + '_' + cfg.datasetname,
                          'test', time.strftime("%h%d_%H"))
    writer = tbx.FileWriter(tb_dir)

    # main loop
    timer_all = Timer()
    timer_post = Timer()
    all_results1 = []
    all_results2 = []
    all_results_gt = []
    for step, batch in enumerate(test_data):

        timer_all.tic()

        # NOTE: Targets are in NHWC order!!
        # input, anchors_np, im_scale_list, image_ids, gt_boxes_list = batch
        # input = everything2cuda(input)
        input_t, anchors_np, im_scale_list, image_ids, gt_boxes_list = batch
        input = everything2cuda(input_t, volatile=True)

        outs = model(input, gt_boxes_list=None, anchors_np=anchors_np)

        if cfg.model_type == 'maskrcnn':
            rpn_logit, rpn_box, rpn_prob, rpn_labels, rpn_bbtargets, rpn_bbwghts, anchors, \
            rois, roi_img_ids, rcnn_logit, rcnn_box, rcnn_prob, rcnn_labels, rcnn_bbtargets, rcnn_bbwghts = outs
            outputs = [
                rois, roi_img_ids, rpn_logit, rpn_box, rpn_prob, rcnn_logit,
                rcnn_box, rcnn_prob, anchors
            ]
            targets = []
        elif cfg.model_type == 'retinanet':
            rpn_logit, rpn_box, rpn_prob, _, _, _ = outs
            outputs = [rpn_logit, rpn_box, rpn_prob]
        else:
            raise ValueError('Unknown model type: %s' % cfg.model_type)

        timer_post.tic()

        dets_dict = model.get_final_results(
            outputs,
            everything2cuda(anchors_np),
            score_threshold=0.01,
            max_dets=cfg.max_det_num * cfg.batch_size,
            overlap_threshold=cfg.overlap_threshold)
        if 'stage1' in dets_dict:
            Dets = dets_dict['stage1']
        else:
            raise ValueError('No stage1 results:', dets_dict.keys())
        Dets2 = dets_dict['stage2'] if 'stage2' in dets_dict else Dets

        t3 = timer_post.toc()
        t = timer_all.toc()

        formal_res1 = dataset.to_detection_format(copy.deepcopy(Dets),
                                                  image_ids, im_scale_list)
        formal_res2 = dataset.to_detection_format(copy.deepcopy(Dets2),
                                                  image_ids, im_scale_list)
        all_results1 += formal_res1
        all_results2 += formal_res2

        Dets_gt = []
        for gb in gt_boxes_list:
            cpy_mask = gb[:, 4] >= 1
            gb = gb[cpy_mask]
            n = cpy_mask.astype(np.int32).sum()
            res_gt = np.zeros((n, 6))
            res_gt[:, :4] = gb[:, :4]
            res_gt[:, 4] = 1.
            res_gt[:, 5] = gb[:, 4]
            Dets_gt.append(res_gt)
        formal_res_gt = dataset.to_detection_format(Dets_gt, image_ids,
                                                    im_scale_list)
        all_results_gt += formal_res_gt

        if step % cfg.log_image == 0:
            input_np = everything2numpy(input)
            summary_out = []
            Is = single_shot.draw_detection(input_np,
                                            Dets,
                                            class_names=class_names)
            Is = Is.astype(np.uint8)
            summary_out += log_images(Is, image_ids, step, prefix='Detection/')

            Is = single_shot.draw_detection(input_np,
                                            Dets2,
                                            class_names=class_names)
            Is = Is.astype(np.uint8)
            summary_out += log_images(Is,
                                      image_ids,
                                      step,
                                      prefix='Detection2/')

            Imgs = single_shot.draw_gtboxes(input_np,
                                            gt_boxes_list,
                                            class_names=class_names)
            Imgs = Imgs.astype(np.uint8)
            summary_out += log_images(Imgs,
                                      image_ids,
                                      float(step),
                                      prefix='GT')

            for s in summary_out:
                writer.add_summary(s, float(step))

        if step % cfg.display == 0:
            print(time.strftime("%H:%M:%S ") +
                  'Epoch %d iter %d: speed %.3fs (%.3fs)' % (0, step, t, t3) +
                  ' ImageIds: ' + ', '.join(str(s) for s in image_ids),
                  end='\r')

    res_dict = {
        'stage1': all_results1,
        'stage2': all_results2,
        'gt': all_results_gt
    }
    return res_dict
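For reference, a self-contained sketch of the ground-truth-to-detection conversion in the loop above, with dummy boxes; the column layout [x1, y1, x2, y2, class_id] is inferred from the indexing:

import numpy as np

# One image's ground-truth boxes: [x1, y1, x2, y2, class_id].
gb = np.array([[10., 10., 50., 60., 3.],
               [ 0.,  0.,  5.,  5., 0.]])  # class 0 is filtered out

cpy_mask = gb[:, 4] >= 1
gb = gb[cpy_mask]
res_gt = np.zeros((gb.shape[0], 6))
res_gt[:, :4] = gb[:, :4]  # box coordinates
res_gt[:, 4] = 1.          # fixed "score" of 1 for ground truth
res_gt[:, 5] = gb[:, 4]    # class id
print(res_gt)              # one detection-style row per kept gt box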
Code Example #9
File: pose_train.py Project: xyhak47/LSPS
def main(argv):
    (opts, args) = parser.parse_args(argv)

    global colorPlatte, bones, Evaluation

    if 'nyu' in opts.config:
        colorPlatte = utils.util.nyuColorIdx
        bones = utils.util.nyuBones
        Evaluation = NYUHandposeEvaluation
    elif 'icvl' in opts.config:
        colorPlatte = utils.util.icvlColorIdx
        bones = utils.util.icvlBones
        Evaluation = ICVLHandposeEvaluation

    # Load experiment setting
    assert isinstance(opts, object)
    config = NetConfig(opts.config)

    batch_size = config.hyperparameters['batch_size_pose']
    max_iterations = 200000  #config.hyperparameters['max_iterations']
    frac = opts.frac

    dataset_a = get_dataset(config.datasets['train_a'])
    dataset_b = get_dataset(config.datasets['train_b'])
    dataset_test = get_dataset(config.datasets['test_b'])

    train_loader_a = get_data_loader(dataset_a, batch_size, shuffle=True)
    train_loader_b = get_data_loader(dataset_b, batch_size, shuffle=True)
    test_loader_real = get_data_loader(dataset_test, 1, shuffle=True)

    cmd = "trainer=%s(config.hyperparameters)" % config.hyperparameters[
        'trainer']
    local_dict = locals()
    exec(cmd, globals(), local_dict)
    trainer = local_dict['trainer']

    iterations = 0
    trainer.cuda(opts.gpu)

    dataset_a.pose_only = True
    dataset_b.pose_only = True

    if frac > 0. and frac < 1.:
        dataset_b.set_nmax(frac)

    di_a = dataset_a.di
    di_b = dataset_b.di

    dataset_a.sample_poses()
    dataset_b.sample_poses()

    ###################################################################
    # Set up logger and prepare image outputs
    train_writer = tensorboardX.FileWriter(
        "%s/%s" %
        (opts.log, os.path.splitext(os.path.basename(opts.config))[0]))
    image_directory, snapshot_directory = prepare_snapshot_and_image_folder(
        config.snapshot_prefix, iterations, config.image_save_iterations)

    print('using %.2f percent of the labeled real data' % frac)
    start_time = time.time()
    for ep in range(0, MAX_EPOCHS):
        for it, ((labels_a),
                 (labels_b)) in enumerate(izip(train_loader_a,
                                               train_loader_b)):
            if labels_a.size(0) != batch_size or labels_b.size(
                    0) != batch_size:
                continue
            labels_a = Variable(labels_a.cuda(opts.gpu))
            labels_b = Variable(labels_b.cuda(opts.gpu))
            labels = labels_a

            if frac > 0.:
                labels = torch.cat((labels_a, labels_b), 0)

            if (iterations + 1) % 1000 == 0:
                trainer.vae_sch.step()

            recon_pose = trainer.vae_update(labels, config.hyperparameters)

            # Dump training stats in log file
            if (iterations + 1) % config.display == 0:
                elapsed_time = time.time() - start_time
                write_loss(iterations, max_iterations, trainer, train_writer,
                           elapsed_time)
                start_time = time.time()

            if (iterations + 1) % (10 * config.image_save_iterations) == 0:
                score, maxerr = 0, 0
                num_samples = 0
                maxJntError = []
                img2sav = None
                gt3D = []
                joints = []
                for tit, (test_images_b, test_labels_b, com_b, trans_b,
                          cube_b, _) in enumerate(test_loader_real):
                    test_images_b = Variable(test_images_b.cuda(opts.gpu))
                    test_labels_b = Variable(test_labels_b.cuda(opts.gpu))

                    pred_pose = trainer.vae.decode(
                        trainer.vae.encode(test_labels_b)[1])

                    gt3D.append(
                        test_labels_b.data.cpu().numpy().reshape((-1, 3)) *
                        (cube_b.numpy()[0]/2.) + com_b.numpy())
                    joints.append(
                        pred_pose.data.cpu().numpy().reshape((-1, 3)) *
                        (cube_b.numpy()[0]/2.) + com_b.numpy())

                    if tit < 8:
                        real_img = visPair(di_b, test_images_b.data.cpu().numpy(),
                                           test_labels_b.data.cpu().numpy(),
                                           trans_b.numpy(), com_b.numpy(),
                                           cube_b.numpy(), 50.0)
                        est_img = visPair(di_b, test_images_b.data.cpu().numpy(),
                                          pred_pose.data.cpu().numpy(),
                                          trans_b.numpy(), com_b.numpy(),
                                          cube_b.numpy(), 50.0)

                        if img2sav is None:
                            img2sav = np.vstack((real_img, est_img))
                        else:
                            img2sav = np.hstack(
                                (img2sav, np.vstack((real_img, est_img))))

                    num_samples += test_images_b.size(0)

                cv2.imwrite(image_directory + '/_test.jpg',
                            img2sav.astype('uint8'))
                #maxerr = Evaluation.plotError(maxJntError, image_directory + '/maxJntError.txt')

                hpe = Evaluation(np.array(gt3D), np.array(joints))
                print("Mean error: {}mm, max error: {}mm".format(
                    hpe.getMeanError(), hpe.getMaxError()))

            # Save network weights
            if (iterations + 1) % (4 * config.snapshot_save_iterations) == 0:
                trainer.save_vae(config.snapshot_prefix, iterations, 2 + frac)

            iterations += 1
            if iterations >= max_iterations:
                return
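The gt3D/joints bookkeeping above relies on the pose normalization used throughout these scripts: joints are stored normalized within a crop cube around the hand and mapped back to millimeters with the cube size and center of mass. A tiny self-contained sketch of that denormalization (values illustrative):

import numpy as np

norm_pose = np.array([[0.0, 0.5, -0.5]])  # hypothetical normalized joint
cube = np.array([250., 250., 250.])       # crop cube edge length in mm
com = np.array([30., -10., 400.])         # hand center of mass in mm
pose_mm = norm_pose * (cube[0] / 2.) + com
print(pose_mm)  # [[ 30.   52.5 337.5]]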
Code Example #10
def main():
    # config model and lr
    num_anchors = len(cfg.anchor_ratios) * len(cfg.anchor_scales[0]) \
        if isinstance(cfg.anchor_scales[0], list) else \
        len(cfg.anchor_ratios) * len(cfg.anchor_scales)

    resnet = resnet50 if cfg.backbone == 'resnet50' else resnet101
    detection_model = MaskRCNN if cfg.model_type.lower() == 'maskrcnn' else RetinaNet

    model = detection_model(resnet(pretrained=True),
                            num_classes=cfg.num_classes,
                            num_anchors=num_anchors,
                            strides=cfg.strides,
                            in_channels=cfg.in_channels,
                            f_keys=cfg.f_keys,
                            num_channels=256,
                            is_training=False,
                            activation=cfg.class_activation)

    lr = cfg.lr
    start_epoch = 0
    if cfg.restore is not None:
        meta = load_net(cfg.restore, model)
        print(meta)
        if meta[0] >= 0:
            start_epoch = meta[0] + 1
            lr = meta[1]
        print('Restored from %s, starting from %d epoch, lr:%.6f' %
              (cfg.restore, start_epoch, lr))
    else:
        raise ValueError('restore is not set')

    model.cuda()
    model.eval()

    ANCHORS = np.vstack(
        [anc.reshape([-1, 4]) for anc in test_data.dataset.ANCHORS])
    model.anchors = everything2cuda(ANCHORS.astype(np.float32))

    class_names = test_data.dataset.classes
    print('dataset len: {}'.format(len(test_data.dataset)))

    tb_dir = os.path.join(cfg.train_dir, cfg.backbone + '_' + cfg.datasetname,
                          'test', time.strftime("%h%d_%H"))
    writer = tbx.FileWriter(tb_dir)
    summary_out = []

    # main loop
    timer_all = Timer()
    timer_post = Timer()
    all_results1 = []
    all_results2 = []
    all_results_gt = []
    for step, batch in enumerate(test_data):

        timer_all.tic()

        # NOTE: Targets are in NHWC order!!
        input, image_ids, gt_boxes_list, image_ori = batch
        input = everything2cuda(input)

        outs = model(input)

        timer_post.tic()

        dets_dict = model.get_final_results(
            score_threshold=0.05,
            max_dets=cfg.max_det_num * cfg.batch_size,
            overlap_threshold=cfg.overlap_threshold)
        if 'stage1' in dets_dict:
            Dets = dets_dict['stage1']
        else:
            raise ValueError('No stage1 results:', dets_dict.keys())
        Dets2 = dets_dict['stage2'] if 'stage2' in dets_dict else Dets

        t3 = timer_post.toc()
        t = timer_all.toc()

        formal_res1 = dataset.to_detection_format(
            copy.deepcopy(Dets),
            image_ids,
            ori_sizes=[im.shape for im in image_ori])
        formal_res2 = dataset.to_detection_format(
            copy.deepcopy(Dets2),
            image_ids,
            ori_sizes=[im.shape for im in image_ori])
        all_results1 += formal_res1
        all_results2 += formal_res2

        if step % cfg.log_image == 0:
            input_np = everything2numpy(input)
            summary_out = []
            Is = single_shot.draw_detection(input_np,
                                            Dets,
                                            class_names=class_names)
            Is = Is.astype(np.uint8)
            summary_out += log_images(Is, image_ids, step, prefix='Detection/')

            Is = single_shot.draw_detection(input_np,
                                            Dets2,
                                            class_names=class_names)
            Is = Is.astype(np.uint8)
            summary_out += log_images(Is,
                                      image_ids,
                                      step,
                                      prefix='Detection2/')

            Imgs = single_shot.draw_gtboxes(input_np,
                                            gt_boxes_list,
                                            class_names=class_names)
            Imgs = Imgs.astype(np.uint8)
            summary_out += log_images(Imgs,
                                      image_ids,
                                      float(step),
                                      prefix='GT')

            for s in summary_out:
                writer.add_summary(s, float(step))

        if step % cfg.display == 0:
            print(time.strftime("%H:%M:%S ") +
                  'Epoch %d iter %d: speed %.3fs (%.3fs)' % (0, step, t, t3) +
                  ' ImageIds: ' + ', '.join(str(s) for s in image_ids),
                  end='\r')

    res_dict = {
        'stage1': all_results1,
        'stage2': all_results2,
        'gt': all_results_gt
    }
    return res_dict
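log_images is not shown in the excerpt; it evidently returns image summary protobufs that FileWriter.add_summary can consume. If only the end result matters, the higher-level SummaryWriter API achieves the same thing; a minimal sketch (array contents and tags illustrative):

import numpy as np
import tensorboardX

writer = tensorboardX.SummaryWriter("runs/detections")
Is = np.zeros((2, 128, 128, 3), dtype=np.uint8)  # stand-in for drawn detections
image_ids = ['img_001', 'img_002']
for img, img_id in zip(Is, image_ids):
    # HWC uint8 arrays are accepted when dataformats is given explicitly.
    writer.add_image('Detection/%s' % img_id, img, global_step=0, dataformats='HWC')
writer.close()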