def train(config, params, max_steps=None):
    """Train a classification model with a hand-written TensorFlow loop.

    Builds either the MSG or SSG classification model, then repeatedly pulls
    batches from ``train.tfrecord``, runs ``train_step`` and, every
    ``config['log_freq']`` / ``config['test_freq']`` optimizer steps, writes
    train / validation loss and accuracy to a TensorBoard summary writer.

    NOTE(review): this module defines ``train`` more than once; only the last
    definition is bound after import — confirm which one callers expect.

    Parameters
    ----------
    config : dict
        Run configuration. Keys read here: ``'dataset_dir'``, ``'log_dir'``,
        ``'log_code'``, ``'log_freq'``, ``'test_freq'``.
    params : dict
        Model hyper-parameters. Keys read here: ``'msg'``, ``'batch_size'``,
        ``'num_points'``, ``'num_classes'``, ``'bn'``, ``'lr'``.
    max_steps : int or None, optional
        Stop once the optimizer has taken this many steps. ``None`` (the
        default) keeps the original behaviour of looping until interrupted.
    """
    # Explicit ``== True`` is kept on purpose so non-bool config values
    # (e.g. the string "False") select the same model as before.
    model_cls = CLS_MSG_Model if params['msg'] == True else CLS_SSG_Model  # noqa: E712
    model = model_cls(params['batch_size'], params['num_points'],
                      params['num_classes'], params['bn'])

    model.build(input_shape=(params['batch_size'], params['num_points'], 3))
    # summary() prints on its own and returns None; wrapping it in print()
    # only emitted a stray "None" line.
    model.summary()

    print('[info] model training...')

    # ``lr`` is a deprecated alias that newer Keras optimizers reject;
    # ``learning_rate`` is accepted by every TF2 release.
    optimizer = tf.keras.optimizers.Adam(learning_rate=params['lr'])
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    # NOTE(review): the same metric object is handed to both train_step and
    # test_step; unless those helpers reset it, the logged test accuracy
    # also reflects training batches — confirm.
    acc_object = tf.keras.metrics.SparseCategoricalAccuracy()

    train_ds = TFDataset(os.path.join(config['dataset_dir'], 'train.tfrecord'),
                         params['batch_size'])
    val_ds = TFDataset(os.path.join(config['dataset_dir'], 'val.tfrecord'),
                       params['batch_size'])

    train_summary_writer = tf.summary.create_file_writer(
        os.path.join(config['log_dir'], config['log_code']))

    with train_summary_writer.as_default():
        while True:
            train_pts, train_labels = train_ds.get_batch()
            loss, acc = train_step(optimizer, model, loss_object, acc_object,
                                   train_pts, train_labels)

            if optimizer.iterations % config['log_freq'] == 0:
                tf.summary.scalar('train loss', loss, step=optimizer.iterations)
                tf.summary.scalar('train accuracy', acc, step=optimizer.iterations)

            if optimizer.iterations % config['test_freq'] == 0:
                test_pts, test_labels = val_ds.get_batch()
                test_loss, test_acc = test_step(optimizer, model, loss_object,
                                                acc_object, test_pts, test_labels)
                tf.summary.scalar('test loss', test_loss, step=optimizer.iterations)
                tf.summary.scalar('test accuracy', test_acc, step=optimizer.iterations)

            # Optional termination; without max_steps the loop runs forever,
            # as the original did.
            if max_steps is not None and optimizer.iterations >= max_steps:
                break
