Example #1
def train(config):
    """
    Train the point cloud classification network.

    Parameters
    ----------
    config: dict
        Model training configuration
    """

    # load dataset:
    training_data = load_dataset(config['training_data'], config['batch_size'])
    validation_data = load_dataset(config['validation_data'],
                                   config['batch_size'])

    # init model:
    if config['msg']:
        model = CLS_MSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])
    else:
        model = CLS_SSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])

    # enable early stopping:
    callbacks = [
        keras.callbacks.EarlyStopping('val_sparse_categorical_accuracy',
                                      min_delta=0.01,
                                      patience=10),
        keras.callbacks.TensorBoard('./logs/{}'.format(config['log_dir']),
                                    update_freq=50),
        keras.callbacks.ModelCheckpoint('./logs/{}/model/weights.ckpt'.format(
            config['log_dir']),
                                        'val_sparse_categorical_accuracy',
                                        save_best_only=True)
    ]

    model.build(input_shape=(config['batch_size'],
                             KITTIPCDClassificationDataset.N,
                             KITTIPCDClassificationDataset.d +
                             KITTIPCDClassificationDataset.C))
    model.summary()

    model.compile(optimizer=keras.optimizers.Adam(config['lr']),
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])

    model.fit(training_data,
              validation_data=validation_data,
              validation_steps=20,
              validation_freq=1,
              callbacks=callbacks,
              epochs=100,
              verbose=1)
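
For reference, a config dict covering every key that train() reads might look like the sketch below; the keys come from the code above, while the paths and values are purely illustrative.

config = {
    'training_data': '/workspace/data/train.tfrecord',  # hypothetical path
    'validation_data': '/workspace/data/val.tfrecord',  # hypothetical path
    'batch_size': 16,
    'msg': True,  # True selects the multi-scale grouping (MSG) variant
    'num_classes': 8,
    'batch_normalization': True,
    'log_dir': 'pointnet2_cls',
    'lr': 1e-3,
}

train(config)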
Example #2
def predict(config):
    """
    Load the trained network and make predictions on the test set.

    Parameters
    ----------
    config: dict
        Model training configuration
    """
    # load dataset:
    data = load_dataset(config['test_data'], config['batch_size'])

    # init model:
    if config['msg']:
        model = CLS_MSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])
    else:
        model = CLS_SSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])

    # load params:
    model.load_weights(config['checkpoint_path'])

    y_truths = []
    y_preds = []
    for X, y in data:
        y_truths.append(y.numpy().flatten())
        y_preds.append(np.argmax(model(X), axis=1))

    y_truth = np.hstack(y_truths)
    y_pred = np.hstack(y_preds)

    # get decoder (maps class index to class name):
    decoder = KITTIPCDClassificationDataset(
        input_dir='/workspace/data/kitti_3d_object_classification_normal_resampled'
    ).get_decoder()

    # get confusion matrix:
    conf_mat = confusion_matrix(y_truth, y_pred)
    # convert counts to per-class percentages (row-normalize):
    conf_mat = 100.0 * conf_mat / conf_mat.sum(axis=1, keepdims=True)
    labels = [decoder[i] for i in range(config['num_classes'])]
    plt.figure(figsize=(10, 10))
    sn.heatmap(conf_mat, annot=True, xticklabels=labels, yticklabels=labels)
    plt.title('KITTI 3D Object Classification -- Confusion Matrix')
    plt.show()

    print(classification_report(y_truth, y_pred, target_names=labels))
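
load_dataset() itself is not shown in these examples. Given how its return value is used (iterated as (X, y) batches and passed straight to model.fit), it is presumably a batched tf.data pipeline over a TFRecord file. A minimal sketch under that assumption; the feature names in parse_example are guesses, not taken from the original code:

import tensorflow as tf

def load_dataset(path, batch_size):
    # Hypothetical reconstruction of the helper used in these examples.
    def parse_example(serialized):
        features = tf.io.parse_single_example(serialized, {
            'points': tf.io.VarLenFeature(tf.float32),
            'label': tf.io.FixedLenFeature([1], tf.int64),
        })
        # 3 columns (xyz) assumed; 6 if normals are stored as well.
        points = tf.reshape(tf.sparse.to_dense(features['points']), (-1, 3))
        return points, features['label']

    return (tf.data.TFRecordDataset(path)
            .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
            .shuffle(1000)
            .batch(batch_size, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))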
Example #3
def predict(config):
    """
    Load the trained network and make predictions on the test set.

    Parameters
    ----------
    config: dict
        Model training configuration
    """
    # load dataset:
    data = load_dataset(config['test_data'], config['batch_size'])

    # init model:
    if config['msg']:
        model = CLS_MSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])
    else:
        model = CLS_SSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])

    # load params:
    model.load_weights(config['checkpoint_path'])

    y_truths = []
    y_preds = []
    for X, y in data:
        y_truths.append(y.numpy().flatten())
        y_preds.append(np.argmax(model(X), axis=1))

    y_truth = np.hstack(y_truths)
    y_pred = np.hstack(y_preds)

    # get decoder (maps class index to class name):
    decoder = ModelNet40Dataset(
        input_dir='/workspace/data/modelnet40_normal_resampled').get_decoder()

    # plot raw-count confusion matrix:
    plt.figure(figsize=(10, 10))
    sn.heatmap(confusion_matrix(y_truth, y_pred), annot=True)
    plt.show()

    # ModelNet40 has 40 classes:
    labels = [decoder[i] for i in range(40)]
    print(classification_report(y_truth, y_pred, target_names=labels))
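
get_decoder() is only ever indexed with integer class ids (decoder[i]), so it presumably returns an index-to-class-name mapping built from the dataset's metadata. A toy stand-in that would satisfy both predict() variants above:

# Toy stand-in: the real decoder is derived from the dataset's class list.
decoder = {0: 'airplane', 1: 'bathtub', 2: 'bed'}  # one entry per class
labels = [decoder[i] for i in range(len(decoder))]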
Example #4
def load_model(config):
    """
    Load the pre-trained object classification network.
    """
    # init model:
    if config['msg']:
        model = CLS_MSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])
    else:
        model = CLS_SSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])

    # load params:
    model.load_weights(config['checkpoint_path'])

    return model
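
A usage sketch for load_model(): restore the weights, then run a dummy batch through the model. The point-cloud shape here is illustrative, not taken from the original code:

import numpy as np

model = load_model(config)  # same config dict as in Example #1

# dummy batch of point clouds: (batch_size, num_points, xyz):
X = np.random.rand(config['batch_size'], 8192, 3).astype(np.float32)
scores = model(X, training=False)
print(np.argmax(scores, axis=1))  # predicted class index per cloud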
Example #5
def train():
    # note: this variant reads `config` from module scope instead of
    # taking it as a parameter.

    if config['msg']:
        model = CLS_MSG_Model(config['batch_size'], config['num_classes'],
                              config['bn'])
    else:
        model = CLS_SSG_Model(config['batch_size'], config['num_classes'],
                              config['bn'])

    train_ds = load_dataset(config['train_ds'], config['batch_size'])
    val_ds = load_dataset(config['val_ds'], config['batch_size'])

    callbacks = [
        keras.callbacks.EarlyStopping('val_sparse_categorical_accuracy',
                                      min_delta=0.01,
                                      patience=10),
        keras.callbacks.TensorBoard('./logs/{}'.format(config['log_dir']),
                                    update_freq=50),
        keras.callbacks.ModelCheckpoint('./logs/{}/model/weights.ckpt'.format(
            config['log_dir']),
                                        'val_sparse_categorical_accuracy',
                                        save_best_only=True)
    ]

    model.build(input_shape=(config['batch_size'], 8192, 3))
    model.summary()

    model.compile(optimizer=keras.optimizers.Adam(config['lr']),
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])

    model.fit(train_ds,
              validation_data=val_ds,
              validation_steps=20,
              validation_freq=1,
              callbacks=callbacks,
              epochs=100,
              verbose=1)
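
Example #5 hard-codes the input shape to (batch_size, 8192, 3): 8192 points per cloud with xyz coordinates only, no normals. A quick sanity check that the pipeline actually yields that shape, assuming load_dataset returns (points, label) batches:

points, labels = next(iter(train_ds))
assert points.shape == (config['batch_size'], 8192, 3), points.shape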
Example #6
def train(config, params):

    if params['msg']:
        model = CLS_MSG_Model(params['batch_size'], params['num_points'],
                              params['num_classes'], params['bn'])
    else:
        model = CLS_SSG_Model(params['batch_size'], params['num_points'],
                              params['num_classes'], params['bn'])

    model.build(input_shape=(params['batch_size'], params['num_points'], 3))
    model.summary()
    print('[info] model training...')

    optimizer = tf.keras.optimizers.Adam(learning_rate=params['lr'])
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    acc_object = tf.keras.metrics.SparseCategoricalAccuracy()

    train_ds = TFDataset(os.path.join(config['dataset_dir'], 'train.tfrecord'),
                         params['batch_size'])
    val_ds = TFDataset(os.path.join(config['dataset_dir'], 'val.tfrecord'),
                       params['batch_size'])

    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(config['log_dir'], config['log_code']))

    with train_summary_writer.as_default():

        while True:  # no stopping criterion in this example; runs until interrupted

            train_pts, train_labels = train_ds.get_batch()

            loss, acc = train_step(optimizer, model, loss_object, acc_object,
                                   train_pts, train_labels)

            if optimizer.iterations % config['log_freq'] == 0:
                tf.summary.scalar('train loss',
                                  loss,
                                  step=optimizer.iterations)
                tf.summary.scalar('train accuracy',
                                  acc,
                                  step=optimizer.iterations)

            if optimizer.iterations % config['test_freq'] == 0:

                test_pts, test_labels = val_ds.get_batch()

                test_loss, test_acc = test_step(optimizer, model, loss_object,
                                                acc_object, test_pts,
                                                test_labels)

                tf.summary.scalar('test loss',
                                  test_loss,
                                  step=optimizer.iterations)
                tf.summary.scalar('test accuracy',
                                  test_acc,
                                  step=optimizer.iterations)
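
train_step() and test_step() are not defined in Example #6. Below are hedged reconstructions consistent with how they are called, i.e. each returns a (loss, accuracy) pair and train_step() advances optimizer.iterations via apply_gradients:

@tf.function
def train_step(optimizer, model, loss_object, acc_object, pts, labels):
    # one gradient update; apply_gradients increments optimizer.iterations:
    with tf.GradientTape() as tape:
        logits = model(pts, training=True)
        loss = loss_object(labels, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    acc_object.update_state(labels, logits)
    return loss, acc_object.result()

@tf.function
def test_step(optimizer, model, loss_object, acc_object, pts, labels):
    # forward pass only; no parameter update:
    logits = model(pts, training=False)
    loss = loss_object(labels, logits)
    acc_object.update_state(labels, logits)
    return loss, acc_object.result()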