Example #1
0
def main():
    """Train the Autoencoder on 3-D EUDT volumes listed in text files.

    Command-line arguments supply text-file lists for the training data,
    training ground truth, validation data, and validation ground truth,
    plus an output directory.  Per-epoch weight checkpoints
    (``weights.EE-V.VV.hdf5``) and a training-history plot are written to
    ``--outdir``.
    """
    parser = argparse.ArgumentParser(
        description=
        'py, train_data_txt, train_ground_truth_txt, validation_data_txt, outdir')
    parser.add_argument('--train_data_txt',
                        '-i1',
                        default='',
                        help='train data list')

    parser.add_argument('--train_ground_truth_txt',
                        '-i2',
                        default='',
                        help='train ground truth list')

    parser.add_argument('--validation_data_txt',
                        '-i3',
                        default='',
                        help='validation data list')

    parser.add_argument('--validation_ground_truth_txt',
                        '-i4',
                        default='',
                        help='validation ground truth list')

    parser.add_argument('--outdir', '-i5', default='', help='outdir')
    args = parser.parse_args()

    # Create the output folder; makedirs also creates missing parents and
    # is a no-op when the directory already exists (os.mkdir would raise).
    os.makedirs(args.outdir, exist_ok=True)

    # Training hyperparameters.
    batch_size = 3
    epoch = 2500

    # Load train data and append a trailing channel axis.
    # NOTE(review): assumes load_matrix_data returns a 4-D (n, z, y, x)
    # array so expand_dims(axis=4) yields (n, z, y, x, 1) — TODO confirm.
    train_data = io.load_matrix_data(args.train_data_txt, 'float32')
    train_data = np.expand_dims(train_data, axis=4)

    # Load train ground truth.
    train_truth = io.load_matrix_data(args.train_ground_truth_txt, 'float32')
    train_truth = np.expand_dims(train_truth, axis=4)

    # Load validation data.
    val_data = io.load_matrix_data(args.validation_data_txt, 'float32')
    val_data = np.expand_dims(val_data, axis=4)

    # Load validation ground truth.
    val_truth = io.load_matrix_data(args.validation_ground_truth_txt,
                                    'float32')
    val_truth = np.expand_dims(val_truth, axis=4)

    print(' number of training: {}'.format(len(train_data)))
    print('size of training: {}'.format(train_data.shape))
    print(' number of validation: {}'.format(len(val_data)))
    print('size of validation: {}'.format(val_data.shape))

    # Spatial dimensions of one volume (axes 1-3 of the 5-D batch array).
    image_size = list(train_data.shape[1:4])

    # Build and compile the network; the constructor takes the three
    # spatial dimensions as positional arguments.
    network = Autoencoder(*image_size)
    model = network.model()
    model.summary()
    model.compile(optimizer='Nadam',
                  loss=losses.mean_squared_error,
                  metrics=['mse'])

    # Wrap the arrays in batch generators; batch_iter returns
    # (steps_per_epoch, generator).
    train_steps, train_data = batch_iter(train_data, train_truth, batch_size)
    valid_steps, val_data = batch_iter(val_data, val_truth, batch_size)

    # Checkpoint weights after every epoch, tagged with epoch and val_loss.
    model_checkpoint = ModelCheckpoint(os.path.join(
        args.outdir, 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'),
                                       verbose=1)

    # fit_generator is deprecated in modern Keras (use model.fit), but is
    # kept here for compatibility with the Keras version this file targets.
    history = model.fit_generator(train_data,
                                  steps_per_epoch=train_steps,
                                  epochs=epoch,
                                  validation_data=val_data,
                                  validation_steps=valid_steps,
                                  verbose=1,
                                  callbacks=[model_checkpoint])

    plot_history(history, args.outdir)
Example #2
0
def predict():
    """Run inference with a trained Autoencoder and write EUDT/label volumes.

    Command-line arguments supply the test-data list, a matching name list
    (one output name per test volume), the trained weight file, and an
    output directory.  For every test volume the raw network output is
    saved under ``<outdir>/EUDT/`` and a binarized label (1 where the
    prediction is <= 0, else 0) under ``<outdir>/label/`` as .mhd/.raw.
    """
    parser = argparse.ArgumentParser(
        description='py, test_data_list, name_list, outdir')
    parser.add_argument('--test_data_list',
                        '-i1',
                        default='',
                        help='test data')
    parser.add_argument('--name_list', '-i2', default='', help='name list')
    parser.add_argument('--model', '-i3', default='', help='model')
    parser.add_argument('--outdir', '-i4', default='', help='outdir')
    args = parser.parse_args()

    # makedirs creates missing parents and tolerates an existing directory
    # (os.mkdir would raise in both cases).
    os.makedirs(args.outdir, exist_ok=True)

    # Read one whitespace-split name per non-empty line; each entry is a
    # list of path components, later unpacked into os.path.join.
    name_list = []
    with open(args.name_list) as paths_file:
        for line in paths_file:
            line = line.split()
            if not line: continue
            name_list.append(line[:])

    print('number of test data : {}'.format(len(name_list)))

    # Load test volumes and append a trailing channel axis.
    # NOTE(review): assumes a 4-D (n, z, y, x) array in — TODO confirm.
    test_data = io.load_matrix_data(args.test_data_list, 'float32')
    test_data = np.expand_dims(test_data, axis=4)
    print(test_data.shape)

    # Spatial dimensions of one volume (axes 1-3 of the 5-D batch array).
    image_size = list(test_data.shape[1:4])
    print(image_size)

    # Rebuild the network and load the trained weights.
    network = Autoencoder(*image_size)
    model = network.model()
    model.load_weights(args.model)

    # Predict with batch size 1, then drop the channel axis.
    preds = model.predict(test_data, 1)
    preds = preds[:, :, :, :, 0]

    print(preds.shape)

    for i in range(preds.shape[0]):
        # Raw EUDT prediction as a SimpleITK image with unit spacing.
        eudt_image = sitk.GetImageFromArray(preds[i])
        eudt_image.SetSpacing([1, 1, 1])
        eudt_image.SetOrigin([0, 0, 0])

        # Binarize: 1 inside (prediction <= 0), 0 outside.
        label = np.where(preds[i] > 0, 0, 1)
        label_image = sitk.GetImageFromArray(label)
        label_image.SetSpacing([1, 1, 1])
        label_image.SetOrigin([0, 0, 0])

        # NOTE(review): presumably write_mhd_and_raw creates the EUDT/ and
        # label/ subdirectories itself — verify, otherwise pre-create them.
        io.write_mhd_and_raw(
            eudt_image,
            '{}.mhd'.format(os.path.join(args.outdir, 'EUDT', *name_list[i])))
        io.write_mhd_and_raw(
            label_image,
            '{}.mhd'.format(os.path.join(args.outdir, 'label', *name_list[i])))