Example #1
def main_summary(task_index, cluster, dataset):
    import tensorflow as tf
    from dxpy.learn.distribute.cluster import get_server
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.learn.net.zoo.srms.summ import SRSummaryWriter_v2
    from dxpy.learn.utils.general import pre_work
    from dxpy.learn.session import set_default_session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    server = get_server(cluster, 'worker', task_index, config=config)
    with tf.device(
            tf.train.replica_device_setter(
                worker_device="/job:worker/task:{}".format(task_index),
                cluster=cluster)):
        pre_work()
        dataset = get_dataset(name='dataset/srms')
        network = get_network(name='network/srms', dataset=dataset)
        result = network()
        hooks = [tf.train.StepCounterHook()]
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=(task_index == 0),
                                               checkpoint_dir="./save",
                                               config=config,
                                               hooks=hooks) as sess:
            set_default_session(sess)
            network.nodes['trainer'].run('set_learning_rate')
            # NOTE: `tensors` and `summary_config` are not defined in this
            # snippet; they are presumably constructed by the surrounding module.
            sw = SRSummaryWriter_v2(network=network,
                                    tensors=tensors,
                                    name=summary_config)
            sw.post_session_created()
            while True:
                sw.auto_summary()
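
A hypothetical launch sketch for this entry point, assuming the get_cluster_spec helper from the later examples and a cluster.yml in the working directory:

from dxpy.learn.distribute.cluster import get_cluster_spec

# Build the cluster spec from cluster.yml and run the summary task for task 0.
# The dataset argument is rebound inside main_summary, so None suffices here.
cluster = get_cluster_spec('cluster.yml', job_name=None)
main_summary(task_index=0, cluster=cluster, dataset=None)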
Example #2
def main_worker(task_index, cluster):
    import tensorflow as tf
    from tqdm import tqdm
    from dxpy.learn.distribute.cluster import get_server
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.learn.utils.general import pre_work
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    server = get_server(cluster, 'worker', task_index, config=config)
    with tf.device(
            tf.train.replica_device_setter(
                worker_device="/job:worker/task:{}".format(task_index),
                cluster=cluster)):
        pre_work()
        dataset = get_dataset(name='dataset/srms')
        network = get_network(name='network/srms', dataset=dataset)
        result = network()
        hooks = [tf.train.StepCounterHook()]
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=(task_index == 0),
                                               checkpoint_dir="./save",
                                               config=config,
                                               hooks=hooks) as sess:
            from dxpy.learn.session import set_default_session
            set_default_session(sess)
            network.nodes['trainer'].run('set_learning_rate')
            for _ in tqdm(range(10000000000000), ascii=True):  # train until the job is killed
                network.train()
Example #3
def create_dataset_network_summary(dataset_maker_name, network_maker_name, summary_maker_name):
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network, get_summary
    from dxpy.learn.utils.general import pre_work
    pre_work()
    dataset = get_dataset(dataset_maker_name)
    network = get_network(network_maker_name, dataset=dataset)
    result = network()
    summary = get_summary(summary_maker_name, dataset, network, result)
    return dataset, network, summary
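
A minimal usage sketch; 'dataset/srms' and 'network/srms' follow the maker names used elsewhere on this page, while the summary maker name is purely illustrative:

dataset, network, summary = create_dataset_network_summary(
    dataset_maker_name='dataset/srms',
    network_maker_name='network/srms',
    summary_maker_name='summary/srms')  # hypothetical summary config name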
Example #4
File: ps.py  Project: Hong-Xiang/dxl
def get_dist_network(job_name='ps',
                     task_index=0,
                     network_config=None,
                     dataset=None,
                     name='cluster/ps/task0'):
    import tensorflow as tf
    if network_config is None:
        network_config = name
    from dxpy.learn.net.api import get_network
    with tf.device('/job:{}/task:{}'.format(job_name, task_index)):
        network = get_network(name=network_config, dataset=dataset)
    return network
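
The '/job:ps/task:0' device string presumes a cluster spec that names a ps job. A sketch of a plausible layout, consistent with the job names used on this page (get_cluster_spec presumably builds something similar from cluster.yml; the addresses are illustrative):

import tensorflow as tf

cluster = tf.train.ClusterSpec({
    'dataset': ['localhost:2220'],
    'ps': ['localhost:2221'],
    'worker': ['localhost:2222', 'localhost:2223'],
    'summary': ['localhost:2224'],
})
with tf.device('/job:ps/task:0'):
    w = tf.get_variable('w', shape=[4])  # variable pinned to the parameter server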
Example #5
def get_dist_network(job_name='dataset',
                     task_index=0,
                     network_config=None,
                     dataset=None,
                     network_ps=None,
                     name='cluster/ps/task0'):
    import tensorflow as tf  # the original `import tf.train.replica_device_setter` is not valid Python
    if network_config is None:
        network_config = name
    from dxpy.learn.net.api import get_network
    with tf.device('/job:{}/task:{}'.format(job_name, task_index)):
        network = get_network(name=network_config,
                              dataset=dataset,
                              network_ps=network_ps,
                              reuse=True,
                              scope=network_ps._scope)
        result = network()
    return network, result
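
The reuse=True / scope=network_ps._scope arguments follow TensorFlow 1.x variable sharing, so the worker graph reuses the variables created by the ps-side network; a minimal sketch of that mechanism (names are illustrative):

import tensorflow as tf

with tf.variable_scope('srms') as scope:
    w = tf.get_variable('w', shape=[3, 3])   # created once, e.g. on /job:ps
with tf.variable_scope(scope, reuse=True):
    w_shared = tf.get_variable('w')          # reused, e.g. on a worker replica
assert w is w_shared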
Example #6
def main(task='train', job_name='worker', task_index=0, cluster_config='cluster.yml'):
    import time
    import tensorflow as tf
    from tqdm import tqdm
    from dxpy.learn.distribute.cluster import get_cluster_spec, get_server, get_nb_tasks
    from dxpy.learn.distribute.dataset import get_dist_dataset
    from dxpy.learn.distribute.ps import get_dist_network as get_ps_network
    from dxpy.learn.distribute.worker import apply_dist_network
    from dxpy.learn.distribute.summary import get_dist_summary
    from dxpy.learn.distribute.worker import get_dist_network as get_worker_network
    from dxpy.learn.net.api import get_network
    from dxpy.learn.utils.general import pre_work
    # NodeKeys, global_step and SessionDist are also used below; their import
    # paths are not shown in this snippet (they come from the dxpy.learn package).
    cluster = get_cluster_spec(cluster_config, job_name=None)
    server = get_server(cluster, job_name, task_index)
    dataset = get_dist_dataset(name='cluster/dataset/task0')
    if job_name in ('dataset', 'ps'):
        server.join()
        return
    # (A commented-out variant here built a ps-hosted network via
    # get_ps_network, trained and inspected it through SessionDist, and bound
    # `result_main`, which the 'test' branch below still references.)
    elif job_name in ['worker', 'summary']:
        with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:{}".format(task_index), cluster=cluster)):
            pre_work()
            network = get_network(name='network/sin', dataset=dataset)
            result = network()
        if job_name == 'worker':
            hooks = [tf.train.StepCounterHook()]
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            with tf.train.MonitoredTrainingSession(master=server.target,
                                                   is_chief=(task_index == 0),
                                                   checkpoint_dir="./save",
                                                   config=config,
                                                   hooks=hooks) as sess:
                from dxpy.learn.session import set_default_session
                set_default_session(sess)
                for _ in tqdm(range(40)):
                    for _ in range(100):
                        network.train()
                    print(sess.run(result[NodeKeys.LOSS]))
    if job_name == 'test':
        sess = SessionDist(target=server.target)

        with sess.as_default():
            # NOTE: `result_main` was defined only in the commented-out
            # ps-network setup above, so this branch does not run as written.
            xv, yv, yp, loss, gs = sess.run([dataset['x'],
                                             dataset['y'],
                                             result_main[NodeKeys.INFERENCE],
                                             result_main[NodeKeys.LOSS],
                                             global_step()])
            print(xv)
            print(yv)
            print(yp)
            print(loss)
            print(gs)
        while True:
            time.sleep(1)
    if job_name == 'summary':
        name = 'cluster/summary/task{}'.format(task_index)
        result = apply_dist_network(
            network=network, dataset=dataset, name=name)
        sw = get_dist_summary(tensors={NodeKeys.LOSS: result[NodeKeys.LOSS],
                                       'global_step': global_step()}, name=name)
        sess = SessionDist(target=server.target)
        with sess.as_default():
            sw.post_session_created()
            while True:
                sw.summary()
                print('Add one summary.')
                time.sleep(1)
        return
    if job_name == 'saver':
        # NOTE: `network` is only bound in the worker/summary branch above,
        # so the saver path also depends on the commented-out setup.
        network.save()
Example #7
def main(task='train',
         job_name='worker',
         task_index=0,
         cluster_config='cluster.yml'):
    import tensorflow as tf
    from dxpy.learn.distribute.cluster import get_cluster_spec, get_server, get_nb_tasks
    from dxpy.learn.distribute.dataset import get_dist_dataset
    from dxpy.learn.distribute.ps import get_dist_network as get_ps_network
    from dxpy.learn.distribute.worker import apply_dist_network
    from dxpy.learn.distribute.summary import get_dist_summary
    from dxpy.learn.distribute.worker import get_dist_network as get_worker_network
    from dxpy.learn.net.api import get_network
    from dxpy.learn.utils.general import pre_work
    cluster = get_cluster_spec(cluster_config, job_name=None)
    if job_name != 'worker':
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        server = get_server(cluster, job_name, task_index, config=config)
        datasets = []
        nb_datasets = get_nb_tasks(cluster_config, 'dataset')
        # The original iterated over range(1); range(nb_datasets) matches the
        # `task_index % nb_datasets` indexing used below.
        for i in range(nb_datasets):
            datasets.append(
                get_dist_dataset(name='cluster/dataset/task{}'.format(i)))
    if job_name in ('dataset', 'ps'):
        server.join()
        return
    elif job_name == 'worker':
        main_worker(task_index, cluster)
    elif job_name == 'summary':
        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:worker/task:{}".format(task_index),
                    cluster=cluster)):
            pre_work()
            network = get_network(name='network/srms',
                                  dataset=datasets[task_index % nb_datasets])
            result = network()
        # (The original repeated an unreachable `if job_name == 'worker':`
        # training block here; workers are dispatched to main_worker() above.)
        name = 'cluster/summary/task{}'.format(task_index)
        result = apply_dist_network(network=network,
                                    dataset=datasets[-1],
                                    name=name)
        result.update(datasets[-1].nodes)
        sw = get_dist_summary(tensors=result, network=network, name=name)
        hooks = [tf.train.StepCounterHook()]
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=False,
                                               config=config,
                                               hooks=hooks) as sess:
            from dxpy.learn.session import set_default_session
            set_default_session(sess)
            sw.post_session_created()
            while True:
                sw.auto_summary()
        return
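
Taken together with Examples #1 and #2, a plausible per-process entry point (a sketch; the actual CLI wiring is not shown on this page):

import sys

# Hypothetical dispatch: one process per (job_name, task_index) pair.
if __name__ == '__main__':
    job_name, task_index = sys.argv[1], int(sys.argv[2])
    main(job_name=job_name, task_index=task_index)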
Example #8
def infer_mice(dataset, nb_samples, output):
    import numpy as np
    import tensorflow as tf
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.learn.utils.general import load_yaml_config, pre_work
    from dxpy.learn.session import Session
    from dxpy.numpy_extension.visual import grid_view
    from tqdm import tqdm
    pre_work()
    input_data = np.load(
        '/home/hongxwing/Workspace/NetInference/Mice/mice_test_data.npz')
    input_data = {k: np.array(input_data[k]) for k in input_data}
    dataset_origin = get_dataset('dataset/srms')
    is_low_dose = dataset_origin.param('low_dose')
    from dxpy.debug.utils import dbgmsg
    dbgmsg('IS LOW DOSE: ', is_low_dose)
    for k in input_data:
        print(k, input_data[k].shape)
    input_keys = ['input/image{}x'.format(2**i) for i in range(4)]
    label_keys = ['label/image{}x'.format(2**i) for i in range(4)]
    shapes = [[1] + list(input_data['clean/image{}x'.format(2**i)].shape)[1:] +
              [1] for i in range(4)]
    inputs = {
        input_keys[i]: tf.placeholder(tf.float32,
                                      shapes[i],
                                      name='input{}x'.format(2**i))
        for i in range(4)
    }
    labels = {
        label_keys[i]: tf.placeholder(tf.float32,
                                      shapes[i],
                                      name='label{}x'.format(2**i))
        for i in range(4)
    }
    dataset = dict(inputs)
    dataset.update(labels)
    network = get_network('network/srms', dataset=dataset)
    nb_down_sample = network.param('nb_down_sample')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir='./save',
                                             config=config,
                                             save_checkpoint_secs=None)

    if not is_low_dose:
        prefix = 'clean/image'
    else:
        prefix = 'noise/image'

    def get_feed(idx):
        data_raw = input_data['{}{}x'.format(prefix, 2**nb_down_sample)][idx, ...]
        data_raw = np.reshape(data_raw, [1] + list(data_raw.shape) + [1])
        data_label = input_data['{}1x'.format(prefix)][idx, ...]
        data_label = np.reshape(data_label, [1] + list(data_label.shape) + [1])
        return {
            dataset['input/image{}x'.format(2**nb_down_sample)]: data_raw,
            dataset['input/image1x']: data_label,
            dataset['label/image1x']: data_label,
        }

    to_run = {
        'inf': network['outputs/inference'],
        'itp': network['outputs/interp'],
        'high': network['input/image1x'],
        #     'li': network['outputs/loss_itp'],
        #     'ls': network['outputs/loss'],
        #     'la': network['outputs/aligned_label']
    }

    def crop(data, target):
        # Squeeze [1, H, W, 1] batches to [H, W], then crop: rows taken from
        # the vertical midline down, columns centered.
        if len(data.shape) == 4:
            data = data[0, :, :, 0]
        o1 = data.shape[0] // 2
        o2 = (data.shape[1] - target[1]) // 2
        return data[o1:o1 + target[0], o2:-o2]

    MEAN = 100.0
    STD = 150.0
    if is_low_dose:
        MEAN /= dataset_origin.param('low_dose_ratio')
        STD /= dataset_origin.param('low_dose_ratio')
    NB_IMAGES = nb_samples

    def get_infer(idx):
        result = sess.run(to_run, feed_dict=get_feed(idx))
        inf = crop(result['inf'], [320, 64])
        itp = crop(result['itp'], [320, 64])
        high = crop(input_data['{}1x'.format(prefix)][idx, ...], [320, 64])
        low = crop(
            input_data['{}{}x'.format(prefix, 2**nb_down_sample)][idx, ...],
            [320 // (2**nb_down_sample), 64 // (2**nb_down_sample)])

        high = high * STD + MEAN
        low = low * STD + MEAN
        inf = inf * STD + MEAN
        itp = itp * STD + MEAN
        high = np.pad(high, [[0, 0], [32, 32]], mode='constant')
        low = np.pad(low, [[0, 0], [32 // (2**nb_down_sample)] * 2],
                     mode='constant')
        inf = np.pad(inf, [[0, 0], [32, 32]], mode='constant')
        # inf = np.maximum(inf, 0.0)
        itp = np.pad(itp, [[0, 0], [32, 32]], mode='constant')
        return high, low, inf, itp

    results = {'high': [], 'low': [], 'inf': [], 'itp': []}
    for i in tqdm(range(NB_IMAGES)):
        high, low, inf, itp = get_infer(i)
        results['high'].append(high)
        results['low'].append(low)
        results['inf'].append(inf)
        results['itp'].append(itp)

    np.savez(output, **results)
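
The saved .npz can be inspected directly; its keys follow the results dict above (a quick sketch, with 'mice_infer.npz' standing in for the output argument):

import numpy as np

r = np.load('mice_infer.npz')
for key in ('high', 'low', 'inf', 'itp'):
    print(key, np.array(r[key]).shape)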
Example #9
def infer_ext(input_npz_filename,
              input_key='clean/image1x',
              phantom_npz_filename=None,
              phantom_key='phantom',
              dataset_config_name='dataset/srms',
              network_config_name='network/srms',
              output_shape=(320, 320),
              output='infer_ext_result.npz',
              ids=None,
              nb_run=1,
              low_dose_ratio=1.0,
              crop_method='hc',
              config_filename='dxln.yml'):
    """
    """
    import numpy as np
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.tensor.collections import dict_append, dict_element_to_tensor
    from dxpy.learn.utils.general import load_yaml_config, pre_work
    from tqdm import tqdm
    from dxpy.learn.session import SessionMonitored
    pre_work()
    # _load_input_and_phantom, _update_statistics and _get_input_dict are
    # module-level helpers not shown in this snippet.
    inputs, phantoms = _load_input_and_phantom(input_npz_filename,
                                               phantom_npz_filename, input_key,
                                               phantom_key)
    data_input = inputs / low_dose_ratio
    load_yaml_config(config_filename)
    mean, std = _update_statistics(dataset_config_name, low_dose_ratio)
    dataset = get_dataset(dataset_config_name)
    dataset_feed = dataset['external_place_holder']
    nb_down = dataset.param('nb_down_sample')
    nb_down_ratio = [2**i for i in range(nb_down + 1)]
    prefix = 'noise' if dataset.param('with_poission_noise') else 'clean'
    label_key = '{}/image1x'.format(prefix)
    input_key = '{}/image{}x'.format(prefix, 2**nb_down)
    network = get_network(network_config_name,
                          dataset=_get_input_dict(dataset))
    sess = SessionMonitored()
    fetches = {
        'label': dataset[label_key],
        'input': dataset[input_key],
        'infer': network['inference'],
        'interp': network['outputs/interp']
    }

    def proc(result, is_low=False):
        from dxpy.tensor.transform import crop_to_shape
        if is_low:
            target_shape = [s // (2**nb_down) for s in output_shape]
        else:
            target_shape = output_shape
        result = crop_to_shape(result, target_shape, '0' + crop_method + '0')
        result = result * std + mean
        result = np.maximum(result, 0.0)
        return result

    def get_result(idx):
        phan = phantoms[idx, ...]
        result = sess.run(fetches,
                          feed_dict={
                              dataset_feed:
                              np.reshape(data_input[idx:idx + 1, ...],
                                         dataset.param('input_shape'))
                          })
        result_c = {
            'sino_highs': proc(result['label']),
            'sino_lows': proc(result['input'], True),
            'sino_infs': proc(result['infer']),
            'sino_itps': proc(result['interp']),
            'phantoms': phan
        }
        return result_c

    keys = ['sino_highs', 'sino_lows', 'sino_infs', 'sino_itps', 'phantoms']
    results_sino = {k: [] for k in keys}
    for idx in tqdm(ids, ascii=True):
        for _ in tqdm(range(nb_run), ascii=True, leave=False):
            result_sino = get_result(idx)
            dict_append(results_sino, result_sino)
    dict_element_to_tensor(results_sino)
    np.savez(output, **results_sino)
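
A hypothetical invocation; the file names and ids are placeholders. Note that ids must be an iterable, since the default of None would fail in the tqdm loop:

infer_ext('test_sinograms.npz',
          phantom_npz_filename='phantoms.npz',
          ids=range(16),
          nb_run=4,
          low_dose_ratio=2.0,
          output='infer_ext_result.npz')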
Example #10
def infer_sino_sr(dataset, nb_samples, output):
    """
    Use network in current directory as input for inference
    """
    import tensorflow as tf
    from tqdm import tqdm
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.configs import ConfigsView
    from dxpy.learn.config import config
    import numpy as np
    import yaml
    from dxpy.debug.utils import dbgmsg
    dbgmsg(dataset)
    data_raw = np.load(dataset)
    data_raw = {k: np.array(data_raw[k]) for k in data_raw.keys()}
    config_view = ConfigsView(config)

    def tensor_shape(key):
        shape_origin = data_raw[key].shape
        return [1] + list(shape_origin[1:3]) + [1]

    with tf.name_scope('inputs'):
        keys = ['input/image{}x'.format(2**i) for i in range(4)]
        keys += ['label/image{}x'.format(2**i) for i in range(4)]
        dataset = {
            k: tf.placeholder(tf.float32, tensor_shape(k))
            for k in keys
        }

    network = get_network('network/srms', dataset=dataset)
    nb_down_sample = network.param('nb_down_sample')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir='./save',
                                             config=config,
                                             save_checkpoint_secs=None)

    STAT_STD = 9.27
    STAT_MEAN = 9.76
    BASE_SHAPE = (640, 320)

    dataset_configs = config_view['dataset']['srms']
    with_noise = dataset_configs['with_poission_noise']
    if with_noise:
        PREFIX = 'input'
    else:
        PREFIX = 'label'

    def crop_sinogram(tensor, target_shape=None):
        # Center-crop to target_shape; assumes the tensor is strictly larger
        # than the target in both dimensions (so o1, o2 > 0).
        if target_shape is None:
            target_shape = BASE_SHAPE
        if len(tensor.shape) == 4:
            tensor = tensor[0, :, :, 0]
        o1 = (tensor.shape[0] - target_shape[0]) // 2
        o2 = (tensor.shape[1] - target_shape[1]) // 2
        return tensor[o1:-o1, o2:-o2]

    def run_infer(idx):
        input_key = '{}/image{}x'.format(PREFIX, 2**nb_down_sample)
        low_sino = np.reshape(data_raw[input_key][idx, :, :],
                              tensor_shape(input_key))
        low_sino = (low_sino - STAT_MEAN) / STAT_STD
        feeds = {dataset['input/image{}x'.format(2**nb_down_sample)]: low_sino}
        inf, itp = sess.run(
            [network['outputs/inference'], network['outputs/interp']],
            feed_dict=feeds)
        infc = crop_sinogram(inf)
        itpc = crop_sinogram(itp)
        infc = infc * STAT_STD + STAT_MEAN
        itpc = itpc * STAT_STD + STAT_MEAN
        return infc, itpc

    phans = []
    sino_highs = []
    sino_lows = []
    sino_itps = []
    sino_infs = []
    NB_MAX = data_raw['phantom'].shape[0]
    for idx in tqdm(range(nb_samples), ascii=True):
        if idx >= NB_MAX:  # `>` in the original would index out of range first
            import sys
            print(
                'Index {} out of range {}, stop running and store current result...'
                .format(idx, NB_MAX),
                file=sys.stderr)
            break

        phans.append(data_raw['phantom'][idx, ...])
        sino_highs.append(
            crop_sinogram(data_raw['{}/image1x'.format(PREFIX)][idx, :, :]))
        sino_lows.append(
            crop_sinogram(
                data_raw['{}/image{}x'.format(PREFIX, 2**nb_down_sample)][idx,
                                                                          ...],
                [s // (2**nb_down_sample) for s in BASE_SHAPE]))
        sino_inf, sino_itp = run_infer(idx)
        sino_infs.append(sino_inf)
        sino_itps.append(sino_itp)

    results = {
        'phantom': phans,
        'sino_itps': sino_itps,
        'sino_infs': sino_infs,
        'sino_highs': sino_highs,
        'sino_lows': sino_lows
    }
    np.savez(output, **results)
Example #11
def infer_mct(dataset, nb_samples, output):
    """
    Use network in current directory as input for inference
    """
    import tensorflow as tf
    from tqdm import tqdm
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.configs import ConfigsView
    from dxpy.learn.config import config
    import numpy as np
    import yaml
    from dxpy.debug.utils import dbgmsg
    print('Using dataset file:', dataset)
    data_raw = np.load(dataset)
    data_raw = {k: np.array(data_raw[k]) for k in data_raw.keys()}
    config_view = ConfigsView(config)

    def data_key(nd):
        return 'image{}x'.format(2**nd)

    def tensor_shape(key):
        shape_origin = data_raw[key].shape
        return [1] + list(shape_origin[1:3]) + [1]

    with tf.name_scope('inputs'):
        keys = ['input/image{}x'.format(2**i) for i in range(4)]
        keys += ['label/image{}x'.format(2**i) for i in range(4)]
        dataset = {
            k: tf.placeholder(tf.float32, tensor_shape(data_key(i % 4)))
            for i, k in enumerate(keys)
        }

    network = get_network('network/srms', dataset=dataset)
    nb_down_sample = network.param('nb_down_sample')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir='./save',
                                             config=config,
                                             save_checkpoint_secs=None)

    STAT_MEAN = 9.93
    STAT_STD = 7.95
    STAT_MEAN_LOW = 9.93 * (4.0**nb_down_sample)
    STAT_STD_LOW = 7.95 * (4.0**nb_down_sample)
    BASE_SHAPE = (384, 384)

    def crop_image(tensor, target_shape=None):
        # Squeeze [1, H, W, 1] batches to [H, W], then crop: rows taken from
        # the vertical midline down, columns centered.
        if target_shape is None:
            target_shape = BASE_SHAPE
        if len(tensor.shape) == 4:
            tensor = tensor[0, :, :, 0]
        o1 = tensor.shape[0] // 2
        o2 = (tensor.shape[1] - target_shape[1]) // 2
        return tensor[o1:o1 + target_shape[0], o2:-o2]

    input_key = data_key(nb_down_sample)
    dbgmsg('input_key:', input_key)

    def run_infer(idx):
        low_phan = np.reshape(data_raw[input_key][idx, :, :],
                              tensor_shape(input_key))
        # low_phan = (low_phan - STAT_MEAN) / STAT_STD
        feeds = {dataset['input/image{}x'.format(2**nb_down_sample)]: low_phan}
        inf, itp = sess.run(
            [network['outputs/inference'], network['outputs/interp']],
            feed_dict=feeds)
        infc = crop_image(inf)
        itpc = crop_image(itp)
        infc = infc * STAT_STD_LOW + STAT_MEAN_LOW
        itpc = itpc * STAT_STD_LOW + STAT_MEAN_LOW
        return infc, itpc

    phans = []
    img_highs = []
    img_lows = []
    img_itps = []
    img_infs = []
    NB_MAX = data_raw['phantom'].shape[0]
    for idx in tqdm(range(nb_samples), ascii=True):
        if idx >= NB_MAX:  # `>` in the original would index out of range first
            import sys
            print(
                'Index {} out of range {}, stop running and store current result...'
                .format(idx, NB_MAX),
                file=sys.stderr)
            break

        phans.append(data_raw['phantom'][idx, ...])
        img_high = crop_image(data_raw[data_key(0)][idx, :, :])
        img_high = img_high * STAT_STD + STAT_MEAN
        # img_high = img_high * STAT_STD / \
        # (4.0**nb_down_sample) + STAT_MEAN / (4.0**nb_down_sample)
        img_highs.append(img_high)
        img_low = crop_image(data_raw[data_key(nb_down_sample)][idx, ...],
                             [s // (2**nb_down_sample) for s in BASE_SHAPE])
        img_low = img_low * STAT_STD_LOW + STAT_MEAN_LOW
        img_lows.append(img_low)
        img_inf, img_itp = run_infer(idx)
        img_infs.append(img_inf)
        img_itps.append(img_itp)

    img_highs = np.array(img_highs)
    img_infs = np.array(img_infs) / (4.0**nb_down_sample)
    img_itps = np.array(img_itps) / (4.0**nb_down_sample)
    img_lows = np.array(img_lows) / (4.0**nb_down_sample)

    results = {
        'phantom': phans,
        'sino_itps': img_itps,
        'sino_infs': img_infs,
        'sino_highs': img_highs,
        'sino_lows': img_lows
    }
    np.savez(output, **results)