Example #1
    def __init__(self, cf, logger):

        super(net, self).__init__()
        self.cf = cf
        self.logger = logger
        backbone = utils.import_module('bbone', cf.backbone_path)
        self.logger.info("loaded backbone from {}".format(
            self.cf.backbone_path))
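        # conv factory; presumably yields 2D or 3D conv blocks depending on cf.dim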
        conv_gen = backbone.ConvGenerator(cf.dim)

        # set operate_stride1=True to generate a U-Net-like FPN.
        self.fpn = backbone.FPN(cf,
                                conv=conv_gen,
                                relu_enc=cf.relu,
                                operate_stride1=True)
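        # 1x1 conv mapping the FPN output features to per-class segmentation logits (no final activation)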
        self.conv_final = conv_gen(cf.end_filts,
                                   cf.num_seg_classes,
                                   ks=1,
                                   pad=0,
                                   norm=cf.norm,
                                   relu=None)

        # initialize parameters
        if self.cf.weight_init == "custom":
            logger.info(
                "Tried to use custom weight init which is not defined. Using pytorch default."
            )
        elif self.cf.weight_init:
            mutils.initialize_weights(self)
        else:
            logger.info("using default pytorch weight init")
Example #2
    def __init__(self, cf, logger):

        super(net, self).__init__()
        self.cf = cf
        self.logger = logger
        self.build()

        loss_order = [
            'rpn_class', 'rpn_bbox', 'mrcnn_bbox', 'mrcnn_mask', 'mrcnn_class',
            'mrcnn_rg'
        ]
        if hasattr(cf, "mrcnn_loss_weights"):
            # reorder the configured per-loss weights to match loss_order
            self.loss_weights = np.array(
                [cf.mrcnn_loss_weights[k] for k in loss_order])
        else:
            self.loss_weights = np.array([1.] * len(loss_order))

        if self.cf.weight_init == "custom":
            logger.info(
                "Tried to use custom weight init which is not defined. Using pytorch default."
            )
        elif self.cf.weight_init:
            mutils.initialize_weights(self)
        else:
            logger.info("using default pytorch weight init")
Example #3
    def __init__(self, cf, logger):

        super(net, self).__init__()
        self.cf = cf
        self.logger = logger
        self.build()
        if self.cf.weight_init is not None:
            logger.info("using pytorch weight init of type {}".format(self.cf.weight_init))
            mutils.initialize_weights(self)
        else:
            logger.info("using default pytorch weight init")
Example #4
    def __init__(self, cf, logger):
        super(net, self).__init__()

        self.cf = cf
        self.dim = cf.dim
        self.norm = cf.norm
        self.logger = logger
        backbone = utils.import_module('bbone', cf.backbone_path)
        self.c_gen = backbone.ConvGenerator(cf.dim)
        self.Interpolator = backbone.Interpolate
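        # interpolation op, presumably used by the up-sampling (decoder) blocks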

        #down = DownBlockGen(cf.dim)
        #up = UpBlockGen(cf.dim, backbone.Interpolate)
        down = self.down
        up = self.up

        pad = cf.pad
        if pad == "same":
            # "same" padding for an odd kernel size
            pad = (cf.kernel_size - 1) // 2

        
        self.dims = "not yet recorded"
        self.is_cuda = False
              
        self.init = horiz_conv(len(cf.channels), cf.init_filts, cf.kernel_size, self.c_gen, self.norm, pad=pad,
                               relu=cf.relu)
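        # encoder: down blocks increase the filter count with depth; the two deepest keep the z-resolution (maintain_z=True)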
        
        self.down1 = down(cf.init_filts,    cf.init_filts*2,  cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu)
        self.down2 = down(cf.init_filts*2,  cf.init_filts*4,  cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu)
        self.down3 = down(cf.init_filts*4,  cf.init_filts*6,  cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu)
        self.down4 = down(cf.init_filts*6,  cf.init_filts*8,  cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu,
                          maintain_z=True)
        self.down5 = down(cf.init_filts*8,  cf.init_filts*12, cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu,
                          maintain_z=True)
        #self.down6 = down(cf.init_filts*10, cf.init_filts*14, cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu)
        
        #self.up1 = up(cf.init_filts*14, cf.init_filts*10, cf.kernel_size, pad=pad, relu=cf.relu)
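        # decoder: up blocks mirror the encoder, reducing the filter count back to cf.init_filts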
        self.up2 = up(cf.init_filts*12, cf.init_filts*8,  cf.kernel_size, pad=pad, relu=cf.relu, maintain_z=True)
        self.up3 = up(cf.init_filts*8,  cf.init_filts*6,  cf.kernel_size, pad=pad, relu=cf.relu, maintain_z=True)
        self.up4 = up(cf.init_filts*6,  cf.init_filts*4,  cf.kernel_size, pad=pad, relu=cf.relu)
        self.up5 = up(cf.init_filts*4,  cf.init_filts*2,  cf.kernel_size, pad=pad, relu=cf.relu)
        self.up6 = up(cf.init_filts*2,  cf.init_filts,    cf.kernel_size, pad=pad, relu=cf.relu)
        
        self.seg = self.c_gen(cf.init_filts, cf.num_seg_classes, 1, norm=None, relu=None)


        # initialize parameters
        if self.cf.weight_init == "custom":
            logger.info("Tried to use custom weight init which is not defined. Using pytorch default.")
        elif self.cf.weight_init:
            mutils.initialize_weights(self)
        else:
            logger.info("using default pytorch weight init")
Example #5
    def __init__(self, cf, logger):

        super(net, self).__init__()
        self.cf = cf
        self.logger = logger
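        # set regress_flag if any of the configured prediction tasks involves regression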
        self.regress_flag = any(['regression' in task for task in self.cf.prediction_tasks])
        self.build()


        if self.cf.weight_init == "custom":
            logger.info("Tried to use custom weight init which is not defined. Using pytorch default.")
        elif self.cf.weight_init:
            mutils.initialize_weights(self)
        else:
            logger.info("using default pytorch weight init")
Example #6
    def __init__(self, cf, logger):
        """
        cf: A Sub-class of the cf class
        model_dir: Directory to save training logs and trained weights
        """
        super(net, self).__init__()
        self.cf = cf
        self.logger = logger
        self.build()
        if self.cf.weight_init is not None:
            mutils.initialize_weights(self)
        else:
            logger.info("using default pytorch weight init")

        self.debug_acm = []
Example #7
    def __init__(self, cf, logger):

        super(net, self).__init__()
        self.cf = cf
        self.logger = logger
        backbone = utils.import_module('bbone', cf.backbone_path)
        conv = mutils.NDConvGenerator(cf.dim)

        # set operate_stride1=True to generate a U-Net-like FPN.
        self.fpn = backbone.FPN(cf, conv, operate_stride1=True).cuda()
        self.conv_final = conv(cf.end_filts, cf.num_seg_classes, ks=1, pad=0, norm=cf.norm, relu=None)

        if self.cf.weight_init is not None:
            logger.info("using pytorch weight init of type {}".format(self.cf.weight_init))
            mutils.initialize_weights(self)
        else:
            logger.info("using default pytorch weight init")