예제 #1
0
def main():
    """Train a model from a config file, with command-line overrides.

    The config file named by ``args.config_file`` is parsed first. Any of
    the override arguments listed below whose value is not ``None`` then
    replaces the matching config attribute before ``train`` is called.
    """
    args = parse_args()

    cfg = parse_config_file(args.config_file)

    # (args attribute, cfg attribute) pairs. An argument whose value is
    # None is treated as "not given" and leaves the config-file value alone.
    overrides = (
        ('max_number_of_steps', 'NUM_TRAIN_ITERATIONS'),
        ('learning_rate_decay_type', 'LEARNING_RATE_DECAY_TYPE'),
        ('learning_rate', 'INITIAL_LEARNING_RATE'),
        ('batch_size', 'BATCH_SIZE'),
        ('model_name', 'MODEL_NAME'),
    )
    for arg_name, cfg_name in overrides:
        value = getattr(args, arg_name)
        # Identity check against the None singleton (PEP 8), not `!=`.
        if value is not None:
            setattr(cfg, cfg_name, value)

    train(
        tfrecords=args.tfrecords,
        logdir=args.logdir,
        cfg=cfg,
        pretrained_model_path=args.pretrained_model,
        trainable_scopes=args.trainable_scopes,
        checkpoint_exclude_scopes=args.checkpoint_exclude_scopes,
        restore_variables_with_moving_averages=args.restore_variables_with_moving_averages,
        restore_moving_averages=args.restore_moving_averages,
        read_images=args.read_images,
    )
예제 #2
0
def main():
    """Load the config named on the command line and visualize training inputs."""
    cli_args = parse_args()
    config = parse_config_file(cli_args.config_file)

    visualize_train_inputs(
        tfrecords=cli_args.tfrecords,
        cfg=config,
        show_text_labels=cli_args.show_text_labels,
        read_images=cli_args.read_images,
    )
예제 #3
0
def main():
    """Run ``classify`` using a config file, with command-line overrides.

    When ``args.batch_size`` or ``args.model_name`` is not ``None``, it
    replaces ``cfg.BATCH_SIZE`` / ``cfg.MODEL_NAME`` respectively.
    """
    args = parse_args()

    cfg = parse_config_file(args.config_file)

    # None means "not given": keep the value from the config file.
    if args.batch_size is not None:
        cfg.BATCH_SIZE = args.batch_size

    if args.model_name is not None:
        cfg.MODEL_NAME = args.model_name

    classify(tfrecords=args.tfrecords,
             checkpoint_path=args.checkpoint_path,
             save_path=args.save_path,
             max_iterations=args.batches,
             save_logits=args.save_logits,
             cfg=cfg)
예제 #4
0
def main():
    """Run ``extract_and_save`` using a config file, with command-line overrides.

    When ``args.batch_size`` or ``args.model_name`` is not ``None``, it
    replaces ``cfg.BATCH_SIZE`` / ``cfg.MODEL_NAME`` respectively.
    """
    args = parse_args()

    cfg = parse_config_file(args.config_file)

    # None means "not given": keep the value from the config file.
    if args.batch_size is not None:
        cfg.BATCH_SIZE = args.batch_size

    if args.model_name is not None:
        cfg.MODEL_NAME = args.model_name

    extract_and_save(tfrecords=args.tfrecords,
                     checkpoint_path=args.checkpoint_path,
                     save_path=args.save_path,
                     num_iterations=args.batches,
                     feature_keys=args.features,
                     cfg=cfg)
예제 #5
0
def main():
    """Run ``test`` using a config file, with command-line overrides.

    When ``args.batch_size`` or ``args.model_name`` is not ``None``, it
    replaces ``cfg.BATCH_SIZE`` / ``cfg.MODEL_NAME`` respectively.
    """
    args = parse_args()

    cfg = parse_config_file(args.config_file)

    # None means "not given": keep the value from the config file.
    if args.batch_size is not None:
        cfg.BATCH_SIZE = args.batch_size

    if args.model_name is not None:
        cfg.MODEL_NAME = args.model_name

    test(tfrecords=args.tfrecords,
         checkpoint_path=args.checkpoint_path,
         save_dir=args.savedir,
         max_iterations=args.batches,
         eval_interval_secs=args.eval_interval_secs,
         cfg=cfg)
예제 #6
0
    parser.add_argument(
        '--serving',
        dest='serving',
        help=
        'Export for TensorFlow Serving usage. Otherwise, a constant graph will be generated.',
        action='store_true',
        default=False)

    parser.add_argument(
        '--do_preprocess',
        dest='do_preprocess',
        help='Add the image decoding and preprocessing nodes to the graph.',
        action='store_true',
        default=False)

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    # Script entry point: read the flags, load the config file they name,
    # then hand everything to export(). `args` and `cfg` stay module-level.
    args = parse_args()
    cfg = parse_config_file(args.config_file)
    export(
        args.checkpoint_path,
        args.export_dir,
        args.export_version,
        args.serving,
        args.do_preprocess,
        cfg=cfg,
    )