def init_network( params=None, train=True, network_specfile=None,
                  output_patch_shape=None, num_threads=None, optimize=None,
                  force_fft=None ):
    '''
    Initializes a random network using the Boost Python interface and
    configuration file options.

    The network is defined either by a parameter object (as generated by
    the front_end.parse function) or by the explicitly specified options.
    If both a parameter object and optional arguments are given, the
    parameter object supplies the defaults and every optional argument that
    is not None overrides the corresponding default.

    Parameters
    ----------
    params : parameter object or None
        Indexed with the keys 'train_outsz' / 'is_train_optimize' (when
        train is True) or 'forward_outsz' / 'is_forward_optimize' (when
        train is False), plus 'force_fft', 'fnet_spec' and 'num_threads'.
    train : bool
        Selects which of the above option sets is used; the phase passed
        to pyznn.CNet is int(not train), i.e. 0 when train is True and 1
        otherwise.
    network_specfile, output_patch_shape, num_threads, optimize :
        Overrides for 'fnet_spec', the output size, 'num_threads' and the
        optimize flag respectively. When params is None these are
        validated by assert_arglist (defined elsewhere; presumably it
        rejects a missing params object combined with missing args).
    force_fft :
        Override for params['force_fft']. It is not part of the
        assert_arglist check, so it may be omitted even without a params
        object; in that case it defaults to False.

    Returns
    -------
    pyznn.CNet
        Newly constructed network.
    '''
    # Need to specify either a params object, or all of the other optional
    # args. "ALL" optional args excludes train (it has a default) and
    # force_fft (it gets a default below).
    assert_arglist(params,
                   [network_specfile, output_patch_shape, num_threads, optimize])

    # Defining phase argument by train argument
    phase = int(not train)

    # force_fft is not covered by assert_arglist, so without a params object
    # it used to remain unbound and the CNet call raised UnboundLocalError.
    _force_fft = False

    # If a params object exists, then those options are the default
    if params is not None:
        if train:
            _output_patch_shape = params['train_outsz']
            _optimize = params['is_train_optimize']
        else:
            _output_patch_shape = params['forward_outsz']
            _optimize = params['is_forward_optimize']
        _force_fft = params['force_fft']
        _network_specfile = params['fnet_spec']
        _num_threads = params['num_threads']

    # Overwriting defaults with any other optional args
    if network_specfile is not None:
        _network_specfile = network_specfile
    if output_patch_shape is not None:
        _output_patch_shape = output_patch_shape
    if num_threads is not None:
        _num_threads = num_threads
    if optimize is not None:
        _optimize = optimize
    if force_fft is not None:
        _force_fft = force_fft

    return pyznn.CNet(_network_specfile, _output_patch_shape,
                      _num_threads, _optimize, phase, _force_fft)
def load_network(params=None, train=True, hdf5_filename=None,
                 network_specfile=None, output_patch_shape=None,
                 num_threads=None, optimize=None):
    '''
    Loads a network from an hdf5 file.

    The loading process is defined either by a parameter object (as
    generated by the front_end.parse function) or by the explicitly
    specified options. If both a parameter object and optional arguments
    are given, the parameter object supplies the defaults and every
    optional argument that is not None overrides the corresponding default.

    Parameters
    ----------
    params : parameter object or None
        Indexed with 'train_load_net' / 'train_outsz' / 'is_train_optimize'
        (when train is True) or 'forward_net' / 'forward_outsz' /
        'is_forward_optimize' (when train is False), plus 'fnet_spec' and
        'num_threads'.
    train : bool
        Selects which of the above option sets is used; the phase passed to
        pyznn.CNet is int(not train).
    hdf5_filename, network_specfile, output_patch_shape, num_threads,
    optimize :
        Overrides for the corresponding params entries. When params is None
        these are validated by assert_arglist (defined elsewhere).

    Returns
    -------
    pyznn.CNet
        Network built from the consolidated options. If the hdf5 file does
        not exist, the options of a freshly initialized template network are
        used instead.
    '''
    # Need to specify either a params object, or all of the other optional args
    params_defined = params is not None

    # "ALL" optional args excludes train (it has a default)
    assert_arglist(params,
                   [hdf5_filename, network_specfile, output_patch_shape,
                    num_threads, optimize])

    # Defining phase argument by train argument
    phase = int(not train)

    # If a params object exists, then those options are the default
    if params_defined:
        if train:
            _hdf5_filename = params['train_load_net']
            _output_patch_shape = params['train_outsz']
            _optimize = params['is_train_optimize']
        else:
            _hdf5_filename = params['forward_net']
            _output_patch_shape = params['forward_outsz']
            _optimize = params['is_forward_optimize']
        _network_specfile = params['fnet_spec']
        _num_threads = params['num_threads']

    # Overwriting defaults with any other optional args
    if hdf5_filename is not None:
        _hdf5_filename = hdf5_filename
    if network_specfile is not None:
        _network_specfile = network_specfile
    if output_patch_shape is not None:
        _output_patch_shape = output_patch_shape
    if num_threads is not None:
        _num_threads = num_threads
    # Previously missing: without it an explicit `optimize` was ignored, and
    # with params=None `_optimize` was unbound at the final CNet call.
    if optimize is not None:
        _optimize = optimize

    # ACTUAL LOADING FUNCTIONALITY
    # This is a little strange to allow for "seeding" larger
    # nets with other training runs
    # 1) Initialize template net for network_specfile
    # 2) Load options from hdf5_filename (possibly containing the seed net)
    # 3) Consolidate options from the template and seed (see consolidate_opts)
    # NOTE: the seed network could simply be the network we want to load,
    #  in which case this will overwrite all of the relevant template opts
    # 4) Return consolidated CNet object
    template = init_network(params, train, _network_specfile,
                            _output_patch_shape, _num_threads, optimize=False)

    if os.path.isfile(_hdf5_filename):
        # The file exists: merge its stored options into the template's
        load_options = load_opts(_hdf5_filename)
        template_options = template.get_opts()
        del template

        print("consolidating options...")
        final_options = consolidate_opts(load_options, template_options, params)
    else:
        # The file doesn't exist: start from the freshly initialized template
        final_options = template.get_opts()
        del template

    return pyznn.CNet(final_options, _network_specfile, _output_patch_shape,
                      _num_threads, _optimize, phase)