def main_summary(task_index, cluster, dataset):
    """Run the summary-writer role of a distributed srms training job.

    Builds the 'dataset/srms' dataset and 'network/srms' network on this
    worker task's replica device, opens a MonitoredTrainingSession against
    the cluster, and loops forever emitting summaries.

    NOTE(review): `tensors` and `summary_config` (used in the
    SRSummaryWriter_v2 call) are defined nowhere in this function and are
    not parameters — as written this raises NameError unless they exist as
    module-level globals; confirm their intended source.
    NOTE(review): the `dataset` parameter is immediately shadowed by
    get_dataset() below, so the passed-in value is never used.

    :param task_index: integer index of this task within the 'worker' job.
    :param cluster: a tf.train.ClusterSpec describing the cluster.
    :param dataset: unused (shadowed; see note above).
    """
    from dxpy.learn.distribute.cluster import get_server
    from dxpy.learn.net.zoo.srms.summ import SRSummaryWriter_v2
    from dxpy.learn.utils.general import pre_work
    from dxpy.learn.session import set_default_session
    # Grow GPU memory on demand rather than reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    server = get_server(cluster, 'worker', task_index, config=config)
    # Build the graph with variables placed by the replica device setter.
    with tf.device(
            tf.train.replica_device_setter(
                worker_device="/job:worker/task:{}".format(task_index),
                cluster=cluster)):
        pre_work()
        dataset = get_dataset(name='dataset/srms')
        network = get_network(name='network/srms', dataset=dataset)
        result = network()
    hooks = [tf.train.StepCounterHook()]
    # task 0 acts as chief and owns checkpointing under ./save.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(task_index == 0),
                                           checkpoint_dir="./save",
                                           config=config,
                                           hooks=hooks) as sess:
        set_default_session(sess)
        network.nodes['trainer'].run('set_learning_rate')
        sw = SRSummaryWriter_v2(network=network,
                                tensors=tensors,
                                name=summary_config)
        sw.post_session_created()
        # Runs until the process is killed; no stop condition.
        while True:
            sw.auto_summary()
def main_worker(task_index, cluster):
    """Run the training-worker role of a distributed srms job.

    Builds the srms dataset/network on this task's replica device and
    trains in an effectively endless loop inside a
    MonitoredTrainingSession (checkpoints to ./save; task 0 is chief).

    NOTE(review): `tf`, `tqdm`, `get_dataset` and `get_network` are
    assumed to come from module scope — confirm the file's top-level
    imports provide them.

    :param task_index: integer index of this task within the 'worker' job.
    :param cluster: a tf.train.ClusterSpec describing the cluster.
    """
    from dxpy.learn.distribute.cluster import get_server
    from dxpy.learn.utils.general import pre_work
    # Grow GPU memory on demand rather than reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    server = get_server(cluster, 'worker', task_index, config=config)
    # Build the graph with variables placed by the replica device setter.
    with tf.device(
            tf.train.replica_device_setter(
                worker_device="/job:worker/task:{}".format(task_index),
                cluster=cluster)):
        pre_work()
        dataset = get_dataset(name='dataset/srms')
        network = get_network(name='network/srms', dataset=dataset)
        result = network()
    hooks = [tf.train.StepCounterHook()]
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(task_index == 0),
                                           checkpoint_dir="./save",
                                           config=config,
                                           hooks=hooks) as sess:
        from dxpy.learn.session import set_default_session
        set_default_session(sess)
        network.nodes['trainer'].run('set_learning_rate')
        # Effectively an infinite training loop (10^13 steps).
        for _ in tqdm(range(10000000000000), ascii=True):
            network.train()
def create_dataset_network_summary(dataset_maker_name, network_maker_name, summary_maker_name):
    """Build a (dataset, network, summary) triple from three maker names.

    Runs the common pre-work hook, constructs the dataset and the network
    (fed by that dataset), applies the network once to obtain its result,
    and wraps everything in a summary object.

    :param dataset_maker_name: config name passed to get_dataset.
    :param network_maker_name: config name passed to get_network.
    :param summary_maker_name: config name passed to get_summary.
    :return: tuple of (dataset, network, summary).
    """
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network, get_summary
    from dxpy.learn.utils.general import pre_work
    pre_work()
    ds = get_dataset(dataset_maker_name)
    net = get_network(network_maker_name, dataset=ds)
    net_result = net()
    summ = get_summary(summary_maker_name, ds, net, net_result)
    return ds, net, summ