def train(config):
    """Load data, build, compile and fit a classification model.

    Chooses the MSG or SSG classification model, trains it with
    ``model.fit`` for 100 epochs using Adam and sparse categorical
    cross-entropy, and attaches early stopping, TensorBoard logging and
    best-only checkpointing (all monitoring
    ``val_sparse_categorical_accuracy``).

    NOTE(review): this module defines ``train`` more than once; only the last
    definition is bound after import — confirm which one callers expect.

    Parameters
    ----------
    config : dict
        Model training configuration. Keys read here: ``'training_data'``,
        ``'validation_data'``, ``'batch_size'``, ``'msg'``, ``'num_classes'``,
        ``'batch_normalization'``, ``'log_dir'``, ``'lr'``.
    """
    # load dataset:
    training_data = load_dataset(config['training_data'], config['batch_size'])
    validation_data = load_dataset(config['validation_data'], config['batch_size'])

    # init model (explicit ``== True`` kept so non-bool flags behave as before):
    if config['msg'] == True:  # noqa: E712
        model = CLS_MSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])
    else:
        model = CLS_SSG_Model(config['batch_size'], config['num_classes'],
                              config['batch_normalization'])

    # early stopping, TensorBoard logs and best-only checkpoints:
    callbacks = [
        keras.callbacks.EarlyStopping('val_sparse_categorical_accuracy',
                                      min_delta=0.01, patience=10),
        keras.callbacks.TensorBoard('./logs/{}'.format(config['log_dir']),
                                    update_freq=50),
        keras.callbacks.ModelCheckpoint(
            './logs/{}/model/weights.ckpt'.format(config['log_dir']),
            'val_sparse_categorical_accuracy',
            save_best_only=True),
    ]

    model.build(input_shape=(config['batch_size'],
                             KITTIPCDClassificationDataset.N,
                             KITTIPCDClassificationDataset.d
                             + KITTIPCDClassificationDataset.C))
    # summary() prints on its own and returns None; print() around it only
    # emitted a stray "None" line.
    model.summary()

    model.compile(optimizer=keras.optimizers.Adam(config['lr']),
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])

    model.fit(training_data,
              validation_data=validation_data,
              validation_steps=20,
              validation_freq=1,
              callbacks=callbacks,
              epochs=100,
              verbose=1)
def train():
    """Build, compile and fit a classification model from the global ``config``.

    Reads the module-level ``config`` mapping (keys used: ``'msg'``,
    ``'batch_size'``, ``'num_classes'``, ``'bn'``, ``'train_ds'``,
    ``'val_ds'``, ``'log_dir'``, ``'lr'``), chooses the MSG or SSG
    classification model, and trains it with ``model.fit`` for 100 epochs
    using early stopping, TensorBoard logging and best-only checkpointing
    (all monitoring ``val_sparse_categorical_accuracy``).

    The model is built for a fixed input shape of
    ``(batch_size, 8192, 3)``.
    """
    # Explicit ``== True`` kept so non-bool flags select the same model as before.
    if config['msg'] == True:  # noqa: E712
        model = CLS_MSG_Model(config['batch_size'], config['num_classes'], config['bn'])
    else:
        model = CLS_SSG_Model(config['batch_size'], config['num_classes'], config['bn'])

    train_ds = load_dataset(config['train_ds'], config['batch_size'])
    val_ds = load_dataset(config['val_ds'], config['batch_size'])

    callbacks = [
        keras.callbacks.EarlyStopping('val_sparse_categorical_accuracy',
                                      min_delta=0.01, patience=10),
        keras.callbacks.TensorBoard('./logs/{}'.format(config['log_dir']),
                                    update_freq=50),
        keras.callbacks.ModelCheckpoint(
            './logs/{}/model/weights.ckpt'.format(config['log_dir']),
            'val_sparse_categorical_accuracy',
            save_best_only=True),
    ]

    model.build(input_shape=(config['batch_size'], 8192, 3))
    # summary() prints on its own and returns None; print() around it only
    # emitted a stray "None" line.
    model.summary()

    model.compile(optimizer=keras.optimizers.Adam(config['lr']),
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])

    model.fit(train_ds,
              validation_data=val_ds,
              validation_steps=20,
              validation_freq=1,
              callbacks=callbacks,
              epochs=100,
              verbose=1)