Example #1
0
def main_summary(task_index, cluster, dataset):
    """Distributed summary task: build the SRMS graph pinned to this worker's
    device, join a MonitoredTrainingSession against the cluster, and write
    TensorBoard summaries in an endless loop (the process is expected to be
    terminated externally).

    Args:
        task_index: index of this task within the 'worker' job.
        cluster: cluster spec passed through to the device setter/server.
        dataset: unused — shadowed by the local ``get_dataset(...)`` below.

    NOTE(review): ``tensors`` and ``summary_config`` used below are not
    defined in this function and are not visible at module scope here; unless
    they are globals elsewhere in the file this raises NameError — confirm.
    """
    from dxpy.learn.distribute.cluster import get_server
    from dxpy.learn.net.zoo.srms.summ import SRSummaryWriter_v2
    from dxpy.learn.utils.general import pre_work
    from dxpy.learn.session import set_default_session
    # Let the GPU allocator grow on demand instead of reserving all memory.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    server = get_server(cluster, 'worker', task_index, config=config)
    # replica_device_setter pins ops to this worker and distributes variables
    # onto the parameter servers.
    with tf.device(
            tf.train.replica_device_setter(
                worker_device="/job:worker/task:{}".format(task_index),
                cluster=cluster)):
        pre_work()
        # Shadows the `dataset` parameter on purpose or by accident — verify.
        dataset = get_dataset(name='dataset/srms')
        network = get_network(name='network/srms', dataset=dataset)
        # Builds the graph; the returned result dict is unused here.
        result = network()
        hooks = [tf.train.StepCounterHook()]
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=(task_index == 0),
                                               checkpoint_dir="./save",
                                               config=config,
                                               hooks=hooks) as sess:
            set_default_session(sess)
            network.nodes['trainer'].run('set_learning_rate')
            sw = SRSummaryWriter_v2(network=network,
                                    tensors=tensors,
                                    name=summary_config)
            sw.post_session_created()
            # Summary task loops forever; killed from outside.
            while True:
                sw.auto_summary()
Example #2
0
def train(definition_func):
    """Generic training loop driven by a ``dxln.yml`` configuration file.

    Args:
        definition_func: callable taking the parsed YAML config dict and
            returning a ``(network, summary)`` pair.

    Side effects:
        Reads ``dxln.yml`` from the current directory, merges it into the
        global ``config``, trains for ``steps`` iterations (writing summaries
        every ``summary_freq`` steps and checkpoints every ``save_freq``
        steps), and saves a final checkpoint at the end.
    """
    with open('dxln.yml') as fin:
        # safe_load: calling yaml.load without an explicit Loader is
        # deprecated (and a TypeError in PyYAML >= 6); a config file only
        # needs plain YAML tags, so safe_load is the correct replacement.
        ycfg = yaml.safe_load(fin)
    config.update(ycfg)
    pre_work()
    train_cfgs = get_train_configs()
    steps = train_cfgs['steps']
    summary_freq = train_cfgs['summary_freq']
    save_freq = train_cfgs['save_freq']
    network, summary = definition_func(ycfg)
    session = Session()
    with session.as_default():
        # One-time hooks that must run after the session exists.
        network.post_session_created()
        summary.post_session_created()
        session.post_session_created()

    with session.as_default():
        network.load()
        for i in tqdm(range(steps)):
            network.train()
            # Skip step 0 so an untrained model is not summarized or saved.
            if i % summary_freq == 0 and i > 0:
                summary.summary()
                summary.flush()
            if i % save_freq == 0 and i > 0:
                network.save()

    with session.as_default():
        network.save()
