Example #1
import os
import json

from models.segmentor.dynamicUNet import Unet
from dataset.transformer import TransformerVal
from dataset.dataset import OralSlide, collate
from helper.helper_unet import SlideInference, create_model_load_weights
from helper.runner import Runner
from configs.config_patch_merge_unet import Config

distributed = False
cfg = Config(mode='patch-merge', train=False)
model = Unet(classes=cfg.n_class, encoder_name=cfg.encoder, **cfg.model_cfg)
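# Runner wraps the model together with the checkpoint-loading helper (create_model_load_weights)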
runner = Runner(cfg, model, create_model_load_weights, distributed=distributed)

###################################
print("preparing datasets......")
slideset_cfg = cfg.testset_cfg
slide_list = sorted(os.listdir(slideset_cfg["img_dir"]))
transformer = TransformerVal()
dataset = OralSlide(
    slide_list,
    slideset_cfg["img_dir"],
    slideset_cfg["meta_file"],
    slide_mask_dir=slideset_cfg["mask_dir"],
    label=slideset_cfg['label'],
    transform=transformer,
)

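# run slide-level inference and write predictions under cfg.test_output_path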
runner.eval_slide(dataset, SlideInference, cfg.test_output_path)
Example #2
from models.segmentor.dynamicUNet import Unet
from dataset.transformer import TransformerMerge  # module path assumed from the sibling examples
from dataset.dataset import OralDatasetLocal  # module path assumed from the sibling examples
from utils.seg_loss import FocalLoss, SymmetricCrossEntropyLoss, DecoupledSegLoss_v1, DecoupledSegLoss_v2
from helper.helper_unet import Trainer, Evaluator, SlideInference, save_ckpt_model, update_log, update_writer, get_optimizer, create_model_load_weights
from configs.config_local_merge_unet import Config
from helper.runner import argParser, seed_everything, Runner

args = argParser()
distributed = False
# if torch.cuda.device_count() > 1:
#     distributed = True

# seed
SEED = 233
seed_everything(SEED)
cfg = Config(mode='local_merge', train=True)
model = Unet(classes=cfg.n_class, encoder_name=cfg.encoder, **cfg.model_cfg)
runner = Runner(cfg, model, create_model_load_weights, distributed=distributed)

###################################
print("preparing datasets......")
batch_size = cfg.batch_size
num_workers = cfg.num_workers
trainset_cfg = cfg.trainset_cfg
valset_cfg = cfg.valset_cfg
testset_cfg = cfg.testset_cfg

transformer_train = TransformerMerge()
dataset_train = OralDatasetLocal(
    trainset_cfg["img_dir"],
    trainset_cfg["mask_dir"],
    trainset_cfg["meta_file"],
    label=trainset_cfg["label"],
    transform=transformer_train,
)  # call closed following the pattern of Examples #1 and #4; the source page truncates the snippet here
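The source page cuts this example off mid-setup. A minimal sketch of the likely continuation, assuming standard torch DataLoader wiring over the batch_size and num_workers values read above; everything below is an assumption, not code from the original project:

from torch.utils.data import DataLoader  # assumed continuation, not in the original snippet

# batch the training set; shuffling is the usual choice for training (assumption)
dataloader_train = DataLoader(
    dataset_train,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
)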
Example #3
File: main.py Project: chenxh06/task1
from helper.runner import Runner

if __name__ == '__main__':
    Runner().start()
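Note that this Runner comes from a different project (chenxh06/task1): it exposes a no-argument constructor and a start() entry point, unlike the helper.runner.Runner used in the other examples.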
Example #4
from models.segmentor.dynamicUNet import Unet
from dataset.transformer import Transformer  # module path assumed from the sibling examples
from dataset.dataset import OralDataset  # module path assumed from the sibling examples
from utils.seg_loss import FocalLoss, SymmetricCrossEntropyLoss, DecoupledSegLoss_v1, DecoupledSegLoss_v2
from helper.helper_unet import Trainer, Evaluator, save_ckpt_model, update_log, update_writer, get_optimizer, create_model_load_weights
from configs.config_global_unet import Config
from helper.runner import argParser, seed_everything, Runner

args = argParser()
distributed = False
# if torch.cuda.device_count() > 1:
#     distributed = True

# seed
SEED = 23
seed_everything(SEED)
cfg = Config(mode='global', train=True)
model = Unet(classes=cfg.n_class, encoder_name=cfg.encoder, **cfg.model_cfg)
runner = Runner(cfg, model, create_model_load_weights, distributed=distributed)

###################################
print("preparing datasets......")
batch_size = cfg.batch_size
num_workers = cfg.num_workers
trainset_cfg = cfg.trainset_cfg
valset_cfg = cfg.valset_cfg

transformer_train = Transformer()
dataset_train = OralDataset(
    trainset_cfg["img_dir"],
    trainset_cfg["mask_dir"],
    trainset_cfg["meta_file"],
    label=trainset_cfg["label"],
    transform=transformer_train,
)  # call closed to match Example #1's pattern; the source page truncates the snippet here
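This example is also cut off by the source page. A hedged sketch of what the otherwise-unused valset_cfg suggests comes next, mirroring the training dataset call; TransformerVal and the argument order are assumptions carried over from Examples #1 and #5:

from dataset.transformer import TransformerVal  # assumed, as in Examples #1 and #5

# validation set built the same way as the training set (assumption)
transformer_val = TransformerVal()
dataset_val = OralDataset(
    valset_cfg["img_dir"],
    valset_cfg["mask_dir"],
    valset_cfg["meta_file"],
    label=valset_cfg["label"],
    transform=transformer_val,
)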
Example #5
import os
import json

from models.segmentor.dynamicUNet import Unet
from dataset.transformer import TransformerVal
from dataset.dataset import OralSlide, collate
from helper.helper_unet import SlideInference, create_model_load_weights
from helper.runner import Runner
from configs.config_patch_unet import Config

distributed = False
cfg = Config(mode='patch', train=False)
model = Unet(classes=cfg.n_class, encoder_name=cfg.encoder, **cfg.model_cfg)
runner = Runner(cfg, model, create_model_load_weights, distributed=distributed)

###################################
print("preparing datasets......")
with open('/media/ldy/7E1CA94545711AE6/OSCC/train_val_part.json', 'r') as f:
    slide_list = json.load(f)['val']
slideset_cfg = cfg.slideset_cfg
transformer = TransformerVal()
dataset = OralSlide(
    slide_list,
    slideset_cfg["img_dir"],
    slideset_cfg["meta_file"],
    slide_mask_dir=slideset_cfg["mask_dir"],
    label=slideset_cfg['label'],
    transform=transformer,
)

runner.eval_slide(dataset, SlideInference)