Example #1
    def compute_loss_for_batch(self,
                               data,
                               model,
                               K=K,
                               testing_mode=False,
                               alpha=alpha):
        # data = (B, 1, H, W)
        B, _, H, W = data.shape
        data_k_vec = data.repeat((1, K, 1, 1)).view(-1, H * W)
        mu, logstd = model.encode(data_k_vec)
        # (B*K, #latents)
        z = model.reparameterize(mu, logstd)

        # summing over latents due to independence assumption
        # (B*K)
        log_q = compute_log_probabitility_gaussian(z, mu, logstd)

        log_p_z = compute_log_probabitility_gaussian(
            z, torch.zeros_like(z, requires_grad=False),
            torch.zeros_like(z, requires_grad=False))
        decoded = model.decode(z)
        if discrete_data:
            log_p = compute_log_probabitility_bernoulli(decoded, data_k_vec)
        else:
            # Gaussian with fixed logstd = 0 (i.e. unit sigma); sigma is not predicted at the moment
            log_p = compute_log_probabitility_gaussian(
                decoded, data_k_vec, torch.zeros_like(decoded))
        # reshape so that each row groups the K importance samples for one observation
        if model_type == 'iwae' or testing_mode:
            log_w_matrix = (log_p_z + log_p - log_q).view(B, K)
        elif model_type == 'vae':
            # treat each sample for a given data point as you would treat all samples in the minibatch
            # scale by 1/K so the batch sum averages over the K copies of each observation (standard VAE objective)
            log_w_matrix = (log_p_z + log_p - log_q).view(B * K, 1) * 1 / K
        elif model_type == 'general_alpha' or model_type == 'vralpha':
            log_w_matrix = (log_p_z + log_p - log_q).view(-1, K) * (1 - alpha)
        elif model_type == 'vrmax':
            log_w_matrix = (log_p_z + log_p - log_q).view(-1, K).max(
                dim=1, keepdim=True).values

        log_w_minus_max = log_w_matrix - torch.max(
            log_w_matrix, 1, keepdim=True)[0]
        ws_matrix = torch.exp(log_w_minus_max)
        ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

        if model_type == 'vralpha' and not testing_mode:
            sample_dist = Multinomial(1, ws_norm)
            ws_sum_per_datapoint = log_w_matrix.gather(
                1,
                sample_dist.sample().argmax(1, keepdim=True))
        else:
            ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

        if model_type in ["general_alpha", "vralpha"] and not testing_mode:
            ws_sum_per_datapoint /= (1 - alpha)

        loss = -torch.sum(ws_sum_per_datapoint)

        return decoded, mu, logstd, loss
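The `repeat`/`view` pattern at the top of this loss function is the core idiom of the example: `data.repeat((1, K, 1, 1))` tiles each image K times along the channel dimension, and the following `view` flattens every copy into a row, so the encoder receives B*K independent rows and per-row quantities can later be regrouped with `.view(B, K)`. A minimal, standalone shape check; the values of B, K, H, W below are illustrative only:

import torch

B, K, H, W = 4, 5, 28, 28
data = torch.rand(B, 1, H, W)                          # (B, 1, H, W)

data_k_vec = data.repeat((1, K, 1, 1)).view(-1, H * W)
print(data_k_vec.shape)                                # torch.Size([20, 784]) == (B*K, H*W)

# Rows b*K .. b*K+K-1 all belong to image b, so a per-row quantity such as a
# log-weight vector can be regrouped into one row of K importance samples per image:
log_w = torch.randn(B * K)
print(log_w.view(B, K).shape)                          # torch.Size([4, 5])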
def adjust_batch(data, group=1):
    # Adjust batch size to multiple of group
    # data is of shape [N, C, H, W]
    batch = data.shape[0]
    if batch % group == 0:
        return data
    batch_new = int(math.ceil(batch / group) * group)
    repeat_dims = [int(math.ceil(batch_new / batch))] + [1] * (data.ndim - 1)
    # Repeating can overshoot batch_new, so truncate to exactly batch_new rows
    return data.repeat(*repeat_dims)[:batch_new]
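A quick sanity check of the intended behaviour (with the truncation above), assuming adjust_batch is in scope and using toy sizes; it relies on `import math` as well:

import math   # needed by adjust_batch
import torch

x = torch.rand(5, 3, 8, 8)
print(adjust_batch(x, group=4).shape)   # torch.Size([8, 3, 8, 8]); rows 5-7 repeat rows 0-2
print(adjust_batch(x, group=5).shape)   # torch.Size([5, 3, 8, 8]); already a multiple, returned as-is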
Example #3
 def animate(i):
     view = block_rot.get_view()
     # Set skip index to anything > 15 to make sure nothing is skipped
     x_mu = model.generate(x_real, v_real, view, 44, 99)
     data = x_mu.squeeze(0)          # drop the batch dimension
     data = data.repeat(3, 1, 1)     # replicate the single channel to get a 3-channel image
     data = data.permute(1, 2, 0)    # CHW -> HWC layout for matplotlib
     data = data.cpu().detach()
     im.set_data(data)
     return im
 def __getitem__(self,idx):
     img_path = os.path.join(self.root,self.imgs[idx][0])
     assert os.path.exists(img_path)
     data = Image.open(img_path).convert("RGB")
     if self.transform is not None:
         data = self.transform(data)
     label = self.imgs[idx][1] # int
     if data.shape[0]==1: # some images have only one channel
         print(data.shape)
         data = data.repeat(3, 1, 1)
     return data,label
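Both snippets above rely on the same idiom for turning a single-channel CHW tensor into a 3-channel one: `repeat(3, 1, 1)` copies the channel dimension three times and leaves height and width untouched. A standalone illustration with a dummy tensor:

import torch

gray = torch.rand(1, 28, 28)         # (C=1, H, W)
rgb = gray.repeat(3, 1, 1)           # (3, 28, 28)
print(rgb.shape)
print(torch.equal(rgb[0], rgb[2]))   # True -- all three channels are identical copies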
Example #5
def run_fid(data, sample):
    assert data.max() <= 1 and data.min() >= 0
    assert sample.max() <= 1 and sample.min() >= 0
    data = 2*data - 1
    if data.shape[1] == 1:
        data = data.repeat(1,3,1,1)
    data = data.detach()      
    with torch.no_grad():
        iss, _, _, acts_real = inception_score(data, cuda=True, batch_size=32, resize=True, splits=10, return_preds=True)
    sample = 2*sample - 1
    if sample.shape[1] == 1:
        sample = sample.repeat(1,3,1,1)
    sample = sample.detach()

    with torch.no_grad():
        issf, _, _, acts_fake = inception_score(sample, cuda=True, batch_size=32, resize=True, splits=10, return_preds=True)
    # idxs_ = np.argsort(np.abs(acts_fake).sum(-1))[:1800] # filter the ones with super large values
    # acts_fake = acts_fake[idxs_]
    m1, s1 = calculate_activation_statistics(acts_real)
    m2, s2 = calculate_activation_statistics(acts_fake)
    fid_value = calculate_frechet_distance(m1, s1, m2, s2)
    return fid_value
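run_fid applies the batched version of the same trick: for NCHW tensors, `repeat` takes one factor per dimension, so `repeat(1, 3, 1, 1)` tiles only the channel axis of a grayscale batch before the images are fed to the Inception network. A sketch of just the shape and range handling (the inception_score / FID helpers come from the surrounding repo and are not redefined here):

import torch

batch = torch.rand(16, 1, 28, 28)        # values in [0, 1], single channel
batch = 2 * batch - 1                    # rescale to [-1, 1], as run_fid does
if batch.shape[1] == 1:
    batch = batch.repeat(1, 3, 1, 1)     # tile only the channel dimension
print(batch.shape)                       # torch.Size([16, 3, 28, 28])
print(batch.min().item() >= -1, batch.max().item() <= 1)   # True True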
Example #6
half_label = Variable(torch.ones(batch_size) * 0.5).cuda()

check_points = 500
num_epoches = 100000
criterion = nn.BCELoss()
criterion_t = nn.TripletMarginLoss(p=1)
criterion_mse = nn.MSELoss()

epoch = 1

for epoch in range(10):
    # D PART ./ data_old/ d_f3 / error
    data, label = mm.batch_next(10, shuffle=False)
    data = torch.from_numpy(data.astype('float32'))
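    # Tile the 10 fetched samples along the batch dimension to fill a batch of
    # size batch_size (this assumes batch_size is a multiple of 10).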
    data = data.repeat(batch_size // 10, 1, 1, 1)
    D.zero_grad()
    G.zero_grad()
    data_old = Variable(data).cuda()
    data_old.data.resize_(batch_size, 1, 28, 28)
    _, _, _, f3 = enet(data_old.detach())
    d_f3 = f3.cuda()

    g_sampler = torch.randn([batch_size, hidden_d, 1, 1])
    # g_sampler = g_sampler.repeat(10, 1, 1, 1)
    g_sampler = Variable(g_sampler).cuda()
    d_f3.data.resize_(batch_size, feature_size, 1, 1)
    d_f3 = d_f3.detach()
    d_f3_output = G(torch.cat([d_f3, g_sampler], 1)).detach()

    zeroinput = Variable(torch.zeros([batch_size, hidden_d, 1, 1])).cuda()
Example #7
def get_noisy_data(data, targets):
    noisy_samples_np = np.random.rand(data.shape[0] * N_NOISY_SAMPLES_PER_TEST_SAMPLE * 28 * 28) * 1.0
    noisy_samples = torch.from_numpy(noisy_samples_np.reshape([-1,28,28]).astype(np.float32))
    noisy_data = data.repeat(N_NOISY_SAMPLES_PER_TEST_SAMPLE, 1, 1).float() + noisy_samples
    noisy_targets = targets.repeat(N_NOISY_SAMPLES_PER_TEST_SAMPLE)
    return noisy_data, noisy_targets
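Here `repeat` duplicates along the batch dimension: every test image and its label are copied N_NOISY_SAMPLES_PER_TEST_SAMPLE times before uniform noise is added. A standalone shape sketch with a stand-in value for the constant (3 is illustrative, not the value used in the repo):

import torch

N_NOISY_SAMPLES_PER_TEST_SAMPLE = 3      # stand-in value
data = torch.rand(10, 28, 28)            # (B, H, W) -- no channel dimension here
targets = torch.arange(10)

noisy_data = data.repeat(N_NOISY_SAMPLES_PER_TEST_SAMPLE, 1, 1)
noisy_targets = targets.repeat(N_NOISY_SAMPLES_PER_TEST_SAMPLE)
print(noisy_data.shape, noisy_targets.shape)   # torch.Size([30, 28, 28]) torch.Size([30])

# Copies are laid out block-wise (images 0..9, then 0..9 again, ...), and
# targets.repeat follows the same layout, so image/label pairs stay aligned.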
Example #8
    def train_model(self, trainloader, epochs=100, test_every_x=4, epochs_per_level=50):

        #avgDiff = 0
        for epoch in range(self.start_epoch, epochs):

           self.model.train()

           #current_level = self.model.module.levels - epoch // epochs_per_level

           #urrent_level = max(1, current_level)

           #print("Current level:", current_level)
           #trainloader.shuffle()
           losses = []
           kld_fs = []
           kld_zs = []
           print("Running Epoch : {}".format(epoch+1))
           print(len(trainloader))


           lastDiff = 0
           #lastDiff = avgDiff
           avgDiff = 0

           loss_type = "levels"

           mse_loss = nn.MSELoss()

           for i, dataitem in tqdm(enumerate(trainloader, 1)):
               if i >= len(trainloader):
                break
               data, _ = dataitem
               data = data.cuda()

               if data.shape[1] == 1:
                data = data.repeat(1, 3, 1, 1)
                bs = data.shape[0]
                #data *= torch.rand([bs, 3, 1, 1]).cuda()

                #data[:, :, 10:20, 10:20] = 1.

                #plt.imshow(data[0].permute(1, 2, 0).cpu())
                #plt.show()
                #input()

               data = (data - 0.5) * 2

               data = Variable(data)

               self.optimizer.zero_grad()

               data_down = self.avgpool(data)

               data_noise = data_down + torch.randn(data_down.shape).cuda() * 0.1

               #plt.imshow(data[0].permute(1, 2, 0).cpu() / 2. + 0.5)
               #plt.show()
               #plt.imshow(data_noise[0].permute(1, 2, 0).cpu() / 2. + 0.5)
               #plt.show()
               #print(data_noise.shape)

               out = self.model.forward(data_noise)
               
               loss_mix = discretized_mix_logistic_loss(data_down, out)

               #norm = torch.randn(z.shape).cuda()

               #loss_z = mmd(z, norm) * 200000

               loss = loss_mix# + loss_z

               if i % 5 == 0:

                   bh = out.shape[0] // 8

                   out_sample = sample_from_logistic_mix(out[:bh])

                   out_final = self.model.forward(out_sample, final=True)

                   loss_mse = mse_loss(out_final, data[:bh])

                   loss = loss + loss_mse

               loss = loss.mean()

               if i % 10 == 0:
                 print(loss_mix, loss_mse)

               loss.backward()

               torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)

               self.optimizer.step()

               losses.append(loss.item())

           self.scheduler.step()

           meanloss = np.mean(losses)

           #avgDiff /= len(trainloader)
           #meanf = np.mean(kld_fs)
           #meanz = np.mean(kld_zs)
           #self.epoch_losses.append(meanloss)
           print("Epoch {} : Average Loss: {}".format(epoch+1, meanloss))

           #print("Disc. quality: {}".format(avgDiff))
           self.save_checkpoint(epoch)


           self.model.eval()
           _, (sample, _)  = next(enumerate(trainloader))
           #sample = torch.unsqueeze(sample,0)


           sample = sample.cuda()

           if sample.shape[1] == 1:
               sample = sample.repeat(1, 3, 1, 1)
               bs = sample.shape[0]
               sample *= torch.rand([bs, 3, 1, 1]).cuda()

           sample = (sample - 0.5) * 2

           if (epoch + 1) % test_every_x == 0:
            self.recon_frames(epoch+1, sample)
            self.sample_frames(epoch+1)
            #self.umap_codes(epoch+1, trainloader)
           self.model.train()
        print("Training is complete")
Example #9
    def compute_loss_for_batch(self, data, model, K=K, test=False):
        # data = (B, 1, H, W)
        B, _, H, W = data.shape

        # Generate K copies of each observation. Each will get sampled once according to the generated distribution to generate a total of K observation samples
        data_k_vec = data.repeat((1, K, 1, 1)).view(-1, H * W)

        # Retrieve the estimated mean and log(standard deviation) estimates from the posterior approximator
        mu, logstd = model.encode(data_k_vec)

        # Use the reparametrization trick to generate (mean)+(epsilon)*(standard deviation) for each sample of each observation
        z = model.reparameterize(mu, logstd)

        # Calculate log q(z|x) - how likely are the importance samples given the distribution that generated them?
        log_q = compute_log_probabitility_gaussian(z, mu, logstd)

        # Calculate log p(z) - how likely are the importance samples under the prior N(0,1) assumption?
        log_p_z = compute_log_probabitility_gaussian(
            z, torch.zeros_like(z, requires_grad=False),
            torch.zeros_like(z, requires_grad=False))

        # Hand the samples to the decoder network and get a reconstruction of each sample.
        decoded = model.decode(z)

        # Calculate log p(x|z) with a bernoulli distribution - how likely are the recreations given the latents that generated them?
        log_p = compute_log_probabitility_bernoulli(decoded, data_k_vec)

        # Begin calculating L_alpha depending on the (a) model type, and (b) optimization method
        # log_p_z + log_p - log_q = log(p(z_i)p(x|z_i)/q(z_i|x)) = log(p(x,z_i)/q(z_i|x)) = L_VI
        #   (for each importance sample i out of K for each observation)
        if model_type == 'iwae' or test:
            # Re-order the entries so that each row holds the K importance samples for each observation
            log_w_matrix = (log_p_z + log_p - log_q).view(B, K)

        elif model_type == 'vae':
            # Don't reorder, and divide by K in anticipation of taking a batch sum of (1/K)*SUM(log(p(x,z)/q(z|x)))
            log_w_matrix = (log_p_z + log_p - log_q).view(B * K, 1) * 1 / K

        elif model_type == 'general_alpha' or model_type == 'vralpha':
            # Re-order the entries so that each row holds the K importance samples for each observation
            # Multiply by (1-alpha) because (1-alpha)* log(p(x,z_i)/q(z_i|x)) =  log([p(x,z_i)/q(z_i|x)]^(1-alpha))
            log_w_matrix = (log_p_z + log_p - log_q).view(-1, K) * (1 - alpha)

        elif model_type == 'vrmax':
            # Re-order the entries so that each row holds the K importance samples for each observation
            # Take the max in each row, representing the maximum-weighted sample
            log_w_matrix = (log_p_z + log_p - log_q).view(-1, K).max(
                dim=1, keepdim=True).values

            # immediately return loss = -sum(L_alpha) over each observation
            return -torch.sum(log_w_matrix)

        # Begin using the "max trick". Subtract the maximum log(*) sample value for each observation.
        # log_w_minus_max = log([p(z_i,x)/q(z_i|x)] / max([p(z_k,x)/q(z_k|x)]))
        log_w_minus_max = log_w_matrix - torch.max(
            log_w_matrix, 1, keepdim=True)[0]

        # Exponentiate so that each term is [p(z_i,x)/q(z_i|x)] / max([p(z_k,x)/q(z_k|x)]) (no log)
        ws_matrix = torch.exp(log_w_minus_max)

        # Calculate normalized weights in each row. Max denominators cancel out!
        # ws_norm = [p(z_i,x)/q(z_i|x)]/SUM([p(z_k,x)/q(z_k|x)])
        ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

        if model_type == 'vralpha' and not test:
            # If we're specifically using a VR-alpha model, we want to choose a sample to backprop according to the values in ws_norm above
            # So we make a distribution in each row
            sample_dist = Multinomial(1, ws_norm)

            # Then we choose a sample in each row according to this distribution
            ws_sum_per_datapoint = log_w_matrix.gather(
                1,
                sample_dist.sample().argmax(1, keepdim=True))
        else:
            # For any other model, we're taking the full sum at this point
            ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

        if model_type in ["general_alpha", "vralpha"] and not test:
            # For both VR-alpha and directly estimating L_alpha with a sum, we have to renormalize the sum with 1-alpha
            ws_sum_per_datapoint /= (1 - alpha)

        # Return a value of loss = -L_alpha as the batch sum.
        loss = -torch.sum(ws_sum_per_datapoint)

        return loss
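The "max trick" spelled out in the comments above (subtract the per-row maximum, exponentiate, normalise) is exactly a numerically stable row-wise softmax over the log-weights, so ws_norm could equivalently be obtained from torch.softmax. A standalone check of that equivalence with random log-weights:

import torch

log_w_matrix = torch.randn(4, 5) * 10    # (B, K) log-weights, deliberately spread out

log_w_minus_max = log_w_matrix - torch.max(log_w_matrix, 1, keepdim=True)[0]
ws_matrix = torch.exp(log_w_minus_max)
ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

print(torch.allclose(ws_norm, torch.softmax(log_w_matrix, dim=1)))   # True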
Example #10
    def compute_loss_for_batch(self, data, model, K=K, test=False):
        B, _, H, W = data.shape

        # First repeat the observations K times, representing the data as a flat (B*K, # of pixels) matrix
        data_k_vec = data.repeat((1, K, 1, 1)).view(-1, H * W)

        # Encode the model and retrieve estimated distribution parameters mu and log(standard deviation) for each sample of each observation
        # z1 holds the latent samples generated at the first stochastic layer.
        mu, log_std, [x, z1] = self.encode(data_k_vec)

        # Sample from each observation's approximated latent distribution in each row (i.e. once for each of K importance samples, represented by rows)
        # (this uses the reparametrization trick!)
        z = model.reparameterize(mu, log_std)

        # Calculate Log p(z) (prior) - how likely are these values given the prior assumption N(0,1)?
        log_p_z = torch.sum(
            -0.5 * z**2, 1) - .5 * z.shape[1] * T.log(torch.tensor(2 * np.pi))

        # Calculate q (z | h1) - how likely are the generated output latent samples given the distributions they came from?
        log_qz_h1 = compute_log_probabitility_gaussian(z, mu, log_std)

        # Re-Generate the mu and log_std that generated the first-layer latents z1
        h1 = torch.tanh(self.fc1(x))
        h2 = torch.tanh(self.fc2(h1))
        mu, log_std = self.fc31(h2), self.fc32(h2)

        # Calculate log q(h1|x) - how likely are the first-stochastic-layer latents given the distributions they come from?
        log_qh1_x = compute_log_probabitility_gaussian(z1, mu, log_std)

        # Calculate the distribution parameters that generated the first-layer latents upon decoding
        h5 = torch.tanh(self.fc7(z))
        h6 = torch.tanh(self.fc8(h5))
        mu, log_std = self.fc81(h6), self.fc82(h6)

        # Calculate log p(h1|z) - how likely are the latents z1 under the parameters of the distribution here?
        #   (This directly encourages the decoder to learn the inverse of the map h1->z)
        log_ph1_z = compute_log_probabitility_gaussian(z1, mu, log_std)

        # Finally calculate the reconstructed image
        h7 = torch.tanh(self.fc9(z1))
        h8 = torch.tanh(self.fc10(h7))
        decoded = torch.sigmoid(self.fc11(h8))

        # calculate log p(x | h1) - how likely is the reconstruction given the latent samples that generated it?
        log_px_h1 = compute_log_probabitility_bernoulli(decoded, x)

        # Begin calculating L_alpha depending on the (a) model type, and (b) optimization method
        # log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 - log_qh1_x =
        #           log([p(z0_i)p(x|z1_i)p(z1_i|z0_i)]/[q(z0_i|z1_i)q(z1_i|x)]) = log(p(x,z0_i,z1_i)/q(z0_i,z1_i|x)) = L_VI
        #   (for each importance sample i out of K for each observation)
        # Note that if test==True then we're always using the IWAE objective!
        if model_type == 'iwae' or test:
            # Re-order the entries so that each row holds the K importance samples for each observation
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 -
                            log_qh1_x).view(-1, K)

        elif model_type == 'vae':
            # Don't reorder, and divide by K in anticipation of taking a batch sum of (1/K)*SUM(log(p(x,z)/q(z|x)))
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 -
                            log_qh1_x).view(-1, 1) * 1 / K
            return -torch.sum(log_w_matrix)

        elif model_type == 'general_alpha' or model_type == 'vralpha':
            # Re-order the entries so that each row holds the K importance samples for each observation
            # Multiply by (1-alpha) because (1-alpha)* log(p(x,z_i)/q(z_i|x)) =  log([p(x,z_i)/q(z_i|x)]^(1-alpha))
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 -
                            log_qh1_x).view(-1, K) * (1 - self.alpha)

        elif model_type == 'vrmax':
            # Re-order the entries so that each row holds the K importance samples for each observation
            # Take the max in each row, representing the maximum-weighted sample, then immediately return batch sum loss -L_alpha
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 -
                            log_qh1_x).view(-1, K).max(dim=1, keepdim=True).values
            return -torch.sum(log_w_matrix)

        # Begin using the "max trick". Subtract the maximum log(*) sample value for each observation.
        # log_w_minus_max = log([p(z_i,x)/q(z_i|x)] / max([p(z_k,x)/q(z_k|x)]))
        log_w_minus_max = log_w_matrix - torch.max(
            log_w_matrix, 1, keepdim=True)[0]

        # Exponentiate so that each term is [p(z_i,x)/q(z_i|x)] / max([p(z_k,x)/q(z_k|x)]) (no log)
        ws_matrix = torch.exp(log_w_minus_max)

        # Calculate normalized weights in each row. Max denominators cancel out!
        # ws_norm = [p(z_i,x)/q(z_i|x)]/SUM([p(z_k,x)/q(z_k|x)])
        ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

        if model_type == 'vralpha' and not test:
            # If we're specifically using a VR-alpha model, we want to choose a sample to backprop according to the values in ws_norm above
            # So we make a distribution in each row
            sample_dist = Multinomial(1, ws_norm)

            # Then we choose a sample in each row according to this distribution
            ws_sum_per_datapoint = log_w_matrix.gather(
                1,
                sample_dist.sample().argmax(1, keepdim=True))
        else:
            # For any other model, we're taking the full sum at this point
            ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)

        if model_type in ["general_alpha", "vralpha"] and not test:
            # For both VR-alpha and directly estimating L_alpha with a sum, we have to renormalize the sum with 1-alpha
            ws_sum_per_datapoint /= (1 - self.alpha)

        loss = -torch.sum(ws_sum_per_datapoint)

        return loss
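The closed-form prior term used above, log p(z) = sum_j(-0.5 * z_j**2) - 0.5 * d * log(2*pi) for a d-dimensional standard normal, can be cross-checked against torch.distributions; a standalone sketch with random stand-in latents:

import numpy as np
import torch

z = torch.randn(6, 20)   # (B*K, d) stand-in latent samples

log_p_z = torch.sum(-0.5 * z**2, 1) \
          - 0.5 * z.shape[1] * torch.log(torch.tensor(2 * np.pi))

reference = torch.distributions.Normal(0.0, 1.0).log_prob(z).sum(1)
print(torch.allclose(log_p_z, reference))   # True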