예제 #1
0
    def reverse_flow(self, estimator, flow_embed, multgate, i_flow, input,
                     logdet):
        """Invert one conditional flow step: coupling, permutation, actnorm.

        Args:
            estimator: callable ``estimator(flow_embed, multgate, z1)`` whose
                output is passed through ``self.conv_proj`` to form the
                coupling parameters.
            flow_embed: conditioning embedding; its leading dimension is
                expanded to ``input.shape[0]`` before ``actnorm_embed``.
            multgate: forwarded unchanged to ``estimator``.
            i_flow: not used by this method.
            input: tensor whose ``size(1)`` must be even.
            logdet: running log-determinant, updated for the inverse.

        Returns:
            ``(z, logdet)`` after undoing the step.
        """
        assert input.size(1) % 2 == 0

        # Broadcast the embedding over the batch of `input`, keeping dims 1-3.
        batch = input.shape[0]
        _, emb_c, emb_h, emb_w = flow_embed.shape
        flow_embed = flow_embed.expand(batch, emb_c, emb_h, emb_w)
        # The embedding only conditions the estimator and is never inverted,
        # so its actnorm runs with reverse=False even on this reverse path.
        flow_embed, _ = self.actnorm_embed(flow_embed, reverse=False)

        # 1. Undo the coupling; the first half z1 passes through unchanged.
        z1, z2 = split_feature(input, "split")
        mode = self.flow_coupling
        if mode == "additive":
            offset = self.conv_proj(estimator(flow_embed, multgate, z1))
            z2 = z2 - offset
        elif mode == "affine":
            feats = self.conv_proj(estimator(flow_embed, multgate, z1))
            shift, raw_scale = split_feature(feats, "cross")
            scale = torch.sigmoid(raw_scale + 2.)
            # Inverse of z2 -> (z2 + shift) * scale.
            z2 = z2 / scale - shift
            logdet = -torch.sum(torch.log(scale), dim=[1, 2, 3]) + logdet
        restored = torch.cat((z1, z2), dim=1)

        # 2. Undo the permutation, then 3. undo the actnorm.
        restored, logdet = self.flow_permutation(restored, logdet, True)
        restored, logdet = self.actnorm(restored, logdet=logdet, reverse=True)

        return restored, logdet
예제 #2
0
    def normal_flow(self, estimator, flow_embed, multgate, i_flow, input,
                    logdet):
        """Apply one conditional flow step: actnorm, permutation, coupling.

        Args:
            estimator: callable ``estimator(flow_embed, multgate, z1)`` whose
                output is passed through ``self.conv_proj`` to form the
                coupling parameters.
            flow_embed: conditioning embedding; its leading dimension is
                expanded to the batch size before ``actnorm_embed``.
            multgate: forwarded unchanged to ``estimator``.
            i_flow: not used by this method.
            input: tensor whose ``size(1)`` must be even.
            logdet: running log-determinant, updated by each sub-step.

        Returns:
            ``(z, logdet)`` after the step.
        """
        assert input.size(1) % 2 == 0

        # 1. Activation normalisation of the flow input.
        z, logdet = self.actnorm(input, logdet=logdet, reverse=False)

        # 1.1 Broadcast the embedding to z's batch size and normalise it too.
        _, emb_c, emb_h, emb_w = flow_embed.shape
        flow_embed = flow_embed.expand(z.shape[0], emb_c, emb_h, emb_w)
        flow_embed, _ = self.actnorm_embed(flow_embed)

        # 2. Channel permutation.
        z, logdet = self.flow_permutation(z, logdet, False)

        # 3. Coupling: z1 conditions an update of z2.
        z1, z2 = split_feature(z, "split")
        mode = self.flow_coupling
        if mode == "additive":
            z2 = z2 + self.conv_proj(estimator(flow_embed, multgate, z1))
        elif mode == "affine":
            feats = self.conv_proj(estimator(flow_embed, multgate, z1))
            shift, raw_scale = split_feature(feats, "cross")
            scale = torch.sigmoid(raw_scale + 2.)
            z2 = (z2 + shift) * scale
            logdet = torch.sum(torch.log(scale), dim=[1, 2, 3]) + logdet

        return torch.cat((z1, z2), dim=1), logdet
