示例#1
0
# Provider over the validation file list for the two file suffixes;
# pre-loading is switched off and the shared processor is applied.
validation_provider = DataProvider(
    valid_set,
    [org_suffix, lab_suffix],
    is_pre_load=False,
    processor=processor,
)

# 3D U-Net backbone configured for 5 classes.
u_net = UNet3D(
    n_class=5,
    n_layer=4,
    root_filters=16,
    use_bn=True,
)

# Model wrapper around the network: no dropout, cross-entropy loss with
# weight 1, and the 'balance' weight function.
model = SimpleTFModel(
    u_net,
    org_suffix,
    lab_suffix,
    dropout=0,
    loss_function={'cross-entropy': 1.},
    weight_function={'balance'},
)

# Adam with the learning rate taken from the command-line arguments.
optimizer = tf.keras.optimizers.Adam(args.learning_rate)
trainer = Trainer(model)

# Training run: keyword options are collected first, then handed to the
# trainer together with the training and validation providers.
train_options = {
    'epochs': args.epochs,
    'batch_size': args.batch_size,
    'mini_batch_size': args.minibatch_size,
    'output_path': output_path,
    'optimizer': optimizer,
    'eval_frequency': args.eval_frequency,
    'is_save_train_imgs': False,
    'is_save_valid_imgs': True,
    'is_rebuilt_path': True,
}
result = trainer.train(train_provider, validation_provider, **train_options)

# eval test & pre load test
示例#2
0
# loss_function should look like {'name': weight}.
# weight_function should look like {'name': {'alpha': 1, 'beta': 2}},
# or be a plain collection of names: ['name'] / ('name') / {'name'}.
model = SimpleTFModel(
    unet,
    org_suffix,
    lab_suffix,
    dropout=0,
    loss_function={'cross-entropy': 1.0},
)

# Step-decay schedule: every `decay_step` epochs the learning rate becomes
# learning_rate * decay_rate.
lr = StepDecayLearningRate(
    learning_rate=args.learning_rate,
    decay_step=50,
    decay_rate=0.5,
    data_size=train_provider.size,
    batch_size=args.batch_size,
)

# The optimizer is driven by the schedule object rather than a fixed rate.
optimizer = tf.keras.optimizers.Adam(lr)

# The trainer wraps the model built above.
trainer = Trainer(model)

# Start training. The schedule is passed both through the optimizer and
# explicitly as `learning_rate`.
train_options = {
    'epochs': args.epochs,
    'batch_size': args.batch_size,
    'mini_batch_size': args.minibatch_size,
    'output_path': output_path,
    'optimizer': optimizer,
    'learning_rate': lr,
    'eval_frequency': args.eval_frequency,
    'is_save_train_imgs': False,
    'is_save_valid_imgs': True,
}
result = trainer.train(train_provider, validation_provider, **train_options)

# evaluate test data
示例#3
0
# File-name suffixes of the input volumes and of their restored targets.
org_suffix = '_brain.nii.gz'
lab_suffix = '_brain_restore.nii.gz'

# Resolve every dataset split in one pass; the patterns are globbed in the
# same order as the names on the left-hand side.
train_set, valid_set, test_set1, test_set2, test_result_set2 = map(
    glob.glob,
    (
        'data/distortion_data/train/*_brain.nii.gz',
        'data/distortion_data/validation/*_brain.nii.gz',
        'data/distortion_data/test/*_brain.nii.gz',
        'data/distortion_data/new_test/*_brain.nii.gz',
        'data/distortion_data/new_test/*_brain_restore.nii.gz',
    ),
)

# Directory path used for saved output images.
img_save_path = 'data/out_image/'

# Single-output 3D U-Net wrapped in a regression model, no dropout.
u_net = UNet3D(
    n_class=1,
    n_layer=3,
    root_filters=16,
    use_bn=True,
)
model = RegressionModel(
    u_net,
    org_suffix,
    lab_suffix,
    dropout=0,
)

