Code example #1
    def __init__(self, shape, emb_size=512, ksize=5, h_size=1000):
        """
        shape: tuple of ints (H, W)
            spatial shape of the incoming feature map
        emb_size: int
            number of channels in the feature map (the embedding size)
        ksize: int
            kernel size of the convolution
        h_size: int
            hidden size of the collapsing MLP
        """
        super().__init__()
        self.emb_size = emb_size
        self.ksize = ksize
        self.shape = shape
        self.h_size = h_size
        self.conv = nn.Conv2d(self.emb_size, self.emb_size, self.ksize)
        self.activ = nn.ReLU()
        self.layer = nn.Sequential(self.conv, self.activ)
        new_shape = update_shape(self.shape, kernel=ksize, stride=1, padding=0)
        flat_size = new_shape[-2] * new_shape[-1] * self.emb_size
        self.collapser = nn.Sequential(nn.Linear(flat_size, self.h_size),
                                       nn.ReLU(),
                                       nn.Linear(self.h_size, self.emb_size))
        # The batch size is unknown at construction time, so it is left as -1
        self.x_shape = (-1, self.emb_size, self.shape[0], self.shape[1])
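
Both this block and the decoder in code example #2 depend on an update_shape helper that is not part of these excerpts. The sketch below is a hypothetical re-implementation, assuming it applies the standard PyTorch output-size arithmetic for convolutions and, for op="deconv", transposed convolutions:

def update_shape(shape, kernel=3, stride=1, padding=0, op="conv"):
    """Hypothetical sketch of the update_shape helper used above.

    shape: tuple of ints (H, W)
    Returns the spatial shape produced by a Conv2d (op="conv") or by a
    ConvTranspose2d with output_padding=0 (op="deconv"), assuming dilation=1.
    """
    new_dims = []
    for dim in shape:
        if op == "conv":
            # floor((H + 2*padding - kernel) / stride) + 1
            new_dims.append((dim + 2 * padding - kernel) // stride + 1)
        else:
            # (H - 1)*stride - 2*padding + kernel
            new_dims.append((dim - 1) * stride - 2 * padding + kernel)
    return tuple(new_dims)

For example, update_shape((64, 64), kernel=5, stride=1, padding=0) gives (60, 60), which is the spatial shape that feeds into flat_size above.
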
Code example #2
    def __init__(self,
                 emb_size,
                 img_shape,
                 deconv_start_shape=(512, 3, 3),
                 deconv_ksizes=None,
                 deconv_strides=None,
                 deconv_lnorm=True,
                 fwd_bnorm=False,
                 drop_p=0,
                 end_sigmoid=False,
                 n_resblocks=1,
                 deconv_attn=False,
                 deconv_attn_layers=3,
                 deconv_attn_size=64,
                 deconv_heads=8,
                 deconv_multi_init=False,
                 **kwargs):
        """
        deconv_start_shape - list like [channel1, height1, width1]
            the initial shape to reshape the embedding inputs
        deconv_ksizes - list like of ints
            the kernel size for each layer
        deconv_strides - list like of ints
            the strides for each layer
        img_shape - list like [channel2, height2, width2]
            the final shape of the decoded tensor
        emb_size - int
            size of belief vector h
        deconv_lnorm: bool
            determines if layer norms will be used at each layer
        fwd_bnorm: bool
            determines if batchnorm will be used
        drop_p - float
            dropout probability at each layer
        end_sigmoid: bool
            if true, the final activation is a sigmoid. Otherwise
            there is no final activation
        n_resblocks: int
            number of ending residual blocks
        deconv_attn: bool
            if true, the incoming embedding is expanded using an
            attention-based module
        deconv_attn_layers: int
            the number of decoding layers to use for the attention
            module
        deconv_attn_size: int
            the size of the projections in the multi-headed attn layer
        deconv_heads: int
            the number of projection spaces in the multi-headed attn
            layer
        deconv_multi_init: bool
            if true, the init vector for the attention module will be
            trained uniquely for each position
        """
        super().__init__()
        if deconv_start_shape[0] is None:
            deconv_start_shape = [emb_size, *deconv_start_shape[1:]]
        self.start_shape = deconv_start_shape
        self.img_shape = img_shape
        self.emb_size = emb_size
        self.drop_p = drop_p
        self.bnorm = fwd_bnorm
        self.end_sigmoid = end_sigmoid
        self.strides = deconv_strides
        self.ksizes = deconv_ksizes
        self.lnorm = deconv_lnorm
        self.n_resblocks = n_resblocks
        self.deconv_attn = deconv_attn
        self.dec_layers = deconv_attn_layers
        self.attn_size = deconv_attn_size
        self.n_heads = deconv_heads
        self.multi_init = deconv_multi_init
        print("deconv using bnorm:", self.bnorm)

        if self.ksizes is None:
            self.ksizes = [7, 4, 4, 5, 5, 5, 5, 5, 4]
        if self.strides is None:
            self.strides = [1, 1, 1, 1, 1, 1, 1, 2, 1]

        modules = []
        if deconv_attn:
            if self.start_shape[-3] != self.emb_size:
                modules.append(nn.Linear(self.emb_size, self.start_shape[-3]))
            l = int(np.prod(self.start_shape[-2:]))  # number of spatial positions (H*W)
            modules.append(Reshape((-1, 1, self.start_shape[-3])))

            decoder = Decoder(l,
                              self.start_shape[-3],
                              self.attn_size,
                              self.dec_layers,
                              n_heads=self.n_heads,
                              init_decs=True,
                              multi_init=self.multi_init)
            modules.append(DeconvAttn(decoder=decoder))
        else:
            flat_start = int(np.prod(deconv_start_shape))
            if self.lnorm:
                modules.append(nn.LayerNorm(emb_size))
            modules.append(nn.Linear(emb_size, flat_start))
            if self.bnorm:
                modules.append(nn.BatchNorm1d(flat_start))
        modules.append(Reshape((-1, *deconv_start_shape)))

        depth, height, width = deconv_start_shape
        first_ksize = self.ksizes[0]
        first_stride = self.strides[0]
        self.sizes = []
        deconv = deconv_block(depth,
                              depth,
                              ksize=first_ksize,
                              stride=first_stride,
                              padding=0,
                              bnorm=self.bnorm,
                              drop_p=self.drop_p)
        height, width = update_shape((height, width),
                                     kernel=first_ksize,
                                     stride=first_stride,
                                     op="deconv")
        print("Img shape:", self.img_shape)
        print("Start Shape:", deconv_start_shape)
        print("h:", height, "| w:", width)
        self.sizes.append((height, width))
        modules.append(deconv)

        padding = 0
        for i in range(1, len(self.ksizes)):
            ksize = self.ksizes[i]
            stride = self.strides[i]
            if self.lnorm:
                modules.append(nn.LayerNorm((depth, height, width)))
            height, width = update_shape((height, width),
                                         kernel=ksize,
                                         stride=stride,
                                         padding=padding,
                                         op="deconv")
            end_depth = max(depth // 2, 16)
            modules.append(
                deconv_block(depth,
                             end_depth,
                             ksize=ksize,
                             padding=padding,
                             stride=stride,
                             bnorm=self.bnorm,
                             drop_p=drop_p))
            depth = end_depth
            self.sizes.append((height, width))
            print("h:", height, "| w:", width, "| d:", depth)

        modules.append(nn.UpsamplingBilinear2d(size=self.img_shape[-2:]))
        if self.n_resblocks is not None and self.n_resblocks > 0:
            for r in range(self.n_resblocks):
                modules.append(ResBlock(depth=depth, ksize=3, bnorm=False))
            modules.append(nn.Conv2d(depth, self.img_shape[-3], 1))
        else:
            modules.append(nn.Conv2d(depth, self.img_shape[-3], 3, padding=1))
        self.sizes.append(self.img_shape[-2:])
        if self.end_sigmoid:
            modules.append(nn.Sigmoid())
        print("decoder:", self.sizes[-1][0], self.sizes[-1][1])
        self.sequential = nn.Sequential(*modules)
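
The decoder above also relies on Reshape and deconv_block, which are not shown in this excerpt. Below is a minimal sketch under the assumption that Reshape simply reshapes its input to a fixed shape and that deconv_block wraps an nn.ConvTranspose2d with an optional batch norm, a ReLU, and optional dropout; the real implementations may differ:

import torch.nn as nn

class Reshape(nn.Module):
    """Hypothetical sketch: reshape the incoming tensor to a fixed shape."""
    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(self.shape)

def deconv_block(in_depth, out_depth, ksize=3, stride=1, padding=0,
                 bnorm=False, drop_p=0):
    """Hypothetical sketch: a transposed convolution followed by an optional
    BatchNorm2d, a ReLU, and an optional Dropout."""
    layers = [nn.ConvTranspose2d(in_depth, out_depth, ksize,
                                 stride=stride, padding=padding)]
    if bnorm:
        layers.append(nn.BatchNorm2d(out_depth))
    layers.append(nn.ReLU())
    if drop_p > 0:
        layers.append(nn.Dropout(drop_p))
    return nn.Sequential(*layers)

With this sketch, the default ksizes [7, 4, 4, 5, 5, 5, 5, 5, 4] and strides [1, 1, 1, 1, 1, 1, 1, 2, 1] progressively grow the (3, 3) start shape, and the final UpsamplingBilinear2d snaps the result to img_shape.
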
Code example #3
    def __init__(self, emb_size, intm_attn=0, **kwargs):
        """
        emb_size: int
        intm_attn: int
            number of layers in the attention module inserted between
            convolutions; 0 disables intermediate attention
        """
        super().__init__(**kwargs)
        self.emb_size = emb_size
        self.intm_attn = intm_attn
        self.conv_blocks = nn.ModuleList([])
        self.intm_attns = nn.ModuleList([])
        self.shapes = []
        shape = self.img_shape[-2:]
        self.shapes.append(shape)
        chans = [16, 32, 64, 128, self.emb_size]
        stride = 2
        ksize = 7
        self.chans = chans
        padding = 0
        block = self.get_conv_block(in_chan=self.img_shape[-3],
                                    out_chan=self.chans[0],
                                    ksize=ksize,
                                    stride=stride,
                                    padding=padding,
                                    bnorm=self.bnorm,
                                    act_fxn=self.act_fxn,
                                    drop_p=0)
        self.conv_blocks.append(nn.Sequential(*block))
        shape = update_shape(shape,
                             kernel=ksize,
                             stride=stride,
                             padding=padding)
        self.shapes.append(shape)
        if self.intm_attn > 0:
            attn = ConvAttention(chans[0],
                                 shape,
                                 n_layers=self.intm_attn,
                                 attn_size=self.attn_size,
                                 act_fxn=self.act_fxn)
            self.intm_attns.append(attn)

        ksize = 3
        for i in range(len(chans) - 1):
            if i in {1, 3}: stride = 2
            else: stride = 1
            block = self.get_conv_block(in_chan=chans[i],
                                        out_chan=chans[i + 1],
                                        ksize=ksize,
                                        stride=stride,
                                        padding=padding,
                                        bnorm=self.bnorm,
                                        act_fxn=self.act_fxn,
                                        drop_p=0)
            self.conv_blocks.append(nn.Sequential(*block))
            shape = update_shape(shape,
                                 kernel=ksize,
                                 stride=stride,
                                 padding=padding)
            self.shapes.append(shape)
            print("model shape {}: {}".format(i, shape))
            if self.intm_attn > 0:
                attn = ConvAttention(chans[i + 1],
                                     shape,
                                     n_layers=self.intm_attn,
                                     attn_size=self.attn_size,
                                     act_fxn=self.act_fxn)
                self.intm_attns.append(attn)
        self.seq_len = shape[0] * shape[1]
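
The encoder in code example #3 builds each stage with a get_conv_block method defined elsewhere in the class; the call sites unpack its result into nn.Sequential(*block), so it presumably returns a list of modules. A minimal sketch under that assumption, treating act_fxn as the name of an activation class in torch.nn:

    def get_conv_block(self, in_chan, out_chan, ksize=3, stride=1,
                       padding=0, bnorm=False, act_fxn="ReLU", drop_p=0):
        """Hypothetical sketch: Conv2d, optional BatchNorm2d, the requested
        activation, and optional dropout, returned as a list of modules."""
        block = [nn.Conv2d(in_chan, out_chan, ksize,
                           stride=stride, padding=padding)]
        if bnorm:
            block.append(nn.BatchNorm2d(out_chan))
        block.append(getattr(nn, act_fxn)())
        if drop_p > 0:
            block.append(nn.Dropout(drop_p))
        return block

The final line, self.seq_len = shape[0] * shape[1], suggests that the (emb_size, H, W) output of the last block is later flattened into a sequence of H*W vectors, for example to feed an attention module.
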