예제 #3
0
    def normal_flow(self, input, y_onehot, logdet):
        """Apply one flow step: actnorm, permutation, then coupling.

        When ``self.extra_condition`` is set the coupling network receives the
        tuple ``(z1, y_onehot)``; otherwise it receives ``z1`` alone.

        Returns:
            ``(z, logdet)`` after the step.
        """
        assert input.size(1) % 2 == 0
        # 1. actnorm
        z, logdet = self.actnorm(input, logdet=logdet, reverse=False)

        # 2. permute
        z, logdet = self.flow_permutation(z, logdet, False)

        # 3. coupling
        z1, z2 = split_feature(z, "split")

        def run_block():
            # Feed the label alongside z1 only when extra conditioning is on.
            return self.block((z1, y_onehot) if self.extra_condition else z1)

        mode = self.flow_coupling
        if mode == "additive":
            z2 = z2 + run_block()
        elif mode == "affine":
            shift, raw_scale = split_feature(run_block(), "cross")
            scale = torch.sigmoid(raw_scale + 2.)
            z2 = (z2 + shift) * scale
            logdet = torch.sum(torch.log(scale), dim=[1, 2, 3]) + logdet

        return torch.cat((z1, z2), dim=1), logdet
예제 #4
0
    def normal_flow(self, x, y_onehot):
        """Encode ``x`` to latents and compute the bits-per-dimension loss.

        Args:
            x: input batch; its shape is unpacked as ``(b, c, h, w)``.
            y_onehot: labels forwarded to ``self.prior``.

        Returns:
            ``(z, bpd, y_logits)``:
            - ``z``: latents produced by ``self.flow``.
            - ``bpd``: the negated objective divided by ``log(2) * c * h * w``.
            - ``y_logits``: output of ``self.project_class``, or ``None`` when
              ``self.y_condition`` is false.
        """
        b, c, h, w = x.shape

        x, logdet = uniform_binning_correction(x)

        z, objective = self.flow(x, logdet=logdet, reverse=False)  # z_size = C

        # if condition, mean_size = C//2; else, mean_size = C
        mean, logs = self.prior(x, y_onehot)

        if self.y_condition:
            z_y, z_n = split_feature(z, "split")
            # Fixed prior parameters for the unconditioned half z_n. Built with
            # *_like so they share z_y's device and dtype, rather than a
            # hard-coded .cuda() that fails on CPU-only machines and ignores
            # which device z actually lives on. Still stored on self so any
            # code reading these attributes keeps working.
            # NOTE(review): if gaussian_likelihood treats `logs` as a log-std,
            # ones gives std = e rather than a standard normal — confirm.
            self.mean_normal = torch.zeros_like(z_y)
            self.logs_normal = torch.ones_like(z_y)
            y_logits = self.project_class(z_y.mean(2).mean(2))
            objective += gaussian_likelihood(
                mean, logs, z_y) + gaussian_likelihood(
                    self.mean_normal, self.logs_normal, z_n)
        else:
            objective += gaussian_likelihood(mean, logs, z)
            y_logits = None

        # Full objective - converted to bits per dimension
        bpd = (-objective) / (math.log(2.) * c * h * w)

        return z, bpd, y_logits
예제 #5
0
    def prior(self, data, y_onehot=None, batch_size=32):
        """Build the top-level prior tensor and split it in two.

        Args:
            data: batch whose size sets the number of prior rows, or ``None``
                to use ``batch_size`` rows instead.
            y_onehot: labels; required when ``self.y_condition`` or
                ``self.yd_condition`` is set.
            batch_size: number of rows used when ``data`` is ``None``.

        Returns:
            ``split_feature(h, "split")`` of the prior tensor ``h``.
        """
        n_rows = data.shape[0] if data is not None else batch_size
        h = self.prior_h.repeat(n_rows, 1, 1, 1)

        channels = h.size(1)

        if self.learn_top:
            h = self.learn_top_fn(h)

        # Each enabled flag adds the projected labels once, so with both flags
        # set the projection is added twice (unchanged behavior).
        # NOTE(review): yd_condition reuses project_ycond; confirm a separate
        # projection was not intended.
        for enabled in (self.y_condition, self.yd_condition):
            if enabled:
                assert y_onehot is not None
                yp = self.project_ycond(y_onehot)
                if data is None:
                    print("no data")
                h += yp.view(n_rows, channels, 1, 1)

        return split_feature(h, "split")
