class VUnetEncoder(nn.Module):
    """Multi-stage encoder that exposes the output of every residual block.

    The input first passes through ``conv_layer(nf_in, nf_start, kernel_size=1)``.
    It is then processed by ``n_stages`` stages of ``n_rnb``
    :class:`VUnetResnetBlock` modules each. Every stage after the first is
    preceded by a :class:`Downsample` module whose output channel count is
    twice the current one, capped at ``nf_max``.
    """

    def __init__(
        self,
        n_stages,
        nf_in=3,
        nf_start=64,
        nf_max=128,
        n_rnb=2,
        conv_layer=NormConv2d,
        dropout_prob=0.0,
    ):
        """
        Parameters
        ----------
        n_stages : int
            Number of stages.
        nf_in : int
            Channel count handed to the input layer as its input size.
        nf_start : int
            Channel count of the first stage.
        nf_max : int
            Upper bound for the channel count of later stages.
        n_rnb : int
            Number of residual blocks per stage.
        conv_layer : callable
            Factory used for the input layer and forwarded to the
            sub-modules.
        dropout_prob : float
            Forwarded to every :class:`VUnetResnetBlock`.
        """
        super().__init__()
        self.in_op = conv_layer(nf_in, nf_start, kernel_size=1)
        self.blocks = ModuleDict()
        self.downs = ModuleDict()
        self.n_rnb = n_rnb
        self.n_stages = n_stages

        width = nf_start
        for stage_no in range(1, self.n_stages + 1):
            # Every stage except the first starts with a downsampling step
            # that also widens the feature maps (up to nf_max channels).
            if stage_no > 1:
                widened = min(2 * width, nf_max)
                self.downs[f"s{stage_no}"] = Downsample(
                    width, widened, conv_layer=conv_layer
                )
                width = widened

            for block_no in range(1, self.n_rnb + 1):
                self.blocks[f"s{stage_no}_{block_no}"] = VUnetResnetBlock(
                    width, conv_layer=conv_layer, dropout_prob=dropout_prob
                )

    def _apply_stage_blocks(self, h, stage_no, features):
        """Run the residual blocks of one stage and record each output."""
        for block_no in range(1, self.n_rnb + 1):
            key = f"s{stage_no}_{block_no}"
            h = self.blocks[key](h)
            features[key] = h
        return h

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            Input passed to the input layer.

        Returns
        -------
        dict
            Maps ``"s{stage}_{block}"`` (both 1-based) to the output of the
            corresponding residual block, in execution order.
        """
        features = {}
        # The first stage has no downsampling step and always runs.
        h = self._apply_stage_blocks(self.in_op(x), 1, features)
        for stage_no in range(2, self.n_stages + 1):
            h = self.downs[f"s{stage_no}"](h)
            h = self._apply_stage_blocks(h, stage_no, features)
        return features
class VUnetBottleneck(nn.Module):
    """Bottleneck that processes the two deepest encoder stages.

    For each of the stages ``n_stages`` and ``n_stages - 1`` the module
    computes the parameters (means) of a prior distribution, draws a sample
    from it and merges either that sample or the externally supplied
    ``z_post`` into the feature map handed on to the decoder.
    """

    def __init__(
        self,
        n_stages,
        nf,
        device,
        n_rnb=2,
        n_auto_groups=4,
        conv_layer=NormConv2d,
        dropout_prob=0.0,
    ):
        """
        Parameters
        ----------
        n_stages : int
            Index of the deepest stage; stages ``n_stages`` and
            ``n_stages - 1`` are handled here.
        nf : int
            Channel count used by all sub-modules.
        device : torch.device or str
            Stored as ``self.device``. Sampling itself follows the device of
            the tensor it is given.
        n_rnb : int
            Number of residual blocks per stage.
        n_auto_groups : int
            Stored as ``self.n_auto_groups``. The constructor always builds
            four autoregressive blocks regardless of this value.
        conv_layer : callable
            Factory for the convolution layers; forwarded to sub-modules.
        dropout_prob : float
            Forwarded to every :class:`VUnetResnetBlock`.
        """
        super().__init__()
        self.device = device
        self.blocks = ModuleDict()
        self.channel_norm = ModuleDict()
        self.conv1x1 = conv_layer(nf, nf, 1)
        self.up = Upsample(in_channels=nf, out_channels=nf, conv_layer=conv_layer)
        self.depth_to_space = DepthToSpace(block_size=2)
        self.space_to_depth = SpaceToDepth(block_size=2)
        self.n_stages = n_stages
        self.n_rnb = n_rnb
        # number of autoregressively modeled groups
        self.n_auto_groups = n_auto_groups

        for i_s in range(self.n_stages, self.n_stages - 2, -1):
            # Reduces the concatenation [h, z] (2 * nf channels) back to nf.
            self.channel_norm.update({f"s{i_s}": conv_layer(2 * nf, nf, 1)})
            for ir in range(self.n_rnb):
                self.blocks.update(
                    {
                        f"s{i_s}_{ir+1}": VUnetResnetBlock(
                            nf,
                            use_skip=True,
                            conv_layer=conv_layer,
                            dropout_prob=dropout_prob,
                        )
                    }
                )

        # Residual blocks of the autoregressive group model. The first block
        # has no skip input; the remaining three do. Creation order is kept
        # as block 0, param_converter, blocks 1-3.
        self.auto_blocks = ModuleList()
        self.auto_blocks.append(
            VUnetResnetBlock(nf, conv_layer=conv_layer, dropout_prob=dropout_prob)
        )
        self.param_converter = conv_layer(4 * nf, nf, kernel_size=1)
        for _ in range(3):
            self.auto_blocks.append(
                VUnetResnetBlock(
                    nf,
                    use_skip=True,
                    conv_layer=conv_layer,
                    dropout_prob=dropout_prob,
                )
            )

    def forward(self, x_e, z_post, mode="train"):
        """
        Parameters
        ----------
        x_e : dict(str: torch.Tensor)
            The output from the encoder E_theta, indexed here with the keys
            ``"s{i}_1"`` and ``"s{i}_2"`` of the two bottleneck stages.
        z_post : dict(str: torch.Tensor)
            The output from the encoder F_phi, indexed with ``"s{i}"``. Only
            read when ``mode`` is ``"train"`` or ``"appearance_transfer"``.
        mode : str
            Determines the mode of the bottleneck, documented values are
            ["train","appearance_transfer","sample_appearance"]. ``"train"``
            and ``"appearance_transfer"`` feed ``z_post`` into the network;
            every other value uses samples drawn from the prior.

        Returns
        -------
        h : torch.Tensor
            the output of the last layer of the bottleneck which is
            subsequently used by the decoder.
        p_params : dict(str: torch.Tensor)
            The means of the prior distributions of the two bottleneck
            stages, keyed by stage name.
        z_prior : dict(str: torch.Tensor)
            The current samples of the prior distributions of the two
            bottleneck stages, flattened, keyed by stage name.
        """
        p_params = {}
        z_prior = {}

        use_z = mode == "train" or mode == "appearance_transfer"

        h = self.conv1x1(x_e[f"s{self.n_stages}_2"])
        for i_s in range(self.n_stages, self.n_stages - 2, -1):
            stage = f"s{i_s}"
            spatial_size = x_e[stage + "_2"].shape[-1]

            h = self.blocks[stage + "_2"](h, x_e[stage + "_2"])

            if spatial_size == 1:
                # The features themselves serve as the prior mean.
                p_params[stage] = h
                prior_samples = self._latent_sample(p_params[stage])
                z_prior[stage] = torch.squeeze(
                    torch.squeeze(prior_samples, dim=-1), dim=-1
                )
            else:
                if use_z:
                    z_flat = (
                        self.space_to_depth(z_post[stage])
                        if z_post[stage].shape[2] > 1
                        else z_post[stage]
                    )
                    sec_size = z_flat.shape[1] // 4
                    z_groups = torch.split(z_flat, [sec_size] * 4, dim=1)

                param_groups = []
                sample_groups = []

                param_features = self.auto_blocks[0](h)
                param_features = self.space_to_depth(param_features)
                # convert to fitting depth
                param_features = self.param_converter(param_features)

                for i_a in range(len(self.auto_blocks)):
                    param_groups.append(param_features)
                    prior_samples = self._latent_sample(param_groups[-1])
                    sample_groups.append(prior_samples)

                    if i_a + 1 < len(self.auto_blocks):
                        feedback = z_groups[i_a] if use_z else prior_samples
                        # NOTE(review): this indexes auto_blocks[i_a], so the
                        # skip-less block 0 is called with a feedback tensor
                        # and the last block is never used. It looks like
                        # auto_blocks[i_a + 1] may have been intended; left
                        # unchanged because it affects trained weights.
                        param_features = self.auto_blocks[i_a](
                            param_features, feedback
                        )

                p_params[stage] = torch.cat(param_groups, dim=1)
                prior_samples = self.__merge_groups(sample_groups)
                z_prior[stage] = (
                    self.space_to_depth(prior_samples)
                    .squeeze(dim=-1)
                    .squeeze(dim=-1)
                )

            if use_z:
                # Bring z_post to the spatial size of h if they differ.
                z = (
                    self.depth_to_space(z_post[stage])
                    if z_post[stage].shape[-1] != h.shape[-1]
                    else z_post[stage]
                )
            else:
                z = prior_samples

            h = torch.cat([h, z], dim=1)
            h = self.channel_norm[stage](h)
            h = self.blocks[stage + "_1"](h, x_e[stage + "_1"])

            if i_s == self.n_stages:
                h = self.up(h)

        return h, p_params, z_prior

    def __split_groups(self, x):
        """Rearrange ``x`` with ``space_to_depth`` and split it into 4 groups.

        The group size is derived from the rearranged tensor (as ``forward``
        does) so that the four sections cover its whole channel axis.
        """
        flat = self.space_to_depth(x)
        sec_size = flat.shape[1] // 4
        # split along channel axis
        return torch.split(flat, [sec_size] * 4, dim=1)

    def __merge_groups(self, x):
        """Concatenate the groups along the channel axis and apply
        ``depth_to_space``."""
        return self.depth_to_space(torch.cat(x, dim=1))

    def _latent_sample(self, mean):
        """Return ``mean`` plus noise with zero mean and identity covariance.

        Trailing singleton dimensions of ``mean`` are squeezed before
        sampling and two trailing dimensions are appended to the result.
        The noise is created on the device of ``mean`` so that sampling
        keeps working after the module has been moved to another device.
        """
        sample_mean = torch.squeeze(torch.squeeze(mean, dim=-1), dim=-1)
        noise = MultivariateNormal(
            loc=torch.zeros_like(sample_mean),
            covariance_matrix=torch.eye(
                sample_mean.shape[-1], device=sample_mean.device
            ),
        ).sample()
        return (noise + sample_mean).unsqueeze(dim=-1).unsqueeze(dim=-1)
class VUnetBottleneckOld(nn.Module):
    """Earlier bottleneck variant configured through the global ``FLAGS``.

    Processes the stages ``n_stages`` and ``n_stages - 1``. Prior and
    posterior parameters are taken from the two encoders' outputs; when
    ``FLAGS.group_auto`` is set, the prior of the stage with spatial size
    greater than one is modelled autoregressively in four groups.
    """

    def __init__(
        self,
        n_stages,
        nf,
        device,
        n_rnb=2,
        n_auto_groups=4,
        conv_layer=NormConv2d,
    ):
        """
        Parameters
        ----------
        n_stages : int
            Index of the deepest stage.
        nf : int
            Channel count used by all sub-modules.
        device : torch.device or str
            Stored as ``self.device``. Sampling itself follows the device of
            the tensor it is given.
        n_rnb : int
            Number of residual blocks per stage.
        n_auto_groups : int
            Stored as ``self.n_auto_groups``; the constructor always builds
            four autoregressive blocks when ``FLAGS.group_auto`` is set.
        conv_layer : callable
            Factory for the convolution layers; forwarded to sub-modules.
        """
        super().__init__()
        self.device = device
        self.blocks = ModuleDict()
        self.channel_norm = ModuleDict()
        self.conv1x1 = conv_layer(nf, nf, 1)
        self.up = Upsample(in_channels=nf, out_channels=nf, conv_layer=conv_layer)
        self.depth_to_space = DepthToSpace(block_size=2)
        self.space_to_depth = SpaceToDepth(block_size=2)
        self.n_stages = n_stages
        self.n_rnb = n_rnb
        # number of autoregressively modeled groups
        self.n_auto_groups = n_auto_groups

        for i_s in range(self.n_stages, self.n_stages - 2, -1):
            # Reduces the concatenation [h, z] (2 * nf channels) back to nf.
            self.channel_norm.update({f"s{i_s}": conv_layer(2 * nf, nf, 1)})
            for ir in range(self.n_rnb):
                self.blocks.update(
                    {
                        f"s{i_s}_{ir+1}": VUnetResnetBlock(
                            nf, use_skip=True, conv_layer=conv_layer
                        )
                    }
                )

        if FLAGS.group_auto:
            # Residual blocks of the autoregressive group model. The first
            # block has no skip input; the remaining three do. Creation
            # order is kept as block 0, param_converter, blocks 1-3.
            self.auto_blocks = ModuleList()
            self.auto_blocks.append(VUnetResnetBlock(nf, conv_layer=conv_layer))
            self.param_converter = conv_layer(4 * nf, nf, kernel_size=1)
            for _ in range(3):
                self.auto_blocks.append(
                    VUnetResnetBlock(nf, use_skip=True, conv_layer=conv_layer)
                )

    def forward(self, x_e, x_f, mode="train"):
        """
        :param x_e: The output from the encoder E_theta, indexed with the
            keys ``"s{i}_1"`` and ``"s{i}_2"`` of the two bottleneck stages
        :param x_f: The output from the encoder F_phi, indexed with
            ``"s{i}_2"``
        :param mode: Determines the mode of the bottleneck, must be in
            ["train","appearance_transfer","sample_appearance"]
        :return: a tuple ``(h, prior_params, posterior_params, z_prior)``:
            h: the output of the last layer of the bottleneck which is
            subsequently used by the decoder
            prior_params: dict with the means of the prior distributions
            p(z|ŷ) of the two bottleneck stages
            posterior_params: dict with the means of the posterior
            distributions p(z|ŷ,x) of the two bottleneck stages
            z_prior: dict with the current samples of the prior
            distributions of the two bottleneck stages, flattened
        :raises ValueError: if ``mode`` is not one of the three values above
        """
        use_posterior = mode == "train" or mode == "appearance_transfer"
        # Reject unknown modes before doing any work; the same error would
        # otherwise only surface at the end of the first stage.
        if not use_posterior and mode != "sample_appearance":
            raise ValueError(
                'The \'mode\' parameter in VUnetBottleneck must be in ["train","appearance_transfer","sample_appearance"]'
            )

        prior_params = {}
        posterior_params = {}
        z_prior = {}

        h = self.conv1x1(x_e[f"s{self.n_stages}_2"])
        for i_s in range(self.n_stages, self.n_stages - 2, -1):
            stage = f"s{i_s}"
            spatial_size = x_e[stage + "_2"].shape[-1]

            h = self.blocks[stage + "_2"](h, x_e[stage + "_2"])

            if spatial_size == 1:
                # Encoder outputs are used directly as distribution means.
                prior_params[stage] = x_e[stage + "_2"]
                posterior_params[stage] = x_f[stage + "_2"]
                prior_samples = self._latent_sample(prior_params[stage])
                z_prior[stage] = torch.squeeze(
                    torch.squeeze(prior_samples, dim=-1), dim=-1
                )
                posterior_samples = self._latent_sample(posterior_params[stage])
            else:
                post_params = self.space_to_depth(x_f[stage + "_2"])
                posterior_params[stage] = post_params

                if FLAGS.group_auto:
                    if use_posterior:
                        posterior_samples = self._latent_sample(post_params)
                        sec_size = posterior_samples.shape[1] // 4
                        posterior_sample_groups = torch.split(
                            posterior_samples, [sec_size] * 4, dim=1
                        )
                        posterior_samples = self.depth_to_space(posterior_samples)

                    param_groups = []
                    sample_groups = []

                    param_features = self.auto_blocks[0](h)
                    param_features = self.space_to_depth(param_features)
                    # convert to fitting depth
                    param_features = self.param_converter(param_features)

                    for i_a in range(len(self.auto_blocks)):
                        param_groups.append(param_features)
                        prior_samples = self._latent_sample(param_groups[-1])
                        sample_groups.append(prior_samples)

                        if i_a + 1 < len(self.auto_blocks):
                            if use_posterior:
                                feedback = posterior_sample_groups[i_a]
                            else:
                                feedback = prior_samples
                            # NOTE(review): this indexes auto_blocks[i_a], so
                            # the skip-less block 0 is called with a feedback
                            # tensor and the last block is never used. It
                            # looks like auto_blocks[i_a + 1] may have been
                            # intended; left unchanged because it affects
                            # trained weights.
                            param_features = self.auto_blocks[i_a](
                                param_features, feedback
                            )

                    pri_params = torch.cat(param_groups, dim=1)
                    prior_samples = self.__merge_groups(sample_groups)
                else:
                    pri_params = self.space_to_depth(x_e[stage + "_2"])
                    prior_samples = self.depth_to_space(
                        self._latent_sample(pri_params)
                    )
                    posterior_samples = self.depth_to_space(
                        self._latent_sample(post_params)
                    )

                prior_params[stage] = pri_params
                z_prior[stage] = (
                    self.space_to_depth(prior_samples)
                    .squeeze(dim=-1)
                    .squeeze(dim=-1)
                )

            # training and appearance transfer: sample from posterior
            # appearance sampling: sample from prior
            z = posterior_samples if use_posterior else prior_samples

            h = torch.cat([h, z], dim=1)
            h = self.channel_norm[stage](h)
            h = self.blocks[stage + "_1"](h, x_e[stage + "_1"])

            if i_s == self.n_stages:
                h = self.up(h)

        return h, prior_params, posterior_params, z_prior

    def __split_groups(self, x):
        """Rearrange ``x`` with ``space_to_depth`` and split it into 4 groups.

        The group size is derived from the rearranged tensor so that the
        four sections cover its whole channel axis.
        """
        flat = self.space_to_depth(x)
        sec_size = flat.shape[1] // 4
        # split along channel axis
        return torch.split(flat, [sec_size] * 4, dim=1)

    def __merge_groups(self, x):
        """Concatenate the groups along the channel axis and apply
        ``depth_to_space``."""
        return self.depth_to_space(torch.cat(x, dim=1))

    def _latent_sample(self, mean):
        """Return ``mean`` plus noise with zero mean and identity covariance.

        Trailing singleton dimensions of ``mean`` are squeezed before
        sampling and two trailing dimensions are appended to the result.
        The noise is created on the device of ``mean`` so that sampling
        keeps working after the module has been moved to another device.
        """
        sample_mean = torch.squeeze(torch.squeeze(mean, dim=-1), dim=-1)
        noise = MultivariateNormal(
            loc=torch.zeros_like(sample_mean),
            covariance_matrix=torch.eye(
                sample_mean.shape[-1], device=sample_mean.device
            ),
        ).sample()
        return (noise + sample_mean).unsqueeze(dim=-1).unsqueeze(dim=-1)
class VUnetDecoder(nn.Module):
    """Decoder that alternates upsampling and skip-connected residual blocks.

    Stages are walked from ``n_stages - 2`` down to ``1``. Each stage applies
    an :class:`Upsample` module followed by ``n_rnb`` residual blocks that
    receive the matching entry of ``skips``. The upsampling of the last
    stage halves the channel count.

    NOTE(review): a second class named ``VUnetDecoder`` appears later in
    this file; if both live in the same module, the later definition is the
    one the name refers to after import.
    """

    def __init__(
        self,
        n_stages,
        nf=128,
        nf_out=3,
        n_rnb=2,
        conv_layer=NormConv2d,
        spatial_size=256,
        final_act=True,
        dropout_prob=0.0,
    ):
        """
        Parameters
        ----------
        n_stages : int
            Number of stages; ``2 ** (n_stages - 1)`` must equal
            ``spatial_size``.
        nf : int
            Channel count of the incoming features.
        nf_out : int
            Channel count produced by the final layer.
        n_rnb : int
            Number of residual blocks per stage.
        conv_layer : callable
            Factory for the final layer; forwarded to the sub-modules.
        spatial_size : int
            Only used for the consistency assertion against ``n_stages``.
        final_act : bool
            If truthy, a ``Tanh`` is applied to the output.
        dropout_prob : float
            Forwarded to every :class:`VUnetResnetBlock`.
        """
        super().__init__()
        assert (2 ** (n_stages - 1)) == spatial_size
        self.final_act = final_act
        self.blocks = ModuleDict()
        self.ups = ModuleDict()
        self.n_stages = n_stages
        self.n_rnb = n_rnb

        width = nf
        for stage_no in range(self.n_stages - 2, 0, -1):
            # for final stage, bisect number of filters
            up_width = width // 2 if stage_no == 1 else width

            # upsampling operation
            self.ups[f"s{stage_no + 1}"] = Upsample(
                in_channels=width,
                out_channels=up_width,
                conv_layer=conv_layer,
            )
            width = up_width

            # resnet blocks
            for block_no in range(self.n_rnb, 0, -1):
                self.blocks[f"s{stage_no}_{block_no}"] = VUnetResnetBlock(
                    width,
                    use_skip=True,
                    conv_layer=conv_layer,
                    dropout_prob=dropout_prob,
                )

        # final 1x1 convolution
        self.final_layer = conv_layer(width, nf_out, kernel_size=1)

        # conditionally: replace the flag by the actual activation module
        if self.final_act:
            self.final_act = nn.Tanh()

    def forward(self, x, skips):
        """
        Parameters
        ----------
        x : torch.Tensor
            Latent representation to decode.
        skips : dict
            The skip connections of the VUnet, indexed with
            ``"s{stage}_{block}"``.

        Returns
        -------
        out : torch.Tensor
            Output of the final layer, passed through the final activation
            if one was configured.
        """
        out = x
        for stage_no in range(self.n_stages - 2, 0, -1):
            out = self.ups[f"s{stage_no + 1}"](out)
            for block_no in range(self.n_rnb, 0, -1):
                key = f"s{stage_no}_{block_no}"
                out = self.blocks[key](out, skips[key])

        out = self.final_layer(out)
        return self.final_act(out) if self.final_act else out
class VUnetDecoder(nn.Module):
    """Decoder variant configured through the global ``FLAGS`` object.

    Stages are walked from ``n_stages - 2`` down to ``1``. Each stage applies
    an :class:`Upsample` module followed by ``n_rnb`` residual blocks that
    receive the matching entry of ``skips``. The upsampling of the last
    stage halves the channel count. ``FLAGS.spatial_size`` is checked at
    construction and ``FLAGS.final_act`` is read both at construction and
    on every forward pass.

    NOTE(review): an earlier class in this file is also named
    ``VUnetDecoder``; if both live in the same module, this later
    definition is the one the name refers to after import.
    """

    def __init__(self, n_stages, nf=128, nf_out=3, n_rnb=2, conv_layer=NormConv2d):
        """
        :param n_stages: Number of stages; ``2 ** (n_stages - 1)`` must equal
            ``FLAGS.spatial_size``
        :param nf: Channel count of the incoming features
        :param nf_out: Channel count produced by the final layer
        :param n_rnb: Number of residual blocks per stage
        :param conv_layer: Factory for the final layer; forwarded to the
            sub-modules
        """
        super().__init__()
        assert (2 ** (n_stages - 1)) == FLAGS.spatial_size
        self.blocks = ModuleDict()
        self.ups = ModuleDict()
        self.n_stages = n_stages
        self.n_rnb = n_rnb

        width = nf
        for stage_no in range(self.n_stages - 2, 0, -1):
            # for final stage, bisect number of filters
            up_width = width // 2 if stage_no == 1 else width

            # upsampling operation
            self.ups[f"s{stage_no + 1}"] = Upsample(
                in_channels=width,
                out_channels=up_width,
                conv_layer=conv_layer,
            )
            width = up_width

            # resnet blocks
            for block_no in range(self.n_rnb, 0, -1):
                self.blocks[f"s{stage_no}_{block_no}"] = VUnetResnetBlock(
                    width, use_skip=True, conv_layer=conv_layer
                )

        # final 1x1 convolution
        self.final_layer = conv_layer(width, nf_out, kernel_size=1)

        # conditionally: set final activation
        if FLAGS.final_act:
            self.final_act = nn.Tanh()

    def forward(self, x, skips):
        """
        :param x: Features to decode
        :param skips: The skip connections of the VUnet, indexed with
            ``"s{stage}_{block}"``
        :return: Output of the final layer, passed through ``self.final_act``
            when ``FLAGS.final_act`` is set
        """
        out = x
        for stage_no in range(self.n_stages - 2, 0, -1):
            out = self.ups[f"s{stage_no + 1}"](out)
            for block_no in range(self.n_rnb, 0, -1):
                key = f"s{stage_no}_{block_no}"
                out = self.blocks[key](out, skips[key])

        out = self.final_layer(out)
        if FLAGS.final_act:
            out = self.final_act(out)
        return out