def main():
    """Binarize one image with a trained model and display the result.

    Usage::

        binarize_image.py project_file model_name best_epoch_number test_image_file [models_dir]

    ``models_dir`` (default ``'.'``) is joined with ``<model_name>.model``
    to form the model file path.  When any of the four required arguments
    is missing, the usage line is printed and the process exits with
    status 1.
    """
    # Parse command line args.  Four positional arguments are required
    # (argv[1]..argv[4]); the old ``< 2`` check let short command lines
    # fall through to sys.argv[3]/sys.argv[4] and crash with IndexError.
    if len(sys.argv) < 5:
        print('USAGE: binarize_image.py project_file model_name best_epoch_number test_image_file [models_dir]')
        sys.exit(1)
    project_file = sys.argv[1]
    model_name = sys.argv[2]
    best_epoch_number = int(sys.argv[3])
    test_image_file = sys.argv[4]
    models_dir = sys.argv[5] if len(sys.argv) >= 6 else '.'

    # Load the project file and the model
    project = Project(project_file)
    model_filename = os.path.join(models_dir, model_name + '.model')
    model = get_model_from_name(project, model_name)
    # Load the weights corresponding to the requested epoch
    load_weights(model, model_name, model_filename, best_epoch_number)

    # Register the single test image in a DataSet rooted at '.'.  The same
    # DataSet is passed as both TrainingData arguments; presumably only the
    # input side is used here (generate_input_data) — confirm.
    dataset = DataSet('.')
    dataset.imagespaths[test_image_file] = test_image_file
    data = TrainingData(dataset, dataset)
    # NOTE(review): this preloads the whole training set only to read its
    # patch size; Project.get_training_dataset configures it from
    # project.config['patch_size'], which could be read directly instead.
    patch_size = project.get_training_dataset().config['patch_size']
    data.set_config({'patch_size': patch_size})
    # Build an exhaustive generation config and extract the input patches.
    genconfig = data.exhaustive_gen_config([patch_size, patch_size])
    patches = data.generate_input_data(genconfig)
    # Binarize the image
    binarized_image = binarize(patches, patch_size, genconfig[test_image_file], model,
                               dataset.open_image(test_image_file).size)

    # Convert to 8-bit (assumes binarize returns values in [0, 1] — confirm)
    # and hand it to PIL for display.
    im = Image.fromarray(np.uint8(binarized_image * 255))
    im.show()
class Project(object):
    """Lazy accessor for the datasets and models described by a project file.

    The project file is JSON.  Keys read here: ``patch_size``, ``models``
    (a list of model configs with ``name``, ``network`` and
    ``logging_period``), and the ``training`` / ``validation`` sections
    (``samples``, ``ground_truth``, ``generation_type`` plus the
    parameters of that generation type).  Datasets, generation configs and
    models are built on first request and cached on the instance.
    """

    def __init__(self, config_filepath):
        # Context manager closes the project file right away instead of
        # leaving the handle to the garbage collector.
        with open(config_filepath, 'r') as config_file:
            self.config = json.load(config_file)
        self.training = None
        self.training_genconfig = None
        self.validation = None
        self.validation_genconfig = None
        # One lazily-built slot per entry of config['models'].
        self.models = [None] * len(self.config['models'])
        # generation_type -> factory(dataset, section_config) returning the
        # dataset's generation config.
        self.load_gen_config = {
            'random': lambda dataset, conf: dataset.random_gen_config(conf['patches_per_image']),
            'exhaustive': lambda dataset, conf: dataset.exhaustive_gen_config(conf['patches_padding']),
            # Read the file name from the section config; this used to pass
            # the literal list ['conf.generation_file'].
            'load': lambda dataset, conf: dataset.load_gen_config(conf['generation_file']),
        }

    def _build_dataset(self, section):
        """Return a TrainingData for ``config[section]`` using the project patch size."""
        section_config = self.config[section]
        dataset = TrainingData(DataSet(section_config['samples']),
                               DataSet(section_config['ground_truth']))
        dataset.set_config({'patch_size': self.config['patch_size']})
        return dataset

    def _build_genconfig(self, section, dataset):
        """Return the generation config for ``config[section]`` built from *dataset*.

        Raises ValueError if the section's ``generation_type`` is not a key
        of ``self.load_gen_config``.
        """
        section_config = self.config[section]
        generation_type = section_config['generation_type']
        factory = self.load_gen_config.get(generation_type)
        if factory is None:
            # Previously the None from dict.get() was called directly and
            # failed with an opaque "'NoneType' object is not callable".
            raise ValueError('Unknown generation_type %r in the %r section (expected one of %r)'
                             % (generation_type, section, sorted(self.load_gen_config)))
        return factory(dataset, section_config)

    def get_training_dataset(self):
        """Return the training TrainingData, building it on first call."""
        if self.training is None:
            print('Preloading training dataset...')
            self.training = self._build_dataset('training')
        return self.training

    def get_training_genconfig(self):
        """Return the training generation config, building it on first call."""
        if self.training_genconfig is None:
            print('Generating training dataset configuration...')
            self.training_genconfig = self._build_genconfig('training', self.get_training_dataset())
        return self.training_genconfig

    def get_validation_dataset(self):
        """Return the validation TrainingData, building it on first call."""
        if self.validation is None:
            # Message now matches the training counterpart; it used to repeat
            # the genconfig message from get_validation_genconfig.
            print('Preloading validation dataset...')
            self.validation = self._build_dataset('validation')
        return self.validation

    def get_validation_genconfig(self):
        """Return the validation generation config, building it on first call."""
        if self.validation_genconfig is None:
            print('Generating validation dataset configuration...')
            self.validation_genconfig = self._build_genconfig('validation', self.get_validation_dataset())
        return self.validation_genconfig

    def get_models_count(self):
        """Return the number of models declared in the project file."""
        return len(self.models)

    def get_model(self, i):
        """Return model *i*, creating and initialising its NeuralNetwork on first call."""
        if self.models[i] is None:
            model_config = self.config['models'][i]
            print('Configuring model ' + model_config['name'] + '...')
            network_config = model_config['network']
            epochs = network_config['learning_params']['epochs']
            # Epoch list stepping by logging_period (presumably the epochs at
            # which the network logs/checkpoints — confirm in NeuralNetwork).
            logged_epochs = range(0, epochs, model_config['logging_period'])
            model = NeuralNetwork()
            # patch_size ** 2: presumably the flattened square patch size
            # used as the network input size.
            model.initialise(self.config['patch_size'] ** 2, logged_epochs, network_config)
            # Cache only after initialise succeeds, so a failure does not
            # leave a half-configured network behind.
            self.models[i] = model
        return self.models[i]

    def get_model_name(self, i):
        """Return the ``name`` of model *i* from the project file."""
        return self.config['models'][i]['name']