예제 #6
0
    def normal_flow(self, input, logdet):
        """Apply one flow step: optional actnorm, permutation, coupling.

        Coupling variants selected by ``self.flow_coupling``:
            "additive": z2 = z2 + block(z1); logdet unchanged.
            "affine":   scale = sigmoid(s + affine_scale_eps) + affine_eps.
            "gaffine":  scale = exp(log(max_scale) * tanh(s)); the scale is
                        also stored in ``self.last_scale``.
            "naffine":  scale = (2 * sigmoid(s) - 1) * (1 - affine_eps) + 1.
        For the three affine variants, z2 = (z2 + shift) * scale and logdet
        gains sum(log(scale)) over dims 1-3. Any other value leaves z2 as is.

        Returns:
            ``(z, logdet)`` after the step.
        """
        assert input.size(1) % 2 == 0

        # 1. actnorm, unless disabled by self.no_actnorm
        z = input
        if not self.no_actnorm:
            z, logdet = self.actnorm(input, logdet=logdet, reverse=False)

        # 2. permute
        z, logdet = self.flow_permutation(z, logdet, False)

        # 3. coupling
        z1, z2 = split_feature(z, "split")
        mode = self.flow_coupling
        if mode == "additive":
            z2 = z2 + self.block(z1)
        elif mode in ("affine", "gaffine", "naffine"):
            shift, raw = split_feature(self.block(z1), "cross")
            if mode == "affine":
                scale = (torch.sigmoid(raw + self.affine_scale_eps)
                         + self.affine_eps)
            elif mode == "gaffine":
                scale = torch.exp(np.log(self.max_scale) * torch.tanh(raw))
                self.last_scale = scale
            else:
                eps = self.affine_eps
                scale = (2 * torch.sigmoid(raw) - 1) * (1 - eps) + 1
            z2 = (z2 + shift) * scale
            logdet = torch.sum(torch.log(scale), dim=[1, 2, 3]) + logdet

        return torch.cat((z1, z2), dim=1), logdet
예제 #7
0
 def forward(self, input, logdet=0., reverse=False, temperature=None):
     """Split step: factor out half of ``input`` under a learned prior.

     Forward (``reverse=False``): split ``input`` into ``(z1, z2)``, add
     ``gaussian_likelihood`` of ``z2`` under ``split2d_prior(z1)`` to
     ``logdet``, and return ``(z1, logdet)``.

     Reverse: draw ``z2`` with ``gaussian_sample`` using the prior computed
     from ``input`` and ``temperature``. Return ``(cat(input, z2), logdet)``
     with ``logdet`` unchanged.
     """
     if not reverse:
         z1, z2 = split_feature(input, "split")
         mean, logs = self.split2d_prior(z1)
         return z1, gaussian_likelihood(mean, logs, z2) + logdet

     mean, logs = self.split2d_prior(input)
     sampled = gaussian_sample(mean, logs, temperature)
     return torch.cat((input, sampled), dim=1), logdet
    def reverse_flow(self, input, logdet):
        """Invert one flow step.

        Undoes, in order:
        - the coupling applied to the second half ``z2``;
        - the permutation (``flow_permutation(..., True)``);
        - the actnorm (``actnorm(..., reverse=True)``).

        Returns:
            ``(z, logdet)``; for "affine" coupling ``logdet`` is reduced by the
            summed log-scales.
        """
        assert input.size(1) % 2 == 0

        # 1. Undo the coupling; z1 passes through unchanged.
        z1, z2 = split_feature(input, "split")
        mode = self.flow_coupling
        if mode == "additive":
            z2 = z2 - self.block(z1)
        elif mode == "affine":
            shift, raw_scale = split_feature(self.block(z1), "cross")
            scale = torch.sigmoid(raw_scale + 2.)
            # Inverse of z2 -> (z2 + shift) * scale.
            z2 = z2 / scale - shift
            logdet = -torch.sum(torch.log(scale), dim=[1, 2, 3]) + logdet
        restored = torch.cat((z1, z2), dim=1)

        # 2. Undo the permutation, then 3. undo the actnorm.
        restored, logdet = self.flow_permutation(restored, logdet, True)
        restored, logdet = self.actnorm(restored, logdet=logdet, reverse=True)
        return restored, logdet
