def __init__(self, shape, emb_size=512, ksize=5, h_size=1000):
    """
    shape: tuple of ints (H, W)
        spatial shape of the incoming feature maps
    emb_size: int
        channel depth of the incoming maps and size of the output embedding
    ksize: int
        kernel size of the convolution
    h_size: int
        hidden size of the collapsing MLP
    """
    super().__init__()
    self.emb_size = emb_size
    self.ksize = ksize
    self.shape = shape
    self.h_size = h_size
    self.conv = nn.Conv2d(self.emb_size, self.emb_size, self.ksize)
    self.activ = nn.ReLU()
    self.layer = nn.Sequential(self.conv, self.activ)
    new_shape = update_shape(self.shape, kernel=ksize, stride=1, padding=0)
    flat_size = new_shape[-2] * new_shape[-1] * self.emb_size
    self.collapser = nn.Sequential(
        nn.Linear(flat_size, self.h_size),
        nn.ReLU(),
        nn.Linear(self.h_size, self.emb_size),
    )
    # expected input shape; -1 stands in for the runtime batch size
    self.x_shape = (-1, self.emb_size, self.shape[0], self.shape[1])
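# `update_shape` is imported from elsewhere in the repo and is not shown here.
# The sketch below (hypothetical name `update_shape_sketch`) illustrates the
# shape arithmetic it is assumed to perform: the standard output-size formulas
# for convolutions and transposed convolutions. The real helper's signature
# and return type may differ.
def update_shape_sketch(shape, kernel=3, stride=1, padding=0, op="conv"):
    """
    shape: tuple of ints (H, W)
    Returns the spatial (H, W) produced by a conv (op="conv") or a
    transposed conv (op="deconv") with the given hyperparameters.
    """
    new_shape = []
    for s in shape:
        if op == "conv":
            s = (s + 2 * padding - kernel) // stride + 1
        else:  # op == "deconv", i.e. ConvTranspose2d with no output_padding
            s = (s - 1) * stride - 2 * padding + kernel
        new_shape.append(s)
    return tuple(new_shape)

# Example: with shape=(7, 7) and the default ksize=5 above,
# update_shape_sketch((7, 7), kernel=5) -> (3, 3), so flat_size = 3*3*emb_size.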
def __init__(self, emb_size, img_shape, deconv_start_shape=(512, 3, 3),
             deconv_ksizes=None, deconv_strides=None, deconv_lnorm=True,
             fwd_bnorm=False, drop_p=0, end_sigmoid=False, n_resblocks=1,
             deconv_attn=False, deconv_attn_layers=3, deconv_attn_size=64,
             deconv_heads=8, deconv_multi_init=False, **kwargs):
    """
    deconv_start_shape: list like [channel1, height1, width1]
        the initial shape to reshape the embedding inputs
    deconv_ksizes: list like of ints
        the kernel size for each layer
    deconv_strides: list like of ints
        the strides for each layer
    img_shape: list like [channel2, height2, width2]
        the final shape of the decoded tensor
    emb_size: int
        size of belief vector h
    deconv_lnorm: bool
        determines if layer norms will be used at each layer
    fwd_bnorm: bool
        determines if batchnorm will be used
    drop_p: float
        dropout probability at each layer
    end_sigmoid: bool
        if true, the final activation is a sigmoid. Otherwise there is
        no final activation
    n_resblocks: int
        number of ending residual blocks
    deconv_attn: bool
        if true, the incoming embedding is expanded using an attn based
        module
    deconv_attn_layers: int
        the number of decoding layers to use for the attention module
    deconv_attn_size: int
        the size of the projections in the multi-headed attn layer
    deconv_heads: int
        the number of projection spaces in the multi-headed attn layer
    deconv_multi_init: bool
        if true, the init vector for the attention module will be
        trained uniquely for each position
    """
    super().__init__()
    if deconv_start_shape[0] is None:
        deconv_start_shape = [emb_size, *deconv_start_shape[1:]]
    self.start_shape = deconv_start_shape
    self.img_shape = img_shape
    self.emb_size = emb_size
    self.drop_p = drop_p
    self.bnorm = fwd_bnorm
    self.end_sigmoid = end_sigmoid
    self.strides = deconv_strides
    self.ksizes = deconv_ksizes
    self.lnorm = deconv_lnorm
    self.n_resblocks = n_resblocks
    self.deconv_attn = deconv_attn
    self.dec_layers = deconv_attn_layers
    self.attn_size = deconv_attn_size
    self.n_heads = deconv_heads
    self.multi_init = deconv_multi_init
    print("deconv using bnorm:", self.bnorm)

    if self.ksizes is None:
        self.ksizes = [7, 4, 4, 5, 5, 5, 5, 5, 4]
    if self.strides is None:
        self.strides = [1, 1, 1, 1, 1, 1, 1, 2, 1]

    modules = []
    if deconv_attn:
        # expand the embedding into a grid of vectors using an attention decoder
        if self.start_shape[-3] != self.emb_size:
            modules.append(nn.Linear(self.emb_size, self.start_shape[-3]))
        l = int(np.prod(self.start_shape[-2:]))  # number of spatial positions
        modules.append(Reshape((-1, 1, self.start_shape[-3])))
        decoder = Decoder(l, self.start_shape[-3], self.attn_size,
                          self.dec_layers, n_heads=self.n_heads,
                          init_decs=True, multi_init=self.multi_init)
        modules.append(DeconvAttn(decoder=decoder))
    else:
        # project the embedding up to the flattened starting volume
        flat_start = int(np.prod(deconv_start_shape))
        if self.lnorm:
            modules.append(nn.LayerNorm(emb_size))
        modules.append(nn.Linear(emb_size, flat_start))
        if self.bnorm:
            modules.append(nn.BatchNorm1d(flat_start))
    modules.append(Reshape((-1, *deconv_start_shape)))

    depth, height, width = deconv_start_shape
    first_ksize = self.ksizes[0]
    first_stride = self.strides[0]
    self.sizes = []
    deconv = deconv_block(depth, depth, ksize=first_ksize,
                          stride=first_stride, padding=0,
                          bnorm=self.bnorm, drop_p=self.drop_p)
    height, width = update_shape((height, width), kernel=first_ksize,
                                 stride=first_stride, op="deconv")
    print("Img shape:", self.img_shape)
    print("Start Shape:", deconv_start_shape)
    print("h:", height, "| w:", width)
    self.sizes.append((height, width))
    modules.append(deconv)

    padding = 0
    for i in range(1, len(self.ksizes)):
        ksize = self.ksizes[i]
        stride = self.strides[i]
        if self.lnorm:
            modules.append(nn.LayerNorm((depth, height, width)))
        height, width = update_shape((height, width), kernel=ksize,
                                     stride=stride, padding=padding,
                                     op="deconv")
        end_depth = max(depth // 2, 16)
        modules.append(deconv_block(depth, end_depth, ksize=ksize,
                                    padding=padding, stride=stride,
                                    bnorm=self.bnorm, drop_p=drop_p))
        depth = end_depth
        self.sizes.append((height, width))
        print("h:", height, "| w:", width, "| d:", depth)

    # upsample to the target image size, then finish with residual blocks
    # and a 1x1 conv (or a single 3x3 conv if no resblocks are requested)
    modules.append(nn.UpsamplingBilinear2d(size=self.img_shape[-2:]))
    if self.n_resblocks is not None and self.n_resblocks > 0:
        for r in range(self.n_resblocks):
            modules.append(ResBlock(depth=depth, ksize=3, bnorm=False))
        modules.append(nn.Conv2d(depth, self.img_shape[-3], 1))
    else:
        modules.append(nn.Conv2d(depth, self.img_shape[-3], 3, padding=1))
    self.sizes.append(self.img_shape[-2:])
    if self.end_sigmoid:
        modules.append(nn.Sigmoid())
    print("decoder:", self.sizes[-1][0], self.sizes[-1][1])
    self.sequential = nn.Sequential(*modules)
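# `Reshape` and `deconv_block` are defined elsewhere in the repo. The hedged
# sketches below (hypothetical names `ReshapeSketch` and `deconv_block_sketch`)
# show the behavior the decoder above assumes: a module that reshapes its input
# to a stored shape, and a transposed-conv block returned as a single module.
# The real versions may use a different activation, norm, or dropout layout.
import torch.nn as nn

class ReshapeSketch(nn.Module):
    """Reshapes the incoming tensor to the stored shape (-1 infers the batch)."""
    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(self.shape)

def deconv_block_sketch(in_depth, out_depth, ksize=4, stride=1, padding=0,
                        bnorm=False, drop_p=0):
    """Transposed convolution followed by optional batchnorm, ReLU, and dropout."""
    layers = [nn.ConvTranspose2d(in_depth, out_depth, ksize,
                                 stride=stride, padding=padding)]
    if bnorm:
        layers.append(nn.BatchNorm2d(out_depth))
    layers.append(nn.ReLU())
    if drop_p > 0:
        layers.append(nn.Dropout(drop_p))
    return nn.Sequential(*layers)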
def __init__(self, emb_size, intm_attn=0, **kwargs):
    """
    emb_size: int
        the channel depth of the final feature maps (the embedding size)
    intm_attn: int
        an integer indicating the number of layers for an attention
        layer in between convolutions
    """
    super().__init__(**kwargs)
    self.emb_size = emb_size
    self.intm_attn = intm_attn
    self.conv_blocks = nn.ModuleList([])
    self.intm_attns = nn.ModuleList([])
    self.shapes = []
    shape = self.img_shape[-2:]
    self.shapes.append(shape)
    chans = [16, 32, 64, 128, self.emb_size]
    stride = 2
    ksize = 7
    self.chans = chans
    padding = 0

    # first convolutional block operates directly on the image channels
    block = self.get_conv_block(in_chan=self.img_shape[-3],
                                out_chan=self.chans[0],
                                ksize=ksize, stride=stride,
                                padding=padding, bnorm=self.bnorm,
                                act_fxn=self.act_fxn, drop_p=0)
    self.conv_blocks.append(nn.Sequential(*block))
    shape = update_shape(shape, kernel=ksize, stride=stride, padding=padding)
    self.shapes.append(shape)
    if self.intm_attn > 0:
        attn = ConvAttention(chans[0], shape,
                             n_layers=self.intm_attn,
                             attn_size=self.attn_size,
                             act_fxn=self.act_fxn)
        self.intm_attns.append(attn)

    # remaining blocks step through `chans`, striding by 2 on blocks 1 and 3
    ksize = 3
    for i in range(len(chans) - 1):
        stride = 2 if i in {1, 3} else 1
        block = self.get_conv_block(in_chan=chans[i],
                                    out_chan=chans[i + 1],
                                    ksize=ksize, stride=stride,
                                    padding=padding, bnorm=self.bnorm,
                                    act_fxn=self.act_fxn, drop_p=0)
        self.conv_blocks.append(nn.Sequential(*block))
        shape = update_shape(shape, kernel=ksize, stride=stride,
                             padding=padding)
        self.shapes.append(shape)
        print("model shape {}: {}".format(i, shape))
        if self.intm_attn > 0:
            attn = ConvAttention(chans[i + 1], shape,
                                 n_layers=self.intm_attn,
                                 attn_size=self.attn_size,
                                 act_fxn=self.act_fxn)
            self.intm_attns.append(attn)
    self.seq_len = shape[0] * shape[1]
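# `get_conv_block` comes from the parent class and is not shown here. A hedged
# sketch (hypothetical standalone name `get_conv_block_sketch`) of the list of
# layers it is assumed to return follows; `act_fxn` is assumed to be the name
# of a torch.nn activation (e.g. "ReLU"), which may not match the real code.
import torch.nn as nn

def get_conv_block_sketch(in_chan, out_chan, ksize=3, stride=1, padding=0,
                          bnorm=False, act_fxn="ReLU", drop_p=0):
    """Returns a list of layers that the caller wraps in an nn.Sequential."""
    block = [nn.Conv2d(in_chan, out_chan, ksize, stride=stride,
                       padding=padding)]
    if bnorm:
        block.append(nn.BatchNorm2d(out_chan))
    block.append(getattr(nn, act_fxn)())
    if drop_p > 0:
        block.append(nn.Dropout(drop_p))
    return block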