# Load previously saved weights from the final checkpoint path.
trainer = Trainer(model)
trainer.restore('results/test3/ckpt/final')

# Both suffixes get the same pre-processing step list (a fresh list each).
pre = {suffix: [('channelcheck', 1)] for suffix in (org_suffix, lab_suffix)}
processor = SimpleImageProcessor(pre=pre)

# One provider per dataset split, all without pre-loading and all sharing
# the processor above.
train_provider = DataProvider(
    train_set,
    [org_suffix, lab_suffix],
    is_pre_load=False,
    processor=processor,
)
validation_provider = DataProvider(
    valid_set,
    [org_suffix, lab_suffix],
    is_pre_load=False,
    processor=processor,
)
test_provider = DataProvider(
    test_set1,
    [org_suffix, lab_suffix],
    is_pre_load=False,
    processor=processor,
)
new_test_provider = DataProvider(test_set2, [org_suffix, lab_suffix],
示例#4
0
                 g_beta=g_beta,
                 g_lambda=g_lambda,
                 dropout=dropout)
# Two independent step-decay schedules built from identical settings:
# the first drives the generator, the second the discriminator.
gen_lr, disc_lr = (
    StepDecayLearningRate(
        learning_rate=learning_rate,
        decay_step=10,
        decay_rate=0.8,
        data_size=train_provider.size,
        batch_size=batch_size,
    )
    for _ in range(2)
)

# One Adam optimizer per schedule.
gen_optimizer, disc_optimizer = (
    tf.keras.optimizers.Adam(gen_lr),
    tf.keras.optimizers.Adam(disc_lr),
)
trainer = Trainer(model)

# Train. Optimizers and schedules are passed as parallel
# [generator, discriminator] lists.
train_options = {
    'epochs': epochs,
    'batch_size': batch_size,
    'mini_batch_size': mini_batch_size,
    'output_path': output_path,
    'optimizer': [gen_optimizer, disc_optimizer],
    'learning_rate': [gen_lr, disc_lr],
    'eval_frequency': eval_frequency,
}
results = trainer.train(train_provider, valid_provider, **train_options)

# eval
test_provider = DataProvider(test_list, [org_suffix, age_suffix],
                             is_pre_load=False,
示例#5
0
# Provider over the validation list; pre-loading is disabled.
valid_provider = DataProvider(
    valid_list,
    [org_suffix, lab_suffix],
    is_pre_load=False,
    # temp_dir=output_path,
    processor=processor,
)

# Build the classifier: a 3D VGG network without batch norm, wrapped in
# the classification model with dropout 0.1.
vgg3d = VGG3D(
    n_layer=5,
    root_filters=16,
    use_bn=False,
)
model = ClsModel(
    vgg3d,
    org_suffix,
    lab_suffix,
    dropout=0.1,
)

# Step-decay schedule (decay_step=10, decay_rate=0.8) feeding Adam.
lr = StepDecayLearningRate(
    learning_rate=args.learning_rate,
    decay_step=10,
    decay_rate=0.8,
    data_size=train_provider.size,
    batch_size=args.batch_size,
)
optimizer = tf.keras.optimizers.Adam(lr)
trainer = Trainer(model)

# Train. Keyword options are gathered first, then passed to the trainer.
# NOTE(review): eval_frequency is fed from args.eval_batch_size; confirm this
# is intended and not meant to be a dedicated eval-frequency option.
train_options = {
    'epochs': args.epochs,
    'batch_size': args.batch_size,
    'mini_batch_size': args.minibatch_size,
    'output_path': output_path,
    'optimizer': optimizer,
    'learning_rate': lr,
    'eval_frequency': args.eval_batch_size,
}
results = trainer.train(train_provider, valid_provider, **train_options)

# Evaluation: provider over the test list, configured like the validation
# provider (no pre-loading, same processor).
test_provider = DataProvider(
    test_list,
    [org_suffix, lab_suffix],
    is_pre_load=False,
    processor=processor,
)