def __init__(self, in_channels, hidden_channels, actnorm_scale,
             flow_permutation, flow_coupling, LU_decomposed,
             flow_embed_dim, level):
    super().__init__()
    self.flow_coupling = flow_coupling

    # 1. actnorm
    self.actnorm = ActNorm2d(in_channels, actnorm_scale)
    self.actnorm_embed = ActNorm2d(flow_embed_dim * (2 ** level), actnorm_scale)

    if flow_coupling == "additive":
        self.conv_proj = Conv2dZeros(hidden_channels, in_channels // 2,
                                     kernel_size=(3, 3))
    elif flow_coupling == "affine":
        self.conv_proj = Conv2dZeros(hidden_channels, in_channels,
                                     kernel_size=(3, 3))

    # 2. permute
    if flow_permutation == "invconv":
        self.invconv = InvertibleConv1x1(in_channels,
                                         LU_decomposed=LU_decomposed)
        self.flow_permutation = \
            lambda z, logdet, rev: self.invconv(z, logdet, rev)
    elif flow_permutation == "shuffle":
        self.shuffle = Permute2d(in_channels, shuffle=True)
        self.flow_permutation = \
            lambda z, logdet, rev: (self.shuffle(z, rev), logdet)
    else:
        self.reverse = Permute2d(in_channels, shuffle=False)
        self.flow_permutation = \
            lambda z, logdet, rev: (self.reverse(z, rev), logdet)

    # learnable multiplicative gates, shape (6, hidden_channels, 1, 1)
    self.multgate = nn.Parameter(torch.zeros((6, hidden_channels, 1, 1)))
def __init__(self, image_shape, hidden_channels, K, L, actnorm_scale,
             flow_permutation, flow_coupling, LU_decomposed, y_classes,
             learn_top, y_condition):
    super().__init__()
    self.flow = FlowNet(image_shape=image_shape,
                        hidden_channels=hidden_channels,
                        K=K,
                        L=L,
                        actnorm_scale=actnorm_scale,
                        flow_permutation=flow_permutation,
                        flow_coupling=flow_coupling,
                        LU_decomposed=LU_decomposed)
    self.y_classes = y_classes
    self.y_condition = y_condition
    self.learn_top = learn_top

    # learned prior
    if learn_top:
        C = self.flow.output_shapes[-1][1]
        self.learn_top_fn = Conv2dZeros(C * 2, C * 2)

    if y_condition:
        C = self.flow.output_shapes[-1][1]
        self.project_ycond = LinearZeros(y_classes, 2 * C)
        self.project_class = LinearZeros(C, y_classes)

    self.register_buffer(
        "prior_h",
        torch.zeros([1,
                     self.flow.output_shapes[-1][1] * 2,
                     self.flow.output_shapes[-1][2],
                     self.flow.output_shapes[-1][3]]))
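# Instantiation sketch. Assumptions: the constructor above belongs to a
# Glow-style model class (hypothetically named `Glow` here), and the
# hyperparameter values below are illustrative CIFAR-10-scale settings,
# not values taken from this repository.
model = Glow(image_shape=(32, 32, 3),      # (H, W, C) of the input images
             hidden_channels=512,          # width of the coupling networks
             K=32,                         # flow steps per level
             L=3,                          # number of multi-scale levels
             actnorm_scale=1.0,
             flow_permutation="invconv",   # "invconv", "shuffle", else reverse
             flow_coupling="affine",       # or "additive"
             LU_decomposed=True,
             y_classes=10,
             learn_top=True,
             y_condition=False)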
def __init__(self, in_channels, out_channels, hidden_channels):
    super().__init__()
    # six conv/ReLU stages followed by a zero-initialized 1x1 projection
    self.conv1 = Conv2d(in_channels, hidden_channels)
    self.relu1 = nn.ReLU(inplace=False)
    self.conv2 = Conv2d(hidden_channels, hidden_channels, kernel_size=(3, 3))
    self.relu2 = nn.ReLU(inplace=False)
    self.conv3 = Conv2d(hidden_channels, hidden_channels, kernel_size=(3, 3))
    self.relu3 = nn.ReLU(inplace=False)
    self.conv4 = Conv2d(hidden_channels, hidden_channels, kernel_size=(3, 3))
    self.relu4 = nn.ReLU(inplace=False)
    self.conv5 = Conv2d(hidden_channels, hidden_channels, kernel_size=(3, 3))
    self.relu5 = nn.ReLU(inplace=False)
    self.conv6 = Conv2d(hidden_channels, hidden_channels, kernel_size=(3, 3))
    self.relu6 = nn.ReLU(inplace=False)
    self.conv_proj = Conv2dZeros(hidden_channels, out_channels, kernel_size=(1, 1))
def get_block(in_channels, out_channels, hidden_channels):
    block = nn.Sequential(Conv2d(in_channels, hidden_channels),
                          nn.ReLU(inplace=False),
                          Conv2d(hidden_channels, hidden_channels,
                                 kernel_size=(1, 1)),
                          nn.ReLU(inplace=False),
                          Conv2dZeros(hidden_channels, out_channels))
    return block
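# Usage sketch for get_block. Assumptions: Conv2d and Conv2dZeros are this
# repo's custom layers (Conv2d typically runs a data-dependent actnorm init
# on its first forward pass), and the channel counts below are illustrative.
import torch

coupling_net = get_block(in_channels=6, out_channels=12, hidden_channels=512)
x = torch.randn(4, 6, 16, 16)   # toy half-split of a 12-channel activation
h = coupling_net(x)             # -> (4, 12, 16, 16), e.g. shift/log-scale in affine coupling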
def __init__(self, image_shape, hidden_channels, K, L, actnorm_scale,
             flow_permutation, flow_coupling, LU_decomposed, y_classes,
             d_classes, learn_top, y_condition, extra_condition, sp_condition,
             d_condition, yd_condition):
    super().__init__()
    self.flow = FlowNet(image_shape=image_shape,
                        hidden_channels=hidden_channels,
                        K=K,
                        L=L,
                        actnorm_scale=actnorm_scale,
                        flow_permutation=flow_permutation,
                        flow_coupling=flow_coupling,
                        LU_decomposed=LU_decomposed,
                        extra_condition=extra_condition,
                        sp_condition=sp_condition,
                        num_classes=y_classes + d_classes)
    self.y_classes = y_classes
    if y_condition or d_condition:
        self.y_condition = True
        print("conditional version", self.y_condition)
    else:
        self.y_condition = False
    self.yd_condition = yd_condition
    self.learn_top = learn_top
    print("extra condition", extra_condition)
    print("split prior condition", sp_condition)

    # learned prior
    if learn_top:
        C = self.flow.output_shapes[-1][1]
        self.learn_top_fn = Conv2dZeros(C * 2, C * 2)

    if self.y_condition:
        C = self.flow.output_shapes[-1][1]
        self.project_ycond = LinearZeros(y_classes + d_classes, 2 * C)
        self.project_class = LinearZeros(C, y_classes)
        self.project_domain = LinearZeros(C, d_classes)
    elif self.yd_condition:
        C = self.flow.output_shapes[-1][1]
        self.project_ycond = LinearZeros(y_classes + d_classes, 2 * C)
        self.project_class = LinearZeros(C, y_classes)
        self.project_domain = LinearZeros(C, d_classes)

    self.register_buffer(
        "prior_h",
        torch.zeros([1,
                     self.flow.output_shapes[-1][1] * 2,
                     self.flow.output_shapes[-1][2],
                     self.flow.output_shapes[-1][3]]))
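# Shape sketch for the conditional prior. Assumptions: `model` is an instance
# of the class owning the constructor above (its name is not shown here), and
# the prior statistics are obtained by adding the projected (class, domain)
# one-hot embedding to `prior_h`, as is common in Glow-style conditional
# priors; the repo's actual forward pass may differ.
import torch
import torch.nn.functional as F

B, y_classes, d_classes = 8, 7, 3
y_onehot = F.one_hot(torch.randint(0, y_classes, (B,)), y_classes).float()
d_onehot = F.one_hot(torch.randint(0, d_classes, (B,)), d_classes).float()
yd = torch.cat([y_onehot, d_onehot], dim=1)          # (B, y_classes + d_classes)

h = model.prior_h.repeat(B, 1, 1, 1)                 # (B, 2C, H, W)
h = h + model.project_ycond(yd).view(B, -1, 1, 1)    # broadcast label embedding
mean, logs = h.chunk(2, dim=1)                       # prior mean and log-std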
def __init__(self, image_shape, hidden_channels, K, L, actnorm_scale,
             flow_permutation, flow_coupling, LU_decomposed, flow_embed_dim,
             y_classes, learn_top, y_condition):
    super().__init__()
    self.flow = FlowNet(image_shape=image_shape,
                        hidden_channels=hidden_channels,
                        K=K,
                        L=L,
                        actnorm_scale=actnorm_scale,
                        flow_permutation=flow_permutation,
                        flow_coupling=flow_coupling,
                        LU_decomposed=LU_decomposed,
                        flow_embed_dim=flow_embed_dim)
    self.y_classes = y_classes
    self.y_condition = y_condition
    self.learn_top = learn_top

    # learned prior
    if learn_top:
        C = self.flow.output_shapes[-1][1]
        self.learn_top_fn = Conv2dZeros(C * 2, C * 2)

    if y_condition:
        C = self.flow.output_shapes[-1][1]
        self.project_ycond = LinearZeros(y_classes, 2 * C)
        self.project_class = LinearZeros(C, y_classes)

    self.register_buffer(
        "prior_h",
        torch.zeros([1,
                     self.flow.output_shapes[-1][1] * 2,
                     self.flow.output_shapes[-1][2],
                     self.flow.output_shapes[-1][3]]))

    self.num_param = sum(p.numel() for p in self.parameters() if p.requires_grad)
    print("num_param: {}".format(self.num_param))
def get_block(in_channels, out_channels, hidden_channels, sn=False,
              no_conv_actnorm=False):
    if sn:
        block = nn.Sequential(
            SpectralNormConv2d(in_channels, hidden_channels, 3,
                               stride=1, padding=1, coeff=1),
            nn.ReLU(inplace=False),
            SpectralNormConv2d(hidden_channels, hidden_channels, 3,
                               stride=1, padding=1, coeff=1),
            nn.ReLU(inplace=False),
            SpectralNormConv2d(hidden_channels, out_channels, 3,
                               stride=1, padding=1, coeff=1))
    else:
        block = nn.Sequential(
            Conv2d(in_channels, hidden_channels,
                   do_actnorm=not no_conv_actnorm),
            nn.ReLU(inplace=False),
            Conv2d(hidden_channels, hidden_channels, kernel_size=(1, 1),
                   do_actnorm=not no_conv_actnorm),
            nn.ReLU(inplace=False),
            Conv2dZeros(hidden_channels, out_channels))
    return block
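# Usage sketch for the sn/no_conv_actnorm variants. Assumptions:
# SpectralNormConv2d, Conv2d, and Conv2dZeros are this repo's custom layers;
# channel counts are illustrative.
sn_block = get_block(6, 12, 512, sn=True)                   # spectral-norm convs
plain_block = get_block(6, 12, 512, no_conv_actnorm=True)   # plain convs, no actnorm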