예제 #9
0
    def split2d_prior(self, z, condition=None):
        """Compute prior parameters for the factored-out half from ``z``.

        When ``self.sp_condition`` is truthy, the conditioning path runs as:
        1. ``condition`` is mapped by ``self.cond_fc`` and reshaped to
           ``z.size()``.
        2. The result goes through ``self.cond_conv`` and a ReLU.
        3. It is concatenated to ``z`` along dim 1 before ``self.conv``.

        Args:
            z: the kept half of the split.
            condition: conditioning input; required when ``sp_condition`` is
                set, ignored otherwise.

        Returns:
            ``split_feature(h, "cross")`` of the conv output ``h``.

        Raises:
            ValueError: if ``sp_condition`` is set and ``condition`` is None.
        """
        # Truthiness, not `is False`: 0 / None / numpy bools must also select
        # the unconditioned path.
        if not self.sp_condition:
            h = self.conv(z)
        else:
            if condition is None:
                raise ValueError(
                    "split2d_prior: condition is required when sp_condition "
                    "is set")
            cond = self.cond_fc(condition)
            cond = cond.view(z.size())
            cond = F.relu(self.cond_conv(cond), inplace=False)
            h = self.conv(torch.cat([z, cond], dim=1))
        return split_feature(h, "cross")
예제 #10
0
    def normal_flow(self, estimator, flow_embed, i_flow, input, logdet):
        """Apply one flow step with coupling parameters from ``estimator``.

        Args:
            estimator: callable ``estimator(flow_embed, z1)``.
            flow_embed: forwarded unchanged to ``estimator``.
            i_flow: not used by this method.
            input: tensor whose ``size(1)`` must be even.
            logdet: running log-determinant.

        Returns:
            ``(z, logdet)`` after actnorm, permutation and coupling.
        """
        assert input.size(1) % 2 == 0

        # 1. actnorm
        z, logdet = self.actnorm(input, logdet=logdet, reverse=False)

        # 2. permute
        z, logdet = self.flow_permutation(z, logdet, False)

        # 3. coupling
        z1, z2 = split_feature(z, "split")
        mode = self.flow_coupling
        if mode == "additive":
            z2 = z2 + estimator(flow_embed, z1)
        elif mode == "affine":
            shift, raw_scale = split_feature(estimator(flow_embed, z1),
                                             "cross")
            scale = torch.sigmoid(raw_scale + 2.)
            z2 = (z2 + shift) * scale
            logdet = torch.sum(torch.log(scale), dim=[1, 2, 3]) + logdet

        return torch.cat((z1, z2), dim=1), logdet
예제 #11
0
 def forward(self, input, zn, logdet=0., reverse=False, temperature=None):
     """Split step that stores factored-out halves in ``zn`` instead of sampling.

     Forward (``reverse=False``):
     - split ``input`` into ``(z1, z2)`` and append a clone of ``z2`` to ``zn``;
     - add ``gaussian_likelihood`` of ``z2`` under ``split2d_prior(z1)`` to
       ``logdet``;
     - return ``(z1, logdet, zn)``.

     Reverse: pop the last stored half from ``zn`` and concatenate it back
     onto ``input``. Returns ``(z, logdet, zn)`` with ``logdet`` unchanged;
     ``temperature`` is unused.

     ``zn`` is used LIFO (append / pop), so reverse calls must mirror the
     forward calls in opposite order.
     """
     if reverse:
         # The stored z2 is reused as-is, so the prior is not evaluated here;
         # its mean/logs were previously computed and then discarded.
         z2 = zn.pop()
         z = torch.cat((input, z2), dim=1)
         return z, logdet, zn

     z1, z2 = split_feature(input, "split")
     zn.append(z2.clone())
     mean, logs = self.split2d_prior(z1)
     logdet = gaussian_likelihood(mean, logs, z2) + logdet
     return z1, logdet, zn
