Ejemplo n.º 1
0
    data_config = icg.utils.loadYaml(args.data_config, ['data_config'])
    global_config = icg.utils.loadYaml(args.global_config, ['global_config'])

    # Tensorflow config
    tf_config = tf.compat.v1.ConfigProto(log_device_placement=False)
    tf_config.gpu_options.allow_growth = global_config['tf_allow_gpu_growth']

    # define the output locations
    base_name = os.path.basename(args.network_config).split('.')[0]
    suffix = base_name + '_' + time.strftime('%Y-%m-%d--%H-%M-%S')
    vn.setupLogDirs(suffix, args, checkpoint_config)

    # load data
    filename_producer = VnMriFilenameProducer(data_config)
    data = VnMriReconstructionData(
        data_config,
        filename_dequeue_op=filename_producer.dequeue_op,
        queue_capacity=global_config['data_queue_capacity'])

    network_config['sampling_pattern'] = data_config['sampling_pattern']

    # Create a queue runner that will run 4 threads in parallel to enqueue examples.
    qr_data = tf.train.QueueRunner(data.queue, [data.enqueue_op] *
                                   global_config['data_num_threads'])
    # Create a queue runner to produce the filenames
    qr_filenames = tf.train.QueueRunner(filename_producer.queue,
                                        [filename_producer.enqueue_op])

    # Create a coordinator, launch the queue runner threads.
    coord = tf.train.Coordinator()

    # define parameters
Ejemplo n.º 2
0
                print(traceback.print_exc())
                continue

            # extract a few ops and variables to be used in evaluation
            u_op = tf.get_collection('u_op')[0]
            u_var = tf.get_collection('u_var')
            g_var = tf.get_collection('g_var')
            c_var = tf.get_collection('c_var')
            m_var = tf.get_collection('m_var')
            f_var = tf.get_collection('f_var')

            # create data object
            data = VnMriReconstructionData(data_config,
                                           u_var=u_var,
                                           f_var=f_var,
                                           g_var=g_var,
                                           c_var=c_var,
                                           m_var=m_var,
                                           load_eval_data=False)

            # Evaluate the performance
            for dataset in eval_datasets:
                eval_patients = dataset['eval_patients']
                if not os.path.exists(eval_output_dir +
                                      '/%s' % dataset['name']):
                    os.makedirs(eval_output_dir + '/%s' % dataset['name'])

                print(
                    "Evaluating performance {:s} for {:s}, epoch {:d}".format(
                        suffix, dataset['name'], epoch))
            epoch = vn.utils.loadCheckpoint(sess, ckpt_dir, epoch=epoch)
        except Exception as e:
            print(traceback.print_exc())

        u_op = tf.compat.v1.get_collection('u_op')[0]
        u_var = tf.compat.v1.get_collection('u_var')
        c_var = tf.compat.v1.get_collection('c_var')
        m_var = tf.compat.v1.get_collection('m_var')
        f_var = tf.compat.v1.get_collection('f_var')
        g_var = tf.compat.v1.get_collection('g_var')

        # create data object
        data = VnMriReconstructionData(data_config,
                                       u_var=u_var,
                                       f_var=f_var,
                                       c_var=c_var,
                                       m_var=m_var,
                                       g_var=g_var,
                                       load_eval_data=False,
                                       load_target=False)

        # run the model
        print('start reconstruction')
        eval_start_time = time.time()

        u_volume = []

        path = os.path.expanduser(data_config['base_dir'] + '/' +
                                  data_config['dataset']['name'] + '/')
        num_slices = len(
            glob.glob(path +
                      '/%d/rawdata*.mat' % data_config['dataset']['patient']))
Ejemplo n.º 4
0
    except Exception as e:
        print(traceback.print_exc())

    # extract operators and variables from the graph
    u_op = tf.get_collection('u_op')[0]
    u_var = tf.get_collection('u_var')
    c_var = tf.get_collection('c_var')
    m_var = tf.get_collection('m_var')
    f_var = tf.get_collection('f_var')
    g_var = tf.get_collection('g_var')

    # create data object
    data = VnMriReconstructionData(data_config,
                                   u_var=u_var,
                                   f_var=f_var,
                                   c_var=c_var,
                                   m_var=m_var,
                                   g_var=g_var,
                                   load_eval_data=False,
                                   load_target=True)

    # load data
    kspace, coil_sens, x_adj, ref, mask, norm \
                    = data.get_test_data(data_config['dataset'],
                                              data_config['dataset']['patient'],
                                              data_config['dataset']['slice'])

    # compile functions
    def val_df(x_adj, kspace, coil_sens, mask, label):
        tf_label = tf.placeholder(u_op.dtype)
        loss = tf.nn.l2_loss(tf.abs(tf_label - u_op))
        grad = tf.gradients(loss, u_var[0])