Example 1
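These snippets come without their imports. A minimal preamble they appear to assume, based on the TF1.x/Keras-era APIs used (datagenerators, losses and networks are project-local modules; TensorBoardExt, TensorBoardVal and convert_delta are assumed to come from the same project):

import os
import pickle
from datetime import datetime

import numpy as np
import pandas as pd
import nibabel as nib
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
from keras.optimizers import Adam
from keras.utils.generic_utils import Progbar

# project-local modules assumed by all five examples
import datagenerators
import losses
import networks
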
def predict(gpu_id, csv_path, split_col, split, batch_size, out_dir,
            gen_model_file, delta, out_imgs):

    # GPU handling
    if gpu_id is not None:
        gpu = '/gpu:' + str(gpu_id)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        set_session(tf.Session(config=config))
    else:
        gpu = '/cpu:0'

    # check generator model file exists
    assert os.path.isfile(
        gen_model_file), "generator model file does not exist"

    gen_model_file = os.path.abspath(gen_model_file)

    # extract run directory and model checkpoint name
    model_dir, model_name = os.path.split(gen_model_file)
    model_name = os.path.splitext(model_name)[0]
    model_name += '_{:02.0f}'.format(delta) if delta is not None else ''
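    # e.g. checkpoint 'gen_050' queried with delta=2.0 becomes 'gen_050_02'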

    # load model config
    config_path = os.path.join(model_dir, 'config.pkl')

    assert os.path.isfile(config_path), "model_config not found"

    if isinstance(split, str):
        split = [split]

    # create out_dir
    if out_dir is None or out_dir == '':
        out_dir = os.path.join(model_dir, 'predict',
                               split_col + '_' + ''.join(split), model_name)

    print('out_dir:', out_dir)

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    with open(config_path, 'rb') as f:
        model_config = pickle.load(f)

    # extract config variables
    vol_shape = model_config['vol_shape']
    max_delta = model_config['max_delta']
    vel_resize = model_config['vel_resize']
    prior_lambda = model_config['prior_lambda']
    cri_loss_weights = model_config['cri_loss_weights']
    gen_loss_weights = model_config['gen_loss_weights']
    cri_base_nf = model_config['cri_base_nf']
    int_steps = model_config['int_steps']
    reg_model_file = model_config['reg_model_file']
    clf_model_file = model_config['clf_model_file']
    batchnorm = model_config['batchnorm']
    leaky = model_config['leaky']

    flow_shape = tuple(int(d * vel_resize) for d in vol_shape)
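    # illustrative: vol_shape (160, 192, 160) with vel_resize 0.5 gives a
    # flow_shape of (80, 96, 80); the velocity field is predicted at this
    # lower resolution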

    use_reg = reg_model_file is not None
    use_clf = clf_model_file is not None

    # use csv used in training if no path is provided
    if csv_path is None or csv_path == '':
        csv_path = model_config['csv_path']

    csv_out_path = os.path.join(out_dir, 'meta.csv')

    # reuse an existing meta.csv (allows resuming); otherwise build it fresh
    if not os.path.isfile(csv_out_path):
        csv = pd.read_csv(csv_path)
        csv = csv[csv[split_col].isin(split)]

        # backup then overwrite delta if provided, else use delta from csv
        if delta is not None:
            csv['delta_t_real'] = csv['delta_t']
            csv['delta_t'] = delta * 365
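            # delta is passed in years; delta_t is stored in days, hence * 365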

        # write meta to out_dir
        csv.to_csv(csv_out_path, index=False)

    csv = pd.read_csv(csv_out_path)

    img_keys = ['img_path_0', 'img_path_1']
    lbl_keys = ['delta_t', 'pat_dx_1', 'img_id_0', 'img_id_1']

    # datagenerator (from meta in out_dir!)
    test_csv_data = datagenerators.csv_gen(csv_out_path,
                                           img_keys=img_keys,
                                           lbl_keys=lbl_keys,
                                           batch_size=batch_size,
                                           split=(split_col, split),
                                           n_epochs=1,
                                           sample=False,
                                           shuffle=False)

    _, gen_test_data = datagenerators.gan_generators(csv_gen=test_csv_data,
                                                     batch_size=batch_size,
                                                     vol_shape=vol_shape,
                                                     flow_shape=flow_shape,
                                                     max_delta=max_delta,
                                                     int_steps=int_steps,
                                                     use_reg=use_reg,
                                                     use_clf=use_clf)
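    # (gan_generators above is assumed to wrap csv_gen into (inputs, labels,
    # raw_batch) triples, matching the unpacking in the predict loop below)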

    with tf.device(gpu):

        print('loading model')

        loss_class = losses.GANLosses(prior_lambda=prior_lambda,
                                      flow_shape=flow_shape)

        # create generator model
        _, gen_net = networks.gan_models(vol_shape,
                                         batch_size,
                                         loss_class,
                                         cri_loss_weights=cri_loss_weights,
                                         cri_optimizer=Adam(),
                                         cri_base_nf=cri_base_nf,
                                         gen_loss_weights=gen_loss_weights,
                                         gen_optimizer=Adam(),
                                         vel_resize=vel_resize,
                                         int_steps=int_steps,
                                         reg_model_file=reg_model_file,
                                         clf_model_file=clf_model_file,
                                         batchnorm=batchnorm,
                                         leaky=leaky)

        # load weights into model
        gen_net.load_weights(gen_model_file)

        print('starting predict')

        # predict
        for step, (inputs, _, batch) in enumerate(gen_test_data):

            if step % 10 == 0:
                print('step', step)

            # generator outputs (first six): critic score Df, generated image
            # yf, flow_params, an unused output, flow, and the feature map
            pred_out = gen_net.predict(inputs)

            Df, yf, flow_params, _, flow, features = pred_out[:6]

            xr_ids = batch['img_id_0']
            yr_ids = batch['img_id_1']

            for i in range(len(xr_ids)):  # robust to a short final batch

                xr_id = xr_ids[i][0]
                yr_id = yr_ids[i][0]

                # index of the row in csv
                index = (csv.img_id_0 == xr_id) & (csv.img_id_1 == yr_id)

                img_name = str(xr_id) + '_' + str(yr_id) + '_{img_type}.nii.gz'
                img_path = os.path.join(out_dir, img_name)

                def _save_nii(data, img_type):
                    # save with an identity affine (original affines are not
                    # kept) and record the path in the meta csv
                    nii = nib.Nifti1Image(data, np.eye(4))
                    path = img_path.format(img_type=img_type)
                    nib.save(nii, path)
                    csv.loc[index, 'img_path_' + img_type] = path

                if 'yf' in out_imgs:
                    _save_nii(yf[i], 'yf')

                if 'flow_params' in out_imgs:
                    _save_nii(flow_params[i], 'flow_params')

                if 'flow' in out_imgs:
                    _save_nii(flow[i], 'flow')

                csv.loc[index, 'Df'] = Df[i]

        csv.to_csv(csv_out_path, index=False)
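
# Hypothetical invocation (all paths and values are placeholders; an empty
# csv_path falls back to the csv used in training):
# predict(gpu_id=0, csv_path='', split_col='split', split=['test'],
#         batch_size=1, out_dir='', gen_model_file='runs/gan_run/gen_050.h5',
#         delta=2.0, out_imgs=['yf', 'flow'])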
Example 2
def predict(gpu_id, csv_path, split, batch_size, out_dir, gen_model_file,
            delta, out_imgs):

    # GPU handling
    if gpu_id is not None:
        gpu = '/gpu:' + str(gpu_id)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        set_session(tf.Session(config=config))
    else:
        gpu = '/cpu:0'

    # check generator model file exists
    assert os.path.isfile(
        gen_model_file), "generator model file does not exist"

    gen_model_file = os.path.abspath(gen_model_file)

    # extract run directory and model checkpoint name
    model_dir, model_name = os.path.split(gen_model_file)
    model_name = os.path.splitext(model_name)[0]
    model_name += '_{:02.0f}'.format(delta) if delta is not None else ''

    # load model config
    config_path = os.path.join(model_dir, 'config.pkl')

    assert os.path.isfile(config_path), "model_config not found"

    with open(config_path, 'rb') as f:
        model_config = pickle.load(f)

    # create out_dir
    if out_dir is None or out_dir == '':
        out_dir = os.path.join(model_dir, 'predict', split, model_name)

    print('out_dir:', out_dir)

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # extract config variables
    vol_shape = model_config['vol_shape']
    max_delta = model_config['max_delta']
    vel_resize = model_config['vel_resize']
    prior_lambda = model_config['prior_lambda']
    cri_loss_weights = model_config['cri_loss_weights']
    gen_loss_weights = model_config['gen_loss_weights']
    enc_nf = model_config['enc_nf']
    dec_nf = model_config['dec_nf']
    cri_base_nf = model_config['cri_base_nf']
    int_steps = model_config['int_steps']
    ti_flow = model_config['ti_flow']

    flow_shape = tuple(int(d * vel_resize) for d in vol_shape)

    # use csv used in training if no path is provided
    if csv_path is None or csv_path == '':
        csv_path = model_config['csv_path']

    csv = pd.read_csv(csv_path)
    csv = csv[csv.split == split]

    # backup then overwrite delta if provided, else use delta from csv
    if delta is not None:
        csv['delta_t_real'] = csv['delta_t']
        csv['delta_t'] = delta * 365

    # write meta to out_dir
    csv_out_path = os.path.join(out_dir, 'meta.csv')
    csv.to_csv(csv_out_path, index=False)

    img_keys = ['img_path_0', 'img_path_1']
    lbl_keys = ['delta_t', 'img_id_0', 'img_id_1']

    # datagenerator (from meta in out_dir!)
    test_csv_data = datagenerators.csv_gen(csv_out_path,
                                           img_keys=img_keys,
                                           lbl_keys=lbl_keys,
                                           batch_size=batch_size,
                                           split=split,
                                           n_epochs=1,
                                           sample=False,
                                           shuffle=False)

    # convert raw delta_t into the encodings the generator expects
    test_data = convert_delta(test_csv_data, max_delta, int_steps)

    with tf.device(gpu):

        print('loading model')

        loss_class = losses.GANLosses(prior_lambda=prior_lambda,
                                      flow_shape=flow_shape)

        # create generator model
        _, gen_net = networks.gan_models(vol_shape,
                                         batch_size,
                                         loss_class,
                                         cri_loss_weights=cri_loss_weights,
                                         cri_optimizer=Adam(),
                                         gen_loss_weights=gen_loss_weights,
                                         gen_optimizer=Adam(),
                                         enc_nf=enc_nf,
                                         dec_nf=dec_nf,
                                         cri_base_nf=cri_base_nf,
                                         vel_resize=vel_resize,
                                         ti_flow=ti_flow,
                                         int_steps=int_steps)

        # load weights into model
        gen_net.load_weights(gen_model_file)

        print('starting predict')

        # predict
        for step, (imgs, lbls) in enumerate(test_data):

            if step % 10 == 0:
                print('step', step)

            # generate
            yf, flow, flow_ti, Df = gen_net.predict(
                [imgs[0], lbls[0], lbls[1]])

            xr_ids = lbls[2]
            yr_ids = lbls[3]

            for i in range(len(xr_ids)):  # robust to a short final batch

                xr_id = xr_ids[i][0]
                yr_id = yr_ids[i][0]

                # index of the matching row in csv
                index = (csv.img_id_0 == xr_id) & (csv.img_id_1 == yr_id)

                img_name = str(xr_id) + '_' + str(yr_id) + '_{img_type}.nii.gz'
                img_path = os.path.join(out_dir, img_name)

                def _save_nii(data, img_type):
                    nii = nib.Nifti1Image(data, np.eye(4))
                    path = img_path.format(img_type=img_type)
                    nib.save(nii, path)
                    csv.loc[index, 'img_path_' + img_type] = path

                if 'yf' in out_imgs:
                    _save_nii(yf[i], 'yf')

                if 'flow' in out_imgs:
                    _save_nii(flow[i], 'flow')

                if 'flow_ti' in out_imgs:
                    _save_nii(flow_ti[i], 'flow_ti')

                csv.loc[index, 'Df'] = Df[i]

        csv.to_csv(csv_out_path, index=False)
Example 3
def train(csv_path,
          tag,
          gpu_id,
          epochs,
          steps_per_epoch,
          batch_size,
          vol_shape,
          int_steps,
          vel_resize,
          sample_weights,
          lr,
          beta_1,
          beta_2,
          epsilon,
          prior_lambda,
          batchnorm,
          leaky,
          split_col,
          split_train,
          split_eval,
          reg_model_file,
          clf_model_file,
          cri_base_nf,
          gen_loss_weights,
          cri_loss_weights,
          cri_steps,
          cri_retune_freq,
          cri_retune_steps,
          valid_freq,
          valid_steps):
    
    """
    model training function
    :param csv_path: path to data csv (img paths, labels)
    :param tag: tag for the run, added to run_dir
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param epochs: number of training epochs
    :param steps_per_epoch: number of generator updates per epoch
    :param batch_size: optional, default of 1; can be larger, depending on GPU memory and volume size
    :param prior_lambda: scalar weight of the smoothing (Laplacian) prior term, as in the MICCAI paper
    """

    # grab config (all local variables at this point)
    model_config = locals()
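    # (locals() here captures exactly the function arguments; derived values
    # such as vol_shape, csv_path and max_delta are added to the dict below)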

    # claim gpu, do early so we fail early if it's occupied
    # after masking via CUDA_VISIBLE_DEVICES, the one visible device is 0
    gpu = '/gpu:%d' % 0
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # convert vol_shape (is list)
    vol_shape = tuple(vol_shape)
    model_config['vol_shape'] = vol_shape 
    
    print('input vol_shape is {}'.format(vol_shape))

    # check csv exists
    assert os.path.isfile(csv_path), 'csv not found at {}'.format(csv_path)

    # add csv path to config 
    csv_path = os.path.abspath(csv_path)
    model_config['csv_path'] = csv_path

    # check regressor model file exists 
    if reg_model_file:
        msg = 'reg model file not found at {}'.format(reg_model_file)
        assert os.path.isfile(reg_model_file), msg
        reg_model_file = os.path.abspath(reg_model_file) 
        model_config['reg_model_file'] = reg_model_file

    # check classifier model file exists 
    if clf_model_file:
        msg = 'clf model file not found at {}'.format(clf_model_file)
        assert os.path.isfile(clf_model_file), msg
        clf_model_file = os.path.abspath(clf_model_file)
        model_config['clf_model_file'] = clf_model_file

    # stitch together run_dir name
    run_dir = 'runs/'
    run_dir += 'gan_{:%Y%m%d_%H%M}'.format(datetime.now())
    run_dir += '_gpu={}'.format(str(gpu_id))
    run_dir += '_bs={}'.format(batch_size)
    run_dir += '_cl={}'.format(cri_base_nf)
    run_dir += '_lr={}'.format(lr)
    run_dir += '_b1={}'.format(beta_1)
    run_dir += '_b2={}'.format(beta_2)
    run_dir += '_ep={}'.format(epsilon)
    run_dir += '_pl={}'.format(prior_lambda)
    run_dir += '_lk={}'.format(leaky)
    run_dir += '_bn={}'.format(batchnorm)
    run_dir += '_vr={}'.format(vel_resize)
    run_dir += '_is={}'.format(int_steps)
    run_dir += '_cs={}'.format(cri_steps)
    run_dir += '_rf={}'.format(cri_retune_freq)
    run_dir += '_rs={}'.format(cri_retune_steps)
    run_dir += '_sw={}'.format(sample_weights is not None)
    run_dir += '_reg={}'.format(reg_model_file is not None)
    run_dir += '_clf={}'.format(clf_model_file is not None)
    run_dir += '_glw={}'.format(gen_loss_weights)
    run_dir += '_clw={}'.format(cri_loss_weights)
    run_dir += '_tag={}'.format(tag) if tag != '' else ''
    
    run_dir = run_dir.replace(' ', '')
    run_dir = run_dir.replace(',', '_')
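    # illustrative result:
    # runs/gan_20190101_1200_gpu=0_bs=4_cl=8_lr=0.0001_..._tag=baseline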

    print('run_dir is {}'.format(run_dir))

    # calculate flow_shape given the resize param
    flow_shape = tuple(int(d * vel_resize) for d in vol_shape)

    # create run dirs
    if not os.path.isdir(run_dir):
        os.mkdir(run_dir)

    valid_dir = os.path.join(run_dir, 'eval')

    if not os.path.isdir(valid_dir):
        os.mkdir(valid_dir)

    # prepare the model
    with tf.device(gpu):
        
        # load models
        loss_class = losses.GANLosses(prior_lambda=prior_lambda, flow_shape=flow_shape)

        cri_optimizer = Adam(lr=lr, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)
        gen_optimizer = Adam(lr=lr, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)

        cri_model, gen_model = networks.gan_models(
                                        vol_shape, batch_size, loss_class,
                                        cri_loss_weights=cri_loss_weights,
                                        cri_optimizer=cri_optimizer,
                                        cri_base_nf=cri_base_nf,
                                        gen_loss_weights=gen_loss_weights,
                                        gen_optimizer=gen_optimizer,
                                        vel_resize=vel_resize,
                                        int_steps=int_steps,
                                        reg_model_file=reg_model_file,
                                        clf_model_file=clf_model_file,
                                        batchnorm=batchnorm,
                                        leaky=leaky)
      
        cri_model_save_path = os.path.join(run_dir, 'cri_{:03d}.h5')
        gen_model_save_path = os.path.join(run_dir, 'gen_{:03d}.h5')
 
        # save initial models
        cri_model.save(cri_model_save_path.format(0))
        gen_model.save(gen_model_save_path.format(0))

    # load csv
    csv = pd.read_csv(csv_path)
    
    # get max_delta from csv and store in config
    # max_delta and int_steps determine the resolution of the flow integration
    # e.g. max_delta=6y, int_steps=5 results in a resolution of about 5 weeks
    # max_steps = 2**(int_steps+1)-1 = 63, 6 years = 72 months
    max_delta = csv['delta_t'].max()
    model_config['max_delta'] = max_delta
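    # worked example of the above: int_steps=5 gives
    # max_steps = 2**(5+1) - 1 = 63; with max_delta = 6 years = 72 months,
    # each integration step covers 72/63 ≈ 1.14 months ≈ 5 weeks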
    
    # csv columns for img paths and labels
    img_keys = ['img_path_0', 'img_path_1']
    lbl_keys = ['delta_t', 'pat_dx_1']

    # datagens for training and validation
    train_csv_data = datagenerators.csv_gen(csv_path, img_keys=img_keys,
                            lbl_keys=lbl_keys, batch_size=batch_size,
                            sample=True, weights=sample_weights,
                            split=(split_col, split_train))

    valid_csv_data = datagenerators.csv_gen(csv_path, img_keys=img_keys,
                            lbl_keys=lbl_keys, batch_size=batch_size,
                            sample=True, weights=sample_weights,
                            split=(split_col, split_eval))

    use_reg = reg_model_file is not None
    use_clf = clf_model_file is not None

    cri_train_data, gen_train_data = datagenerators.gan_generators(
                                    csv_gen=train_csv_data, vol_shape=vol_shape,
                                    flow_shape=flow_shape, max_delta=max_delta,
                                    int_steps=int_steps, use_reg=use_reg,
                                    use_clf=use_clf)

    cri_valid_data, gen_valid_data = datagenerators.gan_generators(
                                    csv_gen=valid_csv_data, vol_shape=vol_shape,
                                    flow_shape=flow_shape, max_delta=max_delta,
                                    int_steps=int_steps, use_reg=use_reg,
                                    use_clf=use_clf)

    # write model_config to run_dir
    config_path = os.path.join(run_dir, 'config.pkl')
    with open(config_path, 'wb') as f:
        pickle.dump(model_config, f)
    
    print('model_config:')
    print(model_config)

    # tboard callbacks
    tboard_train = TensorBoardExt(log_dir=run_dir, use_reg=use_reg, use_clf=use_clf)
    tboard_train.set_model(gen_model)

    tboard_valid = TensorBoardVal(log_dir=valid_dir, use_reg=use_reg, use_clf=use_clf,
                              cri_data=cri_valid_data, gen_data=gen_valid_data,
                              cri_model=cri_model, gen_model=gen_model,
                              freq=valid_freq, steps=valid_steps)
    tboard_valid.set_model(gen_model)


    # fit generator
    with tf.device(gpu):

        abs_step = 0

        for epoch in range(epochs):
            
            print('epoch {}/{}'.format(epoch, epochs))

            cri_steps_ep = cri_steps

            # check if retune epoch, if so adjust critic steps
            if epoch % cri_retune_freq == 0:
                cri_steps_ep = cri_retune_steps
                print('retuning critic')

            progress_bar = Progbar(target=steps_per_epoch)

            for step in range(steps_per_epoch):
                
                # train critic
                for c_step in range(cri_steps_ep):
                   
                    inputs, labels, _ = next(cri_train_data)
 
                    cri_logs = cri_model.train_on_batch(inputs, labels)

                inputs, labels, _ = next(gen_train_data)

                # train generator
                gen_logs = gen_model.train_on_batch(inputs, labels)

                # update tensorboard
                tboard_train.on_epoch_end(abs_step, cri_logs, gen_logs)
                tboard_valid.on_epoch_end(abs_step)

                abs_step += 1
                progress_bar.add(1)

            if epoch % 5 == 0:
                cri_model.save(cri_model_save_path.format(epoch))
                gen_model.save(gen_model_save_path.format(epoch))
Example 4
def train(csv_path,
          tag,
          gpu_id,
          epochs,
          steps_per_epoch,
          batch_size,
          int_steps,
          vel_resize,
          ti_flow,
          sample_weights,
          lr,
          beta_1,
          beta_2,
          epsilon,
          prior_lambda,
          enc_nf,
          dec_nf,
          cri_base_nf,
          gen_loss_weights,
          cri_loss_weights,
          cri_steps,
          cri_retune_freq,
          cri_retune_steps,
          valid_freq,
          valid_steps):
    
    """
    model training function
    :param csv_path: path to data csv (img paths, labels)
    :param tag: tag for the run, added to run_dir
    :param gpu_id: comma-separated string of gpu ids to use
    :param lr: learning rate
    :param epochs: number of training epochs
    :param steps_per_epoch: number of generator updates per epoch
    :param batch_size: optional, default of 1; can be larger, depending on GPU memory and volume size
    :param prior_lambda: scalar weight of the smoothing (Laplacian) prior term, as in the MICCAI paper
    """

    model_config = locals()

    # gpu handling
    # after masking via CUDA_VISIBLE_DEVICES, the one visible device is 0
    gpu = '/gpu:%d' % 0
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    vol_shape = (80, 96, 80)  # hard-coded in this variant
    model_config['vol_shape'] = vol_shape  # the predict scripts read this key

    print('input vol_shape is {}'.format(vol_shape))
    
    assert os.path.isfile(csv_path), 'csv not found at {}'.format(csv_path)

    csv_path = os.path.abspath(csv_path)

    model_config['csv_path'] = csv_path

    model_dir = 'runs/'
    model_dir += 'gan_{:%Y%m%d_%H%M}'.format(datetime.now())
    model_dir += '_gpu={}'.format(str(gpu_id))
    model_dir += '_bs={}'.format(batch_size)
    model_dir += '_enc={}'.format(enc_nf)
    model_dir += '_dec={}'.format(dec_nf)
    model_dir += '_cbn={}'.format(cri_base_nf)
    model_dir += '_lr={}'.format(lr)
    model_dir += '_b1={}'.format(beta_1)
    model_dir += '_b2={}'.format(beta_2)
    model_dir += '_ep={}'.format(epsilon)
    model_dir += '_pl={}'.format(prior_lambda)
    model_dir += '_vr={}'.format(vel_resize)
    model_dir += '_ti={}'.format(ti_flow)
    model_dir += '_is={}'.format(int_steps)
    model_dir += '_cs={}'.format(cri_steps)
    model_dir += '_rf={}'.format(cri_retune_freq)
    model_dir += '_rs={}'.format(cri_retune_steps)
    model_dir += '_sw={}'.format(sample_weights is not None)
    model_dir += '_glw={}'.format(gen_loss_weights)
    model_dir += '_clw={}'.format(cri_loss_weights)
    model_dir += '_tag={}'.format(tag) if tag != '' else ''
    
    model_dir = model_dir.replace(' ', '')
    model_dir = model_dir.replace(',', '_')

    print('model_dir is {}'.format(model_dir))

    flow_shape = tuple(int(d * vel_resize) for d in vol_shape)

    valid_dir = os.path.join(model_dir, 'eval')

    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    if not os.path.isdir(valid_dir):
        os.mkdir(valid_dir)

    # prepare the model
    with tf.device(gpu):
        
        # load models
        loss_class = losses.GANLosses(prior_lambda=prior_lambda, flow_shape=flow_shape)

        cri_optimizer = Adam(lr=lr, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)
        gen_optimizer = Adam(lr=lr, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)

        cri_model, gen_model = networks.gan_models(
                                        vol_shape, batch_size, loss_class,
                                        cri_loss_weights=cri_loss_weights,
                                        cri_optimizer=cri_optimizer,
                                        gen_loss_weights=gen_loss_weights,
                                        gen_optimizer=gen_optimizer,
                                        enc_nf=enc_nf, dec_nf=dec_nf,
                                        cri_base_nf=cri_base_nf,
                                        vel_resize=vel_resize,
                                        ti_flow=ti_flow,
                                        int_steps=int_steps)
      
        cri_model_save_path = os.path.join(model_dir, 'cri_{:03d}.h5')
        gen_model_save_path = os.path.join(model_dir, 'gen_{:03d}.h5')
 
        # save initial models
        cri_model.save(cri_model_save_path.format(0))
        gen_model.save(gen_model_save_path.format(0))
       

    # data generator
    num_gpus = len(gpu_id.split(','))
    assert np.mod(batch_size, num_gpus) == 0, \
        'batch_size should be a multiple of the nr. of gpus. ' + \
        'Got batch_size %d, %d gpus' % (batch_size, num_gpus)

    # load csv
    csv = pd.read_csv(csv_path)
    
    # get max_delta from csv and store in config
    # max_delta and int_steps determine the resolution of the flow integration
    # e.g. max_delta=6, int_steps=5 results in a resolution of about 5 weeks
    # max_steps = 2**(int_steps+1)-1 = 63, 6 years = 72 months
    max_delta = csv['delta_t'].max()
    model_config['max_delta'] = max_delta
    
    # csv columns for img paths and labels
    img_keys = ['img_path_0', 'img_path_1']
    lbl_keys = ['delta_t']

    # datagens for training and validation
    train_csv_data = datagenerators.csv_gen(csv_path, img_keys=img_keys,
                            lbl_keys=lbl_keys, batch_size=batch_size,
                            sample=True, weights=sample_weights, split='train')

    valid_csv_data = datagenerators.csv_gen(csv_path, img_keys=img_keys,
                            lbl_keys=lbl_keys, batch_size=batch_size,
                            sample=True, weights=sample_weights, split='eval')

    # convert the delta to channel (for critic) and bin_repr (for ss in gen)
    train_data = datagenerators.gan_gen(train_csv_data, max_delta, int_steps)
    valid_data = datagenerators.gan_gen(valid_csv_data, max_delta, int_steps)


    # write model_config to run_dir
    config_path = os.path.join(model_dir, 'config.pkl')
    with open(config_path, 'wb') as f:
        pickle.dump(model_config, f)


    # labels for train/predict
    # dummy tensor for kl loss, must have correct flow shape
    kl_dummy = np.zeros((batch_size, *flow_shape, len(vol_shape)-1))
   
    # labels for critic ws loss
    real = np.ones((batch_size, 1)) * (-1) # real labels
    fake = np.ones((batch_size, 1))        # fake labels
    avgd = np.ones((batch_size, 1))        # dummy labels for gradient penalty
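    # (WGAN-GP convention, assumed from the loss setup: the ±1 "labels" act
    # as signs on the critic scores, and avgd is a dummy target for the
    # gradient-penalty output on averaged samples)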

 
    # tboard callbacks
    tboard_train = TensorBoardExt(log_dir=model_dir)
    tboard_train.set_model(gen_model)

    tboard_valid = TensorBoardVal(log_dir=valid_dir, data=valid_data,
                                  cri_model=cri_model, gen_model=gen_model,
                                  freq=valid_freq, steps=valid_steps,
                                  batch_size=batch_size, kl_dummy=kl_dummy)
    tboard_valid.set_model(gen_model)


    # fit generator
    with tf.device(gpu):

        abs_step = 0

        for epoch in range(epochs):
            
            print('epoch {}/{}'.format(epoch, epochs))

            cri_steps_ep = cri_steps

            # check if retune epoch, if so adjust critic steps
            if epoch % cri_retune_freq == 0:
                cri_steps_ep = cri_retune_steps
                print('retuning critic')

            progress_bar = Progbar(target=steps_per_epoch)

            for step in range(steps_per_epoch):
                
                # train critic
                for c_step in range(cri_steps_ep):
                    
                    imgs, lbls = next(train_data)

                    cri_in = [imgs[0], imgs[1], lbls[0], lbls[1]] # xr, yr, dr, db
                    cri_true = [real, fake, avgd]

                    cri_logs = cri_model.train_on_batch(cri_in, cri_true)

                imgs, lbls = next(train_data)

                gen_in = [imgs[0], lbls[0], lbls[1]] # xr, dr, db
                gen_true = [imgs[0], kl_dummy, kl_dummy, real]

                # train generator
                gen_logs = gen_model.train_on_batch(gen_in, gen_true)

                # update tensorboard
                tboard_train.on_epoch_end(abs_step, cri_logs, gen_logs)
                tboard_valid.on_epoch_end(abs_step)

                abs_step += 1
                progress_bar.add(1)

            if epoch % 5 == 0:
                cri_model.save(cri_model_save_path.format(epoch))
                gen_model.save(gen_model_save_path.format(epoch))
Example 5
def predict(gpu_id, img_path, out_dir, batch_size, gen_model_file, start, stop,
            step):

    # GPU handling
    if gpu_id is not None:
        gpu = '/gpu:' + str(gpu_id)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        set_session(tf.Session(config=config))
    else:
        gpu = '/cpu:0'

    # check generator model file exists
    assert os.path.isfile(
        gen_model_file), "generator model file does not exist"

    gen_model_file = os.path.abspath(gen_model_file)

    # extract run directory and model checkpoint name
    model_dir, model_name = os.path.split(gen_model_file)
    model_name = os.path.splitext(model_name)[0]

    # load model config
    config_path = os.path.join(model_dir, 'config.pkl')

    assert os.path.isfile(config_path), "model_config not found"

    with open(config_path, 'rb') as f:
        model_config = pickle.load(f)

    # load nifti
    assert os.path.isfile(img_path), "img file not found"
    img_path = os.path.abspath(img_path)

    nii = nib.load(img_path)
    vol = np.squeeze(nii.get_fdata().astype(np.float32))  # get_data() is deprecated

    img_id = os.path.basename(img_path).split('.')[0]

    # create out_dir (run_dir/predict/model_name_delta)
    if out_dir is None or out_dir == '':
        out_dir = os.path.join(model_dir, 'predict', 'longterm', model_name)

    out_dir = os.path.join(out_dir, img_id)

    print('out_dir:', out_dir)

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # extract config variables
    vol_shape = model_config['vol_shape']
    max_delta = model_config['max_delta']
    vel_resize = model_config['vel_resize']
    prior_lambda = model_config['prior_lambda']
    cri_loss_weights = model_config['cri_loss_weights']
    gen_loss_weights = model_config['gen_loss_weights']
    enc_nf = model_config['enc_nf']
    dec_nf = model_config['dec_nf']
    cri_base_nf = model_config['cri_base_nf']
    int_steps = model_config['int_steps']
    ti_flow = model_config['ti_flow']

    flow_shape = tuple(int(d * vel_resize) for d in vol_shape)

    csv = pd.DataFrame()
    csv['delta_t'] = np.arange(start, stop, step) * 365
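    # one row per requested time point: start/stop/step are in years,
    # delta_t is stored in days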
    csv['img_path'] = img_path
    csv['img_id'] = img_id

    # write meta to out_dir
    csv_out_path = os.path.join(out_dir, 'meta.csv')
    csv.to_csv(csv_out_path, index=False)

    img_keys = ['img_path']
    # 'delta_t' appears twice on purpose: convert_delta consumes the first
    # copy, while the second stays raw for output naming (lbls[3])
    lbl_keys = ['delta_t', 'img_id', 'delta_t']

    # datagenerator (from meta in out_dir!)
    test_csv_data = datagenerators.csv_gen(csv_out_path,
                                           img_keys=img_keys,
                                           lbl_keys=lbl_keys,
                                           batch_size=batch_size,
                                           split=None,
                                           n_epochs=1,
                                           sample=False,
                                           shuffle=False)

    test_data = convert_delta(test_csv_data, max_delta, int_steps)

    with tf.device(gpu):

        print('loading model')

        loss_class = losses.GANLosses(prior_lambda=prior_lambda,
                                      flow_shape=flow_shape)

        # create generator model
        _, gen_net = networks.gan_models(vol_shape,
                                         batch_size,
                                         loss_class,
                                         cri_loss_weights=cri_loss_weights,
                                         cri_optimizer=Adam(),
                                         gen_loss_weights=gen_loss_weights,
                                         gen_optimizer=Adam(),
                                         enc_nf=enc_nf,
                                         dec_nf=dec_nf,
                                         cri_base_nf=cri_base_nf,
                                         vel_resize=vel_resize,
                                         ti_flow=ti_flow,
                                         int_steps=int_steps)

        # load weights into model
        gen_net.load_weights(gen_model_file)

        print('starting predict')

        # predict
        for step, (imgs, lbls) in enumerate(test_data):

            if step % 10 == 0:
                print('step', step)

            # generate
            yf, flow, flow_ti, Df = gen_net.predict(
                [imgs[0], lbls[0], lbls[1]])

            img_ids = lbls[2]
            deltas = lbls[3] / 365

            for i in range(len(img_ids)):  # robust to a short final batch

                img_id = img_ids[i][0]
                delta = deltas[i][0]

                img_name = str(img_id) + '_{img_type}_'
                img_name += '{:04.1f}.nii.gz'.format(delta)

                img_path = os.path.join(out_dir, img_name)

                # match the row by img_id AND delta; matching img_id alone
                # would hit every row, since all rows share the same image
                index = ((csv.img_id == img_id) &
                         np.isclose(csv.delta_t, delta * 365))

                def _save_nii(data, img_type):
                    nii = nib.Nifti1Image(data, np.eye(4))
                    path = img_path.format(img_type=img_type)
                    nib.save(nii, path)
                    csv.loc[index, 'img_path_' + img_type] = path

                _save_nii(yf[i], 'yf')
                _save_nii(flow[i], 'flow')
                _save_nii(flow_ti[i], 'flow_ti')

                csv.loc[index, 'Df'] = Df[i]

        csv.to_csv(csv_out_path, index=False)
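
# Hypothetical invocation: age one scan from 0 to 6 years in half-year
# steps (paths are placeholders):
# predict(gpu_id=0, img_path='data/subject_0001.nii.gz', out_dir='',
#         batch_size=1, gen_model_file='runs/gan_run/gen_050.h5',
#         start=0.0, stop=6.0, step=0.5)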