# Example no. 1 ("Exemplo n.º 1", score 0) — scrape separator, kept as comment
def main(_):
    """Export a trained object-detection model as an inference graph.

    NOTE(review): this variant reads module-level names (output_directory,
    pipeline_config_path, config_override, input_type,
    trained_checkpoint_prefix, write_inference_graph, use_side_inputs,
    side_input_shapes/names/types) that are not defined in this chunk —
    presumably flags/globals defined elsewhere; verify before running.
    """
    # Create the output directory if it does not exist yet (non-recursive).
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)

    # Load the pipeline config, then merge the textual override on top of it.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(pipeline_config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)
    text_format.Merge(config_override, pipeline_config)

    # FIX: the original initialized these locals to None and then tested them,
    # so the "split(',')"-parsing branches were unreachable dead code. Both
    # values were therefore always None; the dead branches are removed.
    input_shape = None
    additional_output_tensor_names = None

    if use_side_inputs:
        # NOTE(review): side_input_shapes/names/types are assigned below and
        # hence are locals in this function, so reading them as arguments here
        # raises UnboundLocalError when use_side_inputs is truthy. They were
        # presumably meant to come from flags (cf. the FLAGS-based variants of
        # this script) — confirm against the missing top of this file.
        side_input_shapes, side_input_names, side_input_types = (
            exporter.parse_side_inputs(side_input_shapes, side_input_names,
                                       side_input_types))
    else:
        side_input_shapes = None
        side_input_names = None
        side_input_types = None

    exporter.export_inference_graph(
        input_type,
        pipeline_config,
        trained_checkpoint_prefix,
        output_directory,
        input_shape=input_shape,
        write_inference_graph=write_inference_graph,
        additional_output_tensor_names=additional_output_tensor_names,
        use_side_inputs=use_side_inputs,
        side_input_shapes=side_input_shapes,
        side_input_names=side_input_names,
        side_input_types=side_input_types)
# Example no. 2 ("Exemplo n.º 2", score 0) — scrape separator, kept as comment
def main(_):
    """Export an inference graph using parameters read from system_dict.json.

    FIX: a stray unterminated ''' literal that followed this function (a
    scraping artifact) made the file a SyntaxError; it has been dropped.
    """
    with open('system_dict.json') as json_file:
        args = json.load(json_file)

    # Recreate the output directory from scratch.
    # FIX: replaced `os.system("rm -r " + path)` with shutil.rmtree — the
    # shell string broke on paths containing spaces/metacharacters and was
    # command-injection-prone if the config value is untrusted.
    import shutil
    if os.path.isdir(args["output_directory"]):
        shutil.rmtree(args["output_directory"])
    os.mkdir(args["output_directory"])

    # Load the pipeline config, then merge the textual override on top of it.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(args["pipeline_config_path"], 'r') as f:
        text_format.Merge(f.read(), pipeline_config)
    text_format.Merge(args["config_override"], pipeline_config)

    # Parse a "1,300,300,3"-style shape string; '-1' marks an unknown dim.
    if args["input_shape"]:
        input_shape = [
            int(dim) if dim != '-1' else None
            for dim in args["input_shape"].split(',')
        ]
    else:
        input_shape = None

    # NOTE(review): the original also parsed args["input_shape_flops"] into a
    # local that was never used anywhere; that dead computation was removed.

    if args["use_side_inputs"]:
        side_input_shapes, side_input_names, side_input_types = (
            exporter.parse_side_inputs(args["side_input_shapes"],
                                       args["side_input_names"],
                                       args["side_input_types"]))
    else:
        side_input_shapes = None
        side_input_names = None
        side_input_types = None

    if args["additional_output_tensor_names"]:
        additional_output_tensor_names = list(
            args["additional_output_tensor_names"].split(','))
    else:
        additional_output_tensor_names = None

    exporter.export_inference_graph(
        args["input_type"],
        pipeline_config,
        args["trained_checkpoint_prefix"],
        args["output_directory"],
        input_shape=input_shape,
        write_inference_graph=args["write_inference_graph"],
        additional_output_tensor_names=additional_output_tensor_names,
        use_side_inputs=args["use_side_inputs"],
        side_input_shapes=side_input_shapes,
        side_input_names=side_input_names,
        side_input_types=side_input_types)
