示例#1
0
def train(config, options):
    """Train the model named by ``config.model_type`` and exit.

    For a model type registered in ``models.MODELS``, the leading
    (dash-separated) component of the model and state file names is replaced
    by the model type and the registered class is instantiated. Any other
    type is treated as a genetic model: a ``models.BaseSRCNNModel`` wrapper is
    created and its ``model`` attribute is built by ``genomics.build_model``;
    a build failure is reported via ``oops``/``terminate``.

    Args:
        config: model configuration; ``model_type`` and ``paths['model']`` /
            ``paths['state']`` may be updated in place.
        options: dict of user options. An explicit ``'learning_rate'`` entry
            overrides the model's learning rate before fitting.

    Never returns normally: exits the process with status 0 when done.
    """
    import sys

    # Kept as a list so several model types could be trained in one run.
    model_list = [config.model_type]

    for model in model_list:

        config.model_type = model

        # Put proper model name in the model and state paths (if not a genetic
        # model), then create the model

        if model in models.MODELS:
            for entry in ['model', 'state']:
                folder_name, file_name = os.path.split(config.paths[entry])
                file_name_parts = file_name.split('-')
                file_name_parts[0] = model
                config.paths[entry] = os.path.join(folder_name,
                                                   '-'.join(file_name_parts))

            cur_model = models.MODELS[model](config)
        else:
            cur_model = models.BaseSRCNNModel(model, config)
            cur_model.model, _ = genomics.build_model(model,
                                                      shape=config.image_shape,
                                                      learning_rate=config.learning_rate,
                                                      metrics=[cur_model.evaluation_function])
            errors = oops(False, cur_model.model is None, 'Compilation failed')
            terminate(errors, False)

        # Create and fit model (best model state will be automatically saved)

        print(cur_model)

        cur_config = cur_model.get_config()
        print('Model configuration:')
        for key in cur_config:
            print('{:>18s} : {}'.format(key, cur_config[key]))

        # If learning rate explicitly specified in options, reset it.

        if 'learning_rate' in options:
            print('Learning Rate reset to {}'.format(options['learning_rate']))
            cur_model.set_lr(options['learning_rate'])

        # PU: Cannot adjust ending epoch number until we load the model state,
        # which does not happen until we fit(). So we have to pass both
        # the max epoch and the run # of epochs.

        cur_model.fit(max_epochs=config.epochs, run_epochs=config.run_epochs)

    print('')
    print('Training completed...')
    # sys.exit rather than the interactive-only site builtin exit().
    sys.exit(0)
示例#2
0
def test_terminate(capsys):
    """Check misc.terminate's captured output for every (terminate, verbose) pair.

    terminate() is test aware and will not exit while under test, so the
    terminating cases can be exercised directly.
    """
    expectations = [
        # (terminate?, verbose?, expected stdout)
        (False, False, ''),
        (False, True, ''),
        (True, False, '\n'),
        (True, True, _T1 + '\n\n'),
    ]

    for should_stop, chatty, wanted_stdout in expectations:
        # Drain anything captured so far so only this call's output is checked.
        capsys.readouterr()
        misc.terminate(should_stop, chatty)
        got_stdout, got_stderr = capsys.readouterr()

        assert got_stdout == wanted_stdout
        assert got_stderr == ''
