Esempio n. 1
0
    def _init_from_compact_format(self, opt, model_type, dataset, modules_lib):
        """Create a unet-based model from the compact options format.

        In the compact format the same convolution is used for every layer and
        the per-layer arguments are given as lists. The network is assembled
        from the inside out:

        1. an optional innermost block (when ``opt.innermost`` is set), built
           from the index-0 entries of ``opt.up_conv``; otherwise an
           ``Identity`` placeholder;
        2. one ``UnetSkipConnectionBlock`` per level, from the deepest
           (``num_convs - 1``) up to level 1, each wrapping the previous block;
        3. the outermost block (level 0). Its down and up ``nb_feature``
           arguments are set to ``dataset.feature_dimension``. It is stored in
           ``self.model``.

        Every level-wise block (steps 2 and 3) is passed to
        ``self._save_sampling_and_search``.

        Args:
            opt: model options. Reads ``down_conv.module_name``,
                ``down_conv.down_conv_nn``, ``up_conv.module_name``,
                ``up_conv.up_conv_nn`` and, optionally, ``innermost``.
            model_type: forwarded to ``self._get_factory`` to pick the factory.
            dataset: provides ``feature_dimension`` for the outermost block.
            modules_lib: forwarded to the factory and to the innermost block.

        Raises:
            ValueError: if ``opt.down_conv.down_conv_nn`` is empty, or if an
                innermost block is configured and ``up_conv_nn`` does not have
                exactly one more entry than ``down_conv_nn``.
        """
        num_convs = len(opt.down_conv.down_conv_nn)
        if num_convs == 0:
            # Without at least one level there are no arguments for the
            # outermost block (previously this silently used index -1).
            raise ValueError(
                "opt.down_conv.down_conv_nn must contain at least one layer")

        # Factory for creating up and down modules
        factory_module_cls = self._get_factory(model_type, modules_lib)
        down_conv_cls_name = opt.down_conv.module_name
        up_conv_cls_name = opt.up_conv.module_name
        self._factory_module = factory_module_cls(
            down_conv_cls_name, up_conv_cls_name,
            modules_lib)  # Create the factory object

        # Innermost (global) block, if configured.
        contains_global = hasattr(opt,
                                  "innermost") and opt.innermost is not None
        if contains_global:
            # Explicit check instead of `assert`, which is removed under -O.
            num_up_convs = len(opt.up_conv.up_conv_nn)
            if num_up_convs != num_convs + 1:
                raise ValueError(
                    "with an innermost block, up_conv.up_conv_nn must have "
                    "one more entry than down_conv.down_conv_nn "
                    "(got %d up and %d down)" % (num_up_convs, num_convs))

            args_up = self._fetch_arguments_from_list(opt.up_conv, 0)
            args_up["up_conv_cls"] = self._factory_module.get_module("UP")

            unet_block = UnetSkipConnectionBlock(
                args_up=args_up,
                args_innermost=opt.innermost,
                modules_lib=modules_lib,
                submodule=None,
                innermost=True,
            )  # add the innermost layer
        else:
            unet_block = Identity()

        # Intermediate levels, deepest first, down to level 1. The range is
        # empty when there is a single level.
        for index in range(num_convs - 1, 0, -1):
            args_up, args_down = self._fetch_arguments_up_and_down(opt, index)
            unet_block = UnetSkipConnectionBlock(args_up=args_up,
                                                 args_down=args_down,
                                                 submodule=unet_block)
            self._save_sampling_and_search(unet_block)

        # Outermost level always uses the level-0 arguments.
        args_up, args_down = self._fetch_arguments_up_and_down(opt, 0)
        nb_feature = dataset.feature_dimension
        args_down["nb_feature"] = nb_feature
        args_up["nb_feature"] = nb_feature
        self.model = UnetSkipConnectionBlock(
            args_up=args_up,
            args_down=args_down,
            submodule=unet_block,
            outermost=True)  # add the outermost layer
        self._save_sampling_and_search(self.model)
    def _init_from_compact_format(self, opt, model_type, dataset, modules_lib):
        """Build flat down / inner / up module lists from the compact format.

        In the compact options format every layer of a stage uses the same
        convolution class, and the per-layer arguments are given as lists.

        Populates ``self.down_modules``, ``self.inner_modules`` and
        ``self.up_modules``. Also sets ``self.save_sampling_id``,
        ``self._factory_module``, ``self.metric_loss_module`` and
        ``self.miner_module``.

        Args:
            opt: model options. Reads ``down_conv`` (``module_name``,
                ``down_conv_nn``, optional ``save_sampling_id``), the optional
                ``up_conv`` (``module_name``, ``up_conv_nn``), and the optional
                ``innermost``, ``metric_loss`` and ``miner`` entries.
            model_type: forwarded to ``self._get_factory``.
            dataset: not used by this builder.
            modules_lib: forwarded to the factory and to
                ``self._create_inner_modules``.
        """
        self.down_modules = nn.ModuleList()
        self.inner_modules = nn.ModuleList()
        self.up_modules = nn.ModuleList()

        self.save_sampling_id = opt.down_conv.get('save_sampling_id')

        def build_layer(conv_opt, layer_idx, direction):
            # Resolve this layer's kwargs, take the conv class out of them via
            # _get_from_kwargs (presumably it removes the key), then
            # instantiate the class with the kwargs.
            layer_kwargs = self._fetch_arguments(conv_opt, layer_idx, direction)
            layer_cls = self._get_from_kwargs(layer_kwargs, "conv_cls")
            return layer_cls(**layer_kwargs)

        # The factory is built from the down/up class names. The up name is
        # None when the options have no `up_conv` section.
        factory_cls = self._get_factory(model_type, modules_lib)
        down_name = opt.down_conv.module_name
        if opt.get('up_conv') is not None:
            up_name = opt.up_conv.module_name
        else:
            up_name = None
        self._factory_module = factory_cls(down_name, up_name, modules_lib)

        # Innermost ("global") modules, or a single Identity placeholder.
        if hasattr(opt, "innermost") and opt.innermost is not None:
            self.inner_modules.extend(
                self._create_inner_modules(opt.innermost, modules_lib))
        else:
            self.inner_modules.append(Identity())

        # Encoder: one module per entry in down_conv_nn.
        for layer_idx in range(len(opt.down_conv.down_conv_nn)):
            encoder_layer = build_layer(opt.down_conv, layer_idx, "DOWN")
            self._save_sampling_and_search(encoder_layer)
            self.down_modules.append(encoder_layer)

        # Decoder: only built when an up-conv class name is configured.
        if up_name:
            for layer_idx in range(len(opt.up_conv.up_conv_nn)):
                decoder_layer = build_layer(opt.up_conv, layer_idx, "UP")
                self._save_upsample(decoder_layer)
                self.up_modules.append(decoder_layer)

        metric_loss_opt = getattr(opt, "metric_loss", None)
        miner_opt = getattr(opt, "miner", None)
        self.metric_loss_module, self.miner_module = (
            BaseModel.get_metric_loss_and_miner(metric_loss_opt, miner_opt))