# Example no. 3 ("Exemplo n.º 3", score 0) — scrape separator, kept as comment
def main(_):
    """Export a detection model to an inference graph (flag-driven entry point)."""
    # Optionally pin the job to a single GPU.
    if FLAGS.gpu_device:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu_device)

    # Read the pipeline config file, then layer the textual override on top.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)
    text_format.Merge(FLAGS.config_override, pipeline_config)

    # "1,300,300,3"-style string -> list of ints; '-1' maps to None (unknown).
    input_shape = None
    if FLAGS.input_shape:
        input_shape = []
        for dim in FLAGS.input_shape.split(','):
            input_shape.append(None if dim == '-1' else int(dim))

    side_input_shapes = side_input_names = side_input_types = None
    if FLAGS.use_side_inputs:
        side_input_shapes, side_input_names, side_input_types = (
            exporter.parse_side_inputs(FLAGS.side_input_shapes,
                                       FLAGS.side_input_names,
                                       FLAGS.side_input_types))

    additional_output_tensor_names = None
    if FLAGS.additional_output_tensor_names:
        additional_output_tensor_names = list(
            FLAGS.additional_output_tensor_names.split(','))

    # Fall back to the newest checkpoint in model_dir when no explicit
    # checkpoint prefix was supplied on the command line.
    checkpoint_path = (FLAGS.trained_checkpoint_prefix
                       or tf.train.latest_checkpoint(FLAGS.model_dir))

    exporter.export_inference_graph(
        FLAGS.input_type,
        pipeline_config,
        checkpoint_path,
        FLAGS.output_directory,
        input_shape=input_shape,
        write_inference_graph=FLAGS.write_inference_graph,
        additional_output_tensor_names=additional_output_tensor_names,
        use_side_inputs=FLAGS.use_side_inputs,
        side_input_shapes=side_input_shapes,
        side_input_names=side_input_names,
        side_input_types=side_input_types)
# Example no. 4 ("Exemplo n.º 4", score 0) — scrape separator, kept as comment
def main(_):
    """Build an inference graph for the configured detection model."""
    # Pipeline config: read from disk, then merge the override text on top.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(FLAGS.pipeline_config_path, "r") as config_file:
        text_format.Merge(config_file.read(), pipeline_config)
    text_format.Merge(FLAGS.config_override, pipeline_config)

    # Comma-separated shape flag; "-1" denotes an unknown dimension (None).
    if FLAGS.input_shape:
        dims = FLAGS.input_shape.split(",")
        input_shape = [None if d == "-1" else int(d) for d in dims]
    else:
        input_shape = None

    # Side inputs are only parsed when explicitly enabled.
    if not FLAGS.use_side_inputs:
        side_input_shapes = None
        side_input_names = None
        side_input_types = None
    else:
        parsed = exporter.parse_side_inputs(FLAGS.side_input_shapes,
                                            FLAGS.side_input_names,
                                            FLAGS.side_input_types)
        side_input_shapes, side_input_names, side_input_types = parsed

    # Extra tensors to expose in the exported graph, if any were requested.
    if FLAGS.additional_output_tensor_names:
        additional_output_tensor_names = list(
            FLAGS.additional_output_tensor_names.split(","))
    else:
        additional_output_tensor_names = None

    exporter.export_inference_graph(
        FLAGS.input_type,
        pipeline_config,
        FLAGS.trained_checkpoint_prefix,
        FLAGS.output_directory,
        input_shape=input_shape,
        write_inference_graph=FLAGS.write_inference_graph,
        additional_output_tensor_names=additional_output_tensor_names,
        use_side_inputs=FLAGS.use_side_inputs,
        side_input_shapes=side_input_shapes,
        side_input_names=side_input_names,
        side_input_types=side_input_types,
    )