示例#1
0
# Collect the histogram-matched tiles (sorted for a deterministic order).
files = sorted(glob(os.path.join(data_dir, 'hist_match_ct', '*.tif')))

# adjust gamma
# NOTE(review): this step is disabled. It is wrapped in a bare string literal
# that is evaluated and discarded, so none of it runs.
'''gamma_save_dir = os.path.join(data_dir, 'gamma_adjust')
ersa_utils.make_dir_if_not_exist(gamma_save_dir)
ga = gammaAdjust.GammaAdjust(gamma=gammas[sample_id - 1], path=gamma_save_dir)
ga.run(force_run=False, file_list=files)
files = sorted(glob(os.path.join(gamma_save_dir, '*.tif')))'''

# get image mean
# Each file is passed as its own single-element list. Presumably this is the
# per-sample file grouping that get_channel_mean expects; TODO confirm.
img_mean = cm.get_channel_mean('', [[f] for f in files])

# NOTE(review): in the visible code, model_dir is only referenced inside the
# disabled string below. It is kept in case later code in the file uses it.
model_dir = r'/hdd6/Models/UNET_city/UnetCrop_spca_aug_xcity_0_PS(572, 572)_BS5_EP100_LR0.0001_DS60_DR0.1_SFN32'

# test model
# NOTE(review): this step is disabled. It is wrapped in a bare string literal,
# so none of it runs.
'''nn_utils.tf_warn_level(3)
unet = unet.UNet(class_num, patch_size, suffix=suffix, batch_size=bs)
model_dir = r'/hdd6/Models/UNET_city/UnetCrop_spca_aug_xcity_0_PS(572, 572)_BS5_EP100_LR0.0001_DS60_DR0.1_SFN32'
# model_dir = r'/hdd6/Models/Inria_decay/UnetCrop_inria_decay_0_PS(572, 572)_BS5_EP100_LR0.0001_DS60.0_DR0.1_SFN32'
unet.evaluate([[f] for f in files], patch_size, tile_size, bs, img_mean, model_dir, gpu, save_result_parent_dir='aemo',
              sfn=sfn, force_run=False, score_results=False, split_char='.', best_model=False)'''

my_dir = os.path.join(data_dir, 'bh_pred_ct')

# make dirs
# exist_ok=True replaces the exists-then-create check. That check could race
# with another process creating the directory between the test and makedirs.
os.makedirs(my_dir, exist_ok=True)

# run detector
# The detector input is one single-element list per tile, holding only the
# file's basename (the directory part is stripped).
file_list_valid = [[os.path.basename(x)] for x in files]
示例#2
0
    rgb_ext='.*rgb',
    gt_ext='.*gt_d255',
    file_ext='tif',
    force_run=False,
    clc_name=suffix)
cm.print_meta_data()

# Train on the aus10/aus30 fields and validate on aus50. The extension patterns
# select the '.*rgb' and '.*gt_d255' files; presumably these are image and
# label tiles (TODO confirm).
# NOTE(review): file_list_train is not used by the visible code below.
file_list_train = cm.load_files(
    field_name='aus10,aus30', field_id='', field_ext='.*rgb,.*gt_d255')
file_list_valid = cm.load_files(
    field_name='aus50', field_id='', field_ext='.*rgb,.*gt_d255')

# The per-channel mean stored in the collection metadata is handed to evaluate().
chan_mean = cm.meta_data['chan_mean']

# Presumably lowers TensorFlow's log verbosity; confirm in nn_utils.
nn_utils.tf_warn_level(3)

# Evaluate the validation split with the model stored under model_dir.
# load_epoch_num=4 presumably picks which saved epoch to restore; confirm in unet.evaluate.
model_dir = r'/hdd6/Models/aemo/aemo_resize_new_loss/unet_aemo_0_PS(572, 572)_BS5_EP80_LR0.001_DS30_DR0.1'
unet.evaluate(
    file_list_valid, patch_size, tile_size, bs, chan_mean, model_dir, gpu,
    save_result_parent_dir='aemo', sfn=32, force_run=True, score_results=True,
    split_char='.', load_epoch_num=4,
)
示例#3
0
    rgb_ext='.*rgb',
    gt_ext='.*gt',
    file_ext='tif',
    force_run=False,
    clc_name=suffix)
cm.print_meta_data()

# Train on the aus10/aus30 fields and validate on aus50. The extension patterns
# select the '.*rgb' and '.*gt' files; presumably these are image and label
# tiles (TODO confirm).
# NOTE(review): file_list_valid is not used by the visible code below.
file_list_train = cm.load_files(
    field_name='aus10,aus30', field_id='', field_ext='.*rgb,.*gt')
file_list_valid = cm.load_files(
    field_name='aus50', field_id='', field_ext='.*rgb,.*gt')

# The per-channel mean stored in the collection metadata is handed to evaluate().
chan_mean = cm.meta_data['chan_mean']

# Presumably lowers TensorFlow's log verbosity; confirm in nn_utils.
nn_utils.tf_warn_level(3)

# Evaluate the *training* split here; ds_name='train' labels this run.
model_dir = r'/hdd6/Models/aemo/new4/unet_aemo_1_PS(572, 572)_BS5_EP80_LR0.001_DS30_DR0.1'
unet.evaluate(
    file_list_train, patch_size, tile_size, bs, chan_mean, model_dir, gpu,
    save_result_parent_dir='aemo', sfn=32, force_run=True, score_results=True,
    split_char='.', ds_name='train',
)
示例#4
0
feature, label = reader_op

unet.create_graph(feature, sfn)
unet.compile(feature, label, n_train, n_valid, patch_size, ersaPath.PATH['model'], par_dir='test', loss_type='xent')
train_hook = hook.ValueSummaryHook(verb_step, [unet.loss, unet.lr_op], value_names=['train_loss', 'learning_rate'],
                                   print_val=[0])
model_save_hook = hook.ModelSaveHook(unet.get_epoch_step()*10, unet.ckdir)
valid_loss_hook = hook.ValueSummaryHook(unet.get_epoch_step(), [unet.loss],
                                        value_names=['valid_loss'], log_time=True, run_time=unet.n_valid)
valid_iou_hook = hook.IoUSummaryHook(unet.get_epoch_step(), unet.loss_iou, log_time=True, run_time=unet.n_valid,
                                     cust_str='\t')
image_hook = hook.ImageValidSummaryHook(unet.get_epoch_step(), unet.valid_images, feature, label, unet.pred,
                                        nn_utils.image_summary, img_mean=chan_mean)
start_time = time.time()
unet.train(train_hooks=[train_hook, model_save_hook], valid_hooks=[valid_loss_hook, valid_iou_hook, image_hook],
           train_init=train_init_op, valid_init=valid_init_op)
print('Duration: {:.3f}'.format((time.time() - start_time)/3600))'''

# Presumably lowers TensorFlow's log verbosity; confirm in nn_utils.
nn_utils.tf_warn_level(3)

# Evaluate the validation files with the 'test' model run and save the
# results under the 'ersa' parent directory, recomputing them (force_run=True).
model_dir = r'/hdd6/Models/test/unet_test_PS(572, 572)_BS5_EP6_LR0.0001_DS60_DR0.1'
unet.evaluate(
    file_list_valid, patch_size, tile_size, bs, chan_mean, model_dir, gpu,
    save_result_parent_dir='ersa', sfn=sfn, force_run=True,
)