# validate-treshold.py: fine-tune the binarization threshold on the validation data
def main():
    # Parse command line args
    if len(sys.argv) < 4:
        print 'USAGE: validate-treshold.py project_file model_name best_epoch_number [models_dir]'
        sys.exit(1)
    project_file = sys.argv[1]
    model_name = sys.argv[2]
    best_epoch_number = int(sys.argv[3])
    if len(sys.argv) >= 5:
        models_dir = sys.argv[4]
    else:
        models_dir = '.'

    # Load the project file and the model
    project = Project(project_file)
    model_filename = os.path.join(models_dir, model_name + '.model')
    model = get_model_from_name(project, model_name)
    # Load the weight corresponding to the right epoch
    load_weights(model, model_name, model_filename, best_epoch_number)

    # Generate validation data
    validation_dataset = project.get_validation_dataset()
    validation_genconfig = project.get_validation_genconfig()  # using the same validation dataset for threshold fine-tuning
    input_data = validation_dataset.generate_input_data(validation_genconfig)
    output_data = validation_dataset.generate_ground_truth_data(validation_genconfig)

    validate_threshold(model, input_data, output_data)
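

# binarize_image.py: binarize a single test image with a trained model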
def main():
    # Parse command line args
    if len(sys.argv) < 5:
        print 'USAGE: binarize_image.py project_file model_name best_epoch_number test_image_file [models_dir]'
        sys.exit(1)
    project_file = sys.argv[1]
    model_name = sys.argv[2]
    best_epoch_number = int(sys.argv[3])
    test_image_file = sys.argv[4]
    if len(sys.argv) >= 6:
        models_dir = sys.argv[5]
    else:
        models_dir = '.'

    # Load the project file and the model
    project = Project(project_file)
    model_filename = os.path.join(models_dir, model_name + '.model')
    model = get_model_from_name(project, model_name)
    # Load the weight corresponding to the right epoch
    load_weights(model, model_name, model_filename, best_epoch_number)
    # Load the image
    dataset = DataSet('.')
    dataset.imagespaths[test_image_file] = test_image_file
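    # Wrap the single image in a TrainingData object (the same dataset serves as both
    # input and ground truth) so the project's patch-generation machinery can be reused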
    data = TrainingData(dataset, dataset)
    patch_size = project.get_training_dataset().config['patch_size']
    data.set_config({'patch_size': patch_size})
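    # Build a generation config that exhaustively covers the image with patch_size x patch_size patches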
    genconfig = data.exhaustive_gen_config([patch_size, patch_size])
    patches = data.generate_input_data(genconfig)
    # Binarize the image
    binarized_image = binarize(patches, patch_size, genconfig[test_image_file], model, dataset.open_image(test_image_file).size)

    # Convert the result to a PIL image, which makes displaying and saving it easy
    im = Image.fromarray(np.uint8(binarized_image*255))
    im.show()
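

# train.py: train every model defined in the project file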
def main():
    # Parse command line args
    if len(sys.argv) < 2:
        print 'USAGE: train.py project_file.json [output_dir]'
        sys.exit(1)
    project_file = sys.argv[1]
    if len(sys.argv) >= 3:
        output_dir = sys.argv[2]
    else:
        output_dir = '.'
    # Load project
    project = Project(project_file)
    # Generate training data
    training_dataset = project.get_training_dataset()
    training_genconfig = project.get_training_genconfig()
    input_data = training_dataset.generate_input_data(training_genconfig)
    output_data = training_dataset.generate_ground_truth_data(training_genconfig)
    print str(input_data.shape[0]) + " training samples were generated."
    print "Each sample contains " + str(input_data.shape[1]) + " pixels."
    # Create output directory if necessary
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Train each model
    for i in range(project.get_models_count()):
        model = project.get_model(i)
        model_name = project.get_model_name(i)
        print 'Training model ' + model_name + '...'
        model_path = os.path.join(output_dir, model_name + '.model')
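        # Pass an open file so the model can write its weights there; load_weights reads them back by epoch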
        model.train(input_data, output_data, open(model_path, 'wb+'))
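

# validate.py: compute and plot the validation error of every trained model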
def main():
    # Parse command line args
    if len(sys.argv) < 2:
        print 'USAGE: validate.py project_file.json [models_dir]'
        sys.exit(1)
    project_file = sys.argv[1]
    if len(sys.argv) >= 3:
        models_dir = sys.argv[2]
    else:
        models_dir = '.'
    # Check that the models directory exists
    if not os.path.exists(models_dir):
        raise IOError('The models directory does not exist.')
    # Load project
    project = Project(project_file)
    # Generate validation data
    validation_dataset = project.get_validation_dataset()
    validation_genconfig = project.get_validation_genconfig()
    input_data = validation_dataset.generate_input_data(validation_genconfig)
    output_data = validation_dataset.generate_ground_truth_data(validation_genconfig)
    print str(input_data.shape[0]) + ' validation samples were generated.'
    print 'Each sample contains ' + str(input_data.shape[1]) + ' pixels.'
    # Validate each model
    # TODO: implement a validation method that does not require training to be finished,
    # so that the validation error can be visualized while the app is still running
    for i in range(project.get_models_count()):
        model = project.get_model(i)
        model_name = project.get_model_name(i)
        print 'Validating model ' + model_name + '...'
        model_filename = os.path.join(models_dir, model_name + '.model')
        validate(model, model_filename, model_name, input_data, output_data)
    # Show all plots
    plt.show()