示例#3
0
def setup(options):
    """Create and validate the model configuration for training.

    Steps:
      1. If ``options['paths']['model']`` is given, its basename becomes the
         model type; a bare name (no directory part) is removed so that
         ``ModelIO`` builds the default paths.
      2. For a model type not in ``models.MODELS`` (a genetic model), check
         that ``genomics.build_model`` can build it, and prefer existing
         model/state files under ``<data>/models`` then
         ``<data>/models/genes``.
      3. Validate the training and validation image folders: each needs
         Alpha and Beta sub-folders holding exactly one image type, the same
         number of images and identical file names.
      4. Check the Beta image shape, trimming and even tiling, and pick a
         black level automatically if none was configured.
      5. Fill in default model/state paths and, if a JSON model state file
         exists, merge its saved configuration into this one.

    Args:
        options: dict of user options (mutated in place).

    Returns:
        The (possibly reloaded) ``ModelIO`` configuration.

    Validation problems are reported through ``oops`` and acted upon by
    ``terminate``.
    """

    # set up our initial state

    errors = False

    # If user specified a path to a model, then we have to fix up the model_type.
    # If the path is just a model name, delete the path, this will ensure that
    # the correct paths get set up by ModelIO()

    if 'paths' in options and 'model' in options['paths']:
        options['model_type'] = os.path.basename(options['paths']['model'])
        if options['paths']['model'] == options['model_type']:
            del options['paths']['model']

    config = ModelIO(options)
    dpath = config.paths['data']

    # Fix up model and state paths if genetic model and not explicitly specified in options

    if config.model_type not in models.MODELS:
        errors = oops(errors,
                      genomics.build_model(config.model_type)[0] is None,
                      '{} is an invalid model',
                      config.model_type)
        terminate(errors, False)

        search_dirs = [os.path.join(dpath, 'models'),
                       os.path.join(dpath, 'models', 'genes')]

        # NOTE(review): explicit paths are detected via 'model_path' /
        # 'state_path' keys, while the code above uses options['paths'];
        # confirm callers still set these keys.
        for key, option_name, ext in (('model', 'model_path', '.h5'),
                                      ('state', 'state_path', '.json')):
            if option_name in options:
                continue
            # First existing candidate wins; otherwise keep ModelIO's default.
            for folder in search_dirs:
                candidate = os.path.join(folder, config.model_type + ext)
                if os.path.exists(candidate):
                    config.paths[key] = candidate
                    break

    # Validation and error checking

    image_paths = ['training', 'validation']
    sub_folders = ['Alpha', 'Beta']

    # image_info[set][sub] is frameops.image_files() for that folder: None when
    # the folder is missing, otherwise (apparently) a list with one list of
    # file paths per image type found.
    image_info = [[frameops.image_files(os.path.join(config.paths[fpath], spath), True)
                   for spath in sub_folders]
                  for fpath in image_paths]

    for fcnt in [0, 1]:
        for scnt in [0, 1]:
            errors = oops(errors,
                          image_info[fcnt][scnt] is None,
                          '{} images folder does not exist',
                          image_paths[fcnt] + '/' + sub_folders[scnt])

    terminate(errors, False)

    for fcnt in [0, 1]:
        for scnt in [0, 1]:
            errors = oops(errors,
                          not image_info[fcnt][scnt],
                          '{} images folder does not contain any images',
                          image_paths[fcnt] + '/' + sub_folders[scnt])
            errors = oops(errors,
                          len(image_info[fcnt][scnt]) > 1,
                          '{} images folder contains more than one type of image',
                          image_paths[fcnt] + '/' + sub_folders[scnt])

    terminate(errors, False)

    for fcnt in [0, 1]:
        errors = oops(errors,
                      len(image_info[fcnt][0][0]) != len(image_info[fcnt][1][0]),
                      '{} images folders have different numbers of images',
                      image_paths[fcnt])

    terminate(errors, False)

    for fcnt in [0, 1]:
        for fpath1, fpath2 in zip(image_info[fcnt][0][0], image_info[fcnt][1][0]):
            fpath1, fpath2 = os.path.basename(fpath1), os.path.basename(fpath2)
            errors = oops(errors,
                          fpath1 != fpath2,
                          '{} images folders do not have identical image filenames ({} vs {})',
                          (image_paths[fcnt], fpath1, fpath2))
            terminate(errors, False)

    # Check sizes, even tiling here, using the first image of each folder.

    test_images = [[frameops.imread(image_info[f][g][0][0])
                    for g in [0, 1]] for f in [0, 1]]

    # Check that the Beta tiles are the same size.

    train_beta_shape = np.shape(test_images[0][1])
    beta_shape = np.shape(test_images[1][1])
    errors = oops(errors,
                  train_beta_shape != beta_shape,
                  'Beta training and evaluation images do not have identical size ({} vs {})',
                  (train_beta_shape, beta_shape))

    # Warn if we do have some differences between Alpha and Beta sizes

    for fcnt in [0, 1]:
        alpha_shape, other_beta_shape = np.shape(test_images[fcnt][0]), np.shape(test_images[fcnt][1])
        if alpha_shape != other_beta_shape:
            print('Warning: {} Alpha and Beta images are not the same size.'.format(image_paths[fcnt].title()))

    terminate(errors, False)

    # Only check the size of the Beta output for proper configuration, since Alpha tiles will
    # be scaled as needed. The shape is (height, width, channels).

    errors = oops(errors,
                  len(beta_shape) != 3 or beta_shape[2] != 3,
                  'Images have improper shape ({})',
                  str(beta_shape))

    terminate(errors, False)

    image_width, image_height = beta_shape[1], beta_shape[0]
    trimmed_width = image_width - (config.trim_left + config.trim_right)
    trimmed_height = image_height - (config.trim_top + config.trim_bottom)

    # Report the Beta dimensions actually used in the computation above.
    errors = oops(errors,
                  trimmed_width <= 0,
                  'Trimmed images have invalid width ({} - ({} + {}) <= 0)',
                  (image_width, config.trim_left, config.trim_right))
    errors = oops(errors,
                  trimmed_height <= 0,
                  'Trimmed images have invalid height ({} - ({} + {}) <= 0)',
                  (image_height, config.trim_top, config.trim_bottom))

    terminate(errors, False)

    # Report the same divisor that is tested.
    errors = oops(errors,
                  (trimmed_width % config.base_tile_width) != 0,
                  'Trimmed images do not evenly tile horizontally ({} % {} != 0)',
                  (trimmed_width, config.base_tile_width))
    errors = oops(errors,
                  (trimmed_height % config.base_tile_height) != 0,
                  'Trimmed images do not evenly tile vertically ({} % {} != 0)',
                  (trimmed_height, config.base_tile_height))

    terminate(errors, False)

    # Attempt to automatically figure out the border color black level, by finding the minimum pixel value in one of our
    # sample images. This will definitely work if we are processing 1440x1080 4:3 embedded in 1920x1080 16:9 images.
    # Write back any change into config.

    if config.black_level < 0:
        config.black_level = np.min(test_images[0][0])
        config.config['black_level'] = config.black_level

    # Since we've gone to the trouble of reading in all the path data, let's make it available to our models for reuse

    for fcnt, fpath in enumerate(image_paths):
        for scnt, spath in enumerate(sub_folders):
            config.paths[fpath + '.' + spath] = image_info[fcnt][scnt]

    # Only at this point can we set default model and state filenames because that depends on image type.
    # If we are training a genetic model, then these will already have been set.

    if 'model' not in config.paths:
        name = '{}{}-{}-{}-{}-{}.h5'.format(
            config.model_type,
            '-R' if config.residual else '',
            config.tile_width,
            config.tile_height,
            config.border,
            config.img_suffix)
        config.paths['model'] = os.path.abspath(os.path.join(config.paths['data'], 'models', name))

    if 'state' not in config.paths:
        config.paths['state'] = config.paths['model'][:-3] + '.json'

    tpath = os.path.dirname(config.paths['state'])
    errors = oops(errors,
                  not os.path.exists(tpath),
                  'Model state path ({}) does not exist',
                  tpath)

    tpath = os.path.dirname(config.paths['model'])
    errors = oops(errors,
                  not os.path.exists(tpath),
                  'Model path ({}) does not exist',
                  tpath)

    # If we do have an existing json state, load it and override

    statepath = config.paths['state']
    if os.path.exists(statepath):
        state = None
        if os.path.isfile(statepath):
            print('Loading existing Model state')
            try:
                with open(statepath, 'r') as jsonfile:
                    state = json.load(jsonfile)
            except json.decoder.JSONDecodeError:
                print('Could not parse json. Did you forget to delete a trailing comma?')
                errors = True
            else:
                # PU: Temp hack to change 'io' key to 'config'
                if 'io' in state:
                    state['config'] = state.pop('io')
        else:
            errors = oops(errors,
                          True,
                          'Model state path is not a reference to a file ({})',
                          statepath)

        # Only merge a state that actually loaded; a bad or non-file state path
        # used to fall through to an undefined ``state``.
        if state is not None:
            for setting, value in state['config'].items():
                if setting not in config.config or config.config[setting] != value:
                    config.config[setting] = value

            # There are a couple of options that override saved configurations.
            for key in ('run_epochs', 'learning_rate'):
                if key in options:
                    config.config[key] = options[key]

            # Reload config with possibly changed settings
            config = ModelIO(config.config)

    terminate(errors, False)

    # Remind user what we're about to do.

    print('             Model : {}'.format(config.model_type))
    print('        Tile Width : {}'.format(config.base_tile_width))
    print('       Tile Height : {}'.format(config.base_tile_height))
    print('       Tile Border : {}'.format(config.border))
    print('        Max Epochs : {}'.format(config.epochs))
    print('        Run Epochs : {}'.format(config.run_epochs))
    print('    Data root path : {}'.format(config.paths['data']))
    print('   Training Images : {}'.format(config.paths['training']))
    print(' Validation Images : {}'.format(config.paths['validation']))
    print('        Model File : {}'.format(config.paths['model']))
    print('  Model State File : {}'.format(config.paths['state']))
    print('  Image dimensions : {} x {}'.format(config.image_width, config.image_height))
    print('          Trimming : Top={}, Bottom={}, Left={}, Right={}'.format(
        config.trim_top, config.trim_bottom, config.trim_left, config.trim_right))
    print('Trimmed dimensions : {} x {}'.format(config.trimmed_width, config.trimmed_height))
    print('       Black level : {}'.format(config.black_level))
    print('            Jitter : {}'.format(config.jitter == 1))
    print('           Shuffle : {}'.format(config.shuffle == 1))
    print('              Skip : {}'.format(config.skip == 1))
    print('          Residual : {}'.format(config.residual == 1))
    print('     Learning Rate : {}'.format(config.learning_rate))
    print('           Quality : {}'.format(config.quality))
    print('')

    return config
示例#4
0
def setup(options):
    """Set up configuration for evaluating a trained model on a folder of images.

    Fills in defaults for the data folder, model file and evaluation folder.
    It derives the JSON state path from the model path. The model type is
    the model file name up to the first '-'. It then validates the paths,
    loads the saved configuration with jitter/shuffle/skip disabled, and
    checks the evaluation images.

    Args:
        options: dict of user options (mutated in place; it becomes
            ``config.paths``).

    Returns:
        ``(config, image_info)``, where ``image_info`` is the list of image
        file paths found in ``<evaluation>/<config.alpha>``.
    """

    # Set remaining options

    options.setdefault('data', 'Data')
    dpath = options['data']

    options.setdefault(
        'model', os.path.join(dpath, 'models', 'BasicSR-R-60-60-2-dpx.h5'))
    options.setdefault('evaluation', os.path.join(dpath, 'eval_images'))

    if not options['model'].endswith('.h5'):
        options['model'] = options['model'] + '.h5'

    # A bare file name refers to <data>/models/<name>.
    if os.path.dirname(options['model']) == '':
        options['model'] = os.path.join(dpath, 'models', options['model'])

    options['state'] = os.path.splitext(options['model'])[0] + '.json'

    model_type = os.path.basename(options['model']).split('-')[0]

    # Remind user what we're about to do.

    print('             Data : {}'.format(options['data']))
    print('Evaluation Images : {}'.format(options['evaluation']))
    print('            Model : {}'.format(options['model']))
    print('            State : {}'.format(options['state']))
    print('       Model Type : {}'.format(model_type))
    print('')

    # Validation and error checking. Accumulate into ``errors`` so that an
    # invalid path early in the list is not forgotten when a later one is OK.

    errors = False
    for path in ['evaluation', 'state', 'model', 'data']:
        errors = oops(errors, not os.path.exists(options[path]),
                      'Path to {} is not valid ({})', (path, options[path]))

    terminate(errors, False)

    # Load the actual model state

    with open(options['state'], 'r') as jsonfile:
        state = json.load(jsonfile)

    # Grab the config data (backwards compatible)

    config = state['config' if 'config' in state else 'io']

    # Create real config with configurable parameters. In particular we disable
    # jitter, shuffle and skip, enable edges, and set quality to 1.0.

    config['paths'] = options
    config['jitter'] = False
    config['shuffle'] = False
    config['skip'] = False
    config['edges'] = True
    config['quality'] = 1.0
    config['model_type'] = model_type

    config = ModelIO(config)

    # Check image files -- we do not explore subfolders. Only the evaluation
    # root was checked above, not this sub-folder.

    image_info = frameops.image_files(
        os.path.join(config.paths['evaluation'], config.alpha), False)

    # image_files() presumably returns None for a missing folder; guard the
    # len() call so that case reports an error instead of raising TypeError.
    errors = oops(False, not image_info,
                  'Input folder does not contain any images')
    errors = oops(errors,
                  image_info is not None and len(image_info) > 1,
                  'Images folder contains more than one type of image')

    terminate(errors, False)

    # Get the list of files and check the filetype is correct

    image_info = image_info[0]

    image_ext = os.path.splitext(image_info[0])[1][1:].lower()

    errors = oops(errors, image_ext != config.img_suffix.lower(),
                  'Image files are of type {} but model was trained on {}',
                  (image_ext, config.img_suffix.lower()))

    terminate(errors, False)

    return (config, image_info)
示例#5
0
def setup(options):
    """Set up configuration and the genepool for a genetic-model run.

    Loads the genepool JSON, or checks that its folder is writeable when it
    does not exist yet. Settings from the genepool's ``'config'`` section
    override *options*. Then it validates the training and validation image
    folders and checks the Beta image shape, trimming and tiling.

    Args:
        options: dict of user options (mutated in place). Defaults are set
            for ``'border'`` (10), ``'env'`` ('') and
            ``options['paths']['genepool']`` (``Data/genepool.json``).

    Returns:
        ``(config, genepool, image_info)``. ``genepool`` is ``{}`` when no
        genepool file exists. ``image_info[set][sub]`` holds the image
        listings for training/validation x Alpha/Beta.
    """

    # Set up our initial state. Choosing to use a wide border because I was
    # seeing tile edge effects.

    errors = False
    genepool = {}
    options.setdefault('border', 10)
    options.setdefault('env', '')
    options.setdefault('paths', {}).setdefault(
        'genepool', os.path.join('Data', 'genepool.json'))
    poolpath = options['paths']['genepool']

    if os.path.exists(poolpath):
        if os.path.isfile(poolpath):
            printlog('Loading existing genepool')
            try:
                with open(poolpath, 'r') as jsonfile:
                    genepool = json.load(jsonfile)
            except json.decoder.JSONDecodeError:
                printlog(
                    'Could not parse json. Did you edit "population" and forget to delete the trailing comma?'
                )
                errors = True
            else:
                # Change 'io' key to 'config' (backwards-compatibility)
                if 'io' in genepool:
                    genepool['config'] = genepool.pop('io')
        else:
            errors = oops(errors, True,
                          'Genepool path is not a reference to a file ({})',
                          poolpath)
    else:
        # No genepool yet, so we must be able to create one. dirname() is ''
        # for a bare file name and os.access('') is False, so test '.' then.
        errors = oops(errors,
                      not os.access(os.path.dirname(poolpath) or '.', os.W_OK),
                      'Genepool folder is not writeable ({})', poolpath)

    terminate(errors, False)

    # Genepool settings override config, so we need to update them. A brand-new
    # (empty) genepool has no 'config' section and overrides nothing.

    for setting, value in genepool.get('config', {}).items():
        if setting not in options or options[setting] != value:
            options[setting] = value

    # Reload config with possibly changed settings

    config = ModelIO(options)

    # Validation and error checking

    import Modules.frameops as frameops

    image_paths = ['training', 'validation']
    sub_folders = ['Alpha', 'Beta']

    # image_info[set][sub] is frameops.image_files() for that folder: None when
    # the folder is missing, otherwise (apparently) one file list per image type.
    image_info = [[frameops.image_files(os.path.join(config.paths[fpath], sub), True)
                   for sub in sub_folders]
                  for fpath in image_paths]

    for fcnt in [0, 1]:
        for scnt in [0, 1]:
            errors = oops(errors, image_info[fcnt][scnt] is None,
                          '{} images folder does not exist',
                          image_paths[fcnt] + '/' + sub_folders[scnt])

    terminate(errors, False)

    for fcnt in [0, 1]:
        for scnt in [0, 1]:
            errors = oops(errors,
                          not image_info[fcnt][scnt],
                          '{} images folder does not contain any images',
                          image_paths[fcnt] + '/' + sub_folders[scnt])
            errors = oops(
                errors,
                len(image_info[fcnt][scnt]) > 1,
                '{} images folder contains more than one type of image',
                image_paths[fcnt] + '/' + sub_folders[scnt])

    terminate(errors, False)

    for fcnt in [0, 1]:
        errors = oops(
            errors,
            len(image_info[fcnt][0][0]) != len(image_info[fcnt][1][0]),
            '{} images folders have different numbers of images',
            image_paths[fcnt])

    terminate(errors, False)

    for fcnt in [0, 1]:
        for path1, path2 in zip(image_info[fcnt][0][0],
                                image_info[fcnt][1][0]):
            path1, path2 = os.path.basename(path1), os.path.basename(path2)
            errors = oops(
                errors, path1 != path2,
                '{} images folders do not have identical image filenames ({} vs {})',
                (image_paths[fcnt], path1, path2))
            terminate(errors, False)

    # Read the first image of each folder for the size checks.

    test_images = [[frameops.imread(image_info[f][g][0][0]) for g in [0, 1]]
                   for f in [0, 1]]

    # Check that the Beta tiles are the same size.

    train_beta_shape = np.shape(test_images[0][1])
    beta_shape = np.shape(test_images[1][1])
    errors = oops(
        errors, train_beta_shape != beta_shape,
        'Beta training and evaluation images do not have identical size ({} vs {})',
        (train_beta_shape, beta_shape))

    # Warn if we do have some differences between Alpha and Beta sizes

    for fcnt in [0, 1]:
        alpha_shape = np.shape(test_images[fcnt][0])
        other_beta_shape = np.shape(test_images[fcnt][1])
        if alpha_shape != other_beta_shape:
            printlog(
                'Warning: {} Alpha and Beta images are not the same size. Will attempt to scale Alpha images.'
                .format(image_paths[fcnt].title()))

    terminate(errors, False)

    # Only check the size of the Beta output for proper configuration, since Alpha tiles will
    # be scaled as needed. The shape is (height, width, channels).

    errors = oops(errors,
                  len(beta_shape) != 3 or beta_shape[2] != 3,
                  'Images have improper shape ({0})', str(beta_shape))

    terminate(errors, False)

    image_width, image_height = beta_shape[1], beta_shape[0]
    trimmed_width = image_width - (config.trim_left + config.trim_right)
    trimmed_height = image_height - (config.trim_top + config.trim_bottom)

    # The height check previously tested trimmed_width; both messages now also
    # report the Beta dimension actually used.
    errors = oops(errors, trimmed_width <= 0,
                  'Trimmed images have invalid width ({} - ({} + {}) <= 0)',
                  (image_width, config.trim_left, config.trim_right))
    errors = oops(errors, trimmed_height <= 0,
                  'Trimmed images have invalid height ({} - ({} + {}) <= 0)',
                  (image_height, config.trim_top, config.trim_bottom))

    terminate(errors, False)

    # Report the same divisor that is tested.
    errors = oops(
        errors, (trimmed_width % config.base_tile_width) != 0,
        'Trimmed images do not evenly tile horizontally ({} % {} != 0)',
        (trimmed_width, config.base_tile_width))
    errors = oops(
        errors, (trimmed_height % config.base_tile_height) != 0,
        'Trimmed images do not evenly tile vertically ({} % {} != 0)',
        (trimmed_height, config.base_tile_height))

    terminate(errors, False)

    # Attempt to automatically figure out the border color black level, by finding the minimum pixel value in one of our
    # sample images. This will definitely work if we are processing 1440x1080 4:3 embedded in 1920x1080 16:9 images.
    # Write back any change into config.

    if config.black_level < 0:
        config.black_level = np.min(test_images[0][0])
        config.config['black_level'] = config.black_level

    return (config, genepool, image_info)
