from processors import SimpleDataGenerator
from readers import KittiDataReader

# Ask CUDA to enumerate GPUs by ascending PCI bus ID instead of its default
# "fastest first" ordering. The variable must be in the environment before
# CUDA is initialized in this process for it to have any effect.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID")

# Suppress TensorFlow's Python-side log records below ERROR severity.
tf.get_logger().setLevel("ERROR")

# Input-data directory and the directory that holds saved model weights.
DATA_ROOT = "../training"  # TODO make main arg
MODEL_ROOT = "./logs"

if __name__ == "__main__":

    params = Parameters()

    gpus = tf.config.experimental.list_physical_devices('GPU')
    loss = PointPillarNetworkLoss(params)

    if len(gpus) > 1:
        # Replicate the model on every visible GPU; gradients are combined
        # with hierarchical all-reduce.
        strategy = tf.distribute.MirroredStrategy(
            cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
    else:
        # Default Strategy: lets the same `with strategy.scope()` code path
        # run unchanged on a single device, without conditional logic.
        strategy = tf.distribute.get_strategy()

    with strategy.scope():
        # The optimizer is created inside the same scope as the model so that
        # its state variables are placed/mirrored by the active strategy.
        # `learning_rate` is the documented argument name (`lr` is a
        # deprecated alias).
        # NOTE(review): `decay` is a legacy-optimizer argument; confirm it is
        # still accepted by the TensorFlow version this project pins.
        optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate,
                                             decay=params.decay_rate)
        pillar_net = build_point_pillar_graph(params)
        # Resume from previously saved weights.
        pillar_net.load_weights(os.path.join(MODEL_ROOT, "model.h5"))
        pillar_net.compile(optimizer, loss=loss.losses())

    data_reader = KittiDataReader()
# Example #2
from loss import PointPillarNetworkLoss
from network import build_point_pillar_graph
from processors import SimpleDataGenerator
from readers import KittiDataReader

DATA_ROOT = "/Users/chirag/Documents/Projects/682/PointPillars/data"  # TODO make main arg

# The test split lives under DATA_ROOT. Deriving it (rather than repeating the
# absolute path) keeps both paths in sync when DATA_ROOT changes.
TEST_ROOT = DATA_ROOT + "/testing"  # TODO make main arg

if __name__ == "__main__":

    params = Parameters()

    pillar_net = build_point_pillar_graph(params)

    loss = PointPillarNetworkLoss(params)

    # `learning_rate` is the documented Keras argument name; `lr` is a
    # deprecated alias.
    # NOTE(review): `decay` is a legacy-optimizer argument; confirm it is
    # still accepted by the TensorFlow version this project pins.
    optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate,
                                         decay=params.decay_rate)

    # NOTE(review): confirm 'accuracy' is meaningful for every output the
    # custom loss is attached to; it is kept here unchanged.
    pillar_net.compile(optimizer, loss=loss.losses(), metrics=['accuracy'])

    def step_decay(epoch, lr):
        """Scale the learning rate by 0.8 every 15 epochs (epoch 0 excluded).

        Args:
            epoch: Zero-based epoch index supplied by Keras.
            lr: Current learning rate.

        Returns:
            The learning rate to use for this epoch.
        """
        if epoch != 0 and epoch % 15 == 0:
            return lr * 0.8
        return lr

    log_dir = "./logs"
    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=log_dir),
        # Only overwrite model.h5 when the monitored metric improves
        # (ModelCheckpoint monitors val_loss by default).
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(log_dir, "model.h5"), save_best_only=True),
        tf.keras.callbacks.LearningRateScheduler(step_decay, verbose=True),
        # Stop after 20 epochs without improvement (val_loss by default).
        tf.keras.callbacks.EarlyStopping(patience=20),
    ]

    data_reader = KittiDataReader()