def unet_learner(
    data: DataBunch,
    arch: Callable,
    pretrained: bool = True,
    blur_final: bool = True,
    norm_type: Optional[NormType] = NormType,
    split_on: Optional[SplitFuncOrIdxList] = None,
    blur: bool = False,
    self_attention: bool = False,
    y_range: Optional[Tuple[float, float]] = None,
    last_cross: bool = True,
    bottle: bool = False,
    cut: Optional[Union[int, Callable]] = None,
    hypercolumns: bool = True,
    **learn_kwargs: Any,
) -> Learner:
    """Build a Unet learner from `data` and `arch`.

    Args:
        data: Source of the class count (`data.c`, passed as `n_classes`)
            and of the device the model is moved to (`data.device`).
        arch: Backbone architecture; passed to `cnn_config` and
            `create_body`.
        pretrained: Forwarded to `create_body`. When true, the learner is
            also frozen after the layer groups are split.
        blur_final: Forwarded unchanged to the Unet constructor.
        norm_type: Forwarded unchanged to the Unet constructor.
            NOTE(review): the default is the `NormType` class itself, not
            one of its members or `None` - confirm this is intentional.
        split_on: Layer-group split passed to `learn.split`; when `None`,
            `meta["split"]` from `cnn_config(arch)` is used instead.
        blur: Forwarded unchanged to the Unet constructor.
        self_attention: Forwarded unchanged to the Unet constructor.
        y_range: Forwarded unchanged to the Unet constructor.
        last_cross: Forwarded unchanged to the Unet constructor.
        bottle: Forwarded unchanged to the Unet constructor.
        cut: Forwarded to `create_body` as its third positional argument.
        hypercolumns: Selects `DynamicUnet_Hcolumns` when true, otherwise
            `DynamicUnet`.
        **learn_kwargs: Extra keyword arguments forwarded to `Learner`.

    Returns:
        The constructed `Learner` wrapping the Unet model.
    """
    meta = cnn_config(arch)
    body = create_body(arch, pretrained, cut)
    unet_cls = DynamicUnet_Hcolumns if hypercolumns else DynamicUnet
    model = to_device(
        unet_cls(
            body,
            n_classes=data.c,
            blur=blur,
            blur_final=blur_final,
            self_attention=self_attention,
            y_range=y_range,
            norm_type=norm_type,
            last_cross=last_cross,
            bottle=bottle,
        ),
        data.device,
    )
    learn = Learner(data, model, **learn_kwargs)
    # An explicit `split_on` wins; otherwise use the architecture's own split.
    learn.split(ifnone(split_on, meta["split"]))
    if pretrained:
        learn.freeze()
    # Only the sub-module at index 2 is re-initialised here.
    # NOTE(review): presumably this is the first block after the encoder -
    # confirm against the layout of both Unet classes.
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn
# Example no. 2
# 0
def unet_learner_wide(data: DataBunch,
                      arch: Callable,
                      pretrained: bool = True,
                      blur_final: bool = True,
                      norm_type: Optional[NormType] = NormType,
                      split_on: Optional[SplitFuncOrIdxList] = None,
                      blur: bool = False,
                      self_attention: bool = False,
                      y_range: Optional[Tuple[float, float]] = None,
                      last_cross: bool = True,
                      bottle: bool = False,
                      nf_factor: int = 1,
                      cut: Optional[Union[int, Callable]] = None,
                      **kwargs: Any) -> Learner:
    """Build a wide Unet learner (`DynamicUnetWide`) from `data` and `arch`.

    Args:
        data: Source of the class count (`data.c`, passed as `n_classes`)
            and of the device the model is moved to (`data.device`).
        arch: Backbone architecture; passed to `cnn_config` and
            `create_body`.
        pretrained: Forwarded to `create_body`. When true, the learner is
            also frozen after the layer groups are split.
        blur_final: Forwarded unchanged to `DynamicUnetWide`.
        norm_type: Forwarded unchanged to `DynamicUnetWide`.
            NOTE(review): the default is the `NormType` class itself, not
            one of its members or `None` - confirm this is intentional.
        split_on: Layer-group split passed to `learn.split`; when `None`,
            `meta['split']` from `cnn_config(arch)` is used instead.
        blur: Forwarded unchanged to `DynamicUnetWide`.
        self_attention: Forwarded unchanged to `DynamicUnetWide`.
        y_range: Forwarded unchanged to `DynamicUnetWide`.
        last_cross: Forwarded unchanged to `DynamicUnetWide`.
        bottle: Forwarded unchanged to `DynamicUnetWide`.
        nf_factor: Forwarded unchanged to `DynamicUnetWide`.
        cut: Forwarded to `create_body` as its third positional argument,
            mirroring `unet_learner`. Placed after the pre-existing
            parameters so positional callers are unaffected.
        **kwargs: Extra keyword arguments forwarded to `Learner`.

    Returns:
        The constructed `Learner` wrapping the wide Unet model.
    """
    meta = cnn_config(arch)
    body = create_body(arch, pretrained, cut)
    # The model is moved to `data.device`; this is the place to target a
    # different GPU if one is needed.
    model = to_device(
        DynamicUnetWide(
            body,
            n_classes=data.c,
            blur=blur,
            blur_final=blur_final,
            self_attention=self_attention,
            y_range=y_range,
            norm_type=norm_type,
            last_cross=last_cross,
            bottle=bottle,
            nf_factor=nf_factor,
        ),
        data.device,
    )
    learn = Learner(data, model, **kwargs)
    # An explicit `split_on` wins; otherwise use the architecture's own split.
    learn.split(ifnone(split_on, meta['split']))
    if pretrained:
        learn.freeze()
    # Only the sub-module at index 2 is re-initialised here.
    # NOTE(review): presumably this is the first block after the encoder -
    # confirm against the layout of `DynamicUnetWide`.
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn