def init_network(params=None, train=True, network_specfile=None,
                 output_patch_shape=None, num_threads=None, optimize=None,
                 force_fft=None):
    '''
    Initializes a random network using the Boost Python interface and
    configuration file options.

    The function will define this network by a parameter object
    (as generated by the front_end.parse function), or by the specified
    options.

    If both a parameter object and any optional arguments are specified,
    the parameter object will form the default options, and those will be
    overwritten by the other optional arguments.

    Arguments:
      params              -- parameter mapping; read keys are 'train_outsz',
                             'is_train_optimize', 'forward_outsz',
                             'is_forward_optimize', 'force_fft', 'fnet_spec'
                             and 'num_threads'
      train               -- True selects the train_* options and phase 0,
                             False selects the forward_* options and phase 1
      network_specfile    -- overrides params['fnet_spec']
      output_patch_shape  -- overrides params['train_outsz'/'forward_outsz']
      num_threads         -- overrides params['num_threads']
      optimize            -- overrides params['is_train_optimize'/
                             'is_forward_optimize']
      force_fft           -- overrides params['force_fft']; falls back to
                             False when neither source provides a value

    Returns the object produced by pyznn.CNet.
    '''
    #Need to specify either a params object, or all of the other optional args
    #"ALL" optional args excludes train and force_fft
    assert_arglist(params,
                   [network_specfile, output_patch_shape,
                    num_threads, optimize])

    #Defining phase argument by train argument (train -> 0, forward -> 1)
    phase = int(not train)

    #force_fft is not one of the required args above, so without a fallback
    # the "no params, explicit args only" call would reach pyznn.CNet with
    # _force_fft unbound.
    # NOTE(review): False is assumed to mean "do not force FFT" in
    # pyznn.CNet -- confirm against the binding.
    _force_fft = False

    #If a params object exists, then those options are the default
    if params is not None:

        if train:
            _output_patch_shape = params['train_outsz']
            _optimize = params['is_train_optimize']
        else:
            _output_patch_shape = params['forward_outsz']
            _optimize = params['is_forward_optimize']

        _force_fft = params['force_fft']
        _network_specfile = params['fnet_spec']
        _num_threads = params['num_threads']

    #Overwriting defaults with any other optional args
    if network_specfile is not None:
        _network_specfile = network_specfile
    if output_patch_shape is not None:
        _output_patch_shape = output_patch_shape
    if num_threads is not None:
        _num_threads = num_threads
    if optimize is not None:
        _optimize = optimize
    if force_fft is not None:
        _force_fft = force_fft

    return pyznn.CNet(_network_specfile, _output_patch_shape,
                      _num_threads, _optimize, phase, _force_fft)
예제 #2
0
def load_network(params=None,
                 train=True,
                 hdf5_filename=None,
                 network_specfile=None,
                 output_patch_shape=None,
                 num_threads=None,
                 optimize=None):
    '''
    Loads a network from an hdf5 file.

    The function will define the loading process by a parameter object
    (as generated by the front_end.parse function), or by the specified
    options.

    If both a parameter object and any optional arguments are specified,
    the parameter object will form the default options, and those will be
    overwritten by the other optional arguments.

    Arguments:
      params              -- parameter mapping; read keys are
                             'train_load_net', 'train_outsz',
                             'is_train_optimize', 'forward_net',
                             'forward_outsz', 'is_forward_optimize',
                             'fnet_spec' and 'num_threads'
      train               -- True selects the train_* options and phase 0,
                             False selects the forward_* options and phase 1
      hdf5_filename       -- overrides params['train_load_net'/'forward_net']
      network_specfile    -- overrides params['fnet_spec']
      output_patch_shape  -- overrides params['train_outsz'/'forward_outsz']
      num_threads         -- overrides params['num_threads']
      optimize            -- overrides params['is_train_optimize'/
                             'is_forward_optimize']

    Returns the object produced by pyznn.CNet. If hdf5_filename is not an
    existing file, the returned network is built from the options of a
    freshly initialized network instead.
    '''
    #Need to specify either a params object, or all of the other optional args
    params_defined = params is not None

    #"ALL" optional args excludes train (it has a default)
    assert_arglist(params, [
        hdf5_filename, network_specfile, output_patch_shape, num_threads,
        optimize
    ])

    #Defining phase argument by train argument (train -> 0, forward -> 1)
    phase = int(not train)

    #If a params object exists, then those options are the default
    if params_defined:

        if train:
            _hdf5_filename = params['train_load_net']
            _output_patch_shape = params['train_outsz']
            _optimize = params['is_train_optimize']
        else:
            _hdf5_filename = params['forward_net']
            _output_patch_shape = params['forward_outsz']
            _optimize = params['is_forward_optimize']

        _network_specfile = params['fnet_spec']
        _num_threads = params['num_threads']

    #Overwriting defaults with any other optional args
    if hdf5_filename is not None:
        _hdf5_filename = hdf5_filename
    if network_specfile is not None:
        _network_specfile = network_specfile
    if output_patch_shape is not None:
        _output_patch_shape = output_patch_shape
    if num_threads is not None:
        _num_threads = num_threads
    #Without this override an explicit optimize argument was ignored, and
    # _optimize was left unbound whenever no params object was supplied
    if optimize is not None:
        _optimize = optimize

    #ACTUAL LOADING FUNCTIONALITY
    #This is a little strange to allow for "seeding" larger
    # nets with other training runs
    # 1) Initialize template net for network_specfile
    # 2) Load options from hdf5_filename (possibly containing the seed net)
    # 3) Consolidate options from the template and seed (see consolidate_opts)
    #  NOTE: the seed network could simply be the network we want to load,
    #        in which case this will overwrite all of the relevant template opts
    # 4) Return consolidated CNet object

    #The template is always built with optimize=False; it is only used to
    # read its options and is discarded right away
    template = init_network(params, train, _network_specfile,
                            _output_patch_shape, _num_threads, False)
    template_options = template.get_opts()
    del template

    if os.path.isfile(_hdf5_filename):

        load_options = load_opts(_hdf5_filename)

        #Parenthesized single-argument form prints the same text on
        # Python 2 and is valid syntax on Python 3
        print("consolidating options...")
        final_options = consolidate_opts(load_options, template_options,
                                         params)

    else:
        #If the file doesn't exist, fall back to the freshly initialized
        # template options (i.e. a new network)
        final_options = template_options

    return pyznn.CNet(final_options, _network_specfile, _output_patch_shape,
                      _num_threads, _optimize, phase)