示例#6
0
def setup(options):
    """Set up configuration for a saved model and its graph image.

    Fills in defaults for the data folder, model file and graph image path
    (``<data>/models/graphs/<model name>.png``). It derives the JSON state
    path from the model path. The model type is the model file name up to
    the first '-'. It checks that all of these paths exist and builds a
    ``ModelIO`` from the saved configuration.

    Args:
        options: dict of user options (mutated in place; it becomes
            ``config.paths``).

    Returns:
        ``(config, None)``.
    """

    # Set remaining options

    options.setdefault('data', 'Data')
    dpath = options['data']

    options.setdefault('model',
                       os.path.join(dpath, 'models', 'BasicSR-60-60-2-dpx.h5'))

    basename = os.path.basename(options['model'])
    graphname = os.path.splitext(basename)[0] + '.png'
    options.setdefault('graph',
                       os.path.join(dpath, 'models', 'graphs', graphname))

    if not options['model'].endswith('.h5'):
        options['model'] = options['model'] + '.h5'

    # A bare file name refers to <data>/models/<name>.
    if os.path.dirname(options['model']) == '':
        options['model'] = os.path.join(dpath, 'models', options['model'])

    options['state'] = os.path.splitext(options['model'])[0] + '.json'

    model_type = basename.split('-')[0]

    # Remind user what we're about to do.

    print('             Data : {}'.format(options['data']))
    print('            Model : {}'.format(options['model']))
    print('            State : {}'.format(options['state']))
    print('            Graph : {}'.format(options['graph']))
    print('       Model Type : {}'.format(model_type))
    print('')

    # Validation and error checking. Accumulate into ``errors`` so that an
    # invalid path early in the list is not forgotten when a later one is OK.

    errors = False
    for path in ['graph', 'state', 'model', 'data']:
        errors = oops(errors, not os.path.exists(options[path]),
                      'Path to {} is not valid ({})', (path, options[path]))

    terminate(errors, False)

    # Load the actual model state

    with open(options['state'], 'r') as jsonfile:
        state = json.load(jsonfile)

    # Grab the config data (backwards compatible)

    config = state['config' if 'config' in state else 'io']

    # Create the real config from the saved one, pointing its paths at the
    # options above and recording the model type.

    config['paths'] = options
    config['model_type'] = model_type

    config = ModelIO(config)

    return (config, None)