Example #3
0
def main_worker(task_index, cluster):
    """Distributed worker task: build the SRMS graph on this worker's device
    and train indefinitely inside a MonitoredTrainingSession.

    Args:
        task_index: index of this task within the 'worker' job.
        cluster: cluster spec passed to the device setter and server factory.
    """
    from dxpy.learn.distribute.cluster import get_server
    from dxpy.learn.utils.general import pre_work
    # Let the GPU allocator grow on demand instead of reserving all memory.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    server = get_server(cluster, 'worker', task_index, config=config)
    # replica_device_setter pins ops to this worker and distributes variables
    # onto the parameter servers.
    with tf.device(
            tf.train.replica_device_setter(
                worker_device="/job:worker/task:{}".format(task_index),
                cluster=cluster)):
        pre_work()
        dataset = get_dataset(name='dataset/srms')
        network = get_network(name='network/srms', dataset=dataset)
        # Builds the graph; the returned result dict is unused here.
        result = network()
        hooks = [tf.train.StepCounterHook()]
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=(task_index == 0),
                                               checkpoint_dir="./save",
                                               config=config,
                                               hooks=hooks) as sess:
            from dxpy.learn.session import set_default_session
            set_default_session(sess)
            network.nodes['trainer'].run('set_learning_rate')
            # Effectively an infinite training loop; killed from outside.
            for _ in tqdm(range(10000000000000), ascii=True):
                network.train()
Example #4
0
def create_dataset_network_summary(dataset_maker_name, network_maker_name, summary_maker_name):
    """Construct and wire together a dataset, a network and a summary object.

    Runs ``pre_work()`` first, builds the dataset from its maker name, feeds
    it to the network maker, applies the network once to obtain its result,
    and finally builds the summary from all three.

    Returns:
        Tuple of ``(dataset, network, summary)``.
    """
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network, get_summary
    from dxpy.learn.utils.general import pre_work
    pre_work()
    ds = get_dataset(dataset_maker_name)
    net = get_network(network_maker_name, dataset=ds)
    net_result = net()
    summ = get_summary(summary_maker_name, ds, net, net_result)
    return ds, net, summ
Example #5
0
def main(task='train', job_name='worker', task_index=0, cluster_config='cluster.yml'):
    """Distributed entry point dispatching on ``job_name``.

    Roles: 'dataset' and 'ps' simply join the server; 'worker' builds the
    'network/sin' graph and trains; 'summary' builds a summary writer loop;
    'test' fetches and prints sample values; 'saver' saves the network.

    NOTE(review): the 'test' branch reads ``result_main``, which is only
    assigned inside the commented-out code below — running with
    job_name='test' raises NameError as written; confirm intent.
    NOTE(review): ``get_network`` is not imported in this function and is not
    visible at module scope here — presumably a file-level import.
    """
    from dxpy.learn.distribute.cluster import get_cluster_spec, get_server, get_nb_tasks
    from dxpy.learn.distribute.dataset import get_dist_dataset
    from dxpy.learn.distribute.ps import get_dist_network as get_ps_network
    from dxpy.learn.distribute.worker import apply_dist_network
    from dxpy.learn.distribute.summary import get_dist_summary
    from dxpy.learn.distribute.worker import get_dist_network as get_worker_network
    from dxpy.learn.utils.general import pre_work
    cluster = get_cluster_spec(cluster_config, job_name=None)
    server = get_server(cluster, job_name, task_index)
    # Shared distributed dataset hosted on the dataset job's task 0.
    dataset = get_dist_dataset(name='cluster/dataset/task0')
    if job_name == 'dataset':
        server.join()
        return
    elif job_name == 'ps':
        server.join()
        return
    # if job_name in ['ps', 'worker', 'test', 'saver', 'summary', 'saver']:
    #     pre_work(device='/job:ps/task:0')
    #     network = get_ps_network(name='cluster/ps/task0', dataset=dataset)
    #     result_main = network()
    # if job_name == 'ps':
    #     # sess = SessionDist(target=server.target)
    #     # with sess.as_default():
    #     #     network.post_session_created()
    #     #     sess.post_session_created()
    #     #     network.load()
    #     server.join()
    #     return
    # if job_name in ['worker', 'summary']:
    #     network_worker, result = get_worker_network(network_ps=network,
    #         name='cluster/worker/task{}'.format(task_index), dataset=dataset)
    #     # result = apply_dist_network(
    #     #     network=network, dataset=dataset, name='cluster/worker/task{}'.format(task_index))
    # if job_name == 'worker':
    #     sess = SessionDist(target=server.target)
    #     with sess.as_default():
    #         for _ in tqdm(range(40)):
    #             for _ in range(100):
    #                 network_worker.train()
    #             print(sess.run(result[NodeKeys.LOSS]))
    #     while True:
    #         time.sleep(1)
    #     return
    elif job_name in ['worker', 'summary']:
        # Both roles need the graph; only 'worker' opens the training session.
        with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:{}".format(task_index), cluster=cluster)):
            pre_work()
            network = get_network(name='network/sin', dataset=dataset)
            result = network()
        if job_name == 'worker':
            hooks = [tf.train.StepCounterHook()]
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(task_index == 0),
                                           checkpoint_dir="./save",
                                           config=config,
                                           hooks=hooks) as sess:
                from dxpy.learn.session import set_default_session
                set_default_session(sess)
            # sess = SessionDist(target=server.target)
                # with sess._sess.as_default():
                # 40 reporting rounds of 100 training steps each.
                for _ in tqdm(range(40)):
                    for _ in range(100):
                        network.train()
                    print(sess.run(result[NodeKeys.LOSS]))
    if job_name == 'test':
        sess = SessionDist(target=server.target)

        with sess.as_default():
            # NOTE(review): `result_main` is undefined here (see docstring).
            xv, yv, yp, loss, gs = sess.run([dataset['x'],
                                             dataset['y'],
                                             result_main[NodeKeys.INFERENCE],
                                             result_main[NodeKeys.LOSS], global_step()])
            print(xv)
            print(yv)
            print(yp)
            print(loss)
            print(gs)
        while True:
            time.sleep(1)
    if job_name == 'summary':
        name = 'cluster/summary/task{}'.format(task_index)
        result = apply_dist_network(
            network=network, dataset=dataset, name=name)
        sw = get_dist_summary(tensors={NodeKeys.LOSS: result[NodeKeys.LOSS],
                                       'global_step': global_step()}, name=name)
        sess = SessionDist(target=server.target)
        with sess.as_default():
            sw.post_session_created()
            # Write one summary per second until the process is killed.
            while True:
                sw.summary()
                print('Add one summary.')
                time.sleep(1)
        return
    if job_name == 'saver':
        network.save()
