Code example #1
def test(model_name, threshold=0.5, save=True, verbose=True, refine=False):
    """Evaluate the saved Keras model <model_name>/<model_name>.h5 on the test split.

    Accumulates per-pixel classification counts (as returned by getPixels) and
    optionally writes the predicted and ground-truth masks to <model_name>/results.
    """
    # Running confusion counts over the whole test set.
    classifications = np.array([0, 0, 0, 0])
    results_folder = os.path.join(model_name, 'results')
    if not os.path.exists(results_folder): os.mkdir(results_folder)
    _, test_set = get_dataset_split()
    if refine: test_set = sort_imgs(test_set)
    prediction = None
    model = keras.models.load_model(os.path.join(model_name,
                                                 model_name + '.h5'),
                                    custom_objects=get_custom_objects())
    for i in range(len(test_set)):
        if verbose: display_progress(i / len(test_set))
        img_path, gt_path = test_set[i].replace('\n', '').split(',')

        img = read_image(img_path, pad=(4, 4))
        img = normalise_img(img)
        ground_truth = read_gt(gt_path)
        ground_truth = np.squeeze(ground_truth)

        if refine:
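            # Build a probability map from the previous prediction (the ground
            # truth stands in for it on the first image); it is used below to
            # weight the newly thresholded prediction.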
            prediction = ground_truth if prediction is None else prediction
            pmap = create_map(prediction > 0.5, 1)
            prob = get_prob_map(pmap)

        prediction = model.predict(img)

        prediction = np.squeeze(prediction)
        prediction = prediction[4:-4, ...]

        prediction = (prediction > threshold).astype(np.uint8)
        if refine: prediction = prediction * prob

        classifications += getPixels(prediction, ground_truth, 0.5)

        if save:
            save_image(prediction,
                       os.path.join(results_folder, ntpath.basename(img_path)))
            save_image(
                ground_truth,
                os.path.join(
                    results_folder,
                    ntpath.basename(img_path).replace('.png', '_gt.png')))


#            if refine:
#                prob = prob.astype(np.uint8)
#                save_image(os.path.join(results_folder,
#                                ntpath.basename(img_path).replace('.png', '_prob.png')), prob)

    print(model_name, threshold)
    printMetrics(getMetrics(classifications))
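A minimal driver for the function above might look as follows. The model name 'my_model' is a placeholder, and the call assumes the helpers used inside test() (get_dataset_split, read_image, save_image, getPixels, getMetrics, printMetrics, ...) are defined in the same module, with the trained network stored as my_model/my_model.h5 as the loading code implies.

# Hypothetical usage sketch: evaluate one trained model at a few thresholds,
# saving the predicted masks only for the default threshold.
if __name__ == '__main__':
    for th in (0.3, 0.5, 0.7):
        test('my_model', threshold=th, save=(th == 0.5), verbose=True)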
Code example #2
    # This excerpt begins mid-function; the enclosing "if args.classes_from:" below
    # is inferred from the "else" branch that follows.
    if args.classes_from:
        with open(args.classes_from, 'rb') as f:
            embed_labels = pickle.load(f)['ind2label']
    else:
        embed_labels = None
    data_generator = get_data_generator(args.dataset,
                                        args.data_root,
                                        classes=embed_labels)

    # Load class hierarchy
    id_type = str if args.str_ids else int
    hierarchy = ClassHierarchy.from_file(
        args.hierarchy, is_a_relations=args.is_a,
        id_type=id_type) if args.hierarchy else None

    # Learn SVM classifier on training data and evaluate on test data
    custom_objects = utils.get_custom_objects(args.architecture)
    custom_objects['labelembed_loss'] = labelembed_loss
    perf = OrderedDict()
    for i, model in enumerate(args.model):
        model_name = (args.label[i] if (args.label is not None) and (i < len(args.label))
                      else os.path.splitext(os.path.basename(model))[0])
        if (args.layer is not None) and (i < len(args.layer)):
            try:
                layer = int(args.layer[i])
            except ValueError:
                layer = args.layer[i]
        else:
            layer = None
        normalize = args.norm[i] if (args.norm is not None) and (i < len(args.norm)) else False
        # The excerpt is cut off mid-statement here; the completion below is
        # inferred from the args.norm pattern directly above.
        prob_features = (args.prob_features[i]
                         if (args.prob_features is not None) and (i < len(args.prob_features))
                         else False)
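The loop above repeats one idiom for every per-model option (args.label, args.layer, args.norm, args.prob_features): take the i-th entry if the optional parallel list is given and long enough, otherwise fall back to a default. A hypothetical helper (not part of the original code) that captures the pattern:

# Hypothetical helper: i-th entry of an optional parallel option list, else a default.
def nth_or_default(values, i, default=None):
    return values[i] if (values is not None) and (i < len(values)) else default

# e.g. normalize = nth_or_default(args.norm, i, False)
#      layer     = nth_or_default(args.layer, i)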
Code example #3
    else:
        with open(args.embedding, 'rb') as pf:
            embedding = pickle.load(pf)
            embed_labels = embedding['ind2label']
            embedding = embedding['embedding']

    # Load dataset
    data_generator = get_data_generator(args.dataset, args.data_root, classes=embed_labels)
    if embedding is None:
        embedding = np.eye(data_generator.num_classes)

    # Construct and train model
    if (args.gpus <= 1) or args.gpu_merge:
        if args.snapshot and os.path.exists(args.snapshot):
            print('Resuming from snapshot {}'.format(args.snapshot))
            model = keras.models.load_model(args.snapshot,
                                            custom_objects=utils.get_custom_objects(args.architecture),
                                            compile=False)
        else:
            embed_model = utils.build_network(embedding.shape[1], args.architecture)
            model = embed_model
            if args.loss == 'inv_corr':
                model = keras.models.Model(model.inputs,
                                           keras.layers.Lambda(utils.l2norm, name='l2norm')(model.output))
            elif args.loss == 'softmax_corr':
                model = keras.models.Model(model.inputs,
                                           keras.layers.Activation('softmax', name='softmax')(model.output))
            if args.cls_weight > 0:
                model = cls_model(model, data_generator.num_classes, args.cls_base)
        par_model = (model if args.gpus <= 1
                     else keras.utils.multi_gpu_model(model, gpus=args.gpus, cpu_merge=False))
    else:
        with K.tf.device('/cpu:0'):
            if args.snapshot and os.path.exists(args.snapshot):
                print('Resuming from snapshot {}'.format(args.snapshot))
                model = keras.models.load_model(args.snapshot,
                                                custom_objects=utils.get_custom_objects(args.architecture),
                                                compile=False)
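For the 'inv_corr' loss above, the embedding network's output is wrapped in a Lambda layer that L2-normalises it; utils.l2norm is defined elsewhere in the repository. The sketch below shows the same wiring with keras.backend.l2_normalize as an assumed stand-in, not the repository's own implementation.

# Sketch of the L2-normalisation head; K.l2_normalize stands in for utils.l2norm.
import keras
import keras.backend as K

def add_l2norm_head(base_model):
    normed = keras.layers.Lambda(lambda x: K.l2_normalize(x, axis=-1),
                                 name='l2norm')(base_model.output)
    return keras.models.Model(base_model.inputs, normed)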
Code example #4
def main(args):
    # If output_model path is relative and in cwd, make it absolute from root
    output_model = FLAGS.output_model
    if str(Path(output_model).parent) == '.':
        output_model = str((Path.cwd() / output_model))

    output_fld = Path(output_model).parent
    output_model_name = Path(output_model).name
    output_model_stem = Path(output_model).stem
    output_model_pbtxt_name = output_model_stem + '.pbtxt'

    # Create output directory if it does not exist
    Path(output_model).parent.mkdir(parents=True, exist_ok=True)

    if FLAGS.channels_first:
        K.set_image_data_format('channels_first')
    else:
        K.set_image_data_format('channels_last')

    custom_object_dict = get_custom_objects()

    model = load_input_model(FLAGS.input_model, FLAGS.input_model_json,
                             FLAGS.input_model_yaml, custom_objects=custom_object_dict)

    # TODO(amirabdi): Support networks with multiple inputs
    orig_output_node_names = [node.op.name for node in model.outputs]
    if FLAGS.output_nodes_prefix:
        num_output = len(orig_output_node_names)
        pred = [None] * num_output
        converted_output_node_names = [None] * num_output

        # Create dummy tf nodes to rename output
        for i in range(num_output):
            converted_output_node_names[i] = '{}{}'.format(
                FLAGS.output_nodes_prefix, i)
            pred[i] = tf.identity(model.outputs[i],
                                  name=converted_output_node_names[i])
    else:
        converted_output_node_names = orig_output_node_names
    logging.info('Converted output node names are: %s',
                 str(converted_output_node_names))

    sess = K.get_session()
    if FLAGS.output_meta_ckpt:
        saver = tf.train.Saver()
        saver.save(sess, str(output_fld / output_model_stem))

    if FLAGS.save_graph_def:
        tf.train.write_graph(sess.graph.as_graph_def(), str(output_fld),
                             output_model_pbtxt_name, as_text=True)
        logging.info('Saved the graph definition in ascii format at %s',
                     str(Path(output_fld) / output_model_pbtxt_name))

    if FLAGS.quantize:
        from tensorflow.tools.graph_transforms import TransformGraph
        transforms = ["quantize_weights", "quantize_nodes"]
        transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [],
                                               converted_output_node_names,
                                               transforms)
        constant_graph = graph_util.convert_variables_to_constants(
            sess,
            transformed_graph_def,
            converted_output_node_names)
    else:
        constant_graph = graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(),
            converted_output_node_names)

    graph_io.write_graph(constant_graph, str(output_fld), output_model_name,
                         as_text=False)
    logging.info('Saved the frozen graph at %s',
                 str(Path(output_fld) / output_model_name))
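For reference, the frozen graph written above can be reloaded for inference with the TensorFlow 1.x API roughly as follows. The file name and the tensor names ('input_1:0', 'output_node0:0') are placeholders that depend on the exported model and on --output_nodes_prefix; inspect the graph to find the real ones.

# Sketch: run inference on the frozen graph (TensorFlow 1.x).
import numpy as np
import tensorflow as tf

with tf.gfile.GFile('model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    inp = graph.get_tensor_by_name('input_1:0')       # placeholder tensor name
    out = graph.get_tensor_by_name('output_node0:0')  # placeholder tensor name
    with tf.Session(graph=graph) as sess:
        preds = sess.run(out, feed_dict={inp: np.zeros((1, 224, 224, 3), np.float32)})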