コード例 #1
ファイル: forward.py プロジェクト: pranavcode/znn-release
def generate_full_output(Dataset, network, params, dtype="float32", verbose=True):
	Performs a full forward pass for a given ConfigSample object (Dataset) and
	a given network object.

    # Making sure loaded images expect same size output volume
    output_vol_shapes = Dataset.output_volume_shape()
    assert output_volume_shape_consistent(output_vol_shapes)
    output_vol_shape = output_vol_shapes.values()[0]

    Output = front_end.ConfigSampleOutput(params, network, output_vol_shape, dtype)

    input_num_patches = Dataset.num_patches()
    output_num_patches = Output.num_patches()

    assert num_patches_consistent(input_num_patches, output_num_patches)

    num_patches = output_num_patches.values()[0]

    for i in xrange(num_patches):

        if verbose:
            print "Output patch #{} of {}".format(i + 1, num_patches)  # i is just an index

        input_patches, junk = Dataset.get_next_patch()

        vol_ins = utils.make_continuous(input_patches, dtype=dtype)

        output = network.forward(vol_ins)


    return Output
コード例 #2
ファイル: forward.py プロジェクト: zlinzju/znn-release
def forward_pass(params, Dataset, network, verbose=True):
    Performs a full forward pass for a given ConfigSample object (Dataset) and
    a given network object.
    # Making sure loaded images expect same size output volume
    output_vol_shapes = Dataset.output_volume_shape()
    assert output_volume_shape_consistent(output_vol_shapes)
    output_vol_shape = output_vol_shapes.values()[0]
    Output = zsample.ConfigSampleOutput(params, network, output_vol_shape)
    input_num_patches = Dataset.num_patches()
    output_num_patches = Output.num_patches()
    assert num_patches_consistent(input_num_patches, output_num_patches)
    num_patches = output_num_patches.values()[0]

    for i in xrange(num_patches):
        if verbose:
            print "Output patch #{} of {}".format(
                i + 1, num_patches)  # i is just an index
        input_patches, junk = Dataset.get_next_patch()
        vol_ins = utils.make_continuous(input_patches)
        output = network.forward(vol_ins)
        if params['is_check']:
    # softmax if using softmax_loss
    if 'softmax' in params['cost_fn_str']:
        print "softmax filter..."
        Output = run_softmax(Output)

    return Output
コード例 #3
ファイル: forward.py プロジェクト: seung-lab/znn-release
def forward_pass( params, Dataset, network, verbose=True ):
    Performs a full forward pass for a given ConfigSample object (Dataset) and
    a given network object.
    # Making sure loaded images expect same size output volume
    output_vol_shapes = Dataset.output_volume_shape()
    assert output_volume_shape_consistent(output_vol_shapes)
    output_vol_shape = output_vol_shapes.values()[0]
    Output = zsample.ConfigSampleOutput( params, network, output_vol_shape)
    input_num_patches = Dataset.num_patches()
    output_num_patches = Output.num_patches()
    assert num_patches_consistent(input_num_patches, output_num_patches)
    num_patches = output_num_patches.values()[0]

    for i in xrange( num_patches ):
        if verbose:
	    print "Output patch #{} of {}".format(i+1, num_patches) # i is just an index
        input_patches, junk = Dataset.get_next_patch()
	vol_ins = utils.make_continuous(input_patches)
	output = network.forward( vol_ins )
        Output.set_next_patch( output )
        if params['is_check']:
    # softmax if using softmax_loss
    if 'softmax' in params['cost_fn_str']:
        print "softmax filter..."
        Output = run_softmax( Output )

    return Output
コード例 #4
def generate_full_output( Dataset, network, dtype='float32', verbose=True ):
	Performs a full forward pass for a given ConfigSample object (Dataset) and
	a given network object.

	# Making sure loaded images expect same size output volume
	output_vol_shapes = Dataset.output_volume_shape()
	assert output_volume_shape_consistent(output_vol_shapes)
	output_vol_shape = output_vol_shapes.values()[0]

	Output = front_end.ConfigSampleOutput( network, output_vol_shape, dtype )

	input_num_patches = Dataset.num_patches()
	output_num_patches = Output.num_patches()

	assert num_patches_consistent(input_num_patches, output_num_patches)

	num_patches = output_num_patches.values()[0]

	for i in xrange( num_patches ):

		if verbose:
			print "Output patch #{} of {}".format(i+1, num_patches) # i is just an index

		input_patches, junk = Dataset.get_next_patch()

		vol_ins = utils.make_continuous(input_patches, dtype=dtype)

		output = network.forward( vol_ins )

		Output.set_next_patch( output )

	return Output
