def get_learn(data, device_ids=None):
    """Build a Learner for the ResNet-18 U-Net segmentation model.

    The model comes from ``nb_resnet_unet.get_unet_res18(1, True)`` and is
    initialised from ``./models/unet_res18_allres_init.pth``. The loss is
    ``nb_loss_metrics.balance_bce`` with ``balance_ratio=1``. Dice loss,
    balanced BCE and mask IoU are appended to the learner's metrics.

    Args:
        data: Data object handed to ``Learner``. Its ``device.type`` decides
            whether the model is wrapped in ``torch.nn.DataParallel``.
        device_ids: GPU ids for ``DataParallel`` when ``data.device`` is CUDA.
            ``None`` (the default) keeps the previously hard-coded
            ``[0, 1, 2, 3]``.

    Returns:
        The configured ``Learner``.
    """
    # Create the model. The meaning of the positional args (1, True) is
    # defined in nb_resnet_unet -- TODO confirm (presumably output channels
    # and a pretrained/backbone flag).
    model = nb_resnet_unet.get_unet_res18(1, True)
    # Deserialize onto the CPU first: a checkpoint saved from GPU tensors
    # would otherwise fail to load on a machine without CUDA, even though this
    # function supports a non-CUDA data.device. load_state_dict copies the
    # values into the model's existing parameters.
    model.load_state_dict(
        torch.load('./models/unet_res18_allres_init.pth', map_location='cpu'))

    learn = Learner(data, model)
    # Split into layer groups while learn.model is still the unwrapped model
    # (the DataParallel wrap below replaces learn.model).
    learn.layer_groups = split_model(learn.model)

    if data.device.type == 'cuda':
        if device_ids is None:
            device_ids = [0, 1, 2, 3]
        learn.model = torch.nn.DataParallel(
            learn.model, device_ids=list(device_ids))

    # Alternatives tried previously:
    #   partial(nb_loss_metrics.combo_loss, balance_ratio=1)
    #   nb_loss_metrics.dice_loss
    learn.loss_func = partial(nb_loss_metrics.balance_bce, balance_ratio=1)

    # Extra metrics reported alongside the loss.
    learn.metrics += [
        nb_loss_metrics.dice_loss,
        partial(nb_loss_metrics.balance_bce, balance_ratio=1),
        nb_loss_metrics.mask_iou,
    ]
    return learn
def get_learn_detectsym_17clas(data, gaf, clas_weights=weights):
    """Build a Learner for the 17-class symbol-detection dataset.

    The model is ``resnet_ssd.get_resnet18_1ssd(num_classes=17)``, initialised
    from ``./models/pretrained_res18_1ssd.pth``. The loss is
    ``anchors_loss_metrics.yolo_L``. The individual loss components and
    accuracy/distance measures from ``anchors_loss_metrics`` are appended as
    metrics.

    Args:
        data: Data object handed to ``Learner``. Its ``device.type`` decides
            whether the model is wrapped in ``torch.nn.DataParallel``.
        gaf: Passed through to every loss/metric function. NOTE(review):
            presumably anchor/grid configuration used by anchors_loss_metrics
            -- confirm.
        clas_weights: Per-class weights forwarded to the loss and to the
            clas/cent/pConf metrics. Defaults to the module-level ``weights``
            (bound at definition time).

    Returns:
        The configured ``Learner``.
    """
    # Create the model.
    model = resnet_ssd.get_resnet18_1ssd(num_classes=17)
    # Deserialize onto the CPU first: a checkpoint saved from GPU tensors
    # would otherwise fail to load on a machine without CUDA, even though this
    # function supports a non-CUDA data.device. load_state_dict copies the
    # values into the model's existing parameters.
    model.load_state_dict(
        torch.load('./models/pretrained_res18_1ssd.pth', map_location='cpu'))

    learn = Learner(data, model)
    # Split into layer groups while learn.model is still the unwrapped model
    # (the DataParallel wrap below replaces learn.model).
    learn.layer_groups = split_model(learn.model)

    if data.device.type == 'cuda':
        # device_ids is read from module scope (defined outside this function).
        learn.model = torch.nn.DataParallel(learn.model, device_ids=device_ids)

    # conf_th and lambda_nconf are passed straight through to yolo_L; see
    # anchors_loss_metrics for their meaning.
    learn.loss_func = partial(anchors_loss_metrics.yolo_L, gaf=gaf, conf_th=1,
                              clas_weights=clas_weights, lambda_nconf=10)

    # Extra metrics reported alongside the loss.
    learn.metrics += [
        partial(anchors_loss_metrics.clas_L, gaf=gaf, clas_weights=clas_weights),
        partial(anchors_loss_metrics.cent_L, gaf=gaf, clas_weights=clas_weights),
        partial(anchors_loss_metrics.pConf_L, gaf=gaf, clas_weights=clas_weights),
        partial(anchors_loss_metrics.nConf_L, gaf=gaf, conf_th=1),
        partial(anchors_loss_metrics.clas_acc, gaf=gaf),
        partial(anchors_loss_metrics.cent_d, gaf=gaf),
    ]
    return learn