def train(config):
    """
    Build and train the classification network

    Parameters
    ----------
    config: dict
        Model training configuration. Keys read here: 'training_data',
        'validation_data', 'batch_size', 'msg', 'num_classes',
        'batch_normalization', 'log_dir', 'lr' and, optionally, 'epochs'
        (defaults to 100).
    """
    # load dataset:
    training_data = load_dataset(config['training_data'], config['batch_size'])
    validation_data = load_dataset(config['validation_data'], config['batch_size'])

    # init model: MSG variant when config['msg'] is True, SSG variant otherwise
    if config['msg'] == True:
        model = CLS_MSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])
    else:
        model = CLS_SSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])

    # early stopping on validation accuracy, TensorBoard logging, and a
    # checkpoint that keeps only the best weights by validation accuracy:
    callbacks = [
        keras.callbacks.EarlyStopping('val_sparse_categorical_accuracy',
                                      min_delta=0.01, patience=10),
        keras.callbacks.TensorBoard('./logs/{}'.format(config['log_dir']),
                                    update_freq=50),
        keras.callbacks.ModelCheckpoint(
            './logs/{}/model/weights.ckpt'.format(config['log_dir']),
            'val_sparse_categorical_accuracy',
            save_best_only=True)
    ]

    model.build(input_shape=(config['batch_size'],
                             KITTIPCDClassificationDataset.N,
                             KITTIPCDClassificationDataset.d
                             + KITTIPCDClassificationDataset.C))
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line.
    model.summary()

    model.compile(optimizer=keras.optimizers.Adam(config['lr']),
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])

    model.fit(training_data,
              validation_data=validation_data,
              validation_steps=20,
              validation_freq=1,
              callbacks=callbacks,
              epochs=config.get('epochs', 100),
              verbose=1)
def predict(config):
    """
    Load trained network and make predictions on the test set

    Shows a row-normalised confusion matrix (percent of each true class)
    and prints a per-class classification report.

    Parameters
    ----------
    config: dict
        Model configuration. Keys read here: 'test_data', 'batch_size',
        'msg', 'num_classes', 'batch_normalization', 'checkpoint_path'.

    Raises
    ------
    ValueError
        If the test dataset yields no batches.
    """
    num_classes = config['num_classes']

    # load dataset:
    data = load_dataset(config['test_data'], config['batch_size'])

    # init model:
    if config['msg'] == True:
        model = CLS_MSG_Model(config['batch_size'], num_classes,
                              config['batch_normalization'])
    else:
        model = CLS_SSG_Model(config['batch_size'], num_classes,
                              config['batch_normalization'])

    # load params:
    model.load_weights(config['checkpoint_path'])

    y_truths = []
    y_preds = []
    for X, y in data:
        y_truths.append(y.numpy().flatten())
        # predicted class = index of the largest model output per sample
        y_preds.append(np.argmax(model(X), axis=1))
    if not y_truths:
        raise ValueError('test dataset is empty: nothing to evaluate')
    y_truth = np.hstack(y_truths)
    y_pred = np.hstack(y_preds)

    # get decoder (class index -> class name):
    decoder = KITTIPCDClassificationDataset(
        input_dir='/workspace/data/kitti_3d_object_classification_normal_resampled'
    ).get_decoder()

    # Pin the label set so the matrix is always num_classes x num_classes and
    # lines up with the tick labels, even if a class never occurs in the data.
    class_ids = list(range(num_classes))
    conf_mat = confusion_matrix(y_truth, y_pred, labels=class_ids)

    # change to percentage of each true class; rows without any true sample
    # stay at 0 instead of becoming NaN through a division by zero:
    row_sums = conf_mat.sum(axis=1, keepdims=True)
    conf_mat = 100.0 * np.divide(conf_mat, row_sums,
                                 out=np.zeros(conf_mat.shape, dtype=float),
                                 where=row_sums != 0)

    labels = [decoder[i] for i in class_ids]

    plt.figure(figsize=(10, 10))
    sn.heatmap(conf_mat, annot=True, xticklabels=labels, yticklabels=labels)
    plt.title('KITTI 3D Object Classification -- Confusion Matrix')
    plt.show()

    print(classification_report(y_truth, y_pred,
                                labels=class_ids, target_names=labels))
def predict(config):
    """
    Load trained network and make predictions on the test set

    Shows a confusion matrix of raw counts and prints a per-class
    classification report.

    Parameters
    ----------
    config: dict
        Model configuration. Keys read here: 'test_data', 'batch_size',
        'msg', 'num_classes', 'batch_normalization', 'checkpoint_path'.

    Raises
    ------
    ValueError
        If the test dataset yields no batches.
    """
    # use the configured class count (the model is built with it too)
    # instead of a hard-coded 40:
    num_classes = config['num_classes']

    # load dataset:
    data = load_dataset(config['test_data'], config['batch_size'])

    # init model:
    if config['msg'] == True:
        model = CLS_MSG_Model(config['batch_size'], num_classes,
                              config['batch_normalization'])
    else:
        model = CLS_SSG_Model(config['batch_size'], num_classes,
                              config['batch_normalization'])

    # load params:
    model.load_weights(config['checkpoint_path'])

    y_truths = []
    y_preds = []
    for X, y in data:
        y_truths.append(y.numpy().flatten())
        # predicted class = index of the largest model output per sample
        y_preds.append(np.argmax(model(X), axis=1))
    if not y_truths:
        raise ValueError('test dataset is empty: nothing to evaluate')
    y_truth = np.hstack(y_truths)
    y_pred = np.hstack(y_preds)

    # get decoder (class index -> class name):
    decoder = ModelNet40Dataset(
        input_dir='/workspace/data/modelnet40_normal_resampled'
    ).get_decoder()

    # Pin the label set so the matrix and report always cover every class,
    # even those absent from both y_truth and y_pred.
    class_ids = list(range(num_classes))

    plt.figure(figsize=(10, 10))
    sn.heatmap(confusion_matrix(y_truth, y_pred, labels=class_ids), annot=True)
    plt.show()

    print(
        classification_report(
            y_truth,
            y_pred,
            labels=class_ids,
            target_names=[decoder[i] for i in class_ids]
        )
    )
