Code example #1
def train(epoch):
    print("train")
    model.train()
    train_loss = 0
    train_amount = 0
    dataroot = random.sample(TrainingData, 1)[0]

    dataset = MultipieLoader.FareMultipieExpressionTripletsFrontal(
        opt, root=dataroot, resize=64)
    print('# size of the current (sub)dataset is %d' % len(dataset))
    train_amount = train_amount + len(dataset)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))
    for batch_idx, data_point in enumerate(dataloader, 0):

        gc.collect()  # collect garbage
        # sample the data points:
        # dp0_img: image of data point 0
        # dp9_img: image of data point 9, which differs in ``expression'' compared to dp0
        # dp1_img: image of data point 1, which differs in ``person'' compared to dp0
        dp0_img, dp9_img, dp1_img = data_point
        dp0_img, dp9_img, dp1_img = parseSampledDataTripletMultipie(
            dp0_img, dp9_img, dp1_img)
        if args.cuda:
            dp0_img, dp9_img, dp1_img = setCuda(dp0_img, dp9_img, dp1_img)
        dp0_img, dp9_img, dp1_img = setAsVariable(dp0_img, dp9_img, dp1_img)

        optimizer.zero_grad()
        recon_batch_per0, mu_per0, logvar_per0, recon_batch_ex0, mu_ex0, logvar_ex0 = model(
            dp0_img)
        recon_batch_per9, mu_per9, logvar_per9, recon_batch_ex9, mu_ex9, logvar_ex9 = model(
            dp9_img)
        recon_batch_per1, mu_per1, logvar_per1, recon_batch_ex1, mu_ex1, logvar_ex1 = model(
            dp1_img)

        # sum the VAE objective over both branches of all three sampled images
        loss = (loss_function(recon_batch_per0, dp0_img, mu_per0, logvar_per0)
                + loss_function(recon_batch_ex0, dp0_img, mu_ex0, logvar_ex0)
                + loss_function(recon_batch_per9, dp9_img, mu_per9, logvar_per9)
                + loss_function(recon_batch_ex9, dp9_img, mu_ex9, logvar_ex9)
                + loss_function(recon_batch_per1, dp1_img, mu_per1, logvar_per1)
                + loss_function(recon_batch_ex1, dp1_img, mu_ex1, logvar_ex1))
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * opt.batchSize,
                len(dataloader) * opt.batchSize,
                100. * batch_idx / len(dataloader),
                loss.item() / opt.batchSize))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / (len(dataloader) * opt.batchSize)))
    return train_loss / (len(dataloader) * opt.batchSize)
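loss_function is not defined in this snippet, but its (recon, x, mu, logvar) signature matches the standard VAE objective. A minimal sketch of such a function, assuming a BCE reconstruction term with sum reduction (both the term and the reduction are assumptions, not taken from the source):

import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    # pixel-wise reconstruction term, summed over the batch
    bce = F.binary_cross_entropy(recon_x.view(recon_x.size(0), -1),
                                 x.view(x.size(0), -1),
                                 reduction='sum')
    # closed-form KL(q(z|x) || N(0, I))
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld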
Code example #2
def getTestingSample():
    dataroot = random.sample(TestingData, 1)[0]
    dataset = MultipieLoader.FareMultipieLightingTripletsFrontal(opt,
                                                                 root=dataroot,
                                                                 resize=64)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))
    for batch_idx, data_point in enumerate(dataloader, 0):
        dp0_img, dp9_img, dp1_img = data_point
        dp0_img, dp9_img, dp1_img = parseSampledDataTripletMultipie(
            dp0_img, dp9_img, dp1_img)
        if opt.cuda:
            dp0_img, dp9_img, dp1_img = setCuda(dp0_img, dp9_img, dp1_img)
        dp0_img, dp9_img, dp1_img = setAsVariable(dp0_img, dp9_img, dp1_img)
        return dp0_img
Code example #3
def test(epoch):
    print("test")
    model.eval()
    recon_test_loss = 0
    siamese_test_loss = 0
    expression_test_loss = 0
    dataroot = random.sample(TestingData, 1)[0]

    dataset = MultipieLoader.FareMultipieExpressionTripletsFrontal(
        opt, root=dataroot, resize=64)
    print('# size of the current (sub)dataset is %d' % len(dataset))
    # train_amount = train_amount + len(dataset)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))
    for batch_idx, data_point in enumerate(dataloader, 0):
        gc.collect()  # collect garbage
        # sample the data points:
        dp0_img, dp9_img, dp1_img, dp0_ide, dp9_ide, dp1_ide = data_point
        dp0_img, dp9_img, dp1_img = parseSampledDataTripletMultipie(
            dp0_img, dp9_img, dp1_img)
        if opt.cuda:
            dp0_img, dp9_img, dp1_img = setCuda(dp0_img, dp9_img, dp1_img)
        dp0_img, dp9_img, dp1_img = setAsVariable(dp0_img, dp9_img, dp1_img)

        z_dp9, z_per_dp9, z_exp_dp9 = model.get_latent_vectors(dp9_img)
        z_dp1, z_per_dp1, z_exp_dp1 = model.get_latent_vectors(dp1_img)

        recon_batch_dp0, z_dp0, z_per_dp0, z_exp_dp0 = model(dp0_img)

        # test disentangling

        z_per0_exp9 = torch.cat((z_per_dp0, z_exp_dp9),
                                dim=1)  # should be person 0 with expression 9
        recon_per0_exp9 = model.decode(z_per0_exp9)

        visualizeAsImages(recon_per0_exp9.data.clone(),
                          opt.dirImageoutput,
                          filename='epoch_' + str(epoch) + '_per0_exp9',
                          n_sample=18,
                          nrow=5,
                          normalize=False)

        z_per0_exp1 = torch.cat(
            (z_per_dp0, z_exp_dp1), dim=1
        )  # should look the same as dp0_img (exp1 and exp0 are the same)
        recon_per0_exp1 = model.decode(z_per0_exp1)

        visualizeAsImages(recon_per0_exp1.data.clone(),
                          opt.dirImageoutput,
                          filename='epoch_' + str(epoch) + '_per0_exp1',
                          n_sample=18,
                          nrow=5,
                          normalize=False)

        # calc reconstruction loss (dp0 only)

        recon_loss = recon_loss_func(recon_batch_dp0, dp0_img)
        recon_test_loss += recon_loss.item()

        # calc siamese loss

        sim_loss = siamese_loss_func(z_per_dp0, z_per_dp9,
                                     1) + siamese_loss_func(
                                         z_exp_dp0, z_exp_dp1, 1)  # similarity
        dis_loss = siamese_loss_func(
            z_exp_dp0, z_exp_dp9, -1) + siamese_loss_func(
                z_per_dp0, z_per_dp1, -1)  # dissimilarity
        siamese_loss = sim_loss + dis_loss

        siamese_test_loss += siamese_loss.item()

        # BCE expression loss

        smile_target = torch.ones(z_exp_dp0.size()).cuda()
        neutral_target = torch.zeros(z_exp_dp0.size()).cuda()

        if dp0_ide == '01':  #neutral
            expression_loss = BCE(z_exp_dp0, neutral_target)
        else:  #smile
            expression_loss = BCE(z_exp_dp0, smile_target)

        if dp9_ide == '01':  #neutral
            expression_loss = expression_loss + BCE(z_exp_dp9, neutral_target)
        else:  #smile
            expression_loss = expression_loss + BCE(z_exp_dp9, smile_target)

        if dp1_ide == '01':  #neutral
            expression_loss = expression_loss + BCE(z_exp_dp1, neutral_target)
        else:  #smile
            expression_loss = expression_loss + BCE(z_exp_dp1, smile_target)

        expression_test_loss += expression_loss.item()

    print(
        '====> Test set recon loss: {:.4f}\tSiamese loss: {:.4f}\tExp loss: {:.4f}'.
        format(recon_test_loss / (opt.batchSize * len(dataloader)),
               siamese_test_loss / (opt.batchSize * len(dataloader)),
               expression_test_loss / (opt.batchSize * len(dataloader))))
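siamese_loss_func is called with a scalar +1/-1 target (similar/dissimilar pair). A plausible stand-in, sketched with PyTorch's CosineEmbeddingLoss (the margin and the choice of cosine distance are assumptions):

import torch
import torch.nn as nn

_cosine = nn.CosineEmbeddingLoss(margin=0.5)  # margin is an assumption

def siamese_loss_func(z_a, z_b, target):
    # expand the scalar +1/-1 into a per-sample target tensor
    t = torch.full((z_a.size(0),), float(target), device=z_a.device)
    return _cosine(z_a, z_b, t)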
Code example #4
def train(epoch):
    print("train")
    model.train()
    recon_train_loss = 0
    siamese_train_loss = 0
    expression_train_loss = 0
    dataroot = random.sample(TrainingData, 1)[0]

    dataset = MultipieLoader.FareMultipieExpressionTripletsFrontal(
        opt, root=dataroot, resize=64)
    print('# size of the current (sub)dataset is %d' % len(dataset))
    #   train_amount = train_amount + len(dataset)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    for batch_idx, data_point in enumerate(dataloader, 0):

        gc.collect()  # collect garbage
        # sample the data points:
        # dp0_img: image of data point 0
        # dp9_img: image of data point 9, which differs in ``expression'' compared to dp0
        # dp1_img: image of data point 1, which differs in ``person'' compared to dp0
        dp0_img, dp9_img, dp1_img, dp0_ide, dp9_ide, dp1_ide = data_point
        dp0_img, dp9_img, dp1_img = parseSampledDataTripletMultipie(
            dp0_img, dp9_img, dp1_img)
        if opt.cuda:
            dp0_img, dp9_img, dp1_img = setCuda(dp0_img, dp9_img, dp1_img)
        dp0_img, dp9_img, dp1_img = setAsVariable(dp0_img, dp9_img, dp1_img)

        z_dp9, z_per_dp9, z_exp_dp9 = model.get_latent_vectors(dp9_img)
        z_dp1, z_per_dp1, z_exp_dp1 = model.get_latent_vectors(dp1_img)

        optimizer.zero_grad()
        model.zero_grad()

        recon_batch_dp0, z_dp0, z_per_dp0, z_exp_dp0 = model(dp0_img)

        # calc reconstruction loss (dp0 only)

        recon_loss = recon_loss_func(recon_batch_dp0, dp0_img)
        recon_loss.backward(retain_graph=True)
        recon_train_loss += recon_loss.item()

        # calc siamese loss

        sim_loss = siamese_loss_func(z_per_dp0, z_per_dp9,
                                     1) + siamese_loss_func(
                                         z_exp_dp0, z_exp_dp1, 1)  # similarity
        dis_loss = siamese_loss_func(
            z_exp_dp0, z_exp_dp9, -1) + siamese_loss_func(
                z_per_dp0, z_per_dp1, -1)  # dissimilarity
        siamese_loss = sim_loss + dis_loss

        siamese_loss.backward(retain_graph=True)
        siamese_train_loss += siamese_loss.item()

        # BCE expression loss

        smile_target = torch.ones(z_exp_dp0.size()).cuda()
        neutral_target = torch.zeros(z_exp_dp0.size()).cuda()

        if dp0_ide == '01':  #neutral
            expression_loss = BCE(z_exp_dp0, neutral_target)
        else:  #smile
            expression_loss = BCE(z_exp_dp0, smile_target)

        if dp9_ide == '01':  #neutral
            expression_loss = expression_loss + BCE(z_exp_dp9, neutral_target)
        else:  #smile
            expression_loss = expression_loss + BCE(z_exp_dp9, smile_target)

        if dp1_ide == '01':  #neutral
            expression_loss = expression_loss + BCE(z_exp_dp1, neutral_target)
        else:  #smile
            expression_loss = expression_loss + BCE(z_exp_dp1, smile_target)

        expression_loss.backward()
        expression_train_loss += expression_loss.item()

        optimizer.step()
        print(
            'Train Epoch: {} [{}/{} ({:.0f}%)]\tReconLoss: {:.6f}\tSimLoss: {:.6f}\tDisLoss: {:.6f}\tExpLoss: {:.6f}'
            .format(epoch, batch_idx * opt.batchSize,
                    (len(dataloader) * opt.batchSize),
                    100. * batch_idx / len(dataloader),
                    recon_loss.item() / opt.batchSize,
                    sim_loss.item() / opt.batchSize,
                    dis_loss.item() / opt.batchSize,
                    expression_loss.item() / opt.batchSize))
        #loss is calculated for each img, so divide by batch size to get average loss for the batch

    lossfile.write('Epoch: {} Recon: {:.4f}\n'.format(
        epoch, recon_train_loss / (len(dataloader) * opt.batchSize)))
    lossfile.write('Epoch: {} SiameseSim: {:.4f} SiameseDis: {:.4f}\n'.format(
        epoch, sim_loss.item() / opt.batchSize,
        dis_loss.item() / opt.batchSize))  # last batch only
    lossfile.write('Epoch: {} Expression: {:.4f}\n'.format(
        epoch, expression_train_loss / (len(dataloader) * opt.batchSize)))

    print(
        '====> Epoch: {} Avg recon loss: {:.4f} Avg siamese loss: {:.4f} Avg exp loss: {:.4f}'
        .format(epoch, recon_train_loss / (len(dataloader) * opt.batchSize),
                siamese_train_loss / (len(dataloader) * opt.batchSize),
                expression_train_loss / (len(dataloader) * opt.batchSize)))
    #divide by (batch_size * num_batches) to get loss for the epoch

    #data
    visualizeAsImages(dp0_img.data.clone(),
                      opt.dirImageoutput,
                      filename='epoch_' + str(epoch) + '_img0',
                      n_sample=18,
                      nrow=5,
                      normalize=False)
    visualizeAsImages(dp9_img.data.clone(),
                      opt.dirImageoutput,
                      filename='epoch_' + str(epoch) + '_img9',
                      n_sample=18,
                      nrow=5,
                      normalize=False)
    visualizeAsImages(dp1_img.data.clone(),
                      opt.dirImageoutput,
                      filename='epoch_' + str(epoch) + '_img1',
                      n_sample=18,
                      nrow=5,
                      normalize=False)

    #reconstruction (dp0 only)
    visualizeAsImages(recon_batch_dp0.data.clone(),
                      opt.dirImageoutput,
                      filename='epoch_' + str(epoch) + '_recon0',
                      n_sample=18,
                      nrow=5,
                      normalize=False)

    print('Data and reconstructions saved.')

    return recon_train_loss / (len(dataloader) *
                               opt.batchSize), siamese_train_loss / (
                                   len(dataloader) * opt.batchSize)
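Note that the three separate backward(retain_graph=True) calls accumulate gradients into .grad exactly as a single backward on the summed loss would, since the gradient of a sum is the sum of the gradients. The equivalent single-pass pattern (same names as above) avoids re-traversing the graph:

optimizer.zero_grad()
total_loss = recon_loss + siamese_loss + expression_loss
total_loss.backward()  # one graph traversal instead of three
optimizer.step()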
Code example #5
def test(epoch):
    print("test")
    model.eval()
    recon_test_loss = 0
    cosine_test_loss = 0
    triplet_test_loss = 0
    dataroot = random.sample(Data, 1)[0]

    dataset = MultipieLoader.FareMultipieExpressionTripletsFrontalTrainTestSplit(
        opt, root=dataroot, resize=64)
    print('# size of the current (sub)dataset is %d' % len(dataset))
    # train_amount = train_amount + len(dataset)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))
    for batch_idx, data_point in enumerate(dataloader, 0):
        gc.collect()  # collect garbage

        dp0_img, dp9_img, dp1_img, dp0_ide, dp9_ide, dp1_ide = data_point
        dp0_img, dp9_img, dp1_img = parseSampledDataTripletMultipie(
            dp0_img, dp9_img, dp1_img)
        if opt.cuda:
            dp0_img, dp9_img, dp1_img = setCuda(dp0_img, dp9_img, dp1_img)
        dp0_img, dp9_img, dp1_img = setAsVariable(dp0_img, dp9_img, dp1_img)

        z_dp9, z_per_dp9, z_exp_dp9 = model.get_latent_vectors(dp9_img)
        z_dp1, z_per_dp1, z_exp_dp1 = model.get_latent_vectors(dp1_img)

        optimizer.zero_grad()
        model.zero_grad()

        recon_batch_dp0, z_dp0, z_per_dp0, z_exp_dp0 = model(dp0_img)

        # save test images

        visualizeAsImages(dp0_img.data.clone(),
                          opt.dirImageoutput,
                          filename='e_' + str(epoch) + '_test_img0',
                          n_sample=18,
                          nrow=5,
                          normalize=False)

        visualizeAsImages(dp9_img.data.clone(),
                          opt.dirImageoutput,
                          filename='e_' + str(epoch) + '_test_img9',
                          n_sample=18,
                          nrow=5,
                          normalize=False)

        visualizeAsImages(dp1_img.data.clone(),
                          opt.dirImageoutput,
                          filename='e_' + str(epoch) + '_test_img1',
                          n_sample=18,
                          nrow=5,
                          normalize=False)

        # test disentangling

        z_per0_exp9 = torch.cat((z_per_dp0, z_exp_dp9),
                                dim=1)  # should be person 0 with expression 9
        recon_per0_exp9 = model.decode(z_per0_exp9)

        visualizeAsImages(recon_per0_exp9.data.clone(),
                          opt.dirImageoutput,
                          filename='e_' + str(epoch) + '_test_per0_exp9',
                          n_sample=18,
                          nrow=5,
                          normalize=False)

        z_per0_exp1 = torch.cat(
            (z_per_dp0, z_exp_dp1), dim=1
        )  # should look the same as dp0_img (exp1 and exp0 are the same)
        recon_per0_exp1 = model.decode(z_per0_exp1)

        visualizeAsImages(recon_per0_exp1.data.clone(),
                          opt.dirImageoutput,
                          filename='e_' + str(epoch) + '_test_per0_exp1',
                          n_sample=18,
                          nrow=5,
                          normalize=False)

        z_per1_exp9 = torch.cat((z_per_dp1, z_exp_dp9),
                                dim=1)  # should be unique
        recon_per1_exp9 = model.decode(z_per1_exp9)

        visualizeAsImages(recon_per1_exp9.data.clone(),
                          opt.dirImageoutput,
                          filename='e_' + str(epoch) + '_test_per1_exp9',
                          n_sample=18,
                          nrow=5,
                          normalize=False)

        # calc reconstruction loss (dp0 only)

        recon_loss = recon_loss_func(recon_batch_dp0, dp0_img)
        recon_test_loss += recon_loss.item()

        # calc cosine loss

        sim_loss = cosine_loss_func(z_per_dp0, z_per_dp9,
                                    1) + cosine_loss_func(
                                        z_exp_dp0, z_exp_dp1, 1)  # similarity

        cosine_test_loss += sim_loss.item()

        # calc L1 loss

        L1_loss = L1(z_per_dp9, z_per_dp0) + L1(z_exp_dp1, z_exp_dp0)

        # calc triplet loss

        triplet_loss = triplet_loss_func(z_per_dp0, z_per_dp9,
                                         z_per_dp1) + triplet_loss_func(
                                             z_exp_dp0, z_exp_dp1, z_exp_dp9)
        # triplet(anchor, positive, negative)
        triplet_test_loss += triplet_loss.item()

    print('Test images saved')
    print('====> Test set recon loss: {:.4f}\ttriplet loss: {:.4f}'.format(
        recon_test_loss / (opt.batchSize * len(dataloader)),
        triplet_test_loss / (opt.batchSize * len(dataloader))))
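triplet_loss_func is used as triplet(anchor, positive, negative), pulling the anchor toward the positive and away from the negative. A minimal sketch with PyTorch's built-in triplet loss (the margin and the latent size are assumptions):

import torch
import torch.nn as nn

triplet_loss_func = nn.TripletMarginLoss(margin=1.0)

# dp0's person code should sit closer to dp9's (same person) than to dp1's
anchor, positive, negative = torch.randn(4, 128), torch.randn(4, 128), torch.randn(4, 128)
loss = triplet_loss_func(anchor, positive, negative)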
Code example #6
def train(epoch):
    print("train")
    model.train()
    recon_train_loss = 0
    cosine_train_loss = 0
    triplet_train_loss = 0
    swap_train_loss = 0
    expression_train_loss = 0
    dataroot = random.sample(Data, 1)[0]

    dataset = MultipieLoader.FareMultipieExpressionTripletsFrontalTrainTestSplit(
        opt, root=dataroot, resize=64)
    print('# size of the current (sub)dataset is %d' % len(dataset))
    #   train_amount = train_amount + len(dataset)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    for batch_idx, data_point in enumerate(dataloader, 0):

        gc.collect()  # collect garbage
        # sample the data points:
        # dp0_img: image of data point 0
        # dp9_img: image of data point 9, which differs in ``expression'' compared to dp0 (same person as dp0)
        # dp1_img: image of data point 1, which differs in ``person'' compared to dp0 (same expression as dp0)
        dp0_img, dp9_img, dp1_img, dp0_ide, dp9_ide, dp1_ide = data_point
        dp0_img, dp9_img, dp1_img = parseSampledDataTripletMultipie(
            dp0_img, dp9_img, dp1_img)
        if opt.cuda:
            dp0_img, dp9_img, dp1_img = setCuda(dp0_img, dp9_img, dp1_img)
        dp0_img, dp9_img, dp1_img = setAsVariable(dp0_img, dp9_img, dp1_img)

        z_dp9, z_per_dp9, z_exp_dp9 = model.get_latent_vectors(dp9_img)
        z_dp1, z_per_dp1, z_exp_dp1 = model.get_latent_vectors(dp1_img)

        recon_batch_dp0, z_dp0, z_per_dp0, z_exp_dp0 = model(dp0_img)

        # calc reconstruction loss (dp0 only)

        recon_loss = recon_loss_func(recon_batch_dp0, dp0_img)

        optimizer.zero_grad()
        model.zero_grad()

        recon_loss.backward(retain_graph=True)
        recon_train_loss += recon_loss.item()

        # calc cosine similarity loss

        sim_loss = cosine_loss_func(z_per_dp0, z_per_dp9,
                                    1) + cosine_loss_func(
                                        z_exp_dp0, z_exp_dp1, 1)  # similarity
        #dis_loss = cosine_loss_func(z_exp_dp0, z_exp_dp9, -1) + cosine_loss_func(z_per_dp0, z_per_dp1, -1) # dissimilarity

        cosine_train_loss += sim_loss.item()

        # calc L1 loss

        L1_loss = L1(z_per_dp9, z_per_dp0) + L1(z_exp_dp1, z_exp_dp0)

        # calc triplet loss

        triplet_loss = triplet_loss_func(z_per_dp0, z_per_dp9,
                                         z_per_dp1) + triplet_loss_func(
                                             z_exp_dp0, z_exp_dp1, z_exp_dp9)
        # triplet(anchor, positive, negative)

        triplet_train_loss += triplet_loss.item()

        # BCE expression loss

        smile_target = torch.ones(z_exp_dp0.size()).cuda()
        neutral_target = torch.zeros(z_exp_dp0.size()).cuda()

        if dp0_ide == '01':  #neutral
            expression_loss = BCE(z_exp_dp0, neutral_target)
        else:  #smile
            expression_loss = BCE(z_exp_dp0, smile_target)

        if dp9_ide == '01':  #neutral
            expression_loss = expression_loss + BCE(z_exp_dp9, neutral_target)
        else:  #smile
            expression_loss = expression_loss + BCE(z_exp_dp9, smile_target)

        if dp1_ide == '01':  #neutral
            expression_loss = expression_loss + BCE(z_exp_dp1, neutral_target)
        else:  #smile
            expression_loss = expression_loss + BCE(z_exp_dp1, smile_target)

        expression_train_loss += expression_loss.item()

        # calc gradients for all losses except swap

        losses = L1_loss + sim_loss + triplet_loss + expression_loss
        losses.backward(retain_graph=True)

        # calc swapping loss

        z_per0_exp9 = torch.cat(
            (z_per_dp0, z_exp_dp9),
            dim=1)  # should be equal to img9 (per0 and per9 are the same)
        recon_per0_exp9 = model.decode(z_per0_exp9)

        z_per0_exp1 = torch.cat(
            (z_per_dp0, z_exp_dp1),
            dim=1)  # should be equal to img0 (exp1 and exp0 are the same)
        recon_per0_exp1 = model.decode(z_per0_exp1)

        z_per9_exp0 = torch.cat((z_per_dp9, z_exp_dp0),
                                dim=1)  # should be equal to img0
        recon_per9_exp0 = model.decode(z_per9_exp0)

        z_per1_exp0 = torch.cat((z_per_dp1, z_exp_dp0),
                                dim=1)  # should be equal to img1
        recon_per1_exp0 = model.decode(z_per1_exp0)

        swap_loss1 = recon_loss_func(recon_per0_exp9, dp9_img)
        swap_loss1.backward(retain_graph=True)

        swap_loss2 = recon_loss_func(recon_per0_exp1, dp0_img)
        swap_loss2.backward(retain_graph=True)

        swap_loss3 = recon_loss_func(recon_per9_exp0, dp0_img)
        swap_loss3.backward(retain_graph=True)

        swap_loss4 = recon_loss_func(recon_per1_exp0, dp1_img)
        swap_loss4.backward()

        swap_loss = swap_loss1 + swap_loss2 + swap_loss3 + swap_loss4

        swap_train_loss += swap_loss.item()

        optimizer.step()
        print(
            'Train Epoch: {} [{}/{} ({:.0f}%)] Recon: {:.6f} Cosine: {:.6f} Triplet: {:.6f} Swap: {:.6f}'
            .format(epoch, batch_idx * opt.batchSize,
                    (len(dataloader) * opt.batchSize),
                    100. * batch_idx / len(dataloader),
                    recon_loss.item(), sim_loss.item(),
                    triplet_loss.item(), swap_loss.item()))
        # losses here are batch sums; divide by opt.batchSize for per-image values

    lossfile.write('Epoch:{} Recon:{:.6f} Swap:{:.6f} ExpLoss:{:.6f}\n'.format(
        epoch, recon_train_loss, swap_train_loss, expression_train_loss))
    lossfile.write('Epoch:{} cosineSim:{:.6f} triplet:{:.6f}\n'.format(
        epoch, cosine_train_loss, triplet_train_loss))

    print(
        '====> Epoch: {} Average recon loss: {:.6f} Average cosine loss: {:.6f} Average triplet: {:.6f} Average swap: {:.6f}'
        .format(epoch,
                recon_train_loss / (len(dataloader) * opt.batchSize),
                cosine_train_loss / (len(dataloader) * opt.batchSize),
                triplet_train_loss / (len(dataloader) * opt.batchSize),
                swap_train_loss / (len(dataloader) * opt.batchSize)))
    # divided by (batch_size * num_batches) to get the average loss for the epoch

    #data
    visualizeAsImages(dp0_img.data.clone(),
                      opt.dirImageoutput,
                      filename='e_' + str(epoch) + '_train_img0',
                      n_sample=18,
                      nrow=5,
                      normalize=False)
    visualizeAsImages(dp9_img.data.clone(),
                      opt.dirImageoutput,
                      filename='e_' + str(epoch) + '_train_img9',
                      n_sample=18,
                      nrow=5,
                      normalize=False)
    visualizeAsImages(dp1_img.data.clone(),
                      opt.dirImageoutput,
                      filename='e_' + str(epoch) + '_train_img1',
                      n_sample=18,
                      nrow=5,
                      normalize=False)

    #reconstruction (dp0 only)
    visualizeAsImages(recon_batch_dp0.data.clone(),
                      opt.dirImageoutput,
                      filename='e_' + str(epoch) + '_train_recon0',
                      n_sample=18,
                      nrow=5,
                      normalize=False)

    print('Train data and reconstruction saved.')

    return recon_train_loss / (len(dataloader) *
                               opt.batchSize), triplet_train_loss / (
                                   len(dataloader) * opt.batchSize)
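None of the loss callables are defined in these snippets. One set of definitions consistent with how they are called (every setting here is an assumption):

import torch.nn as nn

recon_loss_func = nn.MSELoss(reduction='sum')  # image reconstruction
cosine_loss_func = nn.CosineEmbeddingLoss()  # scalar +1/-1 targets would need to be tensors, as in the wrapper sketched after code example #3
triplet_loss_func = nn.TripletMarginLoss(margin=1.0)  # triplet(anchor, positive, negative)
L1 = nn.L1Loss()  # latent-code matching
BCE = nn.BCELoss()  # assumes the expression code is sigmoid-activated in [0, 1]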
Code example #7
                   'real/multipie_select_batches/session01_select_test/')

# ------------ training ------------ #
doTraining = True
doTesting = False
iter_mark = 0
for epoch in range(opt.epoch_iter):
    if doTraining:
        train_loss = 0
        train_amount = 0 + 1e-6
        gc.collect()  # collect garbage
        for subprocid in range(10):
            # random sample a dataroot
            dataroot = random.sample(TrainingData, 1)[0]
            n_batches = 0  # cap the batches drawn from this sub-dataset
            dataset = MultipieLoader.FareMultipieExpressionTripletsFrontal(
                opt, root=dataroot, resize=64)
            print('# size of the current (sub)dataset is %d' % len(dataset))
            train_amount = train_amount + len(dataset)
            dataloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=opt.batchSize,
                                                     shuffle=True,
                                                     num_workers=int(
                                                         opt.workers))
            for batch_idx, data_point in enumerate(dataloader, 0):

                n_batches += 1
                if n_batches > 20:
                    break

                gc.collect()  # collect garbage
                # sample the data points:
Code example #8
def testModel():
    def parse_image(dp0_img):
        print("before ========= ")
        print(dp0_img.size())
        dp0_img = dp0_img.float() / 255  # convert to float and rescale to [0, 1]
        dp0_img = dp0_img.permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW [batch, 3, H, W]
        # dp0_img_1 = np.array(dp0_img).reshape(64,64,3)
        print("after ========= ")
        print(dp0_img.size())
        return dp0_img

    def save_images(imgs, path):
        if not os.path.isdir(path):
            os.mkdir(path)
        c = 0
        for i in imgs:
            img = Image.open(i)
            img = img.convert('RGB')
            img = img.resize((64, 64), Image.LANCZOS)
            # img0 = np.array(img0)
            img.save(path + str(c) + ".png")
            c = c + 1

    def load_images(path, batch_size):
        ret = torch.ones(batch_size, 64, 64, 3)
        for i in range(batch_size):
            img = Image.open(path + str(i) + ".png")
            img = img.convert('RGB')
            img = img.resize((64, 64), Image.LANCZOS)
            img0 = np.array(img)
            ret[i, :, :, :] = torch.from_numpy(img0)
        # data = dset.ImageFolder(path, transform=transforms.ToTensor())
        # dl = torch.utils.data.DataLoader(dataset=data, batch_size=batch_size)
        # for d in dl:
        #     i,l = d
        #     # img = Image.open(path + str(i) + ".png")
        #     images.append(parse_image(i))
        return parse_image(ret)

    def set_cuda(x):
        return x.cuda()

    # def getitem(d, ids, e, p, l):
    #     # different ids, same ide, same idl, same idt, same idp
    #     print(len(d.ids))
    #     for i in range(len(d.ids)):
    #         if(d.ids[i] == ids and d.ide[i] == e and d.idp[i] == p and d.idl[i] == l):
    #             return d.imgs[i]

    dataroot0 = TrainingData[0]
    dataset0 = MultipieLoader.FareMultipieLightingTripletsFrontal(
        opt, root=dataroot0, resize=64)
    dataset0.get_Sample('params0.csv')

    dataroot5 = TrainingData[5]
    dataset5 = MultipieLoader.FareMultipieLightingTripletsFrontal(
        opt, root=dataroot5, resize=64)
    dataset5.get_Sample('params5.csv')

    # input_img_list = [dataset0.getitem('001','01','051','00'), dataset0.getitem('002','01','051','00'), dataset0.getitem('003','01','051','00'), dataset0.getitem('004','01','051','00'), dataset0.getitem('005','01','051','00')]
    # print(input_img_list)
    # # i = []
    # save_images(input_img_list, "test_images/inputs/")
    # # for img in input_img_list:
    # #     i.append(parse_image(img))
    # # dataroot5 = TrainingData[5]
    # # dataset5 = MultipieLoader.FareMultipieLightingTripletsFrontal(opt, root=dataroot5, resize=64)
    # # dataset5.get_Sample('params5.csv')
    # swap_img_list = [dataset5.getitem('028','02','051','12'), dataset5.getitem('028','02','051','12'), dataset5.getitem('028','02','051','12'), dataset5.getitem('028','02','051','12'), dataset5.getitem('028','02','051','12')]
    # save_images(swap_img_list, "test_images/swap_inputs/")

    # target_img_list = [dataset5.getitem('001','01','051','12'), dataset5.getitem('002','01','051','12'), dataset5.getitem('003','01','051','12'), dataset5.getitem('004','01','051','12'), dataset5.getitem('005','01','051','12')]
    # # print(target_img_list)
    # save_images(target_img_list, "test_images/targets/")

    inputs = Variable(set_cuda(load_images("test_images/inputs/1/", 5)))
    input_swap = Variable(
        set_cuda(load_images("test_images/swap_inputs/1/", 5)))
    targets = Variable(set_cuda(load_images("test_images/targets/1/", 5)))

    visualizeAsImages(inputs.data.clone(),
                      opt.dirImageoutput,
                      filename='test_input_iter__img0',
                      n_sample=5,
                      nrow=1,
                      normalize=False)
    visualizeAsImages(input_swap.data.clone(),
                      opt.dirImageoutput,
                      filename='test_input_swap_orig_iter__img0',
                      n_sample=5,
                      nrow=1,
                      normalize=False)

    model = torch.load("models/783a5ffc-7a8e-4fa7-b67d-0e92d41fdc40/model.pth")
    # model.eval()
    lightCode_i, nonLightCode_i, o_i = model.forward(inputs)
    lightCode_i_s, nonLightCode_i_s, o_i_s = model.forward(input_swap)
    # lightCode_t, nonLightCode_t, o_t = model.forward(targets)

    z = torch.cat([lightCode_i_s, nonLightCode_i], 1)
    z = z.unsqueeze(2)
    z = z.unsqueeze(3)
    o_swap = model.D(z)
    # print(lightCode0.size())

    visualizeAsImages(o_i.data.clone(),
                      opt.dirImageoutput,
                      filename='test_output_iter__img0',
                      n_sample=5,
                      nrow=1,
                      normalize=False)
    visualizeAsImages(o_swap.data.clone(),
                      opt.dirImageoutput,
                      filename='test_output_swap_iter__img0',
                      n_sample=5,
                      nrow=1,
                      normalize=False)
    visualizeAsImages(targets.data.clone(),
                      opt.dirImageoutput,
                      filename='test_target_iter__img0',
                      n_sample=5,
                      nrow=1,
                      normalize=False)
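parse_image in this snippet converts a uint8 NHWC image batch into the float NCHW layout PyTorch convolutions expect. A quick self-contained check of that transform:

import torch

batch = torch.randint(0, 256, (5, 64, 64, 3), dtype=torch.uint8)
out = batch.float() / 255  # rescale to [0, 1]
out = out.permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW
assert out.shape == (5, 3, 64, 64)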
Code example #9
    doTraining = True
    doTesting = False
else:
    doTraining = False
    doTesting = True
iter_mark = 0
for epoch in range(opt.epoch_iter):
    if doTraining:
        loss_sum = 0
        train_amount = 0 + 1e-6
        gc.collect()  # collect garbage
        # for subprocid in range(10):
        # random sample a dataroot
        dataroot = random.sample(TrainingData, 1)[0]
        n_batches = 0  # cap the batches drawn from this sub-dataset
        dataset = MultipieLoader.FareMultipieLightingTripletsFrontal(
            opt, root=dataroot, resize=64)
        print('# size of the current (sub)dataset is %d' % len(dataset))
        train_amount = train_amount + len(dataset)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=opt.batchSize,
                                                 shuffle=True,
                                                 num_workers=int(opt.workers))
        for batch_idx, data_point in enumerate(dataloader, 0):

            n_batches += 1
            if n_batches > 20:
                break

            gc.collect()  # collect garbage
            # sample the data points:
            # dp0_img: image of data point 0