def get_dist_network(job_name='ps', task_index=0, network_config=None, dataset=None, name='cluster/ps/task0'):
    """Construct a network whose ops are pinned to one cluster device.

    :param job_name: TF job to place the network on (default 'ps').
    :param task_index: task index within that job.
    :param network_config: network config name; falls back to `name`
        when None.
    :param dataset: dataset object handed to the network factory.
    :param name: default config name (also the documented task path).
    :return: the constructed network object.
    """
    from dxpy.learn.net.api import get_network
    cfg = name if network_config is None else network_config
    device_spec = '/job:{}/task:{}'.format(job_name, task_index)
    with tf.device(device_spec):
        net = get_network(name=cfg, dataset=dataset)
    return net
def get_dist_network(job_name='dataset', task_index=0, network_config=None, dataset=None, network_ps=None, name='cluster/ps/task0'):
    """Build a worker-side replica of a parameter-server network.

    Places a network on the given job/task device, reusing the variable
    scope of `network_ps` (the parameter-server copy), then applies it
    once.

    FIX(review): removed `import tf.train.replica_device_setter` — `tf`
    is a module alias, not an importable package, and `import` cannot
    target a function, so that statement always raised ImportError at
    call time; it was also unused.

    :param job_name: TF job to place this replica on.
    :param task_index: task index within that job.
    :param network_config: network config name; falls back to `name`
        when None.
    :param dataset: dataset object handed to the network factory.
    :param network_ps: parameter-server network whose `_scope` is reused;
        must not be None (its `_scope` is accessed unconditionally).
    :param name: default config name.
    :return: tuple of (network, result-of-applying-network).
    """
    if network_config is None:
        network_config = name
    from dxpy.learn.net.api import get_network
    with tf.device('/job:{}/task:{}'.format(job_name, task_index)):
        # Reuse the PS network's variable scope so the worker replica
        # shares weights instead of creating fresh variables.
        network = get_network(name=network_config, dataset=dataset,
                              network_ps=network_ps, reuse=True,
                              scope=network_ps._scope)
        result = network()
    return network, result
def main(task='train', job_name='worker', task_index=0, cluster_config='cluster.yml'):
    """Dispatch a distributed 'sin' experiment by cluster role.

    Reads the cluster spec from `cluster_config`, starts this task's
    server, then branches on `job_name`:
    'dataset'/'ps' park on server.join(); 'worker'/'summary' build the
    graph on the replica device and either train or write summaries;
    'test' prints one evaluation; 'saver' saves the network.

    NOTE(review): several names here are not defined in this function —
    `get_network`, `NodeKeys`, `SessionDist`, `global_step`, `time`,
    `tqdm` must come from module scope, and `result_main` (used in the
    'test' branch) only exists in the commented-out code below, so the
    'test' branch raises NameError as written. The 'saver' branch also
    references `network`, which is only bound in the worker/summary
    branch. The `task` parameter is unused.

    :param task: unused (kept for interface compatibility).
    :param job_name: cluster role of this process.
    :param task_index: task index within the job.
    :param cluster_config: path to the cluster spec YAML.
    """
    from dxpy.learn.distribute.cluster import get_cluster_spec, get_server, get_nb_tasks
    from dxpy.learn.distribute.dataset import get_dist_dataset
    from dxpy.learn.distribute.ps import get_dist_network as get_ps_network
    from dxpy.learn.distribute.worker import apply_dist_network
    from dxpy.learn.distribute.summary import get_dist_summary
    from dxpy.learn.distribute.worker import get_dist_network as get_worker_network
    from dxpy.learn.utils.general import pre_work
    cluster = get_cluster_spec(cluster_config, job_name=None)
    server = get_server(cluster, job_name, task_index)
    dataset = get_dist_dataset(name='cluster/dataset/task0')
    # Passive roles: just serve the graph and block forever.
    if job_name == 'dataset':
        server.join()
        return
    elif job_name == 'ps':
        server.join()
        return
    # if job_name in ['ps', 'worker', 'test', 'saver', 'summary', 'saver']:
    #     pre_work(device='/job:ps/task:0')
    #     network = get_ps_network(name='cluster/ps/task0', dataset=dataset)
    #     result_main = network()
    # if job_name == 'ps':
    #     # sess = SessionDist(target=server.target)
    #     # with sess.as_default():
    #     #     network.post_session_created()
    #     #     sess.post_session_created()
    #     #     network.load()
    #     server.join()
    #     return
    # if job_name in ['worker', 'summary']:
    #     network_worker, result = get_worker_network(network_ps=network,
    #         name='cluster/worker/task{}'.format(task_index), dataset=dataset)
    #     # result = apply_dist_network(
    #     #     network=network, dataset=dataset, name='cluster/worker/task{}'.format(task_index))
    # if job_name == 'worker':
    #     sess = SessionDist(target=server.target)
    #     with sess.as_default():
    #         for _ in tqdm(range(40)):
    #             for _ in range(100):
    #                 network_worker.train()
    #             print(sess.run(result[NodeKeys.LOSS]))
    #         while True:
    #             time.sleep(1)
    #     return
    elif job_name in ['worker', 'summary']:
        # Build the graph with variables placed by the replica device setter.
        with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:{}".format(task_index), cluster=cluster)):
            pre_work()
            network = get_network(name='network/sin', dataset=dataset)
            result = network()
        if job_name == 'worker':
            hooks = [tf.train.StepCounterHook()]
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            with tf.train.MonitoredTrainingSession(master=server.target,
                                                   is_chief=(task_index == 0),
                                                   checkpoint_dir="./save",
                                                   config=config,
                                                   hooks=hooks) as sess:
                from dxpy.learn.session import set_default_session
                set_default_session(sess)
                # sess = SessionDist(target=server.target)
                # with sess._sess.as_default():
                # 40 outer epochs of 100 train steps; print loss per epoch.
                for _ in tqdm(range(40)):
                    for _ in range(100):
                        network.train()
                    print(sess.run(result[NodeKeys.LOSS]))
    if job_name == 'test':
        # NOTE(review): `result_main` is undefined here (see docstring).
        sess = SessionDist(target=server.target)
        with sess.as_default():
            xv, yv, yp, loss, gs = sess.run([dataset['x'], dataset['y'],
                                             result_main[NodeKeys.INFERENCE],
                                             result_main[NodeKeys.LOSS],
                                             global_step()])
            print(xv)
            print(yv)
            print(yp)
            print(loss)
            print(gs)
            while True:
                time.sleep(1)
    if job_name == 'summary':
        name = 'cluster/summary/task{}'.format(task_index)
        result = apply_dist_network(
            network=network, dataset=dataset, name=name)
        sw = get_dist_summary(tensors={NodeKeys.LOSS: result[NodeKeys.LOSS],
                                       'global_step': global_step()},
                              name=name)
        sess = SessionDist(target=server.target)
        with sess.as_default():
            sw.post_session_created()
            # Emit one summary per second until killed.
            while True:
                sw.summary()
                print('Add one summary.')
                time.sleep(1)
        return
    if job_name == 'saver':
        # NOTE(review): `network` is unbound for job_name == 'saver'.
        network.save()
