예제 #1
0
def predict(config):
    """
    Load trained network and make predictions

    Runs the checkpoint stored at ``config['checkpoint_path']`` over
    ``config['test_data']``, shows a confusion matrix whose rows are
    normalised to percentages of each true class, and prints a per-class
    classification report.

    Parameters
    ----------
    config: dict
        Model training configuration. Keys read here: ``test_data``,
        ``batch_size``, ``msg``, ``num_classes``, ``batch_normalization``
        and ``checkpoint_path``.

    """
    # load dataset:
    data = load_dataset(config['test_data'], config['batch_size'])

    # init model:
    if config['msg'] == True:
        model = CLS_MSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])
    else:
        model = CLS_SSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])

    # load params:
    model.load_weights(config['checkpoint_path'])

    # Collect ground truth and predictions batch by batch; the predicted
    # class is the argmax of the model output along axis 1.
    y_truths = []
    y_preds = []
    for X, y in data:
        y_truths.append(y.numpy().flatten())
        y_preds.append(np.argmax(model(X), axis=1))

    y_truth = np.hstack(y_truths)
    y_pred = np.hstack(y_preds)

    # get decoder
    # NOTE(review): dataset root is hard-coded; only used to build the
    # index -> class-name decoder.
    decoder = KITTIPCDClassificationDataset(
        input_dir=
        '/workspace/data/kitti_3d_object_classification_normal_resampled'
    ).get_decoder()

    # Pin the label set to every class index so the matrix is always
    # num_classes x num_classes and lines up with the tick labels, even when
    # some class is absent from both y_truth and y_pred.
    class_ids = list(range(config['num_classes']))
    labels = [decoder[i] for i in class_ids]

    # get confusion matrix:
    conf_mat = confusion_matrix(y_truth, y_pred, labels=class_ids)
    # change to percentage of each true class. Rows with no true samples
    # stay at 0 instead of becoming NaN through a division by zero.
    row_sums = conf_mat.sum(axis=1, keepdims=True)
    conf_mat = 100.0 * np.divide(
        conf_mat,
        row_sums,
        out=np.zeros(conf_mat.shape, dtype=float),
        where=row_sums != 0,
    )

    plt.figure(figsize=(10, 10))
    # fmt='.1f': seaborn's default '.2g' would render 100.0 as '1e+02'.
    sn.heatmap(conf_mat, annot=True, fmt='.1f',
               xticklabels=labels, yticklabels=labels)
    plt.title('KITTI 3D Object Classification -- Confusion Matrix')
    plt.show()

    print(classification_report(y_truth, y_pred,
                                labels=class_ids, target_names=labels))
def predict(config):
    """
    Load trained network and make predictions

    Runs the checkpoint stored at ``config['checkpoint_path']`` over
    ``config['test_data']``, shows the raw-count confusion matrix as a
    heatmap, and prints a per-class classification report.

    Parameters
    ----------
    config: dict
        Model training configuration. Keys read here: ``test_data``,
        ``batch_size``, ``msg``, ``num_classes``, ``batch_normalization``
        and ``checkpoint_path``.

    """
    # load dataset:
    data = load_dataset(config['test_data'], config['batch_size'])

    # init model:
    if config['msg'] == True:
        model = CLS_MSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])
    else:
        model = CLS_SSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])

    # load params:
    model.load_weights(config['checkpoint_path'])

    # Collect ground truth and predictions batch by batch; the predicted
    # class is the argmax of the model output along axis 1.
    y_truths = []
    y_preds = []
    for X, y in data:
        y_truths.append(y.numpy().flatten())
        y_preds.append(np.argmax(model(X), axis=1))

    y_truth = np.hstack(y_truths)
    y_pred = np.hstack(y_preds)

    # get decoder
    # NOTE(review): dataset root is hard-coded; only used to build the
    # index -> class-name decoder.
    decoder = ModelNet40Dataset(
        input_dir='/workspace/data/modelnet40_normal_resampled'
    ).get_decoder()

    # Use the configured class count (the same one the model was built with)
    # rather than a hard-coded 40, and pin the label set so the matrix and
    # the report always cover every class, including ones absent from the
    # test predictions.
    class_ids = list(range(config['num_classes']))
    labels = [decoder[i] for i in class_ids]

    plt.figure(figsize=(10, 10))
    # fmt='d': counts are integers; the default '.2g' turns 100+ into
    # scientific notation.
    sn.heatmap(confusion_matrix(y_truth, y_pred, labels=class_ids),
               annot=True, fmt='d',
               xticklabels=labels, yticklabels=labels)
    plt.show()

    print(
        classification_report(
            y_truth, y_pred,
            labels=class_ids,
            target_names=labels,
        )
    )
예제 #3
0
def load_model(config):
    """
    Build the object classification network described by ``config`` and
    restore its pre-trained weights.

    Parameters
    ----------
    config: dict
        ``msg`` selects ``CLS_MSG_Model`` when it compares equal to True and
        ``CLS_SSG_Model`` otherwise; ``batch_size``, ``num_classes`` and
        ``batch_normalization`` are passed positionally to the chosen
        constructor; ``checkpoint_path`` is handed to ``load_weights``.

    Returns
    -------
    The constructed model after ``load_weights`` has been called on it.

    """
    model_cls = CLS_MSG_Model if config['msg'] == True else CLS_SSG_Model
    net = model_cls(
        config['batch_size'],
        config['num_classes'],
        config['batch_normalization'],
    )
    net.load_weights(config['checkpoint_path'])
    return net