# Prepare the training data, output locations and network, then run training.
roidb = get_training_roidb(imdb)
    # NOTE(review): the "validation" roidb is the very same object as the
    # training roidb (the separate validation load below is commented out),
    # so whatever train_net does with valroidb operates on training entries.
    # Confirm this is intentional.
    valroidb = roidb
    # output directory where the models are saved
    output_dir = get_output_dir(imdb, args.tag)
    print('Output will be saved to `{:s}`'.format(output_dir))

    # tensorboard directory where the summaries are saved during training
    tb_dir = get_output_tb_dir(imdb, args.tag)
    print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir))

    # also add the validation set, but with no flipping images
    # Save the global flip setting, disable it while the validation set would
    # be built, then restore it.
    # NOTE(review): with the combined_roidb call commented out, nothing between
    # the save and the restore reads cfg.TRAIN.USE_FLIPPED, so this toggle
    # currently has no effect.
    orgflip = cfg.TRAIN.USE_FLIPPED
    cfg.TRAIN.USE_FLIPPED = False
    # _, valroidb = combined_roidb(args.imdbval_name)
    print('{:d} validation roidb entries'.format(len(valroidb)))
    cfg.TRAIN.USE_FLIPPED = orgflip
    # Choose the backbone by name; only 'vgg16' and 'res101' are supported,
    # anything else raises NotImplementedError.
    if args.net == 'vgg16':
        net = vgg16(batch_size=cfg.TRAIN.IMS_PER_BATCH)
    elif args.net == 'res101':
        net = Resnet101(batch_size=cfg.TRAIN.IMS_PER_BATCH)
    else:
        raise NotImplementedError
    # Hand everything to the training loop; pretrained weights and the
    # iteration budget come from the command-line arguments.
    train_net(net,
              imdb,
              roidb,
              valroidb,
              output_dir,
              tb_dir,
              pretrained_model=args.weight,
              max_iters=args.max_iters)
# Example #2
    # Fall back to the 'default' tag when tag is None or an empty string.
    tag = tag if tag else 'default'
    # Prefix the output filename with the tag as a sub-directory.
    # NOTE(review): filename is not used again in the visible lines;
    # presumably it is consumed further down — confirm.
    filename = tag + '/' + filename

    # Load the dataset and apply the requested competition-mode setting
    # (presumably toggles dataset-specific evaluation options — verify).
    imdb = get_imdb(args.imdb_name)
    imdb.competition_mode(args.comp_mode)

    # allow_soft_placement lets TensorFlow fall back to another device when an
    # op has no kernel for the requested one; allow_growth allocates GPU
    # memory on demand instead of reserving it all upfront.
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    # NOTE(review): sess is not closed within the visible lines; confirm it is
    # closed (or used as a context manager) later on.
    sess = tf.Session(config=tfconfig)
    # load network
    # Inference uses one image per batch; only 'vgg16' and 'res101' are
    # supported, anything else raises NotImplementedError.
    if args.net == 'vgg16':
        net = vgg16(batch_size=1)
    elif args.net == 'res101':
        net = Resnet101(batch_size=1)
    else:
        raise NotImplementedError
    # load model

    # Build the graph in "TEST" mode for this dataset's number of classes.
    # NOTE(review): the literal tag='default' is passed here instead of the
    # local `tag` computed above — confirm whether `tag=tag` was intended.
    net.create_architecture(sess,
                            "TEST",
                            imdb.num_classes,
                            tag='default',
                            anchor_scales=cfg.ANCHOR_SCALES)

    # Restore weights from a checkpoint only when a model path was given;
    # if args.model is empty, no checkpoint is restored here.
    if args.model:
        print(('Loading model check point from {:s}').format(args.model))
        saver = tf.train.Saver()
        saver.restore(sess, args.model)
        print('Loaded.')