def main(task='train', job_name='worker', task_index=0, cluster_config='cluster.yml'):
    """Dispatch a distributed srms experiment by cluster role.

    For non-worker roles this process starts its own server and attaches
    remote dataset tasks; workers delegate to main_worker(), which
    creates its own server. 'dataset'/'ps' block on server.join();
    'summary' builds the graph and streams summaries forever.

    NOTE(review): `tf`, `tqdm`, `get_network`, and `main_worker` are
    assumed to come from module scope. The `if job_name == 'worker':`
    branch nested inside `elif job_name in ['summary']:` is dead code —
    it can never be true there. `nb_datasets` counts dataset tasks but
    the loop hard-codes `range(1)`, so only task 0's dataset is attached;
    confirm whether that is intentional. The `task` parameter is unused.

    :param task: unused (kept for interface compatibility).
    :param job_name: cluster role of this process.
    :param task_index: task index within the job.
    :param cluster_config: path to the cluster spec YAML.
    """
    from dxpy.learn.distribute.cluster import get_cluster_spec, get_server, get_nb_tasks
    from dxpy.learn.distribute.dataset import get_dist_dataset
    from dxpy.learn.distribute.ps import get_dist_network as get_ps_network
    from dxpy.learn.distribute.worker import apply_dist_network
    from dxpy.learn.distribute.summary import get_dist_summary
    from dxpy.learn.distribute.worker import get_dist_network as get_worker_network
    from dxpy.learn.utils.general import pre_work
    cluster = get_cluster_spec(cluster_config, job_name=None)
    # Workers set up their own server/config inside main_worker().
    if not job_name == 'worker':
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        server = get_server(cluster, job_name, task_index, config=config)
        datasets = []
        nb_datasets = get_nb_tasks(cluster_config, 'dataset')
        for i in range(1):
            datasets.append(
                get_dist_dataset(name='cluster/dataset/task{}'.format(i)))
    if job_name == 'dataset':
        server.join()
        return
    elif job_name == 'ps':
        server.join()
        return
    elif job_name == 'worker':
        main_worker(task_index, cluster)
    elif job_name in ['summary']:
        # Build the graph with variables placed by the replica device setter.
        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:worker/task:{}".format(task_index),
                    cluster=cluster)):
            pre_work()
            network = get_network(name='network/srms',
                                  dataset=datasets[task_index % nb_datasets])
            result = network()
        if job_name == 'worker':
            # NOTE(review): unreachable — job_name is 'summary' here.
            hooks = [tf.train.StepCounterHook()]
            with tf.train.MonitoredTrainingSession(
                    master=server.target,
                    is_chief=(task_index == 0),
                    # is_chief=True,
                    checkpoint_dir="./save",
                    config=config,
                    hooks=hooks) as sess:
                from dxpy.learn.session import set_default_session
                set_default_session(sess)
                network.nodes['trainer'].run('set_learning_rate')
                # sess = SessionDist(target=server.target)
                # with sess._sess.as_default():
                for _ in tqdm(range(10000000000000), ascii=True):
                    network.train()
        if job_name == 'summary':
            name = 'cluster/summary/task{}'.format(task_index)
            result = apply_dist_network(network=network,
                                        dataset=datasets[-1],
                                        name=name)
            # Include raw dataset nodes in the summarized tensors.
            result.update(datasets[-1].nodes)
            sw = get_dist_summary(tensors=result, network=network, name=name)
            hooks = [tf.train.StepCounterHook()]
            # Never chief: the summary task must not own checkpoints.
            with tf.train.MonitoredTrainingSession(master=server.target,
                                                   is_chief=False,
                                                   config=config,
                                                   hooks=hooks) as sess:
                from dxpy.learn.session import set_default_session
                set_default_session(sess)
                sw.post_session_created()
                # Runs until the process is killed.
                while True:
                    sw.auto_summary()
            #
            # sess = SessionDist(target=server.target)
            # with sess.as_default():
            #     sw.post_session_created()
            #     while True:
            #         sw.auto_summary()
            return
