Example #1
    def set_model(self, figs, is_training, reuse=False):
        # return only logits

        h = figs

        # convolution
        with tf.variable_scope(self.name_scope_conv, reuse=reuse):
            for i, (in_chan, out_chan) in enumerate(
                    zip(self.layer_chanels, self.layer_chanels[1:])):
                if i == 0:
                    conved = conv_layer(inputs=h,
                                        out_num=out_chan,
                                        filter_width=5,
                                        filter_hight=5,
                                        stride=1,
                                        l_id=i)

                    h = tf.nn.relu(conved)
                    #h = lrelu(conved)
                else:
                    conved = conv_layer(inputs=h,
                                        out_num=out_chan,
                                        filter_width=5,
                                        filter_hight=5,
                                        stride=2,
                                        l_id=i)

                    bn_conved = batch_norm(conved, i, is_training)
                    h = tf.nn.relu(bn_conved)
                    #h = lrelu(bn_conved)

        feature_image = h

        # full connect
        dim = get_dim(h)
        h = tf.reshape(h, [-1, dim])

        with tf.variable_scope(self.name_scope_fc, reuse=reuse):
            weights = get_weights('fc', [dim, self.fc_dim], 0.02)
            biases = get_biases('fc', [self.fc_dim], 0.0)
            h = tf.matmul(h, weights) + biases
            h = batch_norm(h, 'fc', is_training)
            h = tf.nn.relu(h)

            weights = get_weights('fc2', [self.fc_dim, 1], 0.02)
            biases = get_biases('fc2', [1], 0.0)
            h = tf.matmul(h, weights) + biases

        return h, feature_image
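The helpers `get_weights`, `get_biases`, `lrelu`, `get_dim`, and `conv_layer` are defined elsewhere in the project. A minimal sketch of plausible TF1-style definitions, inferred from the call sites above (the variable naming, initializer scales, and 'SAME' padding are assumptions, not the project's confirmed code):

import numpy as np
import tensorflow as tf  # TF1-style graph API assumed

def get_weights(name, shape, stddev):
    # Trainable weight tensor with truncated-normal initialization.
    return tf.get_variable('w_{}'.format(name), shape,
                           initializer=tf.truncated_normal_initializer(stddev=stddev))

def get_biases(name, shape, value):
    # Trainable bias vector with constant initialization.
    return tf.get_variable('b_{}'.format(name), shape,
                           initializer=tf.constant_initializer(value))

def lrelu(x, leak=0.2):
    # Leaky ReLU, the alternative activation in the commented-out lines.
    return tf.maximum(x, leak * x)

def get_dim(tensor):
    # Product of the non-batch dimensions, used before flattening.
    return int(np.prod(tensor.get_shape().as_list()[1:]))

def conv_layer(inputs, out_num, filter_width, filter_hight, stride, l_id):
    # Plain 2-D convolution; filter shape is [height, width, in_chan, out_chan].
    in_chan = inputs.get_shape().as_list()[-1]
    filters = get_weights(str(l_id), [filter_hight, filter_width, in_chan, out_num], 0.02)
    biases = get_biases(str(l_id), [out_num], 0.0)
    return tf.nn.conv2d(inputs, filters,
                        strides=[1, stride, stride, 1], padding='SAME') + biases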
Example #2
def stat_weights(fname='../Data/tra_aug.csv'):
    df = pd.read_csv(fname)
    labels = df.values.transpose()

    W_p, W_n, N_p, N_n = get_weights(labels)

    showWeights(W_p, W_n, N_p, N_n)
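Here `get_weights` is a different utility from the TF variable helper in Example #1: it derives per-class statistics from the label matrix. A plausible NumPy sketch, assuming each row of `labels` is a 0/1 label vector and that inverse-frequency weighting is intended (the exact scheme is an assumption):

import numpy as np

def get_weights(labels):
    # labels: (num_classes, num_samples) array of 0/1 entries.
    labels = np.asarray(labels, dtype=np.float64)
    N_p = labels.sum(axis=1)        # positive count per class
    N_n = labels.shape[1] - N_p     # negative count per class
    W_p = N_n / (N_p + N_n)         # rarer outcome gets the larger weight
    W_n = N_p / (N_p + N_n)
    return W_p, W_n, N_p, N_n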
Example #3
File: decoder.py Project: takat0m0/VAE_GAN
    def set_model(self, z, batch_size, is_training, reuse=False):

        # reshape z
        with tf.variable_scope(self.name_scope_reshape, reuse=reuse):
            w_r = get_weights('_r',
                              [self.z_dim, self.in_dim * self.in_dim * self.layer_chanels[0]],
                              0.02)
            b_r = get_biases('_r',
                             [self.in_dim * self.in_dim * self.layer_chanels[0]],
                             0.0)
            h = tf.matmul(z, w_r) + b_r
            h = batch_norm(h, 'reshape', is_training)
            #h = tf.nn.relu(h)
            h = lrelu(h)

        h = tf.reshape(h, [-1, self.in_dim, self.in_dim, self.layer_chanels[0]])

        # deconvolution
        layer_num = len(self.layer_chanels) - 1
        with tf.variable_scope(self.name_scope_deconv, reuse=reuse):
            for i, (in_chan, out_chan) in enumerate(zip(self.layer_chanels, self.layer_chanels[1:])):
                deconved = deconv_layer(inputs=h,
                                        out_shape=[batch_size, self.in_dim * 2 ** (i + 1), self.in_dim * 2 ** (i + 1), out_chan],
                                        filter_width=5, filter_hight=5,
                                        stride=2, l_id=i)
                if i == layer_num - 1:
                    h = tf.nn.tanh(deconved)
                else:
                    bn_deconved = batch_norm(deconved, i, is_training)
                    #h = tf.nn.relu(bn_deconved)
                    h = lrelu(bn_deconved)

        return h
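`deconv_layer` is presumably a thin wrapper over `tf.nn.conv2d_transpose`. A sketch consistent with the call above, reusing the helper sketches from Example #1 (the variable naming is an assumption; note the transposed-convolution filter layout is [height, width, out_chan, in_chan]):

def deconv_layer(inputs, out_shape, filter_width, filter_hight, stride, l_id):
    # Upsample `inputs` to `out_shape` with a transposed convolution.
    in_chan = inputs.get_shape().as_list()[-1]
    out_chan = out_shape[-1]
    filters = get_weights(str(l_id),
                          [filter_hight, filter_width, out_chan, in_chan], 0.02)
    biases = get_biases(str(l_id), [out_chan], 0.0)
    deconved = tf.nn.conv2d_transpose(inputs, filters, output_shape=out_shape,
                                      strides=[1, stride, stride, 1],
                                      padding='SAME')
    return deconved + biases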
Example #4
File: encoder.py Project: takat0m0/VAE_GAN
    def set_model(self, figs, is_training, reuse=False):
        '''
        Return mu and log_sigma of the approximate posterior q(z|x).
        '''

        h = figs

        # convolution
        with tf.variable_scope(self.name_scope_conv, reuse=reuse):
            for i, (in_chan, out_chan) in enumerate(
                    zip(self.layer_chanels, self.layer_chanels[1:])):

                conved = conv_layer(inputs=h,
                                    out_num=out_chan,
                                    filter_width=5,
                                    filter_hight=5,
                                    stride=2,
                                    l_id=i)

                if i == 0:
                    h = tf.nn.relu(conved)
                    #h = lrelu(conved)
                else:
                    bn_conved = batch_norm(conved, i, is_training)
                    h = tf.nn.relu(bn_conved)
                    #h = lrelu(bn_conved)
        # full connect
        dim = get_dim(h)
        h = tf.reshape(h, [-1, dim])

        with tf.variable_scope(self.name_scope_fc, reuse=reuse):
            weights = get_weights('fc', [dim, self.fc_dim], 0.02)
            biases = get_biases('fc', [self.fc_dim], 0.0)
            h = tf.matmul(h, weights) + biases
            h = batch_norm(h, 'en_fc_bn', is_training)
            h = tf.nn.relu(h)

            weights = get_weights('mu', [self.fc_dim, self.z_dim], 0.02)
            biases = get_biases('mu', [self.z_dim], 0.0)
            mu = tf.matmul(h, weights) + biases

            weights = get_weights('sigma', [self.fc_dim, self.z_dim], 0.02)
            biases = get_biases('sigma', [self.z_dim], 0.0)
            log_sigma = tf.matmul(h, weights) + biases

        return mu, log_sigma
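The encoder outputs the parameters of q(z|x) rather than a sample; a VAE caller would then draw the latent code with the reparameterization trick. A sketch of that step (this is standard VAE usage, not code from this file):

# z = mu + sigma * eps with eps ~ N(0, I); gradients flow through mu and log_sigma.
eps = tf.random_normal(tf.shape(mu))
z = mu + tf.exp(log_sigma) * eps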
Example #5
File: train.py Project: rowancheung/fetal
def main(options):
    start = time.time()

    np.random.seed(123454321)

    organ = 'all_brains' if options.organ == 'brains' else options.organ

    if options.temporal:
        logging.info('Splitting data.')
        samples = list(constants.GOOD_FRAMES.keys())  # list() so np.random.permutation works on Python 3
        n = len(samples)
        shuffled = np.random.permutation(samples)
        train = shuffled[:(2*n)//3]
        val = shuffled[(2*n)//3:(5*n)//6]
        test = shuffled[(5*n)//6:]

        logging.info('Creating data generators.')
        label_types = LABELS[options.model]
        train_for = []
        train_rev = []
        train_label_for = []
        train_label_rev = []
        for s in train:
            frames = constants.GOOD_FRAMES[s]
            train_for.extend([f'data/raw/{s}/{s}_{str(i).zfill(4)}.nii.gz' for i in frames])
            train_rev.extend([f'data/raw/{s}/{s}_{str(i-1).zfill(4)}.nii.gz' for i in frames])
            train_label_for.extend([f'data/predict_cleaned/unet3000/{s}/{s}_{str(i).zfill(4)}.nii.gz' for i in frames])
            train_label_rev.extend([f'data/predict_cleaned/unet3000/{s}/{s}_{str(i-1).zfill(4)}.nii.gz' for i in frames])
        train_gen = AugmentGenerator(train_for + train_rev,
                                     label_files=train_label_for + train_label_rev,
                                     concat_files=[[train_rev + train_for], [train_label_rev + train_label_for]],
                                     label_types=label_types)
        weights = util.get_weights(train_gen.labels)

        if not options.skip_training:
            val_for = []
            val_rev = []
            val_label_for = []
            val_label_rev = []
            for s in val:
                frames = constants.GOOD_FRAMES[s]
                val_for.extend([f'data/raw/{s}/{s}_{str(i).zfill(4)}.nii.gz' for i in frames])
                val_rev.extend([f'data/raw/{s}/{s}_{str(i-1).zfill(4)}.nii.gz' for i in frames])
                val_label_for.extend([f'data/predict_cleaned/unet3000/{s}/{s}_{str(i).zfill(4)}.nii.gz' for i in frames])
                val_label_rev.extend([f'data/predict_cleaned/unet3000/{s}/{s}_{str(i-1).zfill(4)}.nii.gz' for i in frames])
            val_gen = VolumeGenerator(val_for + val_rev,
                                      label_files=val_label_for + val_label_rev,
                                      concat_files=[[val_rev + val_for], [val_label_rev + val_label_for]],
                                      label_types=label_types)

        if options.predict_all:
            pass
        else:
            test_for = []
            test_rev = []
            test_label_for = []
            test_label_rev = []
            for s in test:
                frames = constants.GOOD_FRAMES[s]
                test_for.extend([f'data/raw/{s}/{s}_{str(i).zfill(4)}.nii.gz' for i in frames])
                test_rev.extend([f'data/raw/{s}/{s}_{str(i-1).zfill(4)}.nii.gz' for i in frames])
                test_label_for.extend([f'data/predict_cleaned/unet3000/{s}/{s}_{str(i).zfill(4)}.nii.gz' for i in frames])
                test_label_rev.extend([f'data/predict_cleaned/unet3000/{s}/{s}_{str(i-1).zfill(4)}.nii.gz' for i in frames])
            pred_gen = VolumeGenerator(test_for + test_rev, tile_inputs=True)
            test_gen = VolumeGenerator(test_for + test_rev,
                                       label_files=test_label_for + test_label_rev,
                                       concat_files=[[test_rev + test_for], [test_label_rev + test_label_for]],
                                       label_types=label_types)

        logging.info('Creating model.')
        shape = constants.SHAPE[:-1] + (3,)
        model = MODELS[options.model](shape, name=options.name, filename=options.model_file, weights=weights)
    else:
        logging.info('Splitting data.')
        n = len(constants.SAMPLES)
        shuffled = np.random.permutation(constants.SAMPLES)
        train = shuffled[:(2*n)//3]
        val = shuffled[(2*n)//3:(5*n)//6]
        test = shuffled[(5*n)//6:]

        logging.info('Creating data generators.')
        label_types = LABELS[options.model]
        train_files = [f'data/raw/{sample}/{sample}_0000.nii.gz' for sample in train]
        train_label_files = [f'data/labels/{sample}/{sample}_0_{organ}.nii.gz' for sample in train]
        train_gen = AugmentGenerator(train_files, label_files=train_label_files, label_types=label_types)
        weights = util.get_weights(train_gen.labels)

        if not options.skip_training:
            val_files = [f'data/raw/{sample}/{sample}_0000.nii.gz' for sample in val]
            val_label_files = [f'data/labels/{sample}/{sample}_0_{organ}.nii.gz' for sample in val]
            val_gen = VolumeGenerator(val_files, label_files=val_label_files, label_types=label_types)

        if options.predict_all:
            pass
        else:
            test_files = [f'data/raw/{sample}/{sample}_0000.nii.gz' for sample in test]
            test_label_files = [f'data/labels/{sample}/{sample}_0_{organ}.nii.gz' for sample in test]
            pred_gen = VolumeGenerator(test_files, tile_inputs=True)
            test_gen = VolumeGenerator(test_files, label_files=test_label_files, label_types=label_types)

        logging.info('Creating model.')
        shape = constants.SHAPE
        model = MODELS[options.model](shape, name=options.name, filename=options.model_file, weights=weights)

    if not options.skip_training:
        logging.info('Training model.')
        model.train(train_gen, val_gen, options.epochs)

    # FIXME
    if options.predict_all:
        for folder in glob.glob('data/raw/*'):
            try:
                sample = folder.split('/')[-1]
                logging.info(f'{sample}..............................')
                if options.temporal:
                    pass  # TODO
                else:
                    pred_files = glob.glob(f'data/raw/{sample}/{sample}_*.nii.gz')
                    pred_gen = VolumeGenerator(pred_files, tile_inputs=True)
                    model.predict(pred_gen, f'data/predict/{options.name}/{sample}/')
            except Exception as e:
                logging.error(f'ERROR during {sample}: {e}')
    else:
        logging.info('Making predictions.')
        model.predict(pred_gen, f'data/predict/{options.name}/')

        logging.info('Testing model.')
        metrics = model.test(test_gen)
        logging.info(metrics)
        dice = {}
        for i in range(len(test)):
            sample = test[i]
            dice[sample] = util.dice_coef(util.read_vol(test_label_files[i]), util.read_vol(f'data/predict/{options.name}/{sample}_0000.nii.gz'))
        logging.info(dice)
        logging.info(np.mean(list(dice.values())))

    end = time.time()
    logging.info(f'total time: {datetime.timedelta(seconds=(end - start))}')
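`util.dice_coef` scores a predicted volume against its label. A standard NumPy sketch of the Dice coefficient 2|A∩B|/(|A|+|B|); the 0.5 threshold and smoothing epsilon are assumptions:

import numpy as np

def dice_coef(y_true, y_pred, threshold=0.5, eps=1e-7):
    # Binarize both volumes, then 2 * |intersection| / (|A| + |B|).
    a = np.asarray(y_true) > threshold
    b = np.asarray(y_pred) > threshold
    return (2.0 * np.sum(a & b) + eps) / (np.sum(a) + np.sum(b) + eps)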
Example #6
def main(options):
    start = time.time()

    np.random.seed(123456789)

    organ = 'all_brains' if options.organ == 'brains' else options.organ

    logging.info('Splitting data.')
    if options.temporal:
        samples = list(constants.GOOD_FRAMES.keys())
        n = len(samples)
        shuffled = np.random.permutation(samples)
        input_file_format = [
            'data/raw/{s}/{s}_{n}.nii.gz', 'data/raw/{s}/{s}_{p}.nii.gz',
            f'data/predict_cleaned/{options.temporal}/{{s}}/{{s}}_{{p}}.nii.gz'
        ]
        label_file_format = f'data/predict_cleaned/{options.temporal}/{{s}}/{{s}}_{{n}}.nii.gz'
        random_gen = True
        shape = constants.SHAPE[:-1] + (3, )
    elif options.good_frames:
        samples = list(constants.GOOD_FRAMES.keys())
        n = len(samples)
        shuffled = np.random.permutation(samples)
        input_file_format = 'data/raw/{s}/{s}_{n}.nii.gz'
        label_file_format = f'data/predict_cleaned/{options.good_frames}/{{s}}/{{s}}_{{n}}.nii.gz'
        random_gen = True
        shape = constants.SHAPE
    else:
        n = len(constants.LABELED_SAMPLES)
        shuffled = np.random.permutation(constants.LABELED_SAMPLES)
        input_file_format = 'data/raw/{s}/{s}_{n}.nii.gz'
        label_file_format = f'data/labels/{{s}}/{{s}}_{{n}}_{organ}.nii.gz'
        random_gen = False
        shape = constants.SHAPE

    assert np.sum(options.split) <= 1, 'Split is greater than 1.'
    train_split = int(options.split[0] * n)
    val_split = int(np.sum(options.split) * n)
    train = shuffled[:train_split]
    val = shuffled[train_split:val_split]
    test = shuffled[val_split:]
    frame_reference = constants.GOOD_FRAMES if options.good_frames else constants.LABELED_FRAMES

    logging.info('Creating data generators.')
    label_types = LABELS[options.model]
    if not options.skip_training:
        d = {s: frame_reference[s] for s in train}
        d[options.sample] = constants.LABELED_FRAMES[options.sample]
        train_gen = DataGenerator(d,
                                  input_file_format,
                                  label_file_format,
                                  label_types=label_types,
                                  load_files=options.load_files,
                                  random_gen=random_gen,
                                  augment=True)
        logging.info(f'  Training generator with {len(train_gen)} samples.')

        val_gen = None
        if len(val) > 0:
            val_gen = DataGenerator({s: frame_reference[s]
                                     for s in val},
                                    input_file_format,
                                    label_file_format,
                                    label_types=label_types,
                                    load_files=options.load_files,
                                    random_gen=random_gen,
                                    resize=True)
            logging.info(
                f'  Validation generator with {len(val_gen)} samples.')

    if options.predict_all or len(test) == 0:
        pred_gen = DataGenerator(
            {s: np.arange(length)
             for s, length in constants.SEQ_LENGTH.items()},
            input_file_format,
            load_files=False,
            tile_inputs=True)
        logging.info(
            f'  Prediction generator with {len(pred_gen)//8} samples.')
    else:
        pred_gen = DataGenerator({s: frame_reference[s]
                                  for s in test},
                                 input_file_format,
                                 load_files=options.load_files,
                                 random_gen=random_gen,
                                 tile_inputs=True)
        logging.info(
            f'  Prediction generator with {len(pred_gen)//8} samples.')

    if len(test) > 0:
        test_gen = DataGenerator({s: frame_reference[s]
                                  for s in test},
                                 input_file_format,
                                 label_file_format,
                                 label_types=label_types,
                                 load_files=options.load_files,
                                 random_gen=random_gen,
                                 resize=True)
        logging.info(f'  Testing generator with {len(test_gen)} samples.')

    logging.info('Creating model.')
    weights = util.get_weights(glob.glob(f'data/labels/*/*_{organ}.nii.gz'))
    model = MODELS[options.model](shape,
                                  name=options.name,
                                  filename=options.model_file,
                                  weights=weights)

    if not options.skip_training:
        logging.info('Training model.')
        model.train(train_gen, val_gen, options.epochs)

    logging.info('Making predictions.')
    model.predict(pred_gen)

    if len(test) > 0:
        logging.info('Testing model.')
        metrics = model.test(test_gen)
        logging.info(metrics)

    end = time.time()
    logging.info(f'total time: {datetime.timedelta(seconds=(end - start))}')
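The format strings use `{s}` (sample), `{n}` (current frame), and `{p}` (previous frame) as deferred placeholders; the doubled braces in the f-strings escape them so they survive until the generator calls `str.format`. For example, with frame numbers zero-padded to four digits as in Example #5:

fmt = 'data/raw/{s}/{s}_{n}.nii.gz'
print(fmt.format(s='010918L', n='0012'))
# data/raw/010918L/010918L_0012.nii.gz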
Example #7
File: train.py Project: wonsang/placenta
def main(options):
    start = time.time()

    logging.info('Creating model.')
    shape = constants.TARGET_SHAPE
    if options.seed:
        shape = tuple(list(shape[:-1]) + [shape[-1] + 1])
    if options.concat:
        shape = tuple(list(shape[:-1]) + [shape[-1] + 2])
    if options.size == 'small':
        m = UNetSmall
    elif options.size == 'big':
        m = UNetBig
    else:
        m = UNet
    model = m(shape, name=options.name, filename=options.model_file)

    gen_seed = options.seed in ('slice', 'volume')

    if options.train:
        logging.info('Creating data generator.')

        input_path = options.train[0].split('*')[0]
        label_path = options.train[1].split('*')[0]

        label_files = glob.glob(options.train[1])
        input_files = [label_file.replace(label_path, input_path) for label_file in label_files]

        aug_gen = AugmentGenerator(input_files,
                                   label_files=label_files,
                                   batch_size=options.batch_size,
                                   seed_type=options.seed,
                                   concat_files=options.concat)
        #FIXME
        val_gen = VolumeGenerator(input_files,
                                  label_files=label_files,
                                  batch_size=options.batch_size,
                                  seed_type=options.seed,
                                  concat_files=options.concat,
                                  load_files=True,
                                  include_labels=True)

        logging.info('Compiling model.')
        model.compile(util.get_weights(aug_gen.labels))

        logging.info('Training model.')
        model.train(aug_gen, val_gen, options.epochs)
        model.save()

    if options.predict:
        logging.info('Making predictions.')

        input_files = glob.glob(options.predict[0])
        seed_files = None if gen_seed else glob.glob(options.predict[1])
        label_files = glob.glob(options.predict[1]) if gen_seed else None
        save_path = options.predict[2]

        pred_gen = VolumeGenerator(input_files,
                                   seed_files=seed_files,
                                   label_files=label_files,
                                   batch_size=options.batch_size,
                                   seed_type=options.seed,
                                   concat_files=options.concat,
                                   include_labels=False)
        model.predict(pred_gen, save_path)

    if options.test:
        logging.info('Testing model.')

        input_files = glob.glob(options.test[0])
        seed_files = None if gen_seed else glob.glob(options.test[1])
        label_files = glob.glob(options.test[1]) if gen_seed else glob.glob(options.test[2])

        test_gen = VolumeGenerator(input_files,
                                   seed_files=seed_files,
                                   label_files=label_files,
                                   batch_size=options.batch_size,
                                   seed_type=options.seed,
                                   concat_files=options.concat,
                                   include_labels=True)
        metrics = model.test(test_gen)
        logging.info(metrics)

    end = time.time()
    logging.info('total time: {}s'.format(end - start))
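The training branch pairs each label file with its raw input by swapping the directory prefix extracted from the glob patterns. A worked example of that substitution (the concrete paths are illustrative):

input_path = 'data/raw/'      # options.train[0].split('*')[0]
label_path = 'data/labels/'   # options.train[1].split('*')[0]
label_file = 'data/labels/010918L/010918L_0000.nii.gz'
input_file = label_file.replace(label_path, input_path)
# input_file == 'data/raw/010918L/010918L_0000.nii.gz'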
Example #8
File: train.py Project: wonsang/placenta
def run(options):
    start = time.time()

    organ = 'all_brains' if options.organ[0] == 'brains' else options.organ[0]
    for sample in ['010918L', '010918S', '012115', '013018L', '013018S',
                   '013118L', '013118S', '021015', '021218L', '021218S',
                   '022318L', '022318S', '022415', '022618', '030217',
                   '030315', '031317L', '031317T', '031516', '031615',
                   '031616', '031716', '032217', '032318a', '032318b',
                   '032318c', '032318d', '032818', '040218', '040417']:
        logging.info(sample)

        logging.info('Creating model.')
        shape = constants.TARGET_SHAPE
        if options.seed:
            shape = tuple(list(shape[:-1]) + [shape[-1] + 1])
        if options.run == 'concat':
            #TODO
            pass
        if options.size == 'small':
            m = UNetSmall
        elif options.size == 'big':
            m = UNetBig
        else:
            m = UNet
        model = m(shape, name='unet_brains_{}_{}'.format(options.run, sample), filename=options.model_file)

        logging.info('Creating data generator.')

        if options.run == 'concat':
            #TODO
            pass
        else:
            concat_files = None

        if options.run == 'one-out':
            label_files = [file for file in glob.glob('data/labels/*/*_{}.nii.gz'.format(organ))
                           if not os.path.basename(file).startswith(sample)]
        elif options.run == 'single':
            label_files = glob.glob('data/labels/{}/{}_0_{}.nii.gz'.format(sample, sample, organ))
        elif options.run == 'concat':
            #TODO
            pass
        else:
            raise ValueError('Preset program not defined.')

        input_files = [file.replace('labels', 'raw').replace('_{}'.format(organ), '') for file in label_files]
        aug_gen = AugmentGenerator(input_files,
                                   label_files=label_files,
                                   batch_size=options.batch_size,
                                   seed_type=options.seed,
                                   concat_files=concat_files)
        #FIXME
        val_gen = VolumeGenerator(input_files,
                                  label_files=label_files,
                                  batch_size=options.batch_size,
                                  seed_type=options.seed,
                                  concat_files=concat_files,
                                  load_files=True,
                                  include_labels=True)

        logging.info('Compiling model.')
        model.compile(util.get_weights(aug_gen.labels))

        logging.info('Training model.')
        model.train(aug_gen, val_gen, options.epochs)

        logging.info('Saving model.')
        model.save()

        logging.info('Making predictions.')
        if options.run == 'one-out':
            predict_files = glob.glob('data/raw/{}/{}_*_{}.nii.gz'.format(sample, sample, organ))
        elif options.run == 'single':
            # Assign predict_files (not label_files); otherwise it is undefined below.
            predict_files = [f for f in glob.glob('data/raw/{}/{}_*_{}.nii.gz'.format(sample, sample, organ))
                             if not os.path.basename(f).endswith('_0_{}.nii.gz'.format(organ))]
        elif options.run == 'concat':
            #TODO
            pass
        else:
            raise ValueError('Preset program not defined.')

        pred_gen = VolumeGenerator(predict_files,
                                   batch_size=options.batch_size,
                                   seed_type=options.seed,
                                   concat_files=concat_files,
                                   include_labels=False)
        save_path = 'data/predict/{}/{}-{}/'.format(sample, options.organ[0], options.run)
        os.makedirs(save_path, exist_ok=True)
        model.predict(pred_gen, save_path)

        logging.info('Testing model.')
        #TODO

    end = time.time()
    logging.info('total time: {}s'.format(end - start))
Example #9
def run(options):
    start = time.time()

    metrics = {}

    organ = 'all_brains' if options.organ[0] == 'brains' else options.organ[0]
    all_labels = glob.glob('data/labels/*/*_{}.nii.gz'.format(organ))

    for sample in ['043015', '051215', '061715', '062515', '081315', '083115', '110214', '112614', '122115', '122215']:
    # for sample in ['043015', '061715']:
        logging.info(sample)

        logging.info('Creating model.')
        shape = constants.TARGET_SHAPE
        if options.seed:
            shape = tuple(list(shape[:-1]) + [shape[-1] + 1])
        if options.run == 'concat':
            shape = tuple(list(shape[:-1]) + [shape[-1] + 2])
        if options.size == 'small':
            m = UNetSmall
        elif options.size == 'big':
            m = UNetBig
        else:
            m = UNet
        model = m(shape, name='unet_brains_{}_{}'.format(options.run, sample), filename=options.model_file)

        logging.info('Creating data generator.')

        if options.run == 'concat':
            concat_files = ['data/raw/{}/{}_1.nii.gz'.format(sample, sample),
                            'data/labels/{}/{}_1_{}.nii.gz'.format(sample, sample, organ)]
        else:
            concat_files = None

        if options.run == 'one-out':
            label_files = [file for file in all_labels if not os.path.basename(file).startswith(sample)]
        elif options.run == 'single':
            label_files = glob.glob('data/labels/{}/{}_1_{}.nii.gz'.format(sample, sample, organ))
        elif options.run == 'concat':
            label_files = glob.glob('data/labels/{}/{}_*_{}.nii.gz'.format(sample, sample, organ))[1:4]
        else:
            raise ValueError('Preset program not defined.')

        input_files = [file.replace('labels', 'raw').replace('_{}'.format(organ), '') for file in label_files]
        aug_gen = AugmentGenerator(input_files,
                                   label_files=label_files,
                                   batch_size=options.batch_size,
                                   seed_type=options.seed,
                                   concat_files=concat_files)
        val_gen = VolumeGenerator(input_files,
                                  label_files=label_files,
                                  batch_size=options.batch_size,
                                  seed_type=options.seed,
                                  concat_files=concat_files,
                                  load_files=True,
                                  include_labels=True)

        logging.info('Compiling model.')
        model.compile(util.get_weights(aug_gen.labels))

        logging.info('Training model.')
        model.train(aug_gen, val_gen, options.epochs)

        logging.info('Saving model.')
        model.save()

        logging.info('Making predictions.')
        if options.run == 'one-out':
            label_files = glob.glob('data/labels/{}/{}_*_{}.nii.gz'.format(sample, sample, organ))
        elif options.run == 'single':
            label_files = [f for f in glob.glob('data/labels/{}/{}_*_{}.nii.gz'.format(sample, sample, organ))
                           if not os.path.basename(f).endswith('_1_{}.nii.gz'.format(organ))]
        elif options.run == 'concat':
            label_files = glob.glob('data/labels/{}/{}_*_{}.nii.gz'.format(sample, sample, organ))[4:]
        else:
            raise ValueError('Preset program not defined.')

        predict_files = [file.replace('labels', 'raw').replace('_{}'.format(organ), '') for file in label_files]
        pred_gen = VolumeGenerator(predict_files,
                                   label_files=label_files,
                                   batch_size=options.batch_size,
                                   seed_type=options.seed,
                                   concat_files=concat_files,
                                   include_labels=False)
        save_path = 'data/predict/{}/{}-{}/'.format(sample, options.organ[0], options.run)
        os.makedirs(save_path, exist_ok=True)
        model.predict(pred_gen, save_path)

        logging.info('Testing model.')
        test_gen = VolumeGenerator(predict_files,
                                   label_files=label_files,
                                   batch_size=options.batch_size,
                                   seed_type=options.seed,
                                   concat_files=concat_files,
                                   include_labels=True)
        metrics[sample] = model.test(test_gen)

    logging.info(metrics)

    end = time.time()
    logging.info('total time: {}s'.format(end - start))
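`metrics` maps each held-out sample to its test score. To summarize the loop, one could average the per-sample values, assuming `model.test` returns a scalar (if it returns a dict, average each key instead):

import numpy as np
logging.info('mean metric: {}'.format(np.mean(list(metrics.values()))))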