コード例 #1
0
def get_learn(data, device_ids=None):
    """Build a ``Learner`` around the ResNet18-UNet with its initial weights.

    The model comes from ``nb_resnet_unet.get_unet_res18(1, True)`` and is
    initialised from ``./models/unet_res18_allres_init.pth``. The loss is
    ``nb_loss_metrics.balance_bce`` with ``balance_ratio=1``; dice loss,
    balanced BCE and mask IoU are appended as metrics.

    Args:
        data: Data object handed to ``Learner``; must expose ``.device``
            (presumably a fastai ``DataBunch`` -- confirm against caller).
        device_ids: Optional GPU indices for ``DataParallel``. When ``None``
            (the default) the first four GPUs are used, capped at the number
            of GPUs actually visible so hosts with fewer devices still work.

    Returns:
        The configured ``Learner``.
    """
    # create model
    model = nb_resnet_unet.get_unet_res18(1, True)
    # Load onto CPU first so a checkpoint saved from a GPU process can also be
    # read on a CPU-only host; load_state_dict copies the values into the
    # model's own parameters regardless of where the checkpoint tensors live.
    state_dict = torch.load('./models/unet_res18_allres_init.pth',
                            map_location='cpu')
    model.load_state_dict(state_dict)

    # create learner
    learn = Learner(data, model)

    # split model (done before the DataParallel wrap below)
    learn.layer_groups = split_model(learn.model)

    # set multi-gpu
    if data.device.type == 'cuda':
        if device_ids is None:
            # Never request more devices than exist: DataParallel rejects
            # indices of GPUs that are not present.
            device_ids = list(range(min(4, torch.cuda.device_count())))
        learn.model = torch.nn.DataParallel(learn.model,
                                            device_ids=device_ids)

    # set loss func
    # (nb_loss_metrics.combo_loss and nb_loss_metrics.dice_loss were earlier
    # alternatives for the loss.)
    learn.loss_func = partial(nb_loss_metrics.balance_bce, balance_ratio=1)

    # add metrics
    learn.metrics += [nb_loss_metrics.dice_loss]
    learn.metrics += [partial(nb_loss_metrics.balance_bce, balance_ratio=1)]
    learn.metrics += [nb_loss_metrics.mask_iou]

    return learn
コード例 #2
0
def get_learn_detectsym_17clas(data, gaf, clas_weights=weights):
    """Build a ``Learner`` for the 17-class symbol-detection dataset.

    The network comes from ``resnet_ssd.get_resnet18_1ssd(num_classes=17)``
    and is initialised from ``./models/pretrained_res18_1ssd.pth``. The loss
    is ``anchors_loss_metrics.yolo_L``; the individual loss terms plus
    ``clas_acc`` and ``cent_d`` are appended as metrics.

    Args:
        data: Data object handed to ``Learner``; must expose ``.device``.
        gaf: Forwarded as ``gaf=`` to the loss and to every metric.
        clas_weights: Forwarded as ``clas_weights=`` to the loss and to the
            ``clas_L``, ``cent_L`` and ``pConf_L`` metrics.

    Returns:
        The configured ``Learner``.
    """
    # Network with pretrained weights.
    net = resnet_ssd.get_resnet18_1ssd(num_classes=17)
    pretrained = torch.load('./models/pretrained_res18_1ssd.pth')
    net.load_state_dict(pretrained)

    # Learner and its layer groups.
    learn = Learner(data, net)
    learn.layer_groups = split_model(learn.model)

    # Multi-GPU wrap; ``device_ids`` is read from module scope.
    on_gpu = data.device.type == 'cuda'
    if on_gpu:
        learn.model = torch.nn.DataParallel(learn.model,
                                            device_ids=device_ids)

    # Loss function.
    learn.loss_func = partial(
        anchors_loss_metrics.yolo_L,
        gaf=gaf,
        conf_th=1,
        clas_weights=clas_weights,
        lambda_nconf=10,
    )

    # Metrics, appended in the order they are reported.
    weighted_kwargs = dict(gaf=gaf, clas_weights=clas_weights)
    learn.metrics += [
        partial(anchors_loss_metrics.clas_L, **weighted_kwargs),
        partial(anchors_loss_metrics.cent_L, **weighted_kwargs),
        partial(anchors_loss_metrics.pConf_L, **weighted_kwargs),
        partial(anchors_loss_metrics.nConf_L, gaf=gaf, conf_th=1),
        partial(anchors_loss_metrics.clas_acc, gaf=gaf),
        partial(anchors_loss_metrics.cent_d, gaf=gaf),
    ]

    return learn