Example #6
0
def main(task='train',
         job_name='worker',
         task_index=0,
         cluster_config='cluster.yml'):
    """Distributed entry point (SRMS variant) dispatching on ``job_name``.

    'dataset'/'ps' join the server; 'worker' delegates to ``main_worker``;
    'summary' builds the graph and runs an auto-summary loop inside a
    non-chief MonitoredTrainingSession.

    NOTE(review): ``nb_datasets`` is computed from the cluster config but the
    loop below is hard-coded to ``range(1)`` — ``datasets[task_index %
    nb_datasets]`` can IndexError when nb_datasets > 1; confirm.
    NOTE(review): inside the 'summary' elif, ``if job_name == 'worker'`` is
    dead code ('worker' is handled by an earlier elif).
    """
    from dxpy.learn.distribute.cluster import get_cluster_spec, get_server, get_nb_tasks
    from dxpy.learn.distribute.dataset import get_dist_dataset
    from dxpy.learn.distribute.ps import get_dist_network as get_ps_network
    from dxpy.learn.distribute.worker import apply_dist_network
    from dxpy.learn.distribute.summary import get_dist_summary
    from dxpy.learn.distribute.worker import get_dist_network as get_worker_network
    from dxpy.learn.utils.general import pre_work
    cluster = get_cluster_spec(cluster_config, job_name=None)
    # Non-worker roles create their own server and dataset handles here;
    # the worker role does this inside main_worker instead.
    if not job_name == 'worker':
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        server = get_server(cluster, job_name, task_index, config=config)
        datasets = []
        nb_datasets = get_nb_tasks(cluster_config, 'dataset')
        for i in range(1):
            datasets.append(
                get_dist_dataset(name='cluster/dataset/task{}'.format(i)))
    if job_name == 'dataset':
        server.join()
        return
    elif job_name == 'ps':
        server.join()
        return
    elif job_name == 'worker':
        main_worker(task_index, cluster)
    elif job_name in ['summary']:
        with tf.device(
                tf.train.replica_device_setter(
                    worker_device="/job:worker/task:{}".format(task_index),
                    cluster=cluster)):
            pre_work()
            network = get_network(name='network/srms',
                                  dataset=datasets[task_index % nb_datasets])
            result = network()
        # NOTE(review): dead branch — job_name is 'summary' here.
        if job_name == 'worker':
            hooks = [tf.train.StepCounterHook()]
            with tf.train.MonitoredTrainingSession(
                    master=server.target,
                    is_chief=(task_index == 0),
                    #    is_chief=True,
                    checkpoint_dir="./save",
                    config=config,
                    hooks=hooks) as sess:
                from dxpy.learn.session import set_default_session
                set_default_session(sess)
                network.nodes['trainer'].run('set_learning_rate')
                # sess = SessionDist(target=server.target)
                # with sess._sess.as_default():
                for _ in tqdm(range(10000000000000), ascii=True):
                    network.train()
    if job_name == 'summary':
        name = 'cluster/summary/task{}'.format(task_index)
        result = apply_dist_network(network=network,
                                    dataset=datasets[-1],
                                    name=name)
        result.update(datasets[-1].nodes)
        sw = get_dist_summary(tensors=result, network=network, name=name)
        hooks = [tf.train.StepCounterHook()]
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=False,
                                               config=config,
                                               hooks=hooks) as sess:
            from dxpy.learn.session import set_default_session
            set_default_session(sess)
            sw.post_session_created()
            # Summary loop runs until the process is terminated externally.
            while True:
                sw.auto_summary()
            #
        # sess = SessionDist(target=server.target)
        # with sess.as_default():
        #     sw.post_session_created()
        #     while True:
        #         sw.auto_summary()
        return
