예제 #1
0
파일: icdar.py 프로젝트: zergmk2/tf_ctpn
    dataset = args.dataset

    # Checkpoints are looked up under output/<netname>/<dataset>/<tag>.
    ckpt_dir = os.path.join('output', netname, dataset, args.tag)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    # get_checkpoint_state returns None when ckpt_dir holds no checkpoint
    # state; fail fast with a clear message here instead of hitting an
    # AttributeError on ckpt.model_checkpoint_path after the session has
    # been created and the whole graph has been built.
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise IOError('No checkpoint found in `{:s}`'.format(ckpt_dir))

    # set config: allow TF to place ops on an available device
    # (allow_soft_placement) and to allocate GPU memory incrementally
    # (allow_growth) instead of reserving it all up front
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    sess = tf.Session(config=tfconfig)
    # load network: map the network name to its class; any other name
    # raises NotImplementedError
    if netname == 'vgg16':
        net = vgg16()
    elif netname == 'res101':
        net = Resnetv1(num_layers=101)
    elif netname == 'mobile':
        net = MobileNetV2()
    else:
        raise NotImplementedError

    # Shared config switch; presumably selects the GPU NMS implementation
    # for the graph built below -- confirm against cfg's consumers.
    cfg.USE_GPU_NMS = True
    # Build the graph in "TEST" mode with the CTPN anchor settings from cfg
    # (H_RADIO_STEP is spelled as the key is defined in cfg).
    net.create_architecture("TEST",
                            num_classes=len(CLASSES),
                            tag=args.tag,
                            anchor_width=cfg.CTPN.ANCHOR_WIDTH,
                            anchor_h_ratio_step=cfg.CTPN.H_RADIO_STEP,
                            num_anchors=cfg.CTPN.NUM_ANCHORS)
    # Load the trained weights from the latest checkpoint into the session.
    saver = tf.train.Saver()
    saver.restore(sess, ckpt.model_checkpoint_path)
예제 #2
0
  print('Output will be saved to `{:s}`'.format(output_dir))

  # tensorboard directory where the summaries are saved during training
  tb_dir = get_output_tb_dir(imdb, args.tag)
  print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir))

  # also add the validation set, but with no flipping images.
  # cfg.TRAIN.USE_FLIPPED lives on the shared cfg object, so its value is
  # saved, forced to False only while combined_roidb builds the validation
  # roidb, and then restored so later code sees the original setting.
  # The first return value of combined_roidb (presumably the validation
  # imdb) is discarded.
  orgflip = cfg.TRAIN.USE_FLIPPED
  cfg.TRAIN.USE_FLIPPED = False
  _, valroidb = combined_roidb(args.imdbval_name)
  print('{:d} validation roidb entries'.format(len(valroidb)))
  cfg.TRAIN.USE_FLIPPED = orgflip

  # load network: map args.net to the corresponding network class; any
  # other value raises NotImplementedError (with no message)
  if args.net == 'vgg16':
    net = Vgg16()
  elif args.net == 'res50':
    net = Resnetv1(num_layers=50)
  elif args.net == 'res101':
    net = Resnetv1(num_layers=101)
  elif args.net == 'res152':
    net = Resnetv1(num_layers=152)
  elif args.net == 'mobile':
    net = mobilenetv1()
  else:
    raise NotImplementedError
    
  # Train on roidb with valroidb as the validation set (imdb and roidb are
  # defined before this excerpt). Output goes to output_dir and summaries
  # to tb_dir; args.weight is passed as the pretrained model, and
  # args.max_iters presumably caps the number of training iterations.
  train_net(net, imdb, roidb, valroidb, output_dir, tb_dir,
            pretrained_model=args.weight,
            max_iters=args.max_iters)