예제 #12
0
    def prior(self, data, y_onehot=None, batch_size=32):
        """Build the top-level prior tensor and split it in two.

        Args:
            data: batch whose size sets the number of prior rows, or ``None``
                to use ``batch_size`` rows instead.
            y_onehot: labels; required when ``self.y_condition`` is set.
                Converted to float before projection.
            batch_size: number of rows used when ``data`` is ``None``
                (defaults to the previously hard-coded 32).

        Returns:
            ``split_feature(h, "split")`` of the prior tensor ``h``.
        """
        n_rows = data.shape[0] if data is not None else batch_size
        h = self.prior_h.repeat(n_rows, 1, 1, 1)

        channels = h.size(1)

        if self.learn_top:
            h = self.learn_top_fn(h)

        if self.y_condition:
            assert y_onehot is not None
            y_onehot = y_onehot.float()
            yp = self.project_ycond(y_onehot)
            # n_rows rather than data.shape[0], which raised AttributeError
            # when data is None.
            h += yp.view(n_rows, channels, 1, 1)

        return split_feature(h, "split")
예제 #13
0
 def forward(self, input, logdet=0., reverse=False, temperature=None):
     """Split step that caches the factored-out half for later reuse.

     Forward (``reverse=False``):
     - split ``input`` into ``(z1, z2)`` and store a clone of ``z2`` in
       ``self._last_z2``;
     - add ``gaussian_likelihood`` of ``z2`` under ``split2d_prior(z1)`` to
       ``logdet``, and also store that term in ``self._last_logdet``;
     - return ``(z1, logdet)``.

     Reverse, when ``self.use_last`` is set:
     - reuse the cached ``self._last_z2`` after calling ``requires_grad_()``
       on it;
     - reset ``use_last`` to False.

     Reverse otherwise:
     - draw ``z2`` with ``gaussian_sample(mean, logs, temperature)`` and cache
       it in ``self._last_z2``.

     In both reverse cases, return ``(cat(input, z2), logdet)`` with
     ``logdet`` unchanged.

     NOTE(review): this method mutates instance state (``_last_z2``,
     ``_last_logdet``, ``use_last``). ``self.use_last`` must exist before the
     first reverse call; presumably the caller sets it to request reuse —
     confirm.
     """
     if reverse:
         z1 = input
         # Evaluated unconditionally; mean/logs are only used by the
         # sampling branch below.
         mean, logs = self.split2d_prior(z1)
         if self.use_last:
             # Reuse the cached z2 with gradients enabled; use_last is
             # one-shot and is cleared here.
             self._last_z2.requires_grad_()
             z2 = self._last_z2
             self.use_last = False
         else:
             z2 = gaussian_sample(mean, logs, temperature)
             self._last_z2 = z2  # aliases z2; no .clone() here
         z = torch.cat((z1, z2), dim=1)
         return z, logdet
     else:
         z1, z2 = split_feature(input, "split")
         # Cache a copy of the factored-out half (read by the use_last
         # branch of a reverse call).
         self._last_z2 = z2.clone()
         mean, logs = self.split2d_prior(z1)
         d = gaussian_likelihood(mean, logs, z2)
         logdet = d + logdet
         # Keep this layer's likelihood contribution on its own.
         self._last_logdet = d
         return z1, logdet
예제 #14
0
 def split2d_prior(self, z):
     """Run ``z`` through ``self.conv`` and split the output with "cross".

     Callers in this file unpack the two halves as ``(mean, logs)``.
     """
     features = self.conv(z)
     return split_feature(features, "cross")