コード例 #5
def check_gradient(pars, net, smp, h=0.00001):
    gradient check method:

    Note that this function is currently not working!
    We should get the gradient from the C++ core, which is not implemented yet.
    # get random sub volume from sample
    vol_ins, lbl_outs, msks, wmsks = smp.get_random_sample()

    # numerical gradient
    # apply the transformations in memory rather than array view
    vol_ins = utils.make_continuous(vol_ins)
    # shift the input to compute the analytical gradient
    vol_ins1 = dict()
    vol_ins2 = dict()
    for key, val in vol_ins.iteritems():
        vol_ins1[key] = val - h
        vol_ins2[key] = val + h
        assert np.any(vol_ins1[key] != vol_ins2[key])
    props = net.forward(vol_ins)
    props1 = net.forward(vol_ins1)
    props2 = net.forward(vol_ins2)
    import copy
    props_tmp, cerr, grdts = pars['cost_fn'](copy.deepcopy(props), lbl_outs,

    # compute the analytical gradient
    for key, g in grdts.iteritems():
        lbl = lbl_outs[key]
        prop = props[key]
        prop1 = props1[key]
        prop2 = props2[key]
        ag = (prop2 - prop1) / (2 * h)
        error = g - ag

        # label value
        print "ground truth label: ", lbl[0, ...]
        print "forward output: ", prop[0, ...]
        print "forward output - h: ", prop1[0, ...]
        print "forward output + h: ", prop2[0, ...]
        print "numerical gradient: ", g[0, ...]
        print "analytical gradient: ", ag[0, ...]
        # check the error range
        print "gradient error: ", error[0, ...]

        com = emirt.show.CompareVol((lbl[0, ...], prop[0, ...], prop1[0, ...],
                                     prop2[0, ...], g[0, ...], ag[0, ...]))

        # check the relative error
        rle = np.abs(ag - g) / (np.maximum(np.abs(ag), np.abs(g)))
        print "relative gradient error: ", rle[0, ...]
        assert error.max < 10 * h * h
        assert rle.max() < 0.01
コード例 #6
ファイル: zcheck.py プロジェクト: seung-lab/znn-release
def check_gradient(pars, net, smp, h=0.00001):
    gradient check method:

    Note that this function is currently not working!
    We should get the gradient from the C++ core, which is not implemented yet.
    # get random sub volume from sample
    vol_ins, lbl_outs, msks, wmsks = smp.get_random_sample()

    # numerical gradient
    # apply the transformations in memory rather than array view
    vol_ins = utils.make_continuous(vol_ins)
    # shift the input to compute the analytical gradient
    vol_ins1 = dict()
    vol_ins2 = dict()
    for key, val in vol_ins.iteritems():
        vol_ins1[key] = val - h
        vol_ins2[key] = val + h
        assert np.any(vol_ins1[key]!=vol_ins2[key])
    props = net.forward( vol_ins )
    props1 = net.forward( vol_ins1 )
    props2 = net.forward( vol_ins2 )
    import copy
    props_tmp, cerr, grdts = pars['cost_fn']( copy.deepcopy(props), lbl_outs, msks )

    # compute the analytical gradient
    for key, g in grdts.iteritems():
        lbl = lbl_outs[key]
        prop = props[key]
        prop1 = props1[key]
        prop2 = props2[key]
        ag = (prop2 - prop1)/ (2 * h)
        error = g-ag

        # label value
        print "ground truth label: ", lbl[0,...]
        print "forward output: ", prop[0,...]
        print "forward output - h: ", prop1[0,...]
        print "forward output + h: ", prop2[0,...]
        print "numerical gradient: ", g[0,...]
        print "analytical gradient: ", ag[0,...]
        # check the error range
        print "gradient error: ", error[0,...]

        com = emirt.show.CompareVol((lbl[0,...], prop[0,...], prop1[0,...], prop2[0,...], g[0,...], ag[0,...]))

        # check the relative error
        rle = np.abs(ag-g) / (np.maximum(np.abs(ag),np.abs(g)))
        print "relative gradient error: ", rle[0,...]
        assert error.max < 10*h*h
        assert rle.max() < 0.01
コード例 #7
ファイル: test.py プロジェクト: yanweifu/znn-release
def _single_test(net, pars, sample):
    """
    Score the network on one random patch of ``sample``.

    Returns:
        ``(props, err, cls)``: the outputs as returned by ``pars['cost_fn']``,
        the cost value, and the classification error from ``cost_fn.get_cls``.
    """
    vol_ins, lbl_outs, _msks = sample.get_random_sample()

    dtype = pars['dtype']
    outputs = net.forward(utils.make_continuous(vol_ins, dtype=dtype))

    scored, cost, _grads = pars['cost_fn'](outputs, lbl_outs)
    return scored, cost, cost_fn.get_cls(scored, lbl_outs)
コード例 #8
ファイル: test.py プロジェクト: Nuzhny007/znn-release
def _single_test(net, pars, sample):
    """
    Score the network on one random patch of ``sample``.

    Returns:
        ``(props, err, cls)``: the outputs as returned by ``pars['cost_fn']``,
        the cost value, and the classification error from ``cost_fn.get_cls``.
    """
    vol_ins, lbl_outs, _msks = sample.get_random_sample()

    continuous = utils.make_continuous(vol_ins, dtype=pars['dtype'])
    raw_outputs = net.forward(continuous)

    scored, cost, _grads = pars['cost_fn'](raw_outputs, lbl_outs)
    pixel_err = cost_fn.get_cls(scored, lbl_outs)
    return scored, cost, pixel_err
コード例 #9
ファイル: test.py プロジェクト: muqiao0626/znn-release
def _single_test(net, pars, sample):
    """
    Score the network on one random patch of ``sample``.

    Returns:
        ``(props, err, cls, re, malis_cls)``. ``re`` and ``malis_cls`` come
        from the first entries of the malis rand-error and malis
        classification dicts; both are 0.0 when ``pars['is_malis']`` is false.
    """
    vol_ins, lbl_outs, _msks, _wmsks = sample.get_random_sample()

    outputs = net.forward(utils.make_continuous(vol_ins, dtype=pars['dtype']))
    outputs, cost, _grads = pars['cost_fn'](outputs, lbl_outs)
    pixel_err = cost_fn.get_cls(outputs, lbl_outs)

    if not pars['is_malis']:
        return outputs, cost, pixel_err, 0.0, 0.0

    weights, rand_errs = cost_fn.malis_weight(pars, outputs, lbl_outs)
    rand_err = rand_errs.values()[0]
    # dictionary of malis classification error
    malis_cls_by_key = utils.get_malis_cls(outputs, lbl_outs, weights)
    return outputs, cost, pixel_err, rand_err, malis_cls_by_key.values()[0]
コード例 #10
ファイル: test.py プロジェクト: zlinzju/znn-release
def _single_test(net, pars, sample):
    """
    Score the network on one random patch of ``sample``.

    Returns:
        ``(props, err, cls, re, malis_cls, malis_eng)``. ``re`` is the rand
        error from ``pyznn.get_rand_error`` on the first output/label. Both
        malis terms are 0.0 unless ``pars['is_malis']`` is set.
    """
    vol_ins, lbl_outs, _msks, _wmsks = sample.get_random_sample()

    outputs = net.forward(utils.make_continuous(vol_ins))
    outputs, cost, _grads = pars['cost_fn'](outputs, lbl_outs)
    # pixel classification error
    pixel_err = cost_fn.get_cls(outputs, lbl_outs)
    # rand error
    rand_err = pyznn.get_rand_error(outputs.values()[0], lbl_outs.values()[0])

    if not pars['is_malis']:
        return outputs, cost, pixel_err, rand_err, 0.0, 0.0

    weights, _rand_errs, _num_non_bdr = cost_fn.malis_weight(pars, outputs, lbl_outs)
    cls_by_key, eng_by_key = utils.get_malis_cost(outputs, lbl_outs, weights)
    return (outputs, cost, pixel_err, rand_err,
            cls_by_key.values()[0], eng_by_key.values()[0])
コード例 #11
ファイル: train.py プロジェクト: pranavcode/znn-release
def main( conf_file='config.cfg', logfile=None ):
    #%% parameters
    print "reading config parameters..."
    config, pars = front_end.parser( conf_file )

    if pars.has_key('logging') and pars['logging']:
        print "recording configuration file..."
        front_end.record_config_file( pars )

        logfile = front_end.make_logfile_name( pars )

    #%% create and initialize the network
    if pars['train_load_net'] and os.path.exists(pars['train_load_net']):
        print "loading network..."
        net = netio.load_network( pars )
        # load existing learning curve
        lc = zstatistics.CLearnCurve( pars['train_load_net'] )
        if pars['train_seed_net'] and os.path.exists(pars['train_seed_net']):
            print "seeding network..."
            net = netio.load_network( pars, is_seed=True )
            print "initializing network..."
            net = netio.init_network( pars )
        # initalize a learning curve
        lc = zstatistics.CLearnCurve()

    # show field of view
    print "field of view: ", net.get_fov()
    print "output volume info: ", net.get_outputs_setsz()

    # set some parameters
    print 'setting up the network...'
    vn = utils.get_total_num(net.get_outputs_setsz())
    eta = pars['eta'] #/ vn
    net.set_eta( eta )
    net.set_momentum( pars['momentum'] )
    net.set_weight_decay( pars['weight_decay'] )

    # initialize samples
    outsz = pars['train_outsz']
    print "\n\ncreate train samples..."
    smp_trn = front_end.CSamples(config, pars, pars['train_range'], net, outsz, logfile)
    print "\n\ncreate test samples..."
    smp_tst = front_end.CSamples(config, pars, pars['test_range'],  net, outsz, logfile)

    # initialization
    elapsed = 0
    err = 0
    cls = 0

    # interactive visualization

    # the last iteration we want to continue training
    iter_last = lc.get_last_it()

    print "start training..."
    start = time.time()
    print "start from ", iter_last+1
    for i in xrange(iter_last+1, pars['Max_iter']+1):
        vol_ins, lbl_outs, msks = smp_trn.get_random_sample()

        # forward pass
        vol_ins = utils.make_continuous(vol_ins, dtype=pars['dtype'])

        props = net.forward( vol_ins )

        # cost function and accumulate errors
        props, cerr, grdts = pars['cost_fn']( props, lbl_outs )
        err = err + cerr
        cls = cls + cost_fn.get_cls(props, lbl_outs)

        # mask process the gradient
        grdts = utils.dict_mul(grdts, msks)

        # run backward pass
        grdts = utils.make_continuous(grdts, dtype=pars['dtype'])
        net.backward( grdts )

        if pars['is_malis'] :
            malis_weights = cost_fn.malis_weight(props, lbl_outs)
            grdts = utils.dict_mul(grdts, malis_weights)

        if i%pars['Num_iter_per_test']==0:
            # test the net
            lc = test.znn_test(net, pars, smp_tst, vn, i, lc)

        if i%pars['Num_iter_per_show']==0:
            # anneal factor
            eta = eta * pars['anneal_factor']
            # normalize
            err = err / vn / pars['Num_iter_per_show']
            cls = cls / vn / pars['Num_iter_per_show']
            lc.append_train(i, err, cls)

            # time
            elapsed = time.time() - start
            elapsed = elapsed / pars['Num_iter_per_show']

            show_string = "iteration %d,    err: %.3f,    cls: %.3f,   elapsed: %.1f s/iter, learning rate: %.6f"\
                    %(i, err, cls, elapsed, eta )

            if pars.has_key('logging') and pars['logging']:
                utils.write_to_log(logfile, show_string)
            print show_string

            if pars['is_visual']:
                # show results To-do: run in a separate thread
                front_end.inter_show(start, lc, eta, vol_ins, props, lbl_outs, grdts, pars)
                if pars['is_rebalance'] and 'aff' not in pars['out_type']:
                    plt.imshow(msks.values()[0][0,0,:,:], interpolation='nearest', cmap='gray')
                    plt.xlabel('rebalance weight')
                if pars['is_malis']:
                    plt.imshow(malis_weights.values()[0][0,0,:,:], interpolation='nearest', cmap='gray')
                    plt.xlabel('malis weight (log)')
            # reset err and cls
            err = 0
            cls = 0
            # reset time
            start = time.time()

        if i%pars['Num_iter_per_save']==0:
            # save network
            netio.save_network(net, pars['train_save_net'], num_iters=i)
            lc.save( pars, elapsed )
コード例 #12
def main(args):
    config, pars, logfile = parse_args(args)
    #%% create and initialize the network
    net, lc = znetio.create_net(pars)

    # total voxel number of output volumes
    vn = utils.get_total_num(net.get_outputs_setsz())

    # initialize samples
    outsz = pars['train_outsz']
    print "\n\ncreate train samples..."
    smp_trn = zsample.CSamples(config, pars, pars['train_range'], net, outsz,
    print "\n\ncreate test samples..."
    smp_tst = zsample.CSamples(config, pars, pars['test_range'], net, outsz,

    if pars['is_check']:
        import zcheck
        zcheck.check_patch(pars, smp_trn)
        # gradient check is not working now.
        # zcheck.check_gradient(pars, net, smp_trn)

    # initialization
    eta = pars['eta']
    elapsed = 0
    err = 0.0  # cost energy
    cls = 0.0  # pixel classification error
    re = 0.0  # rand error
    # number of voxels which accumulate error
    # (if a mask exists)
    num_mask_voxels = 0

    if pars['is_malis']:
        malis_cls = 0.0
        malis_eng = 0.0
        malis_weights = None

    # the last iteration we want to continue training
    iter_last = lc.get_last_it()

    print "start training..."
    start = time.time()
    total_time = 0.0
    print "start from ", iter_last + 1

    #Saving initial/seeded network
    # get file name
    fname, fname_current = znetio.get_net_fname(pars['train_net_prefix'],
    znetio.save_network(net, fname, pars['is_stdio'])
    lc.save(pars, fname, elapsed=0.0, suffix="init_iter{}".format(iter_last))
    # no nan detected
    nonan = True

    for i in xrange(iter_last + 1, pars['Max_iter'] + 1):
        # time cumulation
        total_time += time.time() - start
        start = time.time()

        # get random sub volume from sample
        vol_ins, lbl_outs, msks, wmsks = smp_trn.get_random_sample()

        # forward pass
        # apply the transformations in memory rather than array view
        vol_ins = utils.make_continuous(vol_ins)
        props = net.forward(vol_ins)

        # cost function and accumulate errors
        props, cerr, grdts = pars['cost_fn'](props, lbl_outs, msks)
        err += cerr
        cls += cost_fn.get_cls(props, lbl_outs)
        # compute rand error
        if pars['is_debug']:
            assert not np.all(lbl_outs.values()[0] == 0)
        re += pyznn.get_rand_error(props.values()[0], lbl_outs.values()[0])
        num_mask_voxels += utils.sum_over_dict(msks)

        # check whether there is a NaN here!
        if pars['is_debug']:
            nonan = nonan and utils.check_dict_nan(vol_ins)
            nonan = nonan and utils.check_dict_nan(lbl_outs)
            nonan = nonan and utils.check_dict_nan(msks)
            nonan = nonan and utils.check_dict_nan(wmsks)
            nonan = nonan and utils.check_dict_nan(props)
            nonan = nonan and utils.check_dict_nan(grdts)
            if not nonan:
                utils.inter_save(pars, net, lc, vol_ins, props, lbl_outs, \
                             grdts, malis_weights, wmsks, elapsed, i)
                # stop training

        # gradient reweighting
        grdts = utils.dict_mul(grdts, msks)
        if pars['rebalance_mode']:
            grdts = utils.dict_mul(grdts, wmsks)

        if pars['is_malis']:
            malis_weights, rand_errors, num_non_bdr = cost_fn.malis_weight(
                pars, props, lbl_outs)
            if num_non_bdr <= 1:
                # skip this iteration
            grdts = utils.dict_mul(grdts, malis_weights)
            dmc, dme = utils.get_malis_cost(props, lbl_outs, malis_weights)
            malis_cls += dmc.values()[0]
            malis_eng += dme.values()[0]

        # run backward pass
        grdts = utils.make_continuous(grdts)

        total_time += time.time() - start
        start = time.time()

        if i % pars['Num_iter_per_show'] == 0:
            # time
            elapsed = total_time / pars['Num_iter_per_show']

            # normalize
            if utils.dict_mask_empty(msks):
                err = err / vn / pars['Num_iter_per_show']
                cls = cls / vn / pars['Num_iter_per_show']
                err = err / num_mask_voxels / pars['Num_iter_per_show']
                cls = cls / num_mask_voxels / pars['Num_iter_per_show']
            re = re / pars['Num_iter_per_show']
            lc.append_train(i, err, cls, re)

            if pars['is_malis']:
                malis_cls = malis_cls / pars['Num_iter_per_show']
                malis_eng = malis_eng / pars['Num_iter_per_show']

                show_string = "update %d,    cost: %.3f, pixel error: %.3f, rand error: %.3f, me: %.3f, mc: %.3f, elapsed: %.1f s/iter, learning rate: %.5f"\
                              %(i, err, cls, re, malis_eng, malis_cls, elapsed, eta )
                show_string = "update %d,    cost: %.3f, pixel error: %.3f, rand error: %.3f, elapsed: %.1f s/iter, learning rate: %.5f"\
                    %(i, err, cls, re, elapsed, eta )

            if pars.has_key('logging') and pars['logging']:
                utils.write_to_log(logfile, show_string)
            print show_string

            # reset err and cls
            err = 0
            cls = 0
            re = 0
            num_mask_voxels = 0

            if pars['is_malis']:
                malis_cls = 0

            # reset time
            total_time = 0
            start = time.time()

        # test the net
        if i % pars['Num_iter_per_test'] == 0:
            # time accumulation should skip the test
            total_time += time.time() - start
            lc = test.znn_test(net, pars, smp_tst, vn, i, lc)
            start = time.time()

        if i % pars['Num_iter_per_save'] == 0:
            utils.inter_save(pars, net, lc, vol_ins, props, lbl_outs, \
                             grdts, malis_weights, wmsks, elapsed, i)

        if i % pars['Num_iter_per_annealing'] == 0:
            # anneal factor
            eta = eta * pars['anneal_factor']

        # stop the iteration at checking mode
        if pars['is_check']:
            print "only need one iteration for checking, stop program..."
コード例 #13
ファイル: train.py プロジェクト: muqiao0626/znn-release
def main( conf_file='config.cfg', logfile=None ):
    #%% parameters
    print "reading config parameters..."
    config, pars = zconfig.parser( conf_file )

    if pars.has_key('logging') and pars['logging']:
        print "recording configuration file..."
        zlog.record_config_file( pars )

        logfile = zlog.make_logfile_name( pars )

    #%% create and initialize the network
    if pars['train_load_net'] and os.path.exists(pars['train_load_net']):
        print "loading network..."
        net = znetio.load_network( pars )
        # load existing learning curve
        lc = zstatistics.CLearnCurve( pars['train_load_net'] )
        # the last iteration we want to continue training
        iter_last = lc.get_last_it()
        if pars['train_seed_net'] and os.path.exists(pars['train_seed_net']):
            print "seeding network..."
            net = znetio.load_network( pars, is_seed=True )
            print "initializing network..."
            net = znetio.init_network( pars )
        # initalize a learning curve
        lc = zstatistics.CLearnCurve()
        iter_last = lc.get_last_it()

    # show field of view
    print "field of view: ", net.get_fov()

    # total voxel number of output volumes
    vn = utils.get_total_num(net.get_outputs_setsz())

    # set some parameters
    print 'setting up the network...'
    eta = pars['eta']
    net.set_eta( pars['eta'] )
    net.set_momentum( pars['momentum'] )
    net.set_weight_decay( pars['weight_decay'] )

    # initialize samples
    outsz = pars['train_outsz']
    print "\n\ncreate train samples..."
    smp_trn = zsample.CSamples(config, pars, pars['train_range'], net, outsz, logfile)
    print "\n\ncreate test samples..."
    smp_tst = zsample.CSamples(config, pars, pars['test_range'],  net, outsz, logfile)

    # initialization
    elapsed = 0
    err = 0.0 # cost energy
    cls = 0.0 # pixel classification error
    re = 0.0  # rand error
    # number of voxels which accumulate error
    # (if a mask exists)
    num_mask_voxels = 0

    if pars['is_malis']:
        malis_cls = 0.0

    print "start training..."
    start = time.time()
    total_time = 0.0
    print "start from ", iter_last+1

    #Saving initialized network
    if iter_last+1 == 1:
        znetio.save_network(net, pars['train_save_net'], num_iters=0)
        lc.save( pars, 0.0 )

    for i in xrange(iter_last+1, pars['Max_iter']+1):
        # get random sub volume from sample
        vol_ins, lbl_outs, msks, wmsks = smp_trn.get_random_sample()

        # forward pass
        # apply the transformations in memory rather than array view
        vol_ins = utils.make_continuous(vol_ins, dtype=pars['dtype'])
        props = net.forward( vol_ins )

        # cost function and accumulate errors
        props, cerr, grdts = pars['cost_fn']( props, lbl_outs, msks )
        err += cerr
        cls += cost_fn.get_cls(props, lbl_outs)
        num_mask_voxels += utils.sum_over_dict(msks)

        # gradient reweighting
        grdts = utils.dict_mul( grdts, msks  )
        grdts = utils.dict_mul( grdts, wmsks )

        if pars['is_malis'] :
            malis_weights, rand_errors = cost_fn.malis_weight(pars, props, lbl_outs)
            grdts = utils.dict_mul(grdts, malis_weights)
            # accumulate the rand error
            re += rand_errors.values()[0]
            malis_cls_dict = utils.get_malis_cls( props, lbl_outs, malis_weights )
            malis_cls += malis_cls_dict.values()[0]

        total_time += time.time() - start
        start = time.time()

        # test the net
        if i%pars['Num_iter_per_test']==0:
            lc = test.znn_test(net, pars, smp_tst, vn, i, lc)

        if i%pars['Num_iter_per_show']==0:
            # normalize
            if utils.dict_mask_empty(msks):
                err = err / vn / pars['Num_iter_per_show']
                cls = cls / vn / pars['Num_iter_per_show']
                err = err / num_mask_voxels / pars['Num_iter_per_show']
                cls = cls / num_mask_voxels / pars['Num_iter_per_show']

            lc.append_train(i, err, cls)

            # time
            elapsed = total_time / pars['Num_iter_per_show']

            if pars['is_malis']:
                re = re / pars['Num_iter_per_show']
                lc.append_train_rand_error( re )
                malis_cls = malis_cls / pars['Num_iter_per_show']
                lc.append_train_malis_cls( malis_cls )

                show_string = "iteration %d,    err: %.3f, cls: %.3f, re: %.6f, mc: %.3f, elapsed: %.1f s/iter, learning rate: %.6f"\
                              %(i, err, cls, re, malis_cls, elapsed, eta )
                show_string = "iteration %d,    err: %.3f, cls: %.3f, elapsed: %.1f s/iter, learning rate: %.6f"\
                    %(i, err, cls, elapsed, eta )

            if pars.has_key('logging') and pars['logging']:
                utils.write_to_log(logfile, show_string)
            print show_string

            # reset err and cls
            err = 0
            cls = 0
            re = 0
            num_mask_voxels = 0

            if pars['is_malis']:
                malis_cls = 0

            # reset time
            total_time  = 0
            start = time.time()

        if i%pars['Num_iter_per_annealing']==0:
            # anneal factor
            eta = eta * pars['anneal_factor']

        if i%pars['Num_iter_per_save']==0:
            # save network
            znetio.save_network(net, pars['train_save_net'], num_iters=i)
            lc.save( pars, elapsed )
            if pars['is_malis']:
                utils.save_malis(malis_weights,  pars['train_save_net'], num_iters=i)

        # run backward pass
        grdts = utils.make_continuous(grdts, dtype=pars['dtype'])
        net.backward( grdts )
コード例 #14
def main(conf_file='config.cfg', logfile=None):
    #%% parameters
    print "reading config parameters..."
    config, pars = front_end.parser(conf_file)

    if pars.has_key('logging') and pars['logging']:
        print "recording configuration file..."

        logfile = front_end.make_logfile_name(pars)

    #%% create and initialize the network
    if pars['train_load_net'] and os.path.exists(pars['train_load_net']):
        print "loading network..."
        net = netio.load_network(pars)
        # load existing learning curve
        lc = zstatistics.CLearnCurve(pars['train_load_net'])
        if pars['train_seed_net'] and os.path.exists(pars['train_seed_net']):
            print "seeding network..."
            net = netio.seed_network(pars, is_seed=True)
            print "initializing network..."
            net = netio.init_network(pars)
        # initalize a learning curve
        lc = zstatistics.CLearnCurve()

    # show field of view
    print "field of view: ", net.get_fov()
    print "output volume info: ", net.get_outputs_setsz()

    # set some parameters
    print 'setting up the network...'
    vn = utils.get_total_num(net.get_outputs_setsz())
    eta = pars['eta']  #/ vn

    # initialize samples
    outsz = pars['train_outsz']
    print "\n\ncreate train samples..."
    smp_trn = front_end.CSamples(config, pars, pars['train_range'], net, outsz,
    print "\n\ncreate test samples..."
    smp_tst = front_end.CSamples(config, pars, pars['test_range'], net, outsz,

    # initialization
    elapsed = 0
    err = 0
    cls = 0

    # interactive visualization

    # the last iteration we want to continue training
    iter_last = lc.get_last_it()

    print "start training..."
    start = time.time()
    print "start from ", iter_last + 1
    for i in xrange(iter_last + 1, pars['Max_iter'] + 1):
        vol_ins, lbl_outs, msks = smp_trn.get_random_sample()

        # forward pass
        vol_ins = utils.make_continuous(vol_ins, dtype=pars['dtype'])

        props = net.forward(vol_ins)

        # cost function and accumulate errors
        props, cerr, grdts = pars['cost_fn'](props, lbl_outs)
        err = err + cerr
        cls = cls + cost_fn.get_cls(props, lbl_outs)

        # mask process the gradient
        grdts = utils.dict_mul(grdts, msks)

        # run backward pass
        grdts = utils.make_continuous(grdts, dtype=pars['dtype'])

        if pars['is_malis']:
            malis_weights = cost_fn.malis_weight(props, lbl_outs)
            grdts = utils.dict_mul(grdts, malis_weights)

        if i % pars['Num_iter_per_test'] == 0:
            # test the net
            lc = test.znn_test(net, pars, smp_tst, vn, i, lc)

        if i % pars['Num_iter_per_show'] == 0:
            # anneal factor
            eta = eta * pars['anneal_factor']
            # normalize
            err = err / vn / pars['Num_iter_per_show']
            cls = cls / vn / pars['Num_iter_per_show']
            lc.append_train(i, err, cls)

            # time
            elapsed = time.time() - start
            elapsed = elapsed / pars['Num_iter_per_show']

            show_string = "iteration %d,    err: %.3f,    cls: %.3f,   elapsed: %.1f s/iter, learning rate: %.6f"\
                    %(i, err, cls, elapsed, eta )

            if pars.has_key('logging') and pars['logging']:
                utils.write_to_log(logfile, show_string)
            print show_string

            if pars['is_visual']:
                # show results To-do: run in a separate thread
                front_end.inter_show(start, lc, eta, vol_ins, props, lbl_outs,
                                     grdts, pars)
                if pars['is_rebalance'] and 'aff' not in pars['out_type']:
                    plt.imshow(msks.values()[0][0, 0, :, :],
                    plt.xlabel('rebalance weight')
                if pars['is_malis']:
                    plt.imshow(malis_weights.values()[0][0, 0, :, :],
                    plt.xlabel('malis weight (log)')
            # reset err and cls
            err = 0
            cls = 0
            # reset time
            start = time.time()

        if i % pars['Num_iter_per_save'] == 0:
            # save network
            netio.save_network(net, pars['train_save_net'], num_iters=i)
            lc.save(pars, elapsed)