def infer_mice(dataset, nb_samples, output):
    """Run srms super-resolution inference on the mice test set.

    Loads a fixed .npz of mice sinogram data, builds placeholder-fed
    copies of the srms network inputs, restores the latest checkpoint
    from ./save, and writes cropped/denormalized high/low/inference/
    interpolation images to `output` via np.savez.

    NOTE(review): the `dataset` parameter is never read — the input path
    is hard-coded and the name is rebound to a placeholder dict below.
    NOTE(review): `'input/image1x'.format(...)` in get_feed is a no-op
    format on a literal without placeholders (same for the 'label' key);
    harmless but presumably a copy-paste leftover.
    NOTE(review): in crop(), when `o2 == 0` the slice `o2:-o2` becomes
    `0:0` and yields an empty axis — assumes data.shape[1] > target[1];
    TODO confirm against the stored array shapes.

    :param dataset: unused (see note above).
    :param nb_samples: number of test images to run.
    :param output: path for the resulting .npz archive.
    """
    import numpy as np
    import tensorflow as tf
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.learn.utils.general import load_yaml_config, pre_work
    from dxpy.learn.session import Session
    from dxpy.numpy_extension.visual import grid_view
    from tqdm import tqdm
    pre_work()
    input_data = np.load(
        '/home/hongxwing/Workspace/NetInference/Mice/mice_test_data.npz')
    # Materialize the lazy NpzFile entries as in-memory arrays.
    input_data = {k: np.array(input_data[k]) for k in input_data}
    dataset_origin = get_dataset('dataset/srms')
    is_low_dose = dataset_origin.param('low_dose')
    from dxpy.debug.utils import dbgmsg
    dbgmsg('IS LOW DOSE: ', is_low_dose)
    for k in input_data:
        print(k, input_data[k].shape)
    # One input/label placeholder per downsample level (1x, 2x, 4x, 8x).
    input_keys = ['input/image{}x'.format(2**i) for i in range(4)]
    label_keys = ['label/image{}x'.format(2**i) for i in range(4)]
    # Shapes come from the clean arrays: [batch=1, H, W, channel=1].
    shapes = [[1] + list(input_data['clean/image{}x'.format(2**i)].shape)[1:] + [1]
              for i in range(4)]
    inputs = {
        input_keys[i]: tf.placeholder(tf.float32, shapes[i],
                                      name='input{}x'.format(2**i))
        for i in range(4)
    }
    labels = {
        label_keys[i]: tf.placeholder(tf.float32, shapes[i],
                                      name='label{}x'.format(2**i))
        for i in range(4)
    }
    # Rebinds the `dataset` parameter to a placeholder dict.
    dataset = dict(inputs)
    dataset.update(labels)
    network = get_network('network/srms', dataset=dataset)
    nb_down_sample = network.param('nb_down_sample')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # save_checkpoint_secs=None: restore only, never write checkpoints.
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir='./save',
                                             config=config,
                                             save_checkpoint_secs=None)
    if not is_low_dose:
        prefix = 'clean/image'
    else:
        prefix = 'noise/image'

    def get_feed(idx):
        # Build the feed dict for one sample index.
        # return dict()
        data_raw = input_data['{}{}x'.format(prefix, 2**nb_down_sample)][idx, ...]
        data_raw = np.reshape(data_raw, [1] + list(data_raw.shape) + [1])
        data_label = input_data['{}1x'.format(prefix)][idx, ...]
        data_label = np.reshape(data_label, [1] + list(data_label.shape) + [1])
        return {
            dataset['input/image{}x'.format(2**nb_down_sample)]: data_raw,
            dataset['input/image1x'.format(2**nb_down_sample)]: data_label,
            dataset['label/image1x'.format(2**nb_down_sample)]: data_label,
        }
    to_run = {
        'inf': network['outputs/inference'],
        'itp': network['outputs/interp'],
        'high': network['input/image1x'],
        # 'li': network['outputs/loss_itp'],
        # 'ls': network['outputs/loss'],
        # 'la': network['outputs/aligned_label']
    }

    def crop(data, target):
        # Squeeze a [1, H, W, 1] tensor to [H, W], then crop: rows start
        # at the vertical midpoint, columns are center-cropped to target[1].
        if len(data.shape) == 4:
            data = data[0, :, :, 0]
        o1 = data.shape[0] // 2
        o2 = (data.shape[1] - target[1]) // 2
        return data[o1:o1 + target[0], o2:-o2]
    # Denormalization statistics; scaled down when data is low-dose.
    MEAN = 100.0
    STD = 150.0
    if is_low_dose:
        MEAN /= dataset_origin.param('low_dose_ratio')
        STD /= dataset_origin.param('low_dose_ratio')
    NB_IMAGES = nb_samples

    def get_infer(idx):
        # Run one inference and return denormalized, padded crops.
        result = sess.run(to_run, feed_dict=get_feed(idx))
        inf = crop(result['inf'], [320, 64])
        itp = crop(result['itp'], [320, 64])
        high = crop(input_data['{}1x'.format(prefix)][idx, ...], [320, 64])
        low = crop(
            input_data['{}{}x'.format(prefix, 2**nb_down_sample)][idx, ...],
            [320 // (2**nb_down_sample), 64 // (2**nb_down_sample)])
        high = high * STD + MEAN
        low = low * STD + MEAN
        inf = inf * STD + MEAN
        itp = itp * STD + MEAN
        # Pad columns back to a uniform width for stacking/saving.
        high = np.pad(high, [[0, 0], [32, 32]], mode='constant')
        low = np.pad(low, [[0, 0], [32 // (2**nb_down_sample)] * 2],
                     mode='constant')
        inf = np.pad(inf, [[0, 0], [32, 32]], mode='constant')
        # inf = np.maximum(inf, 0.0)
        itp = np.pad(itp, [[0, 0], [32, 32]], mode='constant')
        return high, low, inf, itp
    results = {'high': [], 'low': [], 'inf': [], 'itp': []}
    for i in tqdm(range(NB_IMAGES)):
        high, low, inf, itp = get_infer(i)
        results['high'].append(high)
        results['low'].append(low)
        results['inf'].append(inf)
        results['itp'].append(itp)
    np.savez(output, **results)
def infer_ext(input_npz_filename, input_key='clean/image1x', phantom_npz_filename=None, phantom_key='phantom', dataset_config_name='dataset/srms', network_config_name='network/srms', output_shape=(320, 320), output='infer_ext_result.npz', ids=None, nb_run=1, low_dose_ratio=1.0, crop_method='hc', config_filename='dxln.yml'):
    """Run srms inference on an external .npz of sinograms.

    Loads input (and optional phantom) arrays via the module helper
    _load_input_and_phantom, scales them by `low_dose_ratio`, feeds them
    through the dataset's external placeholder, and saves cropped,
    denormalized label/input/inference/interp stacks to `output`.

    NOTE(review): `_load_input_and_phantom`, `_update_statistics` and
    `_get_input_dict` are module-level helpers not visible in this chunk
    — behavior documented here is inferred from their call sites only.
    NOTE(review): the default `ids=None` crashes at the `tqdm(ids, ...)`
    loop (None is not iterable); callers must pass an iterable of sample
    indices. `nb_down_ratio` is computed but never used.

    :param input_npz_filename: .npz with the input sinograms.
    :param input_key: key of the input array inside that file.
    :param phantom_npz_filename: optional .npz with phantom images.
    :param phantom_key: key of the phantom array.
    :param dataset_config_name: dataset config name for get_dataset.
    :param network_config_name: network config name for get_network.
    :param output_shape: (H, W) crop target for full-resolution outputs.
    :param output: path of the resulting .npz archive.
    :param ids: iterable of sample indices to process (required; see note).
    :param nb_run: repetitions per sample (e.g. for noise realizations).
    :param low_dose_ratio: divisor applied to the raw inputs.
    :param crop_method: crop-mode character passed to crop_to_shape.
    :param config_filename: YAML config loaded before construction.
    """
    import numpy as np
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.tensor.collections import dict_append, dict_element_to_tensor
    from dxpy.learn.utils.general import load_yaml_config, pre_work
    from tqdm import tqdm
    from dxpy.learn.session import SessionMonitored
    pre_work()
    inputs, phantoms = _load_input_and_phantom(input_npz_filename,
                                               phantom_npz_filename,
                                               input_key, phantom_key)
    data_input = inputs / low_dose_ratio
    load_yaml_config(config_filename)
    mean, std = _update_statistics(dataset_config_name, low_dose_ratio)
    dataset = get_dataset(dataset_config_name)
    dataset_feed = dataset['external_place_holder']
    nb_down = dataset.param('nb_down_sample')
    nb_down_ratio = [2**i for i in range(nb_down + 1)]
    # Key prefix follows the dataset's noise setting ('poission' sic —
    # the misspelled key must match the project config).
    prefix = 'noise' if dataset.param('with_poission_noise') else 'clean'
    label_key = '{}/image1x'.format(prefix)
    input_key = '{}/image{}x'.format(prefix, 2**nb_down)
    network = get_network(network_config_name,
                          dataset=_get_input_dict(dataset))
    sess = SessionMonitored()
    fetches = {
        'label': dataset[label_key],
        'input': dataset[input_key],
        'infer': network['inference'],
        'interp': network['outputs/interp']
    }

    def proc(result, is_low=False):
        # Crop to the output shape (downscaled for low-res tensors),
        # denormalize, and clip negatives to zero.
        from dxpy.tensor.transform import crop_to_shape
        if is_low:
            target_shape = [s // (2**nb_down) for s in output_shape]
        else:
            target_shape = output_shape
        result = crop_to_shape(result, target_shape, '0' + crop_method + '0')
        result = result * std + mean
        result = np.maximum(result, 0.0)
        return result

    def get_result(idx):
        # Run one sample through the network and post-process all fetches.
        phan = phantoms[idx, ...]
        result = sess.run(fetches, feed_dict={
            dataset_feed: np.reshape(data_input[idx:idx + 1, ...],
                                     dataset.param('input_shape'))
        })
        result_c = {
            'sino_highs': proc(result['label']),
            'sino_lows': proc(result['input'], True),
            'sino_infs': proc(result['infer']),
            'sino_itps': proc(result['interp']),
            'phantoms': phan
        }
        return result_c
    keys = ['sino_highs', 'sino_lows', 'sino_infs', 'sino_itps', 'phantoms']
    results_sino = {k: [] for k in keys}
    for idx in tqdm(ids, ascii=True):
        for _ in tqdm(range(nb_run), ascii=True, leave=False):
            result_sino = get_result(idx)
            dict_append(results_sino, result_sino)
    dict_element_to_tensor(results_sino)
    np.savez(output, **results_sino)
def infer_sino_sr(dataset, nb_samples, output):
    """Run sinogram super-resolution inference with the network in cwd.

    Loads sinograms from the .npz at `dataset`, builds placeholder inputs
    for the srms network, restores the latest checkpoint from ./save, and
    saves phantom/high/low/inference/interp stacks to `output`.

    FIX(review): the bounds check was `idx > NB_MAX`, which let
    `idx == NB_MAX` through and raised IndexError (valid indices are
    0..NB_MAX-1); changed to `idx >= NB_MAX`.

    NOTE(review): in crop_sinogram, when an offset is 0 the `o:-o` slice
    becomes `0:0` (empty) — assumes arrays are strictly larger than the
    target shape; TODO confirm against the stored data.
    NOTE(review): `tqdm` is assumed to come from module scope.

    :param dataset: path to the input .npz (must contain 'phantom' plus
        '{input,label}/image{1,2,4,8}x' arrays).
    :param nb_samples: number of samples to attempt.
    :param output: path for the resulting .npz archive.
    """
    import tensorflow as tf
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.configs import ConfigsView
    from dxpy.learn.config import config
    import numpy as np
    import yaml
    from dxpy.debug.utils import dbgmsg
    dbgmsg(dataset)
    data_raw = np.load(dataset)
    # Materialize the lazy NpzFile entries as in-memory arrays.
    data_raw = {k: np.array(data_raw[k]) for k in data_raw.keys()}
    # Keep a view of the project config before `config` is shadowed below.
    config_view = ConfigsView(config)

    def tensor_shape(key):
        # Placeholder shape [1, H, W, 1] derived from the stored array.
        shape_origin = data_raw[key].shape
        return [1] + list(shape_origin[1:3]) + [1]
    with tf.name_scope('inputs'):
        keys = ['input/image{}x'.format(2**i) for i in range(4)]
        keys += ['label/image{}x'.format(2**i) for i in range(4)]
        dataset = {
            k: tf.placeholder(tf.float32, tensor_shape(k))
            for k in keys
        }
    network = get_network('network/srms', dataset=dataset)
    nb_down_sample = network.param('nb_down_sample')
    # NOTE: shadows the imported project `config` with a session config.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # save_checkpoint_secs=None: restore only, never write checkpoints.
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir='./save',
                                             config=config,
                                             save_checkpoint_secs=None)
    # Normalization statistics used by the stored network.
    STAT_STD = 9.27
    STAT_MEAN = 9.76
    BASE_SHAPE = (640, 320)
    dataset_configs = config_view['dataset']['srms']
    with_noise = dataset_configs['with_poission_noise']
    if with_noise:
        PREFIX = 'input'
    else:
        PREFIX = 'label'

    def crop_sinogram(tensor, target_shape=None):
        # Center-crop a [1,H,W,1] or [H,W] array toward target_shape.
        if target_shape is None:
            target_shape = BASE_SHAPE
        if len(tensor.shape) == 4:
            tensor = tensor[0, :, :, 0]
        o1 = (tensor.shape[0] - target_shape[0]) // 2
        o2 = (tensor.shape[1] - target_shape[1]) // 2
        return tensor[o1:-o1, o2:-o2]

    def run_infer(idx):
        # Normalize the low-res input, run the network, denormalize crops.
        input_key = '{}/image{}x'.format(PREFIX, 2**nb_down_sample)
        low_sino = np.reshape(data_raw[input_key][idx, :, :],
                              tensor_shape(input_key))
        low_sino = (low_sino - STAT_MEAN) / STAT_STD
        feeds = {dataset['input/image{}x'.format(2**nb_down_sample)]: low_sino}
        inf, itp = sess.run(
            [network['outputs/inference'], network['outputs/interp']],
            feed_dict=feeds)
        infc = crop_sinogram(inf)
        itpc = crop_sinogram(itp)
        infc = infc * STAT_STD + STAT_MEAN
        itpc = itpc * STAT_STD + STAT_MEAN
        return infc, itpc
    phans = []
    sino_highs = []
    sino_lows = []
    sino_itps = []
    sino_infs = []
    NB_MAX = data_raw['phantom'].shape[0]
    for idx in tqdm(range(nb_samples), ascii=True):
        if idx >= NB_MAX:
            import sys
            print(
                'Index {} out of range {}, stop running and store current result...'
                .format(idx, NB_MAX),
                file=sys.stderr)
            break
        phans.append(data_raw['phantom'][idx, ...])
        sino_highs.append(
            crop_sinogram(data_raw['{}/image1x'.format(PREFIX)][idx, :, :]))
        sino_lows.append(
            crop_sinogram(
                data_raw['{}/image{}x'.format(PREFIX, 2**nb_down_sample)][idx, ...],
                [s // (2**nb_down_sample) for s in BASE_SHAPE]))
        sino_inf, sino_itp = run_infer(idx)
        sino_infs.append(sino_inf)
        sino_itps.append(sino_itp)
    results = {
        'phantom': phans,
        'sino_itps': sino_itps,
        'sino_infs': sino_infs,
        'sino_highs': sino_highs,
        'sino_lows': sino_lows
    }
    np.savez(output, **results)
def infer_mct(dataset, nb_samples, output):
    """Run mCT image super-resolution inference with the network in cwd.

    Loads images from the .npz at `dataset`, builds placeholder inputs
    for the srms network, restores the latest checkpoint from ./save, and
    saves phantom/high/low/inference/interp stacks to `output`.

    FIX(review): the bounds check was `idx > NB_MAX`, which let
    `idx == NB_MAX` through and raised IndexError (valid indices are
    0..NB_MAX-1); changed to `idx >= NB_MAX` — consistent with the same
    fix in infer_sino_sr.

    NOTE(review): `tqdm` is assumed to come from module scope.

    :param dataset: path to the input .npz (must contain 'phantom' plus
        'image{1,2,4,8}x' arrays).
    :param nb_samples: number of samples to attempt.
    :param output: path for the resulting .npz archive.
    """
    import tensorflow as tf
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.configs import ConfigsView
    from dxpy.learn.config import config
    import numpy as np
    import yaml
    from dxpy.debug.utils import dbgmsg
    print('Using dataset file:', dataset)
    data_raw = np.load(dataset)
    # Materialize the lazy NpzFile entries as in-memory arrays.
    data_raw = {k: np.array(data_raw[k]) for k in data_raw.keys()}
    # Keep a view of the project config before `config` is shadowed below.
    config_view = ConfigsView(config)

    def data_key(nd):
        # Array key for downsample level nd: image1x, image2x, ...
        return 'image{}x'.format(2**nd)

    def tensor_shape(key):
        # Placeholder shape [1, H, W, 1] derived from the stored array.
        shape_origin = data_raw[key].shape
        return [1] + list(shape_origin[1:3]) + [1]
    with tf.name_scope('inputs'):
        keys = ['input/image{}x'.format(2**i) for i in range(4)]
        keys += ['label/image{}x'.format(2**i) for i in range(4)]
        # i % 4 maps both input and label keys to the same level's shape.
        dataset = {
            k: tf.placeholder(tf.float32, tensor_shape(data_key(i % 4)))
            for i, k in enumerate(keys)
        }
    network = get_network('network/srms', dataset=dataset)
    nb_down_sample = network.param('nb_down_sample')
    # NOTE: shadows the imported project `config` with a session config.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # save_checkpoint_secs=None: restore only, never write checkpoints.
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir='./save',
                                             config=config,
                                             save_checkpoint_secs=None)
    # Normalization statistics; low-res values scale with 4^nb_down_sample.
    STAT_MEAN = 9.93
    STAT_STD = 7.95
    STAT_MEAN_LOW = 9.93 * (4.0**nb_down_sample)
    STAT_STD_LOW = 7.95 * (4.0**nb_down_sample)
    BASE_SHAPE = (384, 384)

    def crop_image(tensor, target_shape=None):
        # Squeeze [1,H,W,1] to [H,W]; rows start at the vertical midpoint,
        # columns are center-cropped to target_shape[1].
        if target_shape is None:
            target_shape = BASE_SHAPE
        if len(tensor.shape) == 4:
            tensor = tensor[0, :, :, 0]
        # o1 = (tensor.shape[0] - target_shape[0]) // 2
        o1 = tensor.shape[0] // 2
        o2 = (tensor.shape[1] - target_shape[1]) // 2
        return tensor[o1:o1 + target_shape[0], o2:-o2]
    input_key = data_key(nb_down_sample)
    dbgmsg('input_key:', input_key)

    def run_infer(idx):
        # Run one sample through the network and denormalize the crops.
        low_phan = np.reshape(data_raw[input_key][idx, :, :],
                              tensor_shape(input_key))
        # low_phan = (low_phan - STAT_MEAN) / STAT_STD
        feeds = {dataset['input/image{}x'.format(2**nb_down_sample)]: low_phan}
        inf, itp = sess.run(
            [network['outputs/inference'], network['outputs/interp']],
            feed_dict=feeds)
        infc = crop_image(inf)
        itpc = crop_image(itp)
        infc = infc * STAT_STD_LOW + STAT_MEAN_LOW
        itpc = itpc * STAT_STD_LOW + STAT_MEAN_LOW
        return infc, itpc
    phans = []
    img_highs = []
    img_lows = []
    img_itps = []
    img_infs = []
    NB_MAX = data_raw['phantom'].shape[0]
    for idx in tqdm(range(nb_samples), ascii=True):
        if idx >= NB_MAX:
            import sys
            print(
                'Index {} out of range {}, stop running and store current result...'
                .format(idx, NB_MAX),
                file=sys.stderr)
            break
        phans.append(data_raw['phantom'][idx, ...])
        img_high = crop_image(data_raw[data_key(0)][idx, :, :])
        img_high = img_high * STAT_STD + STAT_MEAN
        # img_high = img_high * STAT_STD / \
        #     (4.0**nb_down_sample) + STAT_MEAN / (4.0**nb_down_sample)
        img_highs.append(img_high)
        img_low = crop_image(data_raw[data_key(nb_down_sample)][idx, ...],
                             [s // (2**nb_down_sample) for s in BASE_SHAPE])
        img_low = img_low * STAT_STD_LOW + STAT_MEAN_LOW
        img_lows.append(img_low)
        img_inf, img_itp = run_infer(idx)
        img_infs.append(img_inf)
        img_itps.append(img_itp)
    img_highs = np.array(img_highs)
    # Rescale network outputs back by the downsample intensity factor.
    img_infs = np.array(img_infs) / (4.0**nb_down_sample)
    img_itps = np.array(img_itps) / (4.0**nb_down_sample)
    img_lows = np.array(img_lows) / (4.0**nb_down_sample)
    results = {
        'phantom': phans,
        'sino_itps': img_itps,
        'sino_infs': img_infs,
        'sino_highs': img_highs,
        'sino_lows': img_lows
    }
    np.savez(output, **results)