def main( conf_file='config.cfg', logfile=None ): #%% parameters print "reading config parameters..." config, pars = front_end.parser( conf_file ) if pars.has_key('logging') and pars['logging']: print "recording configuration file..." front_end.record_config_file( pars ) logfile = front_end.make_logfile_name( pars ) #%% create and initialize the network if pars['train_load_net'] and os.path.exists(pars['train_load_net']): print "loading network..." net = netio.load_network( pars ) # load existing learning curve lc = zstatistics.CLearnCurve( pars['train_load_net'] ) else: if pars['train_seed_net'] and os.path.exists(pars['train_seed_net']): print "seeding network..." net = netio.load_network( pars, is_seed=True ) else: print "initializing network..." net = netio.init_network( pars ) # initalize a learning curve lc = zstatistics.CLearnCurve() # show field of view print "field of view: ", net.get_fov() print "output volume info: ", net.get_outputs_setsz() # set some parameters print 'setting up the network...' vn = utils.get_total_num(net.get_outputs_setsz()) eta = pars['eta'] #/ vn net.set_eta( eta ) net.set_momentum( pars['momentum'] ) net.set_weight_decay( pars['weight_decay'] ) # initialize samples outsz = pars['train_outsz'] print "\n\ncreate train samples..." smp_trn = front_end.CSamples(config, pars, pars['train_range'], net, outsz, logfile) print "\n\ncreate test samples..." smp_tst = front_end.CSamples(config, pars, pars['test_range'], net, outsz, logfile) # initialization elapsed = 0 err = 0 cls = 0 # interactive visualization plt.ion() plt.show() # the last iteration we want to continue training iter_last = lc.get_last_it() print "start training..." start = time.time() print "start from ", iter_last+1 for i in xrange(iter_last+1, pars['Max_iter']+1): vol_ins, lbl_outs, msks = smp_trn.get_random_sample() # forward pass vol_ins = utils.make_continuous(vol_ins, dtype=pars['dtype']) props = net.forward( vol_ins ) # cost function and accumulate errors props, cerr, grdts = pars['cost_fn']( props, lbl_outs ) err = err + cerr cls = cls + cost_fn.get_cls(props, lbl_outs) # mask process the gradient grdts = utils.dict_mul(grdts, msks) # run backward pass grdts = utils.make_continuous(grdts, dtype=pars['dtype']) net.backward( grdts ) if pars['is_malis'] : malis_weights = cost_fn.malis_weight(props, lbl_outs) grdts = utils.dict_mul(grdts, malis_weights) if i%pars['Num_iter_per_test']==0: # test the net lc = test.znn_test(net, pars, smp_tst, vn, i, lc) if i%pars['Num_iter_per_show']==0: # anneal factor eta = eta * pars['anneal_factor'] net.set_eta(eta) # normalize err = err / vn / pars['Num_iter_per_show'] cls = cls / vn / pars['Num_iter_per_show'] lc.append_train(i, err, cls) # time elapsed = time.time() - start elapsed = elapsed / pars['Num_iter_per_show'] show_string = "iteration %d, err: %.3f, cls: %.3f, elapsed: %.1f s/iter, learning rate: %.6f"\ %(i, err, cls, elapsed, eta ) if pars.has_key('logging') and pars['logging']: utils.write_to_log(logfile, show_string) print show_string if pars['is_visual']: # show results To-do: run in a separate thread front_end.inter_show(start, lc, eta, vol_ins, props, lbl_outs, grdts, pars) if pars['is_rebalance'] and 'aff' not in pars['out_type']: plt.subplot(247) plt.imshow(msks.values()[0][0,0,:,:], interpolation='nearest', cmap='gray') plt.xlabel('rebalance weight') if pars['is_malis']: plt.subplot(248) plt.imshow(malis_weights.values()[0][0,0,:,:], interpolation='nearest', cmap='gray') plt.xlabel('malis weight (log)') plt.pause(2) plt.show() # reset err and cls err = 0 cls = 0 # reset time start = time.time() if i%pars['Num_iter_per_save']==0: # save network netio.save_network(net, pars['train_save_net'], num_iters=i) lc.save( pars, elapsed )
def main( conf_file='config.cfg', logfile=None ): #%% parameters print "reading config parameters..." config, pars = zconfig.parser( conf_file ) if pars.has_key('logging') and pars['logging']: print "recording configuration file..." zlog.record_config_file( pars ) logfile = zlog.make_logfile_name( pars ) #%% create and initialize the network if pars['train_load_net'] and os.path.exists(pars['train_load_net']): print "loading network..." net = znetio.load_network( pars ) # load existing learning curve lc = zstatistics.CLearnCurve( pars['train_load_net'] ) # the last iteration we want to continue training iter_last = lc.get_last_it() else: if pars['train_seed_net'] and os.path.exists(pars['train_seed_net']): print "seeding network..." net = znetio.load_network( pars, is_seed=True ) else: print "initializing network..." net = znetio.init_network( pars ) # initalize a learning curve lc = zstatistics.CLearnCurve() iter_last = lc.get_last_it() # show field of view print "field of view: ", net.get_fov() # total voxel number of output volumes vn = utils.get_total_num(net.get_outputs_setsz()) # set some parameters print 'setting up the network...' eta = pars['eta'] net.set_eta( pars['eta'] ) net.set_momentum( pars['momentum'] ) net.set_weight_decay( pars['weight_decay'] ) # initialize samples outsz = pars['train_outsz'] print "\n\ncreate train samples..." smp_trn = zsample.CSamples(config, pars, pars['train_range'], net, outsz, logfile) print "\n\ncreate test samples..." smp_tst = zsample.CSamples(config, pars, pars['test_range'], net, outsz, logfile) # initialization elapsed = 0 err = 0.0 # cost energy cls = 0.0 # pixel classification error re = 0.0 # rand error # number of voxels which accumulate error # (if a mask exists) num_mask_voxels = 0 if pars['is_malis']: malis_cls = 0.0 print "start training..." start = time.time() total_time = 0.0 print "start from ", iter_last+1 #Saving initialized network if iter_last+1 == 1: znetio.save_network(net, pars['train_save_net'], num_iters=0) lc.save( pars, 0.0 ) for i in xrange(iter_last+1, pars['Max_iter']+1): # get random sub volume from sample vol_ins, lbl_outs, msks, wmsks = smp_trn.get_random_sample() # forward pass # apply the transformations in memory rather than array view vol_ins = utils.make_continuous(vol_ins, dtype=pars['dtype']) props = net.forward( vol_ins ) # cost function and accumulate errors props, cerr, grdts = pars['cost_fn']( props, lbl_outs, msks ) err += cerr cls += cost_fn.get_cls(props, lbl_outs) num_mask_voxels += utils.sum_over_dict(msks) # gradient reweighting grdts = utils.dict_mul( grdts, msks ) grdts = utils.dict_mul( grdts, wmsks ) if pars['is_malis'] : malis_weights, rand_errors = cost_fn.malis_weight(pars, props, lbl_outs) grdts = utils.dict_mul(grdts, malis_weights) # accumulate the rand error re += rand_errors.values()[0] malis_cls_dict = utils.get_malis_cls( props, lbl_outs, malis_weights ) malis_cls += malis_cls_dict.values()[0] total_time += time.time() - start start = time.time() # test the net if i%pars['Num_iter_per_test']==0: lc = test.znn_test(net, pars, smp_tst, vn, i, lc) if i%pars['Num_iter_per_show']==0: # normalize if utils.dict_mask_empty(msks): err = err / vn / pars['Num_iter_per_show'] cls = cls / vn / pars['Num_iter_per_show'] else: err = err / num_mask_voxels / pars['Num_iter_per_show'] cls = cls / num_mask_voxels / pars['Num_iter_per_show'] lc.append_train(i, err, cls) # time elapsed = total_time / pars['Num_iter_per_show'] if pars['is_malis']: re = re / pars['Num_iter_per_show'] lc.append_train_rand_error( re ) malis_cls = malis_cls / pars['Num_iter_per_show'] lc.append_train_malis_cls( malis_cls ) show_string = "iteration %d, err: %.3f, cls: %.3f, re: %.6f, mc: %.3f, elapsed: %.1f s/iter, learning rate: %.6f"\ %(i, err, cls, re, malis_cls, elapsed, eta ) else: show_string = "iteration %d, err: %.3f, cls: %.3f, elapsed: %.1f s/iter, learning rate: %.6f"\ %(i, err, cls, elapsed, eta ) if pars.has_key('logging') and pars['logging']: utils.write_to_log(logfile, show_string) print show_string # reset err and cls err = 0 cls = 0 re = 0 num_mask_voxels = 0 if pars['is_malis']: malis_cls = 0 # reset time total_time = 0 start = time.time() if i%pars['Num_iter_per_annealing']==0: # anneal factor eta = eta * pars['anneal_factor'] net.set_eta(eta) if i%pars['Num_iter_per_save']==0: # save network znetio.save_network(net, pars['train_save_net'], num_iters=i) lc.save( pars, elapsed ) if pars['is_malis']: utils.save_malis(malis_weights, pars['train_save_net'], num_iters=i) # run backward pass grdts = utils.make_continuous(grdts, dtype=pars['dtype']) net.backward( grdts )
def main(args): config, pars, logfile = parse_args(args) #%% create and initialize the network net, lc = znetio.create_net(pars) # total voxel number of output volumes vn = utils.get_total_num(net.get_outputs_setsz()) # initialize samples outsz = pars['train_outsz'] print "\n\ncreate train samples..." smp_trn = zsample.CSamples(config, pars, pars['train_range'], net, outsz, logfile) print "\n\ncreate test samples..." smp_tst = zsample.CSamples(config, pars, pars['test_range'], net, outsz, logfile) if pars['is_check']: import zcheck zcheck.check_patch(pars, smp_trn) # gradient check is not working now. # zcheck.check_gradient(pars, net, smp_trn) # initialization eta = pars['eta'] elapsed = 0 err = 0.0 # cost energy cls = 0.0 # pixel classification error re = 0.0 # rand error # number of voxels which accumulate error # (if a mask exists) num_mask_voxels = 0 if pars['is_malis']: malis_cls = 0.0 malis_eng = 0.0 else: malis_weights = None # the last iteration we want to continue training iter_last = lc.get_last_it() print "start training..." start = time.time() total_time = 0.0 print "start from ", iter_last + 1 #Saving initial/seeded network # get file name fname, fname_current = znetio.get_net_fname(pars['train_net_prefix'], iter_last, suffix="init") znetio.save_network(net, fname, pars['is_stdio']) lc.save(pars, fname, elapsed=0.0, suffix="init_iter{}".format(iter_last)) # no nan detected nonan = True for i in xrange(iter_last + 1, pars['Max_iter'] + 1): # time cumulation total_time += time.time() - start start = time.time() # get random sub volume from sample vol_ins, lbl_outs, msks, wmsks = smp_trn.get_random_sample() # forward pass # apply the transformations in memory rather than array view vol_ins = utils.make_continuous(vol_ins) props = net.forward(vol_ins) # cost function and accumulate errors props, cerr, grdts = pars['cost_fn'](props, lbl_outs, msks) err += cerr cls += cost_fn.get_cls(props, lbl_outs) # compute rand error if pars['is_debug']: assert not np.all(lbl_outs.values()[0] == 0) re += pyznn.get_rand_error(props.values()[0], lbl_outs.values()[0]) num_mask_voxels += utils.sum_over_dict(msks) # check whether there is a NaN here! if pars['is_debug']: nonan = nonan and utils.check_dict_nan(vol_ins) nonan = nonan and utils.check_dict_nan(lbl_outs) nonan = nonan and utils.check_dict_nan(msks) nonan = nonan and utils.check_dict_nan(wmsks) nonan = nonan and utils.check_dict_nan(props) nonan = nonan and utils.check_dict_nan(grdts) if not nonan: utils.inter_save(pars, net, lc, vol_ins, props, lbl_outs, \ grdts, malis_weights, wmsks, elapsed, i) # stop training return # gradient reweighting grdts = utils.dict_mul(grdts, msks) if pars['rebalance_mode']: grdts = utils.dict_mul(grdts, wmsks) if pars['is_malis']: malis_weights, rand_errors, num_non_bdr = cost_fn.malis_weight( pars, props, lbl_outs) if num_non_bdr <= 1: # skip this iteration continue grdts = utils.dict_mul(grdts, malis_weights) dmc, dme = utils.get_malis_cost(props, lbl_outs, malis_weights) malis_cls += dmc.values()[0] malis_eng += dme.values()[0] # run backward pass grdts = utils.make_continuous(grdts) net.backward(grdts) total_time += time.time() - start start = time.time() if i % pars['Num_iter_per_show'] == 0: # time elapsed = total_time / pars['Num_iter_per_show'] # normalize if utils.dict_mask_empty(msks): err = err / vn / pars['Num_iter_per_show'] cls = cls / vn / pars['Num_iter_per_show'] else: err = err / num_mask_voxels / pars['Num_iter_per_show'] cls = cls / num_mask_voxels / pars['Num_iter_per_show'] re = re / pars['Num_iter_per_show'] lc.append_train(i, err, cls, re) if pars['is_malis']: malis_cls = malis_cls / pars['Num_iter_per_show'] malis_eng = malis_eng / pars['Num_iter_per_show'] lc.append_train_malis_cls(malis_cls) lc.append_train_malis_eng(malis_eng) show_string = "update %d, cost: %.3f, pixel error: %.3f, rand error: %.3f, me: %.3f, mc: %.3f, elapsed: %.1f s/iter, learning rate: %.5f"\ %(i, err, cls, re, malis_eng, malis_cls, elapsed, eta ) else: show_string = "update %d, cost: %.3f, pixel error: %.3f, rand error: %.3f, elapsed: %.1f s/iter, learning rate: %.5f"\ %(i, err, cls, re, elapsed, eta ) if pars.has_key('logging') and pars['logging']: utils.write_to_log(logfile, show_string) print show_string # reset err and cls err = 0 cls = 0 re = 0 num_mask_voxels = 0 if pars['is_malis']: malis_cls = 0 # reset time total_time = 0 start = time.time() # test the net if i % pars['Num_iter_per_test'] == 0: # time accumulation should skip the test total_time += time.time() - start lc = test.znn_test(net, pars, smp_tst, vn, i, lc) start = time.time() if i % pars['Num_iter_per_save'] == 0: utils.inter_save(pars, net, lc, vol_ins, props, lbl_outs, \ grdts, malis_weights, wmsks, elapsed, i) if i % pars['Num_iter_per_annealing'] == 0: # anneal factor eta = eta * pars['anneal_factor'] net.set_eta(eta) # stop the iteration at checking mode if pars['is_check']: print "only need one iteration for checking, stop program..." break
def main(conf_file='config.cfg', logfile=None): #%% parameters print "reading config parameters..." config, pars = front_end.parser(conf_file) if pars.has_key('logging') and pars['logging']: print "recording configuration file..." front_end.record_config_file(pars) logfile = front_end.make_logfile_name(pars) #%% create and initialize the network if pars['train_load_net'] and os.path.exists(pars['train_load_net']): print "loading network..." net = netio.load_network(pars) # load existing learning curve lc = zstatistics.CLearnCurve(pars['train_load_net']) else: if pars['train_seed_net'] and os.path.exists(pars['train_seed_net']): print "seeding network..." net = netio.seed_network(pars, is_seed=True) else: print "initializing network..." net = netio.init_network(pars) # initalize a learning curve lc = zstatistics.CLearnCurve() # show field of view print "field of view: ", net.get_fov() print "output volume info: ", net.get_outputs_setsz() # set some parameters print 'setting up the network...' vn = utils.get_total_num(net.get_outputs_setsz()) eta = pars['eta'] #/ vn net.set_eta(eta) net.set_momentum(pars['momentum']) net.set_weight_decay(pars['weight_decay']) # initialize samples outsz = pars['train_outsz'] print "\n\ncreate train samples..." smp_trn = front_end.CSamples(config, pars, pars['train_range'], net, outsz, logfile) print "\n\ncreate test samples..." smp_tst = front_end.CSamples(config, pars, pars['test_range'], net, outsz, logfile) # initialization elapsed = 0 err = 0 cls = 0 # interactive visualization plt.ion() plt.show() # the last iteration we want to continue training iter_last = lc.get_last_it() print "start training..." start = time.time() print "start from ", iter_last + 1 for i in xrange(iter_last + 1, pars['Max_iter'] + 1): vol_ins, lbl_outs, msks = smp_trn.get_random_sample() # forward pass vol_ins = utils.make_continuous(vol_ins, dtype=pars['dtype']) props = net.forward(vol_ins) # cost function and accumulate errors props, cerr, grdts = pars['cost_fn'](props, lbl_outs) err = err + cerr cls = cls + cost_fn.get_cls(props, lbl_outs) # mask process the gradient grdts = utils.dict_mul(grdts, msks) # run backward pass grdts = utils.make_continuous(grdts, dtype=pars['dtype']) net.backward(grdts) if pars['is_malis']: malis_weights = cost_fn.malis_weight(props, lbl_outs) grdts = utils.dict_mul(grdts, malis_weights) if i % pars['Num_iter_per_test'] == 0: # test the net lc = test.znn_test(net, pars, smp_tst, vn, i, lc) if i % pars['Num_iter_per_show'] == 0: # anneal factor eta = eta * pars['anneal_factor'] net.set_eta(eta) # normalize err = err / vn / pars['Num_iter_per_show'] cls = cls / vn / pars['Num_iter_per_show'] lc.append_train(i, err, cls) # time elapsed = time.time() - start elapsed = elapsed / pars['Num_iter_per_show'] show_string = "iteration %d, err: %.3f, cls: %.3f, elapsed: %.1f s/iter, learning rate: %.6f"\ %(i, err, cls, elapsed, eta ) if pars.has_key('logging') and pars['logging']: utils.write_to_log(logfile, show_string) print show_string if pars['is_visual']: # show results To-do: run in a separate thread front_end.inter_show(start, lc, eta, vol_ins, props, lbl_outs, grdts, pars) if pars['is_rebalance'] and 'aff' not in pars['out_type']: plt.subplot(247) plt.imshow(msks.values()[0][0, 0, :, :], interpolation='nearest', cmap='gray') plt.xlabel('rebalance weight') if pars['is_malis']: plt.subplot(248) plt.imshow(malis_weights.values()[0][0, 0, :, :], interpolation='nearest', cmap='gray') plt.xlabel('malis weight (log)') plt.pause(2) plt.show() # reset err and cls err = 0 cls = 0 # reset time start = time.time() if i % pars['Num_iter_per_save'] == 0: # save network netio.save_network(net, pars['train_save_net'], num_iters=i) lc.save(pars, elapsed)