# Example No. 1
def main():
    """Load the .tif data, build train/test splits, and train a PopModel.

    NOTE(review): ``config`` and ``data_dir`` are not defined inside this
    function; they are presumably module-level names set up elsewhere (e.g.
    a JSON config parsed from the run arguments) — confirm before running.
    """
    # Load every .tif file found in data_dir and pair inputs with labels.
    data_loader = DataLoader(data_dir, config)
    data_loader.load_directory('.tif')
    data_loader.create_np_arrays()
    data_loader.create_data_label_pairs()

    preptt = PrepTrainTest(config, data_loader)

    for data_label_pair in data_loader.data_label_pairs:
        # Element 0 of each pair is the input; element 1 is the label, of
        # which only index 0 of the third axis (the first channel) is kept.
        # The previous code indexed the pair with the DataLoader object and
        # used an undefined loop index ``i``, so it could never run.
        x_data = data_label_pair[0]
        y_true = data_label_pair[1][:, :, 0]

        preptt.add_data(x_data, y_true)

    # Create the experiment dirs
    create_dirs([config.summary_dir, config.checkpoint_dir, config.input_dir])

    # Create tensorflow session
    sess = tf.Session()

    # Create instance of the model you want
    model = PopModel(config)

    # Load model if it exists
    model.load(sess)

    # Create Tensorboard logger
    logger = Logger(sess, config)
    logger.log_config()

    # Create the data generator from the prepared train/test data
    data = DataGenerator(config, preptraintest=preptt)

    data.create_traintest_data()

    # Create trainer and pass all previous components to it
    trainer = PopTrainer(sess, model, data, config, logger)

    # Train model
    trainer.train()
# Example No. 2
def main():
    """Read the run configuration, prepare the .tif data and test a PopModel.

    The config file path comes from ``get_args().config``; when that value is
    the literal string 'None', ``example.json`` inside ``config_dir`` is used.
    """
    args = get_args()
    use_default_config = args.config == 'None'
    config_path = (
        os.path.join(config_dir, 'example.json')
        if use_default_config
        else args.config
    )
    config = process_config(config_path)

    # Read every .tif file in data_dir into numpy arrays.
    loader = DataLoader(data_dir)
    loader.load_directory('.tif')
    loader.create_np_arrays()

    # Both preparation objects are built from the same two arrays and the
    # same batch / chunk settings.
    first_array, second_array = loader.arrays[0], loader.arrays[1]
    prep_args = (
        first_array,
        second_array,
        config.batch_size,
        config.chunk_height,
        config.chunk_width,
    )
    preptt = PrepTrainTest(*prep_args)
    prepd = PrepData(*prep_args)

    # Experiment directories, session and (possibly restored) model.
    create_dirs([config.summary_dir, config.checkpoint_dir])
    session = tf.Session()
    model = PopModel(config)
    model.load(session)

    # Data generator fed by both preparation objects.
    generator = DataGenerator(config, preptt, prepd)
    generator.create_traintest_data()
    generator.create_data()

    # Tensorboard logger, then hand everything to the trainer and run tests.
    logger = Logger(session, config)
    tester = PopTrainer(session, model, generator, config, logger)
    tester.test()