Example #7
0
def infer_mice(dataset, nb_samples, output):
    """Run SRMS super-resolution inference on the pre-baked mice test set.

    Loads the fixed mice test archive, builds the 'network/srms' network fed
    through placeholders, restores the latest checkpoint from ``./save``,
    runs inference on ``nb_samples`` images and writes cropped/denormalized
    'high', 'low', 'inf' and 'itp' stacks to a compressed npz file.

    Args:
        dataset: unused — shadowed by the local placeholder dict below
            (kept for interface compatibility).
        nb_samples: number of test images to run inference on.
        output: path for the resulting ``np.savez`` archive.
    """
    import numpy as np
    import tensorflow as tf
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.learn.utils.general import pre_work
    from dxpy.debug.utils import dbgmsg
    from tqdm import tqdm
    pre_work()
    # Hard-coded test archive; keys are e.g. 'clean/image1x', 'noise/image2x'.
    input_data = np.load(
        '/home/hongxwing/Workspace/NetInference/Mice/mice_test_data.npz')
    input_data = {k: np.array(input_data[k]) for k in input_data}
    dataset_origin = get_dataset('dataset/srms')
    is_low_dose = dataset_origin.param('low_dose')
    dbgmsg('IS LOW DOSE: ', is_low_dose)
    for k in input_data:
        print(k, input_data[k].shape)
    input_keys = ['input/image{}x'.format(2**i) for i in range(4)]
    label_keys = ['label/image{}x'.format(2**i) for i in range(4)]
    # Batch-of-one placeholder shapes with a trailing channel dimension.
    shapes = [[1] + list(input_data['clean/image{}x'.format(2**i)].shape)[1:] +
              [1] for i in range(4)]
    inputs = {
        input_keys[i]: tf.placeholder(tf.float32,
                                      shapes[i],
                                      name='input{}x'.format(2**i))
        for i in range(4)
    }
    labels = {
        label_keys[i]: tf.placeholder(tf.float32,
                                      shapes[i],
                                      name='label{}x'.format(2**i))
        for i in range(4)
    }
    # Shadows the `dataset` parameter: the network is fed via placeholders.
    dataset = dict(inputs)
    dataset.update(labels)
    network = get_network('network/srms', dataset=dataset)
    nb_down_sample = network.param('nb_down_sample')
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # save_checkpoint_secs=None: restore only, never write checkpoints.
    sess = tf.train.MonitoredTrainingSession(checkpoint_dir='./save',
                                             config=config,
                                             save_checkpoint_secs=None)

    if not is_low_dose:
        prefix = 'clean/image'
    else:
        prefix = 'noise/image'

    def get_feed(idx):
        """Feed dict for sample `idx`: downsampled input plus 1x label."""
        data_raw = input_data['{}{}x'.format(prefix, 2**nb_down_sample)][idx,
                                                                         ...]
        data_raw = np.reshape(data_raw, [1] + list(data_raw.shape) + [1])
        data_label = input_data['{}1x'.format(prefix)][idx, ...]
        data_label = np.reshape(data_label, [1] + list(data_label.shape) + [1])
        # Fixed keys: the 1x entries previously carried a no-op .format()
        # call ('input/image1x'.format(...) has no placeholder) — removed.
        return {
            dataset['input/image{}x'.format(2**nb_down_sample)]: data_raw,
            dataset['input/image1x']: data_label,
            dataset['label/image1x']: data_label,
        }

    to_run = {
        'inf': network['outputs/inference'],
        'itp': network['outputs/interp'],
        'high': network['input/image1x'],
        #     'li': network['outputs/loss_itp'],
        #     'ls': network['outputs/loss'],
        #     'la': network['outputs/aligned_label']
    }

    def crop(data, target):
        """Crop a (1, H, W, 1) or (H, W) array around the target region."""
        if len(data.shape) == 4:
            data = data[0, :, :, 0]
        # NOTE(review): o1 = H // 2 takes the lower half along axis 0 rather
        # than centering like axis 1 — looks intentional but confirm.
        o1 = data.shape[0] // 2
        o2 = (data.shape[1] - target[1]) // 2
        return data[o1:o1 + target[0], o2:-o2]

    # De-normalization constants; scaled down for low-dose data.
    MEAN = 100.0
    STD = 150.0
    if is_low_dose:
        MEAN /= dataset_origin.param('low_dose_ratio')
        STD /= dataset_origin.param('low_dose_ratio')
    NB_IMAGES = nb_samples

    def get_infer(idx):
        """Run one inference and return (high, low, inf, itp) padded images."""
        result = sess.run(to_run, feed_dict=get_feed(idx))
        inf = crop(result['inf'], [320, 64])
        itp = crop(result['itp'], [320, 64])
        high = crop(input_data['{}1x'.format(prefix)][idx, ...], [320, 64])
        low = crop(
            input_data['{}{}x'.format(prefix, 2**nb_down_sample)][idx, ...],
            [320 // (2**nb_down_sample), 64 // (2**nb_down_sample)])

        high = high * STD + MEAN
        low = low * STD + MEAN
        inf = inf * STD + MEAN
        itp = itp * STD + MEAN
        high = np.pad(high, [[0, 0], [32, 32]], mode='constant')
        low = np.pad(low, [[0, 0], [32 // (2**nb_down_sample)] * 2],
                     mode='constant')
        inf = np.pad(inf, [[0, 0], [32, 32]], mode='constant')
        # inf = np.maximum(inf, 0.0)
        itp = np.pad(itp, [[0, 0], [32, 32]], mode='constant')
        return high, low, inf, itp

    results = {'high': [], 'low': [], 'inf': [], 'itp': []}
    for i in tqdm(range(NB_IMAGES)):
        high, low, inf, itp = get_infer(i)
        results['high'].append(high)
        results['low'].append(low)
        results['inf'].append(inf)
        results['itp'].append(itp)

    np.savez(output, **results)
Example #8
0
def infer_ext(input_npz_filename,
              input_key='clean/image1x',
              phantom_npz_filename=None,
              phantom_key='phantom',
              dataset_config_name='dataset/srms',
              network_config_name='network/srms',
              output_shape=(320, 320),
              output='infer_ext_result.npz',
              ids=None,
              nb_run=1,
              low_dose_ratio=1.0,
              crop_method='hc',
              config_filename='dxln.yml'):
    """Run SRMS inference on external sinogram data and save stacked results.

    Loads inputs (and optional phantoms) from npz files, feeds them through
    the configured dataset's external placeholder, runs each selected index
    ``nb_run`` times, and writes cropped/denormalized high/low/infer/interp
    sinogram stacks plus phantoms to ``output``.

    NOTE(review): the ``input_key`` parameter is overwritten below before it
    is ever read here (the original value is only used inside
    ``_load_input_and_phantom``) — confirm that is intended.
    NOTE(review): with the default ``ids=None``, ``tqdm(ids, ...)`` raises a
    TypeError — callers apparently must always pass an iterable of indices.
    """
    import numpy as np
    from dxpy.learn.dataset.api import get_dataset
    from dxpy.learn.net.api import get_network
    from dxpy.tensor.collections import dict_append, dict_element_to_tensor
    from dxpy.learn.utils.general import load_yaml_config, pre_work
    from tqdm import tqdm
    from dxpy.learn.session import SessionMonitored
    pre_work()
    inputs, phantoms = _load_input_and_phantom(input_npz_filename,
                                               phantom_npz_filename, input_key,
                                               phantom_key)
    # Simulate a lower dose by scaling the input counts down.
    data_input = inputs / low_dose_ratio
    load_yaml_config(config_filename)
    mean, std = _update_statistics(dataset_config_name, low_dose_ratio)
    dataset = get_dataset(dataset_config_name)
    dataset_feed = dataset['external_place_holder']
    nb_down = dataset.param('nb_down_sample')
    nb_down_ratio = [2**i for i in range(nb_down + 1)]
    # 'with_poission_noise' spelling matches the config key used elsewhere.
    prefix = 'noise' if dataset.param('with_poission_noise') else 'clean'
    label_key = '{}/image1x'.format(prefix)
    input_key = '{}/image{}x'.format(prefix, 2**nb_down)
    network = get_network(network_config_name,
                          dataset=_get_input_dict(dataset))
    sess = SessionMonitored()
    fetches = {
        'label': dataset[label_key],
        'input': dataset[input_key],
        'infer': network['inference'],
        'interp': network['outputs/interp']
    }

    def proc(result, is_low=False):
        # Crop to the output shape (low-res results use a scaled-down shape)
        # then undo normalization and clamp negatives to zero.
        from dxpy.tensor.transform import crop_to_shape
        if is_low:
            target_shape = [s // (2**nb_down) for s in output_shape]
        else:
            target_shape = output_shape
        result = crop_to_shape(result, target_shape, '0' + crop_method + '0')
        result = result * std + mean
        result = np.maximum(result, 0.0)
        return result

    def get_result(idx):
        # One forward pass for sample `idx`, returning the processed stacks.
        phan = phantoms[idx, ...]
        result = sess.run(fetches,
                          feed_dict={
                              dataset_feed:
                              np.reshape(data_input[idx:idx + 1, ...],
                                         dataset.param('input_shape'))
                          })
        result_c = {
            'sino_highs': proc(result['label']),
            'sino_lows': proc(result['input'], True),
            'sino_infs': proc(result['infer']),
            'sino_itps': proc(result['interp']),
            'phantoms': phan
        }
        return result_c

    keys = ['sino_highs', 'sino_lows', 'sino_infs', 'sino_itps', 'phantoms']
    results_sino = {k: [] for k in keys}
    for idx in tqdm(ids, ascii=True):
        # nb_run > 1 repeats each sample (e.g. to average noisy runs later).
        for _ in tqdm(range(nb_run), ascii=True, leave=False):
            result_sino = get_result(idx)
            dict_append(results_sino, result_sino)
    dict_element_to_tensor(results_sino)
    np.savez(output, **results_sino)
Example #9
0
 def __enter__(self):
     """Context-manager entry: load the YAML configuration file and, when
     enabled on this instance, run the standard pre-work initialization.

     Returns:
         self, so the object can be bound in a ``with ... as`` clause.
     """
     from dxpy.learn.utils.general import load_yaml_config, pre_work
     load_yaml_config(self._config)
     if self._with_pre_work:
         pre_work()
     return self