Example #1
# Shared imports for the examples on this page; repo-local helpers such as
# forward_loss, clip_cross_entropy, requires_grad and log_value are assumed
# to be defined elsewhere in the project.
import math
import os

import torch
import torch.nn.functional as F
from tqdm import tqdm


def train_c(epoch, netd_c, optd_c, loader, step, opt, Q):

    netd_c.train()
    pbar = tqdm(enumerate(loader))
    for _, (image_c, label) in pbar:

        # label has two columns: column 0 is the class id, column 1 flags
        # an ordinary label (1) versus a complementary label (0)
        index = label[:, 1].cuda()
        label = label[:, 0]

        real_label = label.cuda().long()

        real_loss_c = torch.zeros(1).cuda()
        if (index == 1).sum() + (index == 0).sum() > 0:

            real_input_c = image_c.cuda()

            _, real_cls = netd_c(real_input_c)
            if (index == 1).sum() > 0:
                # ordinary labels: (clipped) cross-entropy
                real_loss_c += clip_cross_entropy(real_cls[index == 1],
                                                  real_label[index == 1])
            if (index == 0).sum() > 0:
                # complementary labels: forward-corrected loss with transition matrix Q
                real_loss_c += forward_loss(real_cls[index == 0],
                                            real_label[index == 0], Q)

            optd_c.zero_grad()
            real_loss_c.backward()
            optd_c.step()
        pbar.set_description(
            'Loss: {:.6f}'.format(real_loss_c.item()))

    torch.save(netd_c.state_dict(), os.path.join(opt.savingroot, opt.dataset,
                                                 str(opt.p1 * 100) + '%complementary/' + str(
                                                     opt.p1) + f'_chkpts/d_{epoch:03d}.pth'))

    return step
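
train_c takes the complementary-label transition matrix Q as an argument (clip_cross_entropy and forward_loss are repo-local helpers). A minimal sketch of the conventional choice of Q, assuming complementary labels are drawn uniformly over K classes; the source does not show how Q is actually built:

import torch

# Uniform complementary-label transition matrix:
# Q[i, j] = P(complementary label j | true label i) = 1/(K-1) off-diagonal.
K = 10  # hypothetical class count (e.g. MNIST / CIFAR-10)
Q = torch.full((K, K), 1.0 / (K - 1)).fill_diagonal_(0)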
Example #2
def train_forward(args, model, device, loader, optimizer, epoch, T):
    model.train()
    train_loss = 0
    correct = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = forward_loss(output, target, T)
        loss.backward()
        optimizer.step()
        train_loss += data.size(0) * loss.item()
        pred = output.argmax(
            dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
    train_loss /= len(loader.dataset)
    print(
        'Epoch: {}/{}\nTraining loss: {:.4f}, Training accuracy: {}/{} ({:.2f}%)'
        .format(epoch, args.epochs, train_loss, correct, len(loader.dataset),
                100. * correct / len(loader.dataset)))
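
forward_loss(output, target, T) is defined elsewhere in the repository. A minimal sketch of a standard forward correction, an assumption based on the call sites (in the spirit of Yu et al., "Learning with Biased Complementary Labels"); the repo's exact implementation may differ:

import torch
import torch.nn.functional as F

def forward_loss(output, target, T, eps=1e-12):
    # Cross-entropy after pushing clean-class probabilities through the
    # transition matrix T, where T[i, j] = P(observed label j | true class i).
    prob = F.softmax(output, dim=1)  # probabilities over true classes
    prob_obs = prob.matmul(T)        # probabilities over observed labels
    return F.nll_loss(torch.log(prob_obs + eps), target)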
Example #3
def train_c(epoch, netd_c, optd_c, loader, step, opt):

    netd_c.train()
    for _, (image_c, label) in enumerate(loader):

        # column 0: class id; column 1: ordinary (1) vs. complementary (0)
        index = label[:, 1].cuda()
        label = label[:, 0]

        real_label = label.cuda()

        real_loss_c = torch.zeros(1).cuda()
        if (index == 1).sum() + (index == 0).sum() > 0:

            real_input_c = image_c.cuda()

            _, real_cls = netd_c(real_input_c)
            # two independent ifs: a batch can contain both label types
            # (cf. the note in Example #7, where 'elif' was changed to 'if')
            if (index == 1).sum() > 0:
                real_loss_c += F.cross_entropy(real_cls[index == 1],
                                               real_label[index == 1])
            if (index == 0).sum() > 0:
                real_loss_c += forward_loss(real_cls[index == 0],
                                            real_label[index == 0])

            optd_c.zero_grad()
            real_loss_c.backward()
            optd_c.step()

    torch.save(
        netd_c.state_dict(),
        os.path.join(
            opt.savingroot, opt.dataset,
            str(opt.p1 * 100) + '%complementary/' + str(opt.p1) +
            f'_chkpts/d_{epoch:03d}.pth'))

    return step
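
Every loop on this page consumes a two-column label tensor. A tiny illustration of the assumed layout (values and names here are hypothetical):

import torch

class_ids = torch.tensor([3, 7, 1])    # class id per sample
is_ordinary = torch.tensor([1, 0, 1])  # 1 = ordinary label, 0 = complementary
label = torch.stack([class_ids, is_ordinary], dim=1)  # shape [3, 2]
# label[:, 0] -> class ids; label[:, 1] -> the index flag used above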
Example #4
def train_data_g(netd, netg, optd, epoch, step, opt, loader):

    netg.eval()
    netd.train()
    for _, (image_g, image_c, label) in enumerate(loader):

        # column 0: class id; column 1: ordinary (1) vs. complementary (0)
        index = label[:, 1].cuda()
        label = label[:, 0]
        real_label = label.cuda()

        real_loss_c = torch.zeros(1).cuda()
        if (index == 1).sum() + (index == 0).sum() > 0:

            real_input_c = image_c.cuda()

            _, real_cls = netd(real_input_c)
            # two independent ifs: a batch can contain both label types
            if (index == 1).sum() > 0:
                real_loss_c += F.cross_entropy(real_cls[index == 1],
                                               real_label[index == 1])
            if (index == 0).sum() > 0:
                real_loss_c += forward_loss(real_cls[index == 0],
                                            real_label[index == 0])

        #######################
        # fake input and label
        #######################
        noise = torch.randn(opt.batch_size, opt.nz).cuda()
        fake_label = torch.randint(0, 10, (opt.batch_size,)).cuda()
        optd.zero_grad()
        fake_input = netg(noise, fake_label)

        # detach so no gradient flows back into the generator
        fake_pred, fake_cls = netd(fake_input.detach())

        fake_loss_c = F.cross_entropy(fake_cls, fake_label)

        c_loss = fake_loss_c + real_loss_c
        c_loss.backward()
        optd.step()

        log_value('c_f_loss', c_loss.item(), step)
        step += 1  # advance the logging step so entries are not overwritten

    torch.save(
        netd.state_dict(),
        os.path.join(
            opt.savingroot, opt.dataset,
            str(opt.p1 * 100) + '%complementary/' + str(opt.p1) +
            f'_chkpts_fake_data/d_{epoch:03d}.pth'))
    return step
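
log_value(name, value, step) matches the API of the tensorboard_logger package; a minimal setup sketch, assuming that is the library in use (the source does not show the import or configuration):

from tensorboard_logger import configure, log_value

configure('runs/example')  # hypothetical log directory
log_value('c_f_loss', 0.42, step=0)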
Example #5
def train_data_g(netd, netg, optd, epoch, step, opt, loader, Q):

    netg.eval()
    requires_grad(netg, False)  # freeze the generator's parameters
    netd.train()
    for _, (image_c, label) in enumerate(loader):

        # column 0: class id; column 1: ordinary (1) vs. complementary (0)
        index = label[:, 1].cuda()
        label = label[:, 0]
        real_label = label.cuda()

        real_loss_c = torch.zeros(1).cuda()
        if (index == 1).sum() + (index == 0).sum() > 0:

            real_input_c = image_c.cuda()

            _, real_cls = netd(real_input_c)
            # two independent ifs: a batch can contain both label types
            if (index == 1).sum() > 0:
                real_loss_c += F.cross_entropy(real_cls[index == 1],
                                               real_label[index == 1])
            if (index == 0).sum() > 0:
                real_loss_c += forward_loss(real_cls[index == 0],
                                            real_label[index == 0], Q=Q)



        #######################
        # fake input and label
        #######################
        noise = torch.randn(opt.batch_size, opt.nz).cuda()
        # uniform random class labels for the generator
        fake_label = torch.multinomial(
            torch.ones(opt.num_class), opt.batch_size, replacement=True
        ).cuda()
        optd.zero_grad()
        fake_input = netg(noise, fake_label)

        # detach so no gradient flows back into the generator
        fake_pred, fake_cls = netd(fake_input.detach())

        fake_loss_c = F.cross_entropy(fake_cls, fake_label)

        c_loss = fake_loss_c + real_loss_c
        c_loss.backward()
        optd.step()

        log_value('c_f_loss', c_loss.item(), step)
        step += 1  # advance the logging step so entries are not overwritten

    torch.save(
        netd.state_dict(),
        os.path.join(
            opt.savingroot, opt.dataset,
            str(opt.p1 * 100) + '%complementary/' + str(opt.p1) +
            f'_chkpts_fake_data/Nd_{epoch:03d}.pth'))
    return step
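
requires_grad(netg, False) is not a PyTorch built-in. A plausible sketch of the helper the repository likely defines (an assumption; this is the form commonly seen in GAN training code):

def requires_grad(model, flag=True):
    # Toggle gradient tracking for every parameter of the model.
    for p in model.parameters():
        p.requires_grad = flag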
Example #6
def train_c(epoch,
            netd_c,
            optd_c,
            loader,
            step,
            opt,
            co_data=None):  # None instead of a CUDA tensor built at import time

    data_iter = iter(loader)
    iters = len(loader)
    extra_iters = 0
    if co_data is not None and len(co_data[0]) > 0:
        # number of 128-sample chunks needed to cover the extra data
        extra_iters = math.ceil(len(co_data[0]) / 128)
        co_label = co_data[1]
        co_datas = co_data[0]

    netd_c.train()
    i = 0
    print('iters:', iters, 'extra_iters:', extra_iters, 'sum of both:',
          iters + extra_iters)
    while i < iters + extra_iters:
        real_loss_c = torch.zeros(1).cuda()
        if i < iters:
            # a regular labeled batch from the loader
            image_c, label = next(data_iter)

            # column 0: class id; column 1: ordinary (1) vs. complementary (0)
            index = label[:, 1].cuda()
            label = label[:, 0]

            real_label = label.cuda()

            if (index == 1).sum() + (index == 0).sum() > 0:

                real_input_c = image_c.cuda()

                _, real_cls = netd_c(real_input_c)

                if (index == 1).sum() > 0:
                    real_loss_c += F.cross_entropy(real_cls[index == 1],
                                                   real_label[index == 1])
                if (index == 0).sum() > 0:
                    real_loss_c += forward_loss(real_cls[index == 0],
                                                real_label[index == 0])

                optd_c.zero_grad()
                real_loss_c.backward()
                optd_c.step()
            i += 1
        else:
            # consume the extra co_data in chunks of up to 128 samples;
            # the original slice ended at 127 + 128*(i - iters), silently
            # dropping one sample per chunk, fixed here (a Python slice
            # past the end simply returns the remaining samples)
            lo = 128 * (i - iters)
            co_labels = co_label[lo:lo + 128, 0]
            real_input_c = co_datas[lo:lo + 128].cuda()
            _, real_cls = netd_c(real_input_c)
            real_loss_c += F.cross_entropy(real_cls, co_labels.long())
            optd_c.zero_grad()
            real_loss_c.backward()
            optd_c.step()
            i += 1

    torch.save(
        netd_c.state_dict(),
        os.path.join(
            opt.savingroot, opt.dataset,
            str(opt.p1 * 100) + '%complementary/' + str(opt.p1) +
            '_chkpts/d_epoch{:03d}.pth'.format(epoch)))

    return step
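
The chunked indexing above implies co_data is a pair of (images, two-column labels). A hypothetical call, continuing the names of Example #6 (shapes and sizes are assumptions, not from the source):

import torch

co_images = torch.randn(300, 1, 32, 32)             # e.g. 300 extra samples
co_labels = torch.randint(0, 10, (300, 2)).float()  # column 0 holds the class id
step = train_c(epoch, netd_c, optd_c, loader, step, opt,
               co_data=(co_images.cuda(), co_labels.cuda()))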
Example #7
def co_train_c(epoch, netd_c1, optd_c1, netd_c2, optd_c2, loader,
               unlabel_loader, step, opt):

    netd_c1.train()  # set this module and its submodules to training mode
    netd_c2.train()
    min_dataloader = min(len(loader), len(unlabel_loader))
    data_iter = iter(loader)
    unlabel_data_iter = iter(unlabel_loader)

    i = 0
    # image_c: torch.Size([128, 1, 32, 32]); label: torch.Size([128, 2])
    while i < min_dataloader:
        real_loss_c = torch.zeros(1).cuda()

        # labeled batch (contains both ordinary and complementary labels)
        image_g, image_c, label = next(data_iter)
        index = label[:, 1].cuda()
        label = label[:, 0]
        real_label = label.cuda()

        if (index == 1).sum() + (index == 0).sum() > 0:

            real_input_c = image_c.cuda()
            _, real_cls1 = netd_c1(real_input_c)  # real_cls: torch.Size([128, 10])
            _, real_cls2 = netd_c2(real_input_c)

            if (index == 1).sum() > 0:
                real_loss_c += F.cross_entropy(real_cls1[index == 1],
                                               real_label[index == 1])
                real_loss_c += F.cross_entropy(real_cls2[index == 1],
                                               real_label[index == 1])
                if i % 20 == 0:
                    print(real_loss_c)
            # 'if' rather than 'elif': a batch can contain both label types
            if (index == 0).sum() > 0:
                real_loss_c += forward_loss(real_cls1[index == 0],
                                            real_label[index == 0])
                real_loss_c += forward_loss(real_cls2[index == 0],
                                            real_label[index == 0])
                if i % 20 == 0:
                    print(real_loss_c)

        # unlabeled batch (label is a -1 placeholder)
        image_u, label_u = next(unlabel_data_iter)
        unlabel_input_u = image_u.cuda()
        _, unlabel_cls1 = netd_c1(unlabel_input_u)  # torch.Size([128, 10])
        _, unlabel_cls2 = netd_c2(unlabel_input_u)

        # agreement term on unlabeled data: symmetric KL of each model's
        # prediction against the average of the two log-distributions
        probt1 = F.softmax(unlabel_cls1, dim=1)
        out1 = torch.log(probt1)
        probt2 = F.softmax(unlabel_cls2, dim=1)
        out2 = torch.log(probt2)
        Qx = (out1 + out2) / 2
        KLDiv = torch.sum(
            (probt1 * (out1 - Qx) + probt2 * (out2 - Qx)) / 2) / len(probt1)
        if i % 20 == 0:
            print(KLDiv)
        real_loss_c = real_loss_c + KLDiv

        optd_c1.zero_grad()
        optd_c2.zero_grad()
        real_loss_c.backward()
        optd_c1.step()
        optd_c2.step()
        i += 1

    torch.save(
        netd_c1.state_dict(),
        os.path.join(
            opt.savingroot, opt.dataset,
            str(opt.p1 * 100) + '%complementary/' + str(opt.p1) +
            '_chkpts/d1_epoch{:03d}.pth'.format(epoch)))
    torch.save(
        netd_c2.state_dict(),
        os.path.join(
            opt.savingroot, opt.dataset,
            str(opt.p1 * 100) + '%complementary/' + str(opt.p1) +
            '_chkpts/d2_epoch{:03d}.pth'.format(epoch)))
    return step
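
The hand-rolled KLDiv term is equivalent to this F.kl_div formulation, continuing with out1/out2/probt1/probt2 from the loop above (a sketch; reduction='batchmean' reproduces the division by len(probt1)):

Qx = (out1 + out2) / 2  # average of the two log-distributions
KLDiv = 0.5 * (F.kl_div(Qx, probt1, reduction='batchmean')
               + F.kl_div(Qx, probt2, reduction='batchmean'))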