def load_model(config):
    """
    Restore a pre-trained object classification network

    Parameters
    ----------
    config: dict
        Keys read here: 'msg', 'batch_size', 'num_classes',
        'batch_normalization', 'checkpoint_path'.

    Returns
    -------
    The MSG model when config['msg'] == True, otherwise the SSG model,
    with weights loaded from config['checkpoint_path'].
    """
    # choose the architecture first, then build it with the shared arguments
    model_cls = CLS_MSG_Model if config['msg'] == True else CLS_SSG_Model
    net = model_cls(config['batch_size'], config['num_classes'],
                    config['batch_normalization'])

    # restore trained parameters:
    net.load_weights(config['checkpoint_path'])
    return net
def train():
    """
    Build and train the classification network

    Reads its settings from the module-level ``config`` dict (not passed in).
    Keys read: 'msg', 'batch_size', 'num_classes', 'bn', 'train_ds',
    'val_ds', 'log_dir', 'lr'.
    """
    # init model: MSG variant when config['msg'] is True, SSG variant otherwise
    if config['msg'] == True:
        model = CLS_MSG_Model(config['batch_size'], config['num_classes'], config['bn'])
    else:
        model = CLS_SSG_Model(config['batch_size'], config['num_classes'], config['bn'])

    # load dataset:
    train_ds = load_dataset(config['train_ds'], config['batch_size'])
    val_ds = load_dataset(config['val_ds'], config['batch_size'])

    # early stopping on validation accuracy, TensorBoard logging, and a
    # checkpoint that keeps only the best weights by validation accuracy:
    callbacks = [
        keras.callbacks.EarlyStopping('val_sparse_categorical_accuracy',
                                      min_delta=0.01, patience=10),
        keras.callbacks.TensorBoard('./logs/{}'.format(config['log_dir']),
                                    update_freq=50),
        keras.callbacks.ModelCheckpoint(
            './logs/{}/model/weights.ckpt'.format(config['log_dir']),
            'val_sparse_categorical_accuracy',
            save_best_only=True)
    ]

    # hard-coded: 8192 points with 3 values each (presumably xyz -- TODO confirm)
    model.build(input_shape=(config['batch_size'], 8192, 3))
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line.
    model.summary()

    model.compile(optimizer=keras.optimizers.Adam(config['lr']),
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])

    model.fit(train_ds,
              validation_data=val_ds,
              validation_steps=20,
              validation_freq=1,
              callbacks=callbacks,
              epochs=100,
              verbose=1)
def train(config, params):
    """
    Build the classification network and train it with a custom loop

    Parameters
    ----------
    config: dict
        Run configuration. Keys read here: 'dataset_dir', 'log_dir',
        'log_code', 'log_freq', 'test_freq'.
    params: dict
        Model hyper-parameters. Keys read here: 'msg', 'batch_size',
        'num_points', 'num_classes', 'bn', 'lr'.

    Notes
    -----
    The training loop has no exit condition: it runs until the process is
    interrupted or an exception is raised.
    """
    # init model: MSG variant when params['msg'] is True, SSG variant otherwise
    if params['msg'] == True:
        model = CLS_MSG_Model(params['batch_size'], params['num_points'],
                              params['num_classes'], params['bn'])
    else:
        model = CLS_SSG_Model(params['batch_size'], params['num_points'],
                              params['num_classes'], params['bn'])

    model.build(input_shape=(params['batch_size'], params['num_points'], 3))
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line.
    model.summary()

    print('[info] model training...')

    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and rejected by newer Keras optimizer implementations.
    optimizer = tf.keras.optimizers.Adam(learning_rate=params['lr'])
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    acc_object = tf.keras.metrics.SparseCategoricalAccuracy()

    train_ds = TFDataset(os.path.join(config['dataset_dir'], 'train.tfrecord'),
                         params['batch_size'])
    val_ds = TFDataset(os.path.join(config['dataset_dir'], 'val.tfrecord'),
                       params['batch_size'])

    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(config['log_dir'], config['log_code']))

    with train_summary_writer.as_default():
        while True:
            train_pts, train_labels = train_ds.get_batch()
            loss, acc = train_step(optimizer, model, loss_object, acc_object,
                                   train_pts, train_labels)

            # optimizer.iterations is used as the global step (presumably
            # advanced once per train_step -- confirm in train_step)
            if optimizer.iterations % config['log_freq'] == 0:
                tf.summary.scalar('train loss', loss, step=optimizer.iterations)
                tf.summary.scalar('train accuracy', acc, step=optimizer.iterations)

            # every test_freq steps, evaluate on a single validation batch
            if optimizer.iterations % config['test_freq'] == 0:
                test_pts, test_labels = val_ds.get_batch()
                test_loss, test_acc = test_step(optimizer, model, loss_object,
                                                acc_object, test_pts, test_labels)
                tf.summary.scalar('test loss', test_loss, step=optimizer.iterations)
                tf.summary.scalar('test accuracy', test_acc, step=optimizer.iterations)