def train(config):
    """Load data, build, compile and train a point-cloud classification model.

    Parameters
    ----------
    config: dict
        Model training configuration. Keys read here:

        - ``training_data`` / ``validation_data``: dataset sources passed to
          ``load_dataset``.
        - ``batch_size``: batch size for data loading and ``model.build``.
        - ``msg``: truthy selects ``CLS_MSG_Model``, otherwise ``CLS_SSG_Model``.
        - ``num_classes``: number of output classes for the model.
        - ``batch_normalization``: forwarded to the model constructor.
        - ``log_dir``: sub-directory of ``./logs`` used for TensorBoard logs
          and model checkpoints.
        - ``lr``: learning rate for the Adam optimizer.
        - ``epochs`` (optional, default 100): maximum number of epochs.
        - ``validation_steps`` (optional, default 20): validation batches
          evaluated per epoch.
    """
    # load dataset:
    training_data = load_dataset(config['training_data'], config['batch_size'])
    validation_data = load_dataset(config['validation_data'], config['batch_size'])

    # init model: the `msg` flag picks which model class is instantiated.
    model_cls = CLS_MSG_Model if config['msg'] else CLS_SSG_Model
    model = model_cls(config['batch_size'], config['num_classes'],
                      config['batch_normalization'])

    log_dir = './logs/{}'.format(config['log_dir'])
    callbacks = [
        # Stop when validation accuracy fails to improve by >= 0.01 for 10 epochs.
        keras.callbacks.EarlyStopping('val_sparse_categorical_accuracy',
                                      min_delta=0.01, patience=10),
        keras.callbacks.TensorBoard(log_dir, update_freq=50),
        # Only overwrite the checkpoint when validation accuracy improves.
        keras.callbacks.ModelCheckpoint('{}/model/weights.ckpt'.format(log_dir),
                                        'val_sparse_categorical_accuracy',
                                        save_best_only=True),
    ]

    # Input shape is (batch, N, d + C), taken from the dataset class constants.
    model.build(input_shape=(config['batch_size'],
                             KITTIPCDClassificationDataset.N,
                             KITTIPCDClassificationDataset.d
                             + KITTIPCDClassificationDataset.C))
    # summary() prints on its own and returns None, so it is not wrapped in print().
    model.summary()

    model.compile(optimizer=keras.optimizers.Adam(config['lr']),
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])

    model.fit(training_data,
              validation_data=validation_data,
              validation_steps=config.get('validation_steps', 20),
              validation_freq=1,
              callbacks=callbacks,
              epochs=config.get('epochs', 100),
              verbose=1)
def train():
    """Build, compile and train a classification model from the global ``config``.

    This function takes no arguments; it reads a module-level ``config``
    mapping defined outside this function. Keys used: ``msg``, ``batch_size``,
    ``num_classes``, ``bn``, ``train_ds``, ``val_ds``, ``log_dir``, ``lr``.

    NOTE(review): this redefines ``train``. If it lives in the same module as
    ``train(config)``, it replaces that earlier definition; confirm which one
    is intended.
    """
    # The `msg` flag picks which model class is instantiated.
    model_cls = CLS_MSG_Model if config['msg'] else CLS_SSG_Model
    model = model_cls(config['batch_size'], config['num_classes'], config['bn'])

    train_ds = load_dataset(config['train_ds'], config['batch_size'])
    val_ds = load_dataset(config['val_ds'], config['batch_size'])

    log_dir = './logs/{}'.format(config['log_dir'])
    callbacks = [
        # Stop when validation accuracy fails to improve by >= 0.01 for 10 epochs.
        keras.callbacks.EarlyStopping('val_sparse_categorical_accuracy',
                                      min_delta=0.01, patience=10),
        keras.callbacks.TensorBoard(log_dir, update_freq=50),
        # Only overwrite the checkpoint when validation accuracy improves.
        keras.callbacks.ModelCheckpoint('{}/model/weights.ckpt'.format(log_dir),
                                        'val_sparse_categorical_accuracy',
                                        save_best_only=True),
    ]

    # NOTE(review): hard-coded 8192 x 3 per-sample input; presumably this must
    # match what load_dataset yields. Confirm.
    model.build(input_shape=(config['batch_size'], 8192, 3))
    # summary() prints on its own and returns None, so it is not wrapped in print().
    model.summary()

    model.compile(optimizer=keras.optimizers.Adam(config['lr']),
                  loss=keras.losses.SparseCategoricalCrossentropy(),
                  metrics=[keras.metrics.SparseCategoricalAccuracy()])

    model.fit(train_ds,
              validation_data=val_ds,
              validation_steps=20,
              validation_freq=1,
              callbacks=callbacks,
              epochs=100,
              verbose=1)