Example #1
    def train_batch(self, batch, epoch) -> BatchResult:
        ###########################
        # Forming input variables
        ###########################

        src_inputs, src_labels = batch
        if self.opt.gpu >= 0:
            src_inputs, src_labels = src_inputs.cuda(), src_labels.cuda()
        src_inputsv, src_labelsv = Variable(src_inputs), Variable(src_labels)

        ###########################
        # Updates
        ###########################

        self.classifier.zero_grad()
        self.mixer.zero_grad()
        outC = self.classifier(self.mixer(src_inputsv))
        loss = self.criterion(outC, src_labelsv)
        num_of_correct = (outC.argmax(dim=1) == src_labelsv).sum().item()
        loss.backward()
        self.optimizer_classifier.step()
        self.optimizer_mixer.step()

        # Learning rate scheduling
        if self.opt.lrd:
            # Decay is driven by the epoch index rather than the iteration count
            self.optimizer_mixer = utils.exp_lr_scheduler(
                self.optimizer_mixer, self.opt.lr, self.opt.lrd, epoch)
            self.optimizer_classifier = utils.exp_lr_scheduler(
                self.optimizer_classifier, self.opt.lr, self.opt.lrd, epoch)
        return BatchResult(loss.item(), num_of_correct / len(outC))
Example #2
    def train(self):

        curr_iter = 0
        for epoch in range(self.opt.nepochs):

            self.netF.train()
            self.netC.train()

            for i, datas in enumerate(self.source_trainloader):

                ###########################
                # Forming input variables
                ###########################

                src_inputs, src_labels = datas
                if self.opt.gpu >= 0:
                    src_inputs, src_labels = src_inputs.cuda(
                    ), src_labels.cuda()
                src_inputsv, src_labelsv = Variable(src_inputs), Variable(
                    src_labels)

                ###########################
                # Updates
                ###########################

                self.netC.zero_grad()
                self.netF.zero_grad()
                outC = self.netC(self.netF(src_inputsv))
                loss = self.criterion(outC, src_labelsv)
                loss.backward()
                self.optimizerC.step()
                self.optimizerF.step()

                curr_iter += 1

                # print training information
                if (i + 1) % 50 == 0:
                    text_format = 'epoch: {}, iteration: {}, errC: {}'
                    train_text = text_format.format(epoch + 1, i + 1,
                                                    loss.item())
                    print(train_text)

                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizerF = utils.exp_lr_scheduler(
                        self.optimizerF, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerC = utils.exp_lr_scheduler(
                        self.optimizerC, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)

            # Validate every epoch
            self.validate(epoch)
Example #3
    def train(self):

        curr_iter = 0
        for epoch in range(self.opt.nepochs):

            self.netF.train()
            self.netC.train()

            for i, datas in enumerate(self.source_trainloader):

                ###########################
                # Forming input variables
                ###########################

                src_inputs, src_labels = datas
                if self.opt.gpu >= 0:
                    src_inputs, src_labels = src_inputs.cuda(
                    ), src_labels.cuda()
                src_inputsv, src_labelsv = Variable(src_inputs), Variable(
                    src_labels)

                ###########################
                # Updates
                ###########################

                self.netC.zero_grad()
                self.netF.zero_grad()
                outC = self.netC(self.netF(src_inputsv))
                loss = self.criterion(outC, src_labelsv)
                loss.backward()
                self.optimizerC.step()
                self.optimizerF.step()

                curr_iter += 1

                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizerF = utils.exp_lr_scheduler(
                        self.optimizerF, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerC = utils.exp_lr_scheduler(
                        self.optimizerC, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)

            # Validate every epoch
            self.validate(epoch)
Example #4
    def train(self):
        
        curr_iter = 0
        for epoch in range(self.opt.nepochs):
            
            self.netF.train()    
            self.netC.train()    
        
            for i, datas in enumerate(self.source_trainloader):
                ###########################
                # Forming input variables
                ###########################
                
                src_inputs, src_labels = datas
                if self.opt.gpu >= 0:
                    src_inputs, src_labels = src_inputs.cuda(), src_labels.cuda()
                src_inputsv, src_labelsv = Variable(src_inputs), Variable(src_labels)
                
                ###########################
                # Updates
                ###########################
                
                self.netC.zero_grad()
                self.netF.zero_grad()
                outC = self.netC(self.netF(src_inputsv))   
                loss = self.criterion(outC, src_labelsv)
                loss.backward()    
                self.optimizerC.step()
                self.optimizerF.step()

                curr_iter += 1
                
                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizerF = utils.exp_lr_scheduler(self.optimizerF, epoch, self.opt.lr, self.opt.lrd, curr_iter)
                    self.optimizerC = utils.exp_lr_scheduler(self.optimizerC, epoch, self.opt.lr, self.opt.lrd, curr_iter)                  
            
            # Validate every epoch
            self.validate(epoch)
Example #5
def train_s2cnn(mlp,
                s2cnn,
                data,
                train_batches,
                test_batches,
                num_epochs,
                init_learning_rate_s2cnn,
                learning_rate_decay_epochs,
                device_id=0):
    """ train the s2cnn keeping the baseline frozen """
    optim = OPTIMIZER(s2cnn.parameters(), lr=init_learning_rate_s2cnn)
    criterion = nn.MSELoss()
    if torch.cuda.is_available():
        criterion = criterion.cuda(device_id)
    for epoch in range(num_epochs):
        optim = exp_lr_scheduler(optim,
                                 epoch,
                                 init_lr=init_learning_rate_s2cnn,
                                 lr_decay_epoch=learning_rate_decay_epochs)
        train_losses = []
        print("training")
        for iteration, batch_idxs in enumerate(train_batches):
            s2cnn.train()
            mlp.eval()
            optim.zero_grad()
            loss = eval_batch_s2cnn(mlp, s2cnn, data, batch_idxs, criterion)
            loss.backward()
            optim.step()
            train_losses.append(loss.item())
            print("\riteration {}/{} - batch loss: {}".format(
                iteration + 1, train_batches.num_iterations(),
                np.sqrt(train_losses[-1])),
                  end="")
        print()
        test_losses = []
        print("evaluating")
        for iteration, batch_idxs in enumerate(test_batches):
            s2cnn.eval()
            mlp.eval()
            loss = eval_batch_s2cnn(mlp, s2cnn, data, batch_idxs, criterion)
            test_losses.append(loss.item())
            print("\riteration {}/{}  - batch loss: {}".format(
                iteration + 1, test_batches.num_iterations(),
                np.sqrt(test_losses[-1])),
                  end="")
        print()
        train_loss = np.sqrt(np.mean(train_losses))
        test_loss = np.sqrt(np.mean(test_losses))
        print("epoch {}/{} - avg train loss: {}, test loss: {}".format(
            epoch + 1, num_epochs, train_loss, test_loss))
    return train_loss, test_loss
Example #6
    def train(self):

        curr_iter = 0

        reallabel = torch.FloatTensor(self.opt.batchSize).fill_(
            self.real_label_val)
        fakelabel = torch.FloatTensor(self.opt.batchSize).fill_(
            self.fake_label_val)
        temp = torch.LongTensor(self.opt.batchSize).fill_(11)
        if self.opt.gpu >= 0:
            reallabel, fakelabel = reallabel.cuda(), fakelabel.cuda()
            temp = temp.cuda()
        reallabelv = Variable(reallabel)
        fakelabelv = Variable(fakelabel)
        temp = Variable(temp)

        for epoch in range(self.opt.nepochs):
            print("epoch = ", epoch)
            self.generator.train()
            self.mixer.train()
            self.classifier.train()
            self.discriminator.train()

            for i, (datas, datat) in enumerate(
                    itertools.zip_longest(self.source_train_ds,
                                          self.target_train_ds)):
                if i >= min(len(self.source_train_ds), len(
                        self.target_train_ds)):
                    # todo: check what to do with the left over data
                    break

                ###########################
                # Forming input variables
                ###########################

                src_inputs, src_labels = datas
                tgt_inputs, __ = datat
                src_inputs_unnorm = ((
                    (src_inputs * self.std[0]) + self.mean[0]) - 0.5) * 2

                # Creating one hot vector
                labels_onehot = np.zeros(
                    (self.opt.batchSize, self.num_classes + 1),
                    dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, src_labels[num]] = 1
                src_labels_onehot = torch.from_numpy(labels_onehot)

                labels_onehot = np.zeros(
                    (self.opt.batchSize, self.num_classes + 1),
                    dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, self.num_classes] = 1
                tgt_labels_onehot = torch.from_numpy(labels_onehot)

                if self.opt.gpu >= 0:
                    src_inputs, src_labels = src_inputs.cuda(
                    ), src_labels.cuda()
                    src_inputs_unnorm = src_inputs_unnorm.cuda()
                    tgt_inputs = tgt_inputs.cuda()
                    src_labels_onehot = src_labels_onehot.cuda()
                    tgt_labels_onehot = tgt_labels_onehot.cuda()

                # Wrapping in variable
                src_inputsv, src_labelsv = Variable(src_inputs), Variable(
                    src_labels)
                src_inputs_unnormv = Variable(src_inputs_unnorm)
                tgt_inputsv = Variable(tgt_inputs)
                src_labels_onehotv = Variable(src_labels_onehot)
                tgt_labels_onehotv = Variable(tgt_labels_onehot)

                ###########################
                # Updates
                ###########################

                # Updating D network

                self.discriminator.zero_grad()
                src_emb = self.mixer(src_inputsv)
                src_emb_cat = torch.cat((src_labels_onehotv, src_emb), 1)
                src_gen = self.generator(src_emb_cat)

                tgt_emb = self.mixer(tgt_inputsv)
                tgt_emb_cat = torch.cat((tgt_labels_onehotv, tgt_emb), 1)
                tgt_gen = self.generator(tgt_emb_cat)

                src_realoutputD = self.discriminator(src_inputs_unnormv)
                errD_src_real = self.criterion_c(src_realoutputD, src_labelsv)
                tgt_realoutputD = self.discriminator(tgt_inputsv, False)
                errD_tgt_real = self.criterion_s(tgt_realoutputD, reallabelv)

                src_fakeoutputD = self.discriminator(src_gen)
                errD_src_fake = self.criterion_c(src_fakeoutputD, temp)

                tgt_fakeoutputD = self.discriminator(tgt_gen, False)
                errD_tgt_fake = self.criterion_s(tgt_fakeoutputD, fakelabelv)

                errD = errD_src_real + errD_tgt_real + errD_src_fake + errD_tgt_fake
                errD.backward(retain_graph=True)
                self.optimizer_discriminator.step()

                # Updating G network

                self.generator.zero_grad()
                src_fakeoutputD_c = self.discriminator(src_gen)
                errG_c = self.criterion_c(src_fakeoutputD_c, src_labelsv)

                errG = errG_c
                errG.backward(retain_graph=True)
                self.optimizer_generator.step()

                # Updating C network

                self.classifier.zero_grad()
                outC = self.classifier(src_emb)
                errC = self.criterion_c(outC, src_labelsv)
                errC.backward(retain_graph=True)
                self.optimizer_classifier.step()

                # Updating F network

                self.mixer.zero_grad()
                errF_fromC = self.criterion_c(outC, src_labelsv)

                src_fakeoutputD_c = self.discriminator(src_gen)
                errF_src_fromD = self.criterion_c(
                    src_fakeoutputD_c, src_labelsv) * (self.opt.adv_weight)

                tgt_fakeoutputD_s = self.discriminator(tgt_gen, False)
                errF_tgt_fromD = self.criterion_s(
                    tgt_fakeoutputD_s,
                    reallabelv) * (self.opt.adv_weight * self.opt.alpha)

                errF = errF_fromC + errF_src_fromD + errF_tgt_fromD
                errF.backward()
                self.optimizer_mixer.step()

                curr_iter += 1

                # Visualization
                if i == 1:
                    vutils.save_image((src_gen.data / 2) + 0.5,
                                      '%s/visualization/source_gen_%d.png' %
                                      (self.opt.outf, epoch))
                    vutils.save_image((tgt_gen.data / 2) + 0.5,
                                      '%s/visualization/target_gen_%d.png' %
                                      (self.opt.outf, epoch))

                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizer_discriminator = utils.exp_lr_scheduler(
                        self.optimizer_discriminator, self.opt.lr,
                        self.opt.lrd, curr_iter)
                    self.optimizer_mixer = utils.exp_lr_scheduler(
                        self.optimizer_mixer, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizer_classifier = utils.exp_lr_scheduler(
                        self.optimizer_classifier, self.opt.lr, self.opt.lrd,
                        curr_iter)

            # Validate every epoch
            self.validate(epoch + 1)
Example #7
    def train(self):

        curr_iter = 0

        reallabel = torch.FloatTensor(self.opt.batchSize).fill_(
            self.real_label_val)
        fakelabel = torch.FloatTensor(self.opt.batchSize).fill_(
            self.fake_label_val)
        if self.opt.gpu >= 0:
            reallabel, fakelabel = reallabel.cuda(), fakelabel.cuda()

        source_domain = torch.FloatTensor(self.opt.batchSize).fill_(
            self.real_label_val)
        target_domain = torch.FloatTensor(self.opt.batchSize).fill_(
            self.fake_label_val)
        if self.opt.gpu >= 0:
            source_domain, target_domain = source_domain.cuda(
            ), target_domain.cuda()

        # list_errD_src_real_c = []
        list_errD_src_real_s = []
        list_errD_src_fake_s = []
        list_errD_tgt_fake_s = []
        list_errG_c = []
        list_errG_s = []
        # list_errC = []
        # list_errF_fromC = []
        list_errF_src_fromD = []
        list_errF_tgt_fromD = []

        for epoch in range(self.opt.nepochs):

            self.netG.train()
            self.netF.train()
            self.netC.train()
            self.netD.train()

            for i, (datas, datat) in enumerate(
                    zip(self.source_trainloader, self.target_trainloader)):

                ###########################
                # Forming input variables
                ###########################

                src_inputs, src_labels = datas
                tgt_inputs, __ = datat

                # Creating one hot vector
                labels_onehot = np.zeros(
                    (self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, src_labels[num]] = 1
                src_labels_onehot = torch.from_numpy(labels_onehot)

                labels_onehot = np.zeros(
                    (self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, self.nclasses] = 1
                tgt_labels_onehot = torch.from_numpy(labels_onehot)

                if self.opt.gpu >= 0:
                    src_inputs, src_labels = src_inputs.cuda(
                    ), src_labels.cuda()
                    tgt_inputs = tgt_inputs.cuda()
                    src_labels_onehot = src_labels_onehot.cuda()
                    tgt_labels_onehot = tgt_labels_onehot.cuda()

                ###########################
                # Updating D network
                ###########################
                self.netD.zero_grad()

                src_emb = self.netF(src_inputs)
                src_emb_cat = torch.cat((src_labels_onehot, src_emb), 1)
                # print('F src_emb {}, src_emb_cat {}'.format(src_emb.shape, src_emb_cat.shape))
                src_gen = self.netG(src_emb_cat)
                # print('src_gen {}'.format(src_gen.shape))

                tgt_emb = self.netF(tgt_inputs)
                tgt_emb_cat = torch.cat((tgt_labels_onehot, tgt_emb), 1)
                tgt_gen = self.netG(tgt_emb_cat)

                # Source domain: discriminator loss on real images --> reallabel
                src_realoutputD_s, src_realoutputD_c, src_real_feature = self.netD(
                    src_inputs)
                # print('src_realoutputD_s {}'.format(src_realoutputD_s.shape))
                # print('src_realoutputD_c {}'.format(src_realoutputD_c.shape))
                # print('src_real_feature {}'.format(src_real_feature.shape))
                errD_src_real_dloss = self.criterion_s(src_realoutputD_s,
                                                       reallabel)
                errD_src_real_closs = self.criterion_c(src_realoutputD_c,
                                                       src_labels)

                # Target domain: discriminator loss on real images --> reallabel
                # diff3: adds a branch for raw target-domain samples, which GTA lacks
                # tgt_realoutputD_s, tgt_realoutputD_c, tgt_real_feature = self.netD(tgt_inputs)
                # errD_tgt_real_dloss = self.criterion_s(tgt_realoutputD_s, reallabel)

                # Source domain: discriminator loss on generated images --> fakelabel
                # diff2: adds a classification loss on src_gen
                src_fakeoutputD_s, src_fakeoutputD_c, src_fake_feature = self.netD(
                    src_gen)
                errD_src_fake_dloss = self.criterion_s(src_fakeoutputD_s,
                                                       fakelabel)
                errD_src_fake_closs = self.criterion_c(src_fakeoutputD_c,
                                                       src_labels)
                # print('D src_fake_feature {}'.format(src_fake_feature.shape))

                # Target domain: discriminator loss on generated images --> fakelabel
                tgt_fakeoutputD_s, tgt_fakeoutputD_c, tgt_fake_feature = self.netD(
                    tgt_gen)
                errD_tgt_fake_dloss = self.criterion_s(tgt_fakeoutputD_s,
                                                       fakelabel)

                errD_mmd = self.mmd_loss(src_fake_feature, tgt_fake_feature)

                errD = (
                    errD_src_real_dloss + errD_src_fake_dloss +
                    errD_tgt_fake_dloss
                ) + errD_src_fake_closs + errD_src_real_closs + self.opt.mmd_weight * errD_mmd
                if i == 0:
                    msg = ('  D: errD {:.2f}, [errD_src_real_dloss {:.2f}, '
                           'errD_src_fake_dloss {:.2f}, errD_tgt_fake_dloss '
                           '{:.2f}], errD_src_fake_closs {:.2f}, '
                           'errD_src_real_closs {:.2f}, errD_mmd {:.2f}').format(
                               errD.item(), errD_src_real_dloss.item(),
                               errD_src_fake_dloss.item(),
                               errD_tgt_fake_dloss.item(),
                               errD_src_fake_closs.item(),
                               errD_src_real_closs.item(), errD_mmd.item())
                    print(msg)
                    logging.debug(msg)
                errD.backward()
                self.optimizerD.step()

                ###########################
                # Updating C network
                ###########################
                # self.netC.zero_grad()
                # outC = self.netC(src_emb)
                # errC = self.criterion_c(outC, src_labelsv)
                # # src_fakeoutputD_s, src_fakeoutputD_c, _ = self.netD(src_gen)
                # # errC_src_closs = self.criterion_c(src_fakeoutputD_c, src_labelsv)
                # # errG_src_dloss = self.criterion_s(src_fakeoutputD_s, reallabel)
                # # errC = errG_src_closs
                # if i == 0:
                #     print('C: errC {:.2f}'.format(errC.item()))
                # # errC.backward(retain_graph=True)
                # errC.backward(retain_graph=True)
                # self.optimizerC.step()

                ###########################
                # Updating G network
                ###########################
                self.netG.zero_grad()

                src_emb = self.netF(src_inputs)
                src_emb_cat = torch.cat((src_labels_onehot, src_emb), 1)
                src_gen = self.netG(src_emb_cat)
                tgt_emb = self.netF(tgt_inputs)
                tgt_emb_cat = torch.cat((tgt_labels_onehot, tgt_emb), 1)
                tgt_gen = self.netG(tgt_emb_cat)

                # Source domain: discriminator loss on generated images --> reallabel,
                # plus the classification loss
                src_fakeoutputD_s, src_fakeoutputD_c, src_fake_feature = self.netD(
                    src_gen)
                errG_src_closs = self.criterion_c(src_fakeoutputD_c,
                                                  src_labels)
                errG_src_dloss = self.criterion_s(src_fakeoutputD_s, reallabel)

                # # Target domain: discriminator loss on generated images --> reallabel
                # tgt_fakeoutputD_s, _, tgt_fake_feature = self.netD(tgt_gen)
                # errG_tgt_dloss = self.criterion_s(tgt_fakeoutputD_s, reallabel)

                # # MMD between src_gen and tgt_gen
                errG_mmd = self.mmd_loss(src_fake_feature, tgt_fake_feature)

                errG = errG_src_closs + errG_src_dloss + errG_mmd
                if i == 0:
                    msg = ('  G: errG {:.2f}, [errG_src_closs {:.2f}, '
                           'errG_src_dloss {:.2f}, errG_mmd {:.2f}]').format(
                               errG.item(), errG_src_closs.item(),
                               errG_src_dloss.item(), errG_mmd.item())
                    print(msg)
                    logging.debug(msg)
                errG.backward()
                self.optimizerG.step()

                ###########################
                # Updating F network
                ###########################
                self.netF.zero_grad()

                # errF_fromC = self.criterion_c(outC, src_labelsv)
                #############################
                # Includes the classification loss on src_gen, the discriminator
                # loss on src_gen, and the discriminator loss on tgt_gen
                # Added: MMD between src_emb and tgt_emb
                #############################

                src_emb = self.netF(src_inputs)
                src_emb_cat = torch.cat((src_labels_onehot, src_emb), 1)
                src_gen = self.netG(src_emb_cat)

                tgt_emb = self.netF(tgt_inputs)
                tgt_emb_cat = torch.cat((tgt_labels_onehot, tgt_emb), 1)
                tgt_gen = self.netG(tgt_emb_cat)

                # errF_fromC = self.criterion_c(outC, src_labelsv)
                # diff1: moves the classification loss on source samples onto
                # the generated source samples

                # Source domain: discriminator loss on generated images --> reallabel
                src_fakeoutputD_s, src_fakeoutputD_c, src_fake_feature = self.netD(
                    src_gen)
                errF_srcFake_closs = self.criterion_c(
                    src_fakeoutputD_c, src_labels) * (self.opt.adv_weight)
                errF_srcFake_dloss = self.criterion_s(
                    src_fakeoutputD_s,
                    reallabel) * (self.opt.adv_weight * self.opt.alpha)

                # Target domain: discriminator loss on generated images --> reallabel
                tgt_fakeoutputD_s, tgt_fakeoutputD_c, tgt_fake_feature = self.netD(
                    tgt_gen)
                errF_tgtFake_dloss = self.criterion_s(
                    tgt_fakeoutputD_s,
                    reallabel) * (self.opt.adv_weight * self.opt.alpha)

                # errF_mmd = self.mmd_loss(src_fake_feature, tgt_fake_feature)

                errF = errF_srcFake_dloss + errF_tgtFake_dloss + errF_srcFake_closs
                if i == 0:
                    msg = ('  F: errF {:.2f}, [errF_srcFake_dloss {:.2f}, '
                           'errF_tgtFake_dloss {:.2f}], errF_srcFake_closs '
                           '{:.2f}').format(errF.item(),
                                            errF_srcFake_dloss.item(),
                                            errF_tgtFake_dloss.item(),
                                            errF_srcFake_closs.item())
                    print(msg)
                    logging.debug(msg)
                errF.backward()
                self.optimizerF.step()

                curr_iter += 1

                # # list_errD_src_real_c.append(errD_src_real_c.item())
                # list_errD_src_real_s.append(errD_src_real_s.item())
                # list_errD_src_fake_s.append(errD_src_fake_s.item())
                # list_errD_tgt_fake_s.append(errD_tgt_fake_s.item())
                # list_errG_c.append(errG_c.item())
                # list_errG_s.append(errG_s.item())
                # # list_errC.append(errC.item())
                # # list_errF_fromC.append(errF_fromC.item())
                # list_errF_src_fromD.append(errF_src_fromD.item())
                # list_errF_tgt_fromD.append(errF_tgt_fromD.item())

                # Visualization
                # if i == 1:
                #     vutils.save_image((src_gen.data/2)+0.5, '%s/visualization/source_gen_%d.png' %(self.opt.outf, epoch))
                #     vutils.save_image((tgt_gen.data/2)+0.5, '%s/visualization/target_gen_%d.png' %(self.opt.outf, epoch))

                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizerD = utils.exp_lr_scheduler(
                        self.optimizerD, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerF = utils.exp_lr_scheduler(
                        self.optimizerF, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerC = utils.exp_lr_scheduler(
                        self.optimizerC, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    # Should optimizerG also get learning rate decay? The
                    # original implementation does not decay it.
                    self.optimizerG = utils.exp_lr_scheduler(
                        self.optimizerG, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)

            # Validate every epoch
            self.validate(epoch + 1)
            self.validate2(epoch + 1)
Example #8
    # -> Loss
    def nll_loss(pred, label, reduction='mean'):
        return F.nll_loss(F.log_softmax(pred, dim=1), label, reduction=reduction)
    loss_func = torch.nn.CrossEntropyLoss()
    train_loss_func = torch.nn.CrossEntropyLoss(reduction='mean')
    test_loss_func = torch.nn.CrossEntropyLoss(reduction='sum')

    # -> Data
    data_loader = data_loader

    epoch_val, loss_val, epoch_test, loss_test, acc_val, acc_test = [], [], [], [], [], []
    epoch_train, loss_train, roc_auc_val, roc_auc_test = [], [], [], []

    # === TRAINING ===
    for epoch in range(1, settings.epochs + 1):
        optimizer = utils.exp_lr_scheduler(optimizer, epoch, init_lr=settings.lr, lr_decay_epoch=10)
        # train
        train(epoch, model, data_loader, optimizer, train_loss_func, device=device, deepcoral=settings.deepcoral)
        test(epoch, model, data_loader.train_loader, test_loss_func, mode="Train", device=device)
        # val
        accuracy = test(epoch, model, data_loader.val_loader, test_loss_func, mode="Val", device=device)
        # test
        test(epoch, model, data_loader.target_loader, test_loss_func, mode="Test", device=device)

    # === SAVE PLOTS AT THE END ===
    fig, ax = plt.subplots(1, 3, figsize=(21, 5))
    ax[0].plot(epoch_train, loss_train, 'b', label='train')
    ax[0].set_title('Loss')
    ax[0].legend()

    ax[1].plot(epoch_val, acc_val, 'g', label='val')
Example #9
def on_policy_training(yahoo_data_reader,
                       validation_data_reader,
                       model,
                       experiment_name=None,
                       writer=None,
                       args=None):
    position_bias_vector = 1. / np.log2(2 + np.arange(200))
    lr = args.lr
    num_epochs = args.epochs
    weight_decay = args.weight_decay
    sample_size = args.sample_size

    print("Starting training with the following config")
    print(
        "Learning rate {}, Weight decay {}, Sample size {}\n"
        "Lambda_reward: {}, lambda_ind_fairness:{}, lambda_group_fairness:{}".
        format(lr, weight_decay, sample_size, args.lambda_reward,
               args.lambda_ind_fairness, args.lambda_group_fairness))
    if writer is None and args.summary_writing:
        writer = SummaryWriter(log_dir='runs')
    from utils import get_optimizer
    optimizer = get_optimizer(model.parameters(), lr, args.optimizer,
                              weight_decay)
    train_feats, train_rel = yahoo_data_reader.data
    len_train_set = len(train_feats)
    fairness_evaluation = args.lambda_ind_fairness > 0.0
    group_fairness_evaluation = args.lambda_group_fairness > 0.0

    if args.early_stopping:
        time_since_best = 0
        best_metric = 0.0
    for epoch in range(num_epochs):

        # # training
        print("Training....")
        if args.lr_scheduler and epoch >= 1:
            optimizer = exp_lr_scheduler(
                optimizer, epoch, lr, decay_factor=args.lr_decay)
        args.entropy_regularizer = args.entreg_decay * args.entropy_regularizer
        epoch_rewards_list = []
        running_ndcgs_list = []
        running_dcgs_list = []
        fairness_losses = []
        variances = []
        # shuffle(file_list)
        train_feats, train_rel = shuffle_combined(train_feats, train_rel)

        iterator = progressbar(
            range(len_train_set)) if args.progressbar else range(len_train_set)
        for i in iterator:
            if i % args.evaluate_interval == 0:
                if i != 0:
                    print(
                        "\nAverages of last 1000 rewards: {}, ndcgs: {}, dcgs: {}".
                        format(
                            np.mean(epoch_rewards_list[
                                -min([len(epoch_rewards_list), 1000]):]),
                            np.mean(running_ndcgs_list[
                                -min([len(running_ndcgs_list), 1000]):]),
                            np.mean(running_dcgs_list[
                                -min([len(running_dcgs_list), 1000]):])))
                exposure_relevance_plot = False
                print(
                    "Evaluating on validation set: iteration {}/{} of epoch {}".
                    format(i, len_train_set, epoch))
                curr_metric = log_and_print(
                    model,
                    validation_data_reader,
                    writer,
                    epoch,
                    i,
                    len_train_set,
                    "val",
                    experiment_name,
                    args.gpu_id,
                    fairness_evaluation=fairness_evaluation,
                    exposure_relevance_plot=exposure_relevance_plot,
                    deterministic=args.validation_deterministic,
                    group_fairness_evaluation=group_fairness_evaluation,
                    args=args)
                # """
                # Early stopping
                # """
                if args.early_stopping:
                    if curr_metric >= best_metric:
                        best_metric = curr_metric
                        time_since_best = 0
                    elif curr_metric <= best_metric * 0.99:
                        time_since_best += 1
                    if time_since_best >= 5:
                        print(
                            "Validation set metric hasn't increased in 5 steps. Exiting"
                        )
                        return model

                # print("Evaluating on training set")
                # log_and_print(model, yahoo_data_reader, writer, epoch, i,
                #               len_train_set, "train", experiment_name,
                #               args.gpu_id, True)

                # feats, rel = yahoo_data_reader.readfile(file)
            feats, rel = train_feats[i], train_rel[i]
            if feats is None:
                continue
            if len(feats) == 1:
                continue
            if args.lambda_group_fairness > 0.0:
                group_identities = np.array(
                    feats[:, args.group_feat_id], dtype=int)
            if args.gpu_id is not None:
                feats, rel = convert_vars_to_gpu([feats, rel], args.gpu_id)

            scores = model(torchify(feats))
            probs_ = torch.nn.Softmax(dim=0)(scores)
            probs = probs_.data.numpy().flatten()

            rankings, rewards_list, ndcg_list, dcg_list = [], [], [], []
            # propensities = []
            for j in range(sample_size):
                # ranking, propensity = sample_ranking(
                # np.array(probs, copy=True))
                # print([(param.name, param.data)
                #        for param in model.parameters()], probs)
                ranking = sample_ranking(np.array(probs, copy=True), False)
                rankings.append(ranking)
                # propensities.append(propensity)
                ndcg, dcg = compute_dcg(ranking, rel, args.eval_rank_limit)
                if args.reward_type == "ndcg":
                    rewards_list.append(ndcg)
                elif args.reward_type == "dcg":
                    rewards_list.append(dcg)
                elif args.reward_type == "avrank":
                    avrank = -np.mean(compute_average_rank(ranking, rel))
                    rewards_list.append(np.sum(avrank))
                ndcg_list.append(ndcg)
                dcg_list.append(dcg)
            if args.baseline_type == "value":
                baseline = np.mean(rewards_list)
            elif args.baseline_type == "max":
                state = rel
                baseline = compute_baseline(
                    state=state, type=args.baseline_type)
            else:
                print("Choose a valid baseline type! Exiting")
                sys.exit(1)

            # FAIRNESS constraints
            if args.lambda_ind_fairness > 0.0:
                num_docs = len(ranking)
                rel_labels = np.array(rel)
                # relevant_indices_to_onehot(rel, num_docs)
                # relevance_variance = np.var(rel_labels)
                if args.fairness_version == "squared_residual":
                    expected_exposures = get_expected_exposure(
                        rankings, position_bias_vector)
                    k = minimize_for_k(rel_labels, expected_exposures,
                                       args.skip_zero_relevance)
                    disparity_matrix = IndividualFairnessLoss(
                    ).compute_disparities(rankings, rel_labels,
                                          position_bias_vector, k,
                                          args.skip_zero_relevance)
                    marginal_disparity = IndividualFairnessLoss(
                    ).compute_marginal_disparity(
                        disparity_matrix)  # should be size of the ranking set
                    assert len(marginal_disparity) == num_docs, \
                        "Marginal disparity is of the wrong dimension"
                    individual_fairness_coeffs = np.zeros(sample_size)
                    for index in range(sample_size):
                        individual_fairness_coeffs[
                            index] = IndividualFairnessLoss.compute_sq_individual_fairness_loss_coeff(
                                rankings[index], disparity_matrix[index],
                                marginal_disparity, k)
                    fairness_baseline = np.mean(individual_fairness_coeffs)
                    fairness_losses.append(fairness_baseline)
                elif args.fairness_version == "scale_inv_mse":
                    individual_fairness_coeffs = IndividualFairnessLoss(
                    ).get_scale_invariant_mse_coeffs(rankings, rel_labels,
                                                     position_bias_vector,
                                                     args.skip_zero_relevance)
                    fairness_baseline = np.mean(individual_fairness_coeffs
                                                ) if args.use_baseline else 0.0
                    fairness_losses.append(fairness_baseline)
                elif args.fairness_version == "asym_disparity":
                    pdiff = IndividualFairnessLoss.compute_pairwise_disparity_matrix(
                        rankings, rel_labels, position_bias_vector)
                    H_mat = IndividualFairnessLoss.get_H_matrix(rel_labels)
                    sum_h_mat = np.sum(
                        H_mat) + 1e-7  # to prevent Nans when dividing
                    # print(rel_labels, H_mat, sum_h_mat)
                    H_mat = np.tile(H_mat, (len(rankings), 1, 1))
                    pdiff_pi = np.mean(pdiff, axis=0)
                    pdiff_indicator = pdiff_pi > 0
                    pdiff_indicator = np.tile(pdiff_indicator, (len(rankings),
                                                                1, 1))

                    individual_fairness_coeffs = pdiff_indicator * H_mat * pdiff
                    individual_fairness_coeffs = np.sum(
                        individual_fairness_coeffs, axis=(1, 2)) / sum_h_mat
                    # print(pdiff_indicator.shape, H_mat.shape, pdiff_pi.shape,
                    #       pdiff.shape)
                    fairness_baseline = np.mean(individual_fairness_coeffs
                                                ) if args.use_baseline else 0.0

                elif args.fairness_version == "pairwise_disparity":
                    pairwise_disparity_matrix, pair_counts = IndividualFairnessLoss.compute_pairwise_disparity_matrix(
                        rankings,
                        rel_labels,
                        position_bias_vector,
                        conditional=False)
                    marginal_pairwise_disparity_matrix = np.mean(
                        pairwise_disparity_matrix, axis=0)

            if args.lambda_group_fairness > 0.0:
                rel_labels = np.array(rel)
                if np.sum(rel_labels[group_identities == 0]) == 0 or np.sum(
                        rel_labels[group_identities == 1]) == 0:
                    skip_this_query = True
                else:
                    skip_this_query = False
                    group_fairness_coeffs = GroupFairnessLoss.compute_group_fairness_coeffs_generic(
                        rankings, rel_labels, group_identities,
                        position_bias_vector, args.group_fairness_version,
                        args.skip_zero_relevance)

                    fairness_baseline = np.mean(np.mean(group_fairness_coeffs))

                # log the reward/dcg variance
            variances.append(np.var(rewards_list))
            epoch_rewards_list.append(np.mean(rewards_list))
            running_ndcgs_list.append(np.mean(ndcg_list))
            running_dcgs_list.append(np.mean(dcg_list))

            if i % 1000 == 0 and i != 0:
                if experiment_name is None:
                    experiment_name = ""
                if writer is not None:
                    writer.add_scalars(experiment_name + "/var_reward",
                                       {"var_reward": np.mean(variances)},
                                       epoch * len_train_set + i)
                    if fairness_evaluation:
                        writer.add_scalars(
                            experiment_name + "/mean_fairness_loss", {
                                "mean_fairness_loss": np.mean(fairness_losses)
                            }, epoch * len_train_set + i)
                variances = []
                fairness_losses = []
            optimizer.zero_grad()
            for j in range(sample_size):
                ranking = rankings[j]
                reward = rewards_list[j]

                log_model_prob = compute_log_model_probability(
                    scores, ranking, args.gpu_id)
                if args.use_baseline:
                    reinforce_loss = float(args.lambda_reward * -(
                        reward - baseline)) * log_model_prob
                else:
                    reinforce_loss = args.lambda_reward * log_model_prob * -reward
                if args.lambda_ind_fairness != 0.0:
                    if (args.fairness_version == "squared_residual") or (
                            args.fairness_version == "scale_inv_mse"):
                        individual_fairness_cost = float(
                            args.lambda_ind_fairness *
                            (individual_fairness_coeffs[j] - fairness_baseline
                             )) * log_model_prob
                    elif args.fairness_version == "cross_entropy":
                        individual_fairness_cost = float(
                            args.lambda_ind_fairness * IndividualFairnessLoss.
                            compute_cross_entropy_fairness_loss(
                                ranking, rel_labels, expected_exposures,
                                position_bias_vector)) * log_model_prob
                    elif args.fairness_version == "asym_disparity":
                        individual_fairness_cost = float(
                            args.lambda_ind_fairness *
                            (individual_fairness_coeffs[j] - fairness_baseline
                             )) * log_model_prob
                    elif args.fairness_version == "pairwise_disparity":
                        individual_fairness_cost = float(
                            args.lambda_ind_fairness *
                            (np.sum(2 * marginal_pairwise_disparity_matrix *
                                    pairwise_disparity_matrix[j]) / pair_counts
                             )) * log_model_prob
                    else:
                        print("Use a valid version of fairness constraints")
                    reinforce_loss += individual_fairness_cost
                if args.lambda_group_fairness != 0.0 and not skip_this_query:
                    group_fairness_cost = float(
                        args.lambda_group_fairness * group_fairness_coeffs[j]
                    ) * log_model_prob
                    reinforce_loss += group_fairness_cost
                # debias the loss because the model gets updated every sampled ranking
                # i.e. log_model_prob is biased
                # if debias_training:
                #     bias_corrections.append(
                #         math.exp(log_model_prob.data) / propensities[j])
                #     reinforce_loss *= bias_corrections[-1]
                # ^ not reqd anymore
                reinforce_loss.backward(retain_graph=True)
            if args.entropy_regularizer > 0.0:
                entropy_loss = args.entropy_regularizer * (
                    -get_entropy(probs_))
                entropy_loss.backward()
            optimizer.step()
        if args.save_checkpoints:
            if epoch == 0 and not os.path.exists(
                    "models/{}".format(experiment_name)):
                os.makedirs("models/{}/".format(experiment_name))
            torch.save(model, "models/{}/epoch{}.ckpt".format(
                experiment_name, epoch))
    log_and_print(
        model,
        validation_data_reader,
        writer,
        epoch,
        i,
        len_train_set,
        "val",
        experiment_name,
        args.gpu_id,
        fairness_evaluation=fairness_evaluation,
        exposure_relevance_plot=exposure_relevance_plot,
        deterministic=args.validation_deterministic,
        group_fairness_evaluation=group_fairness_evaluation,
        args=args)
    return model
Example #10
def run(args: argparse.Namespace) -> None:
    # save args to dict
    d = vars(args)
    d['time'] = str(datetime.datetime.now())
    save_dict_to_file(d, args.workdir)

    temperature: float = 0.1
    n_class: int = args.n_class
    metric_axis: List = args.metric_axis
    lr: float = args.l_rate
    dtype = eval(args.dtype)

    # Proper params
    savedir: str = args.workdir
    n_epoch: int = args.n_epoch

    net, optimizer, device, loss_fns, loss_weights, loss_fns_source, loss_weights_source, scheduler = setup(args, n_class, dtype)
    print(f'> Loss weights cons: {loss_weights}, Loss weights source:{loss_weights_source}')
    shuffle = False
    #if args.mix:
    #    shuffle = True
    #print("args.dataset",args.dataset)
    loader, loader_val = get_loaders(args, args.dataset, args.source_folders,
                                     args.batch_size, n_class,
                                     args.debug, args.in_memory, dtype, False,
                                     fix_size=[0, 0])

    target_loader, target_loader_val = get_loaders(args, args.target_dataset,
                                                   args.target_folders,
                                                   args.batch_size, n_class,
                                                   args.debug, args.in_memory,
                                                   dtype, shuffle,
                                                   fix_size=[0, 0])

    num_steps = n_epoch * len(loader)
    # print(num_steps)
    print("metric axis", metric_axis)
    best_dice_pos: np.ndarray = np.zeros(1)
    best_dice: np.ndarray = np.zeros(1)
    best_2d_dice: np.ndarray = np.zeros(1)
    best_3d_dice: np.ndarray = np.zeros(1)
    best_3d_dice_source: np.ndarray = np.zeros(1)

    print("Results saved in ", savedir)
    print(">>> Starting the training")
    for i in range(n_epoch):

        tra_losses_vec, tra_target_vec, tra_source_vec = do_epoch(
            args, "train", net, device, loader, i, loss_fns, loss_weights,
            loss_fns_source, loss_weights_source, args.resize, num_steps,
            n_class, metric_axis, savedir="", optimizer=optimizer,
            target_loader=target_loader)

        with torch.no_grad():
            val_losses_vec, val_target_vec, val_source_vec = do_epoch(
                args, "val", net, device, loader_val, i, loss_fns,
                loss_weights, loss_fns_source, loss_weights_source,
                args.resize, num_steps, n_class, metric_axis,
                savedir=savedir, target_loader=target_loader_val)

        #if i == 0:
         #   keep_tra_baseline_target_vec = tra_baseline_target_vec
          #  keep_val_baseline_target_vec = val_baseline_target_vec
        # print(keep_val_baseline_target_vec)

        # print(val_target_vec)
        # df_t_tmp = pd.DataFrame({
        #     "val_dice_3d": [val_target_vec[0]],
        #     "val_dice_3d_sd": [val_target_vec[1]]})

        df_s_tmp = pd.DataFrame({
            "tra_dice_3d": [tra_source_vec[0]],
            "tra_dice_3d_sd": [tra_source_vec[1]],
            "val_dice_3d": [val_source_vec[0]],
            "val_dice_3d_sd": [val_source_vec[1]]})

        if i == 0:
            df_s = df_s_tmp
        else:
            df_s = pd.concat([df_s, df_s_tmp])

        df_s.to_csv(Path(savedir, "_".join((args.source_folders.split("'")[1],"source", args.csv))), float_format="%.4f", index_label="epoch")


        df_t_tmp = pd.DataFrame({
            "tra_loss_inf":[tra_losses_vec[0]],
            "tra_loss_cons":[tra_losses_vec[1]],
            "tra_loss_fs":[tra_losses_vec[2]],
            "val_loss_inf":[val_losses_vec[0]],
            "val_loss_cons":[val_losses_vec[1]],
            "val_loss_fs":[val_losses_vec[2]],
            "tra_dice_3d": [tra_target_vec[0]],
            "tra_dice_3d_sd": [tra_target_vec[1]],
            "tra_dice": [tra_target_vec[2]],
            "val_dice_3d": [val_target_vec[0]],
            "val_dice_3d_sd": [val_target_vec[1]],
            'val_dice': [val_target_vec[2]]})

        if i == 0:
            df_t = df_t_tmp
        else:
            df_t = pd.concat([df_t, df_t_tmp])

        df_t.to_csv(Path(savedir, "_".join((args.target_folders.split("'")[1],"target", args.csv))), float_format="%.4f", index_label="epoch")

        # Save model if better
        current_val_target_2d_dice = val_target_vec[2]
        '''
        if current_val_target_2d_dice > best_2d_dice:
            best_epoch = i
            best_2d_dice = current_val_target_2d_dice
            with open(Path(savedir, "best_epoch_2.txt"), 'w') as f:
                f.write(str(i))
            best_folder_2d = Path(savedir, "best_epoch_2d")
            if best_folder_2d.exists():
                rmtree(best_folder_2d)
            copytree(Path(savedir, f"iter{i:03d}"), Path(best_folder_2d))
            torch.save(net, Path(savedir, "best_2d.pkl"))
        '''
        current_val_target_3d_dice = val_target_vec[0]

        if current_val_target_3d_dice > best_3d_dice:
            best_epoch = i
            best_3d_dice = current_val_target_3d_dice
            with open(Path(savedir, "best_epoch_3d.txt"), 'w') as f:
                f.write(str(i))
            best_folder_3d = Path(savedir, "best_epoch_3d")
            if best_folder_3d.exists():
                rmtree(best_folder_3d)
            copytree(Path(savedir, f"iter{i:03d}"), Path(best_folder_3d))
            torch.save(net, Path(savedir, "best_3d.pkl"))

        #Save source model if better
        current_val_source_3d_dice = val_source_vec[0]

        if current_val_source_3d_dice > best_3d_dice_source:
            best_epoch = i
            best_3d_dice_source = current_val_source_3d_dice
            with open(Path(savedir, "best_epoch_3d_source.txt"), 'w') as f:
                f.write(str(i))
            torch.save(net, Path(savedir, "best_3d_source.pkl"))

        if i == n_epoch - 1:
            with open(Path(savedir, "last_epoch.txt"), 'w') as f:
                f.write(str(i))
            last_folder = Path(savedir, "last_epoch")
            if last_folder.exists():
                rmtree(last_folder)
            copytree(Path(savedir, f"iter{i:03d}"), Path(last_folder))
            torch.save(net, Path(savedir, "last.pkl"))

        # remove images from iteration
        rmtree(Path(savedir, f"iter{i:03d}"))

        if not args.flr:
            # adjust_learning_rate(optimizer, i, args.l_rate, n_epoch, 0.9)
            exp_lr_scheduler(optimizer, i, args.lr_decay)
    print("Results saved in ", savedir)
Example #11
    def train(self):

        curr_iter = 0

        reallabel = torch.FloatTensor(self.opt.batchSize).fill_(
            self.real_label_val)
        fakelabel = torch.FloatTensor(self.opt.batchSize).fill_(
            self.fake_label_val)
        if self.opt.gpu >= 0:
            reallabel, fakelabel = reallabel.cuda(), fakelabel.cuda()
        reallabelv = Variable(reallabel)
        fakelabelv = Variable(fakelabel)

        # parameters
        src_hflip = False
        src_xlat_range = 2.0
        src_affine_std = 0.1
        src_intens_flip = False
        src_intens_scale_range_lower = -1.5
        src_intens_scale_range_upper = 1.5
        src_intens_offset_range_lower = -0.5
        src_intens_offset_range_upper = 0.5
        src_gaussian_noise_std = 0.1
        tgt_hflip = False
        tgt_xlat_range = 2.0
        tgt_affine_std = 0.1
        tgt_intens_flip = False
        tgt_intens_scale_range_lower = -1.5
        tgt_intens_scale_range_upper = 1.5
        tgt_intens_offset_range_lower = -0.5
        tgt_intens_offset_range_upper = 0.5
        tgt_gaussian_noise_std = 0.1

        # augmentation function
        src_aug = augmentation.ImageAugmentation(
            src_hflip,
            src_xlat_range,
            src_affine_std,
            intens_flip=src_intens_flip,
            intens_scale_range_lower=src_intens_scale_range_lower,
            intens_scale_range_upper=src_intens_scale_range_upper,
            intens_offset_range_lower=src_intens_offset_range_lower,
            intens_offset_range_upper=src_intens_offset_range_upper,
            gaussian_noise_std=src_gaussian_noise_std)
        tgt_aug = augmentation.ImageAugmentation(
            tgt_hflip,
            tgt_xlat_range,
            tgt_affine_std,
            intens_flip=tgt_intens_flip,
            intens_scale_range_lower=tgt_intens_scale_range_lower,
            intens_scale_range_upper=tgt_intens_scale_range_upper,
            intens_offset_range_lower=tgt_intens_offset_range_lower,
            intens_offset_range_upper=tgt_intens_offset_range_upper,
            gaussian_noise_std=tgt_gaussian_noise_std)

        combine_batches = False

        if combine_batches:

            def augment(X_sup, y_src, X_tgt):
                X_src_stu, X_src_tea = src_aug.augment_pair(X_sup)
                X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
                return X_src_stu, X_src_tea, y_src, X_tgt_stu, X_tgt_tea
        else:

            def augment(X_src, y_src, X_tgt):
                X_src = src_aug.augment(X_src)
                X_tgt_stu, X_tgt_tea = tgt_aug.augment_pair(X_tgt)
                return X_src, y_src, X_tgt_stu, X_tgt_tea

        for epoch in range(self.opt.nepochs):

            self.netG.train()
            self.netF.train()
            self.netC.train()
            self.netD.train()

            for i, (datas, datat) in enumerate(
                    zip(self.source_trainloader, self.targetloader)):

                ###########################
                # Forming input variables
                ###########################

                src_inputs, src_labels = datas
                tgt_inputs, __ = datat
                if self.augment:
                    if combine_batches:
                        src_inputs, _, src_labels, tgt_inputs, _ = augment(
                            src_inputs, src_labels, tgt_inputs)
                    else:
                        src_inputs, src_labels, tgt_inputs, _ = augment(
                            src_inputs.numpy(), src_labels.numpy(),
                            tgt_inputs.numpy())
                    src_inputs = torch.FloatTensor(src_inputs)
                    src_labels = torch.LongTensor(src_labels)
                    tgt_inputs = torch.FloatTensor(tgt_inputs)

                src_inputs_unnorm = ((
                    (src_inputs * self.std[0]) + self.mean[0]) - 0.5) * 2

                # Creating one hot vector
                labels_onehot = np.zeros(
                    (self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, src_labels[num]] = 1
                src_labels_onehot = torch.from_numpy(labels_onehot)

                labels_onehot = np.zeros(
                    (self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, self.nclasses] = 1
                tgt_labels_onehot = torch.from_numpy(labels_onehot)

                if self.opt.gpu >= 0:
                    src_inputs, src_labels = src_inputs.cuda(
                    ), src_labels.cuda()
                    src_inputs_unnorm = src_inputs_unnorm.cuda()
                    tgt_inputs = tgt_inputs.cuda()
                    src_labels_onehot = src_labels_onehot.cuda()
                    tgt_labels_onehot = tgt_labels_onehot.cuda()

                # Wrapping in variable
                src_inputsv, src_labelsv = Variable(src_inputs), Variable(
                    src_labels)
                src_inputs_unnormv = Variable(src_inputs_unnorm)
                tgt_inputsv = Variable(tgt_inputs)
                src_labels_onehotv = Variable(src_labels_onehot)
                tgt_labels_onehotv = Variable(tgt_labels_onehot)

                ###########################
                # Updates
                ###########################

                # Updating D network

                self.netD.zero_grad()
                src_emb = self.netF(src_inputsv)
                src_emb_cat = torch.cat((src_labels_onehotv, src_emb), 1)
                src_gen = self.netG(src_emb_cat)

                tgt_emb = self.netF(tgt_inputsv)
                tgt_emb_cat = torch.cat((tgt_labels_onehotv, tgt_emb), 1)
                tgt_gen = self.netG(tgt_emb_cat)

                src_realoutputD_s, src_realoutputD_c = self.netD(
                    src_inputs_unnormv)
                errD_src_real_s = self.criterion_s(src_realoutputD_s,
                                                   reallabelv)
                errD_src_real_c = self.criterion_c(src_realoutputD_c,
                                                   src_labelsv)

                src_fakeoutputD_s, src_fakeoutputD_c = self.netD(src_gen)
                errD_src_fake_s = self.criterion_s(src_fakeoutputD_s,
                                                   fakelabelv)

                tgt_fakeoutputD_s, tgt_fakeoutputD_c = self.netD(tgt_gen)
                errD_tgt_fake_s = self.criterion_s(tgt_fakeoutputD_s,
                                                   fakelabelv)

                errD = errD_src_real_c + errD_src_real_s + errD_src_fake_s + errD_tgt_fake_s
                # TODO: add CBL to D loss
                if self.class_balance > 0.0:
                    avg_cls_prob = torch.mean(tgt_fakeoutputD_c, 0)
                    equalise_cls_loss = self.cls_bal_fn(
                        avg_cls_prob, float(1.0 / self.nclasses))
                    equalise_cls_loss = torch.mean(
                        equalise_cls_loss) * self.nclasses
                    errD += equalise_cls_loss * self.class_balance
                errD.backward(retain_graph=True)
                self.optimizerD.step()

                # Updating G network

                self.netG.zero_grad()
                src_fakeoutputD_s, src_fakeoutputD_c = self.netD(src_gen)
                errG_c = self.criterion_c(src_fakeoutputD_c, src_labelsv)
                errG_s = self.criterion_s(src_fakeoutputD_s, reallabelv)
                errG = errG_c + errG_s
                errG.backward(retain_graph=True)
                self.optimizerG.step()

                # Updating C network

                self.netC.zero_grad()
                outC = self.netC(src_emb)
                errC = self.criterion_c(outC, src_labelsv)
                errC.backward(retain_graph=True)
                self.optimizerC.step()

                # Updating F network

                self.netF.zero_grad()
                errF_fromC = self.criterion_c(outC, src_labelsv)

                src_fakeoutputD_s, src_fakeoutputD_c = self.netD(src_gen)
                errF_src_fromD = self.criterion_c(
                    src_fakeoutputD_c, src_labelsv) * (self.opt.adv_weight)

                tgt_fakeoutputD_s, tgt_fakeoutputD_c = self.netD(tgt_gen)

                # TODO: add CBL to D gradient
                errF_tgt_fromD = self.criterion_s(
                    tgt_fakeoutputD_s,
                    reallabelv) * (self.opt.adv_weight * self.opt.alpha)

                errF = errF_fromC + errF_src_fromD + errF_tgt_fromD
                if self.class_balance > 0.0:
                    avg_cls_prob = torch.mean(tgt_fakeoutputD_c, 0)
                    equalise_cls_loss = self.cls_bal_fn(
                        avg_cls_prob, float(1.0 / self.nclasses))
                    equalise_cls_loss = torch.mean(
                        equalise_cls_loss) * self.nclasses
                    errF += equalise_cls_loss * self.class_balance

                errF.backward()
                self.optimizerF.step()

                curr_iter += 1

                # Visualization
                if i == 1:
                    vutils.save_image((src_gen.data / 2) + 0.5,
                                      '%s/visualization/source_gen_%d.png' %
                                      (self.opt.outf, epoch))
                    vutils.save_image((tgt_gen.data / 2) + 0.5,
                                      '%s/visualization/target_gen_%d.png' %
                                      (self.opt.outf, epoch))

                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizerD = utils.exp_lr_scheduler(
                        self.optimizerD, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerF = utils.exp_lr_scheduler(
                        self.optimizerF, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerC = utils.exp_lr_scheduler(
                        self.optimizerC, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)

            # Validate every epoch
            self.validate(epoch + 1)
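
Note: every utils.exp_lr_scheduler call in this example passes (optimizer, epoch, init_lr, lrd, curr_iter). A plausible minimal implementation is inverse-time decay recomputed from the initial rate, so the repeated per-iteration calls do not compound (an inference from the arguments, not verified against utils):

def exp_lr_scheduler(optimizer, epoch, init_lr, lrd, curr_iter):
    # Inverse-time decay: lr = init_lr / (1 + lrd * iterations).
    # epoch is accepted for interface compatibility but unused here.
    lr = init_lr / (1 + lrd * curr_iter)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer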
Exemplo n.º 12
0
model = OCRModel(num_chars=NUM_TOKENS)
device = torch.device("cuda:2")
model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
encoder = model.encoder

train_dataset = Dataset(train_dir)
dataloader = torch.utils.data.DataLoader(train_dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         collate_fn=pack,
                                         num_workers=num_workers)
ctc_loss = nn.CTCLoss(blank=BLANK, zero_infinity=True)
for epoch in range(num_epochs):
    exp_lr_scheduler(optimizer, epoch, init_lr=lr, lr_decay_epoch=decay_epoch)

    pbar = tqdm(iter(dataloader), total=len(dataloader))
    for x, y_padded, y_packed, y_lengths in pbar:
        x = x.to(device=device)
        y_padded = y_padded.to(device)
        y_packed = y_packed.to(device=device)

        logits = encoder(x, device)
        logits, input_lengths = nn.utils.rnn.pad_packed_sequence(
            logits, batch_first=False)
        # check the shape
        L, N, C = logits.shape
        assert (L, N, C) == (max(input_lengths), x.batch_sizes[0], NUM_TOKENS)

        N, L = y_padded.shape
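
Note: this snippet's exp_lr_scheduler(optimizer, epoch, init_lr=..., lr_decay_epoch=...) has the same shape as the scheduler from the classic PyTorch transfer-learning tutorial; a sketch under that assumption:

def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7):
    # Step decay: divide the learning rate by 10 every lr_decay_epoch epochs.
    lr = init_lr * (0.1 ** (epoch // lr_decay_epoch))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer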
Exemplo n.º 13
0
def run(args: argparse.Namespace) -> None:
    # save args to dict
    d = vars(args)
    d['time'] = str(datetime.datetime.now())
    d['server'] = platform.node()
    save_dict_to_file(d, args.workdir)

    temperature: float = 0.1
    n_class: int = args.n_class
    metric_axis: List = args.metric_axis
    lr: float = args.l_rate
    dtype = eval(args.dtype)  # maps a dtype name string (e.g. "torch.float32") to the object

    # Proper params
    savedir: str = args.workdir
    n_epoch: int = args.n_epoch

    net, optimizer, device, loss_fns, loss_weights, scheduler, n_epoch = setup(args, n_class, dtype)
    shuffle = True
    print(args.target_folders)
    target_loader, target_loader_val = get_loaders(args, args.target_dataset, args.target_folders,
                                                   args.batch_size, n_class,
                                                   args.debug, args.in_memory, dtype, shuffle,
                                                   "target", args.val_target_folders)

    print("metric axis",metric_axis)
    best_dice_pos: Tensor = np.zeros(1)
    best_dice: Tensor = np.zeros(1)
    best_hd3d_dice: Tensor = np.zeros(1)
    best_3d_dice: Tensor = 0 
    best_2d_dice: Tensor = 0 
    print("Results saved in ", savedir)
    print(">>> Starting the training")
    for i in range(n_epoch):

        if args.mode == "makeim":
            with torch.no_grad():
                val_losses_vec, val_target_vec, val_source_vec = do_epoch(
                    args, "val", net, device, i, loss_fns, loss_weights,
                    args.resize, n_class, metric_axis, savedir=savedir,
                    target_loader=target_loader_val,
                    best_dice3d_val=best_3d_dice)
                tra_losses_vec = val_losses_vec
                tra_target_vec = val_target_vec
                tra_source_vec = val_source_vec
        else:
            tra_losses_vec, tra_target_vec, tra_source_vec = do_epoch(
                args, "train", net, device, i, loss_fns, loss_weights,
                args.resize, n_class, metric_axis, savedir=savedir,
                optimizer=optimizer, target_loader=target_loader,
                best_dice3d_val=best_3d_dice)

            with torch.no_grad():
                val_losses_vec, val_target_vec, val_source_vec = do_epoch(
                    args, "val", net, device, i, loss_fns, loss_weights,
                    args.resize, n_class, metric_axis, savedir=savedir,
                    target_loader=target_loader_val,
                    best_dice3d_val=best_3d_dice)

        current_val_target_3d_dice = val_target_vec[0]
        if args.dice_3d:
            if current_val_target_3d_dice > best_3d_dice:
                best_3d_dice = current_val_target_3d_dice
                with open(Path(savedir, "3dbestepoch.txt"), 'w') as f:
                    f.write(str(i) + ',' + str(best_3d_dice))
                best_folder_3d = Path(savedir, "best_epoch_3d")
                if best_folder_3d.exists():
                    rmtree(best_folder_3d)
                if args.saveim:
                    copytree(Path(savedir, f"iter{i:03d}"), Path(best_folder_3d))
                torch.save(net, Path(savedir, "best_3d.pkl"))

       
        if not (i % 10):
            print("epoch", str(i), savedir, 'best 3d dice', best_3d_dice)
            torch.save(net, Path(savedir, "epoch_" + str(i) + ".pkl"))

        if i == n_epoch - 1:
            with open(Path(savedir, "last_epoch.txt"), 'w') as f:
                f.write(str(i))
            last_folder = Path(savedir, "last_epoch")
            if last_folder.exists():
                rmtree(last_folder)
            if args.saveim:
                copytree(Path(savedir, f"iter{i:03d}"), Path(last_folder))
            torch.save(net, Path(savedir, "last.pkl"))

        # remove images from iteration
        if args.saveim:
            rmtree(Path(savedir, f"iter{i:03d}"))

        if args.source_metrics:
            df_s_tmp = pd.DataFrame({
                "val_dice_3d": [val_source_vec[0]],
                "val_dice_3d_sd": [val_source_vec[1]],
                "val_dice_2d": [val_source_vec[2]]})
            if i == 0:
                df_s = df_s_tmp
            else:
                df_s = df_s.append(df_s_tmp)
            df_s.to_csv(Path(savedir, "_".join((args.source_folders.split("'")[1], "source", args.csv))),
                        float_format="%.4f", index_label="epoch")
        df_t_tmp = pd.DataFrame({
            "epoch": i,
            "tra_loss_s": [tra_losses_vec[0]],
            "tra_loss_cons": [tra_losses_vec[1]],
            "tra_loss_tot": [tra_losses_vec[2]],
            "tra_size_mean": [tra_losses_vec[3]],
            "tra_size_mean_pos": [tra_losses_vec[4]],
            "val_loss_s": [val_losses_vec[0]],
            "val_loss_cons": [val_losses_vec[1]],
            "val_loss_tot": [val_losses_vec[2]],
            "val_size_mean": [val_losses_vec[3]],
            "val_size_mean_pos": [val_losses_vec[4]],
            "val_gt_size_mean": [val_losses_vec[5]],
            "val_gt_size_mean_pos": [val_losses_vec[6]],
            "tra_dice": [tra_target_vec[4]],
            "val_asd": [val_target_vec[2]],
            "val_asd_sd": [val_target_vec[3]],
            "val_hd": [val_target_vec[4]],
            "val_hd_sd": [val_target_vec[5]],
            "val_dice": [val_target_vec[6]],
            "val_dice_3d_sd": [val_target_vec[1]],
            "val_dice_3d": [val_target_vec[0]]})

        if i == 0:
            df_t = df_t_tmp
        else:
            df_t = df_t.append(df_t_tmp)

        df_t.to_csv(Path(savedir, "_".join((args.target_folders.split("'")[1], "target", args.csv))),
                    float_format="%.4f", index=False)

        if not args.flr:
            exp_lr_scheduler(optimizer, i, args.lr_decay, args.lr_decay_epoch)
    print("Results saved in ", savedir, "best 3d dice", best_3d_dice)
Exemplo n.º 14
0
    def train(self,
              trainloader,
              testloader,
              classes,
              epochs=10,
              img_interval=1,
              cuda=True,
              path=None,
              g_op=None,
              d_op=None,
              init_epoch=0):
        """
        Training function
        :param optimizers: used optimizers (2 entries)
        :param criterions: loss functions (3 entries)
        :param trainloader: dataset in batch
        :param testloader: testset in batch
        :param epochs: epochs
        :param batch_size: batch_size
        :param img_interval: interval for showing the images produced
        :param cuda: usage of gpu
        :return: Loss list, accuracy list
        """

        lr_init = 0.001

        print("Initializing optimizers")
        if self.retrain:
            g_opt = g_op
            d_opt = d_op
        else:
            g_opt = torch.optim.RMSprop(self.G.parameters(),
                                        lr=lr_init,
                                        alpha=0.9)
            d_opt = torch.optim.RMSprop(self.D.parameters(),
                                        lr=lr_init,
                                        alpha=0.9)

        print("Initializing loss dataframe")
        pd_loss = pd.DataFrame(columns=['epoch', 'd_loss', 'g_loss'])
        path_loss_folder = path + "/Wikiart_loss"
        if self.retrain:
            path_loss = path_loss_folder + "/loss_retrain.csv"
        else:
            path_loss = path_loss_folder + "/loss.csv"
        if not os.path.exists(path_loss_folder):
            os.makedirs(path_loss_folder)
        pd_loss.to_csv(path_loss, index=False)

        print("Beginning epochs . . .")
        for epoch in range(init_epoch, epochs):
            # Save loss
            g_loss_l = []
            d_loss_l = []

            # Decay in the learning rate
            d_opt = utils.exp_lr_scheduler(d_opt, epoch)
            g_opt = utils.exp_lr_scheduler(g_opt, epoch)

            for i, data in enumerate(tqdm(trainloader), 0):

                # zero grad
                d_opt.zero_grad()
                # get the inputs
                x_r, k = data
                b_s = len(k)
                # generate z_hat
                z_hat = utils.gen_z(b_s, self.z_dim)
                # print("z_hat = ", z_hat.size())
                # generate Y_k and its label
                y_k = utils.gen_yk(b_s, self.num_classes)
                # print("y_k = ", y_k.size())
                # gen fakes
                y_fake = utils.fake_v(b_s, self.num_classes)
                # print("y_fake = ", y_fake.size())

                t_zeros = torch.zeros(b_s, 1)
                k_hot = F.one_hot(k, self.num_classes + 1)
                # print("k_hot = ", k_hot.size())
                # Cast to int64 so that F.one_hot builds y_k_hot with the correct dtype

                y_fake = y_fake.type(torch.int64)
                y_k_hot = F.one_hot(y_fake, self.num_classes + 1)

                if cuda:
                    y_k = y_k.type(torch.cuda.FloatTensor)
                    x_r = x_r.type(torch.cuda.FloatTensor)
                    k = k.type(torch.cuda.LongTensor)
                    z_hat = z_hat.type(torch.cuda.FloatTensor)
                    y_k = y_k.type(torch.cuda.FloatTensor)
                    k_hot = k_hot.type(torch.cuda.FloatTensor)
                    y_k_hot = y_k_hot.type(torch.cuda.FloatTensor)
                    t_zeros = t_zeros.type(torch.cuda.FloatTensor)

                else:
                    k_hot = k_hot.type(torch.FloatTensor)
                    y_k_hot = y_k_hot.type(torch.FloatTensor)

                # calculate X_hat
                in_G = torch.cat([z_hat, y_k], 1)
                # print("y_k_hot = ", y_k_hot.size())
                # calculate Y
                y = self.D(x_r)
                # Calculate Y_hat
                x_hat = self.G(in_G)
                y_hat = self.D(x_hat)
                # update D
                d_real_loss = F.binary_cross_entropy(y, k_hot)
                d_fake_loss = F.binary_cross_entropy(y_hat, y_k_hot)
                d_loss = d_real_loss + d_fake_loss
                d_loss_l.append(d_loss.item())
                d_loss.backward(retain_graph=True)
                d_opt.step()

                # zero grad
                g_opt.zero_grad()

                # adversarial loss
                new_y_hat = self.D(x_hat)
                # print("new_y_hat = ", new_y_hat.size())
                new_y_k_hot = torch.cat([y_k, t_zeros], 1)
                # print("new_y_k_hot = ", new_y_k_hot.size())
                g_loss_adv = F.binary_cross_entropy(new_y_hat, new_y_k_hot)

                # L2 loss
                # calculate z
                z = self.D.enc(x_r)
                # print("z = ", z.size())
                # calculate X_hat_z
                x_hat_z = self.G.dec(z)
                g_loss_l2 = torch.mean((x_hat_z - x_r)**2)

                # + g_loss_adv
                g_loss = g_loss_l2 + g_loss_adv
                g_loss.backward()
                g_loss_l.append(g_loss.item())
                g_opt.step()

            d = {'epoch': epoch, 'd_loss': d_loss_l, 'g_loss': g_loss_l}
            pd_loss = pd.read_csv(path_loss)
            pd_loss = pd_loss.append(pd.DataFrame(data=d), ignore_index=True)
            pd_loss.to_csv(path_loss, index=False)

            # print image
            if ((epoch + 1) % img_interval == 0):
                utils.save_img(self.G,
                               self.D,
                               epoch,
                               classes,
                               path=path,
                               test_num=len(classes) - 1)
                name_net_folder = path + "/Wikiart_nets"
                name_net = name_net_folder + "/nn_" + str(epoch) + ".pt"
                if not os.path.exists(name_net_folder):
                    os.makedirs(name_net_folder)
                torch.save(
                    {
                        'epoch': epoch,
                        'G': self.G.state_dict(),
                        'D': self.D.state_dict(),
                        'opt_G': g_opt.state_dict(),
                        'opt_D': d_opt.state_dict(),
                    }, name_net)

        return d_loss_l, g_loss_l
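
Note: utils.gen_z, utils.gen_yk and utils.fake_v are not shown. Minimal versions inferred purely from how they are used above (latent noise, a random one-hot condition over the real classes, and the index of the extra "fake" class); all three bodies are assumptions, not the repository's code:

import torch
import torch.nn.functional as F

def gen_z(batch_size, z_dim):
    # Latent noise fed to the generator.
    return torch.randn(batch_size, z_dim)

def gen_yk(batch_size, num_classes):
    # Random one-hot condition vectors over the real classes.
    idx = torch.randint(num_classes, (batch_size,))
    return F.one_hot(idx, num_classes).float()

def fake_v(batch_size, num_classes):
    # Every sample gets the extra "fake" class index.
    return torch.full((batch_size,), num_classes, dtype=torch.long)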
Exemplo n.º 15
0
    def train(self):
        task_criterion = nn.CrossEntropyLoss()
        recon_criterion = SIMSE()
        diff_criterion = DiffLoss()
        sim_criterion = nn.CrossEntropyLoss()
        fix_src_data, _ = next(iter(self.src_valset_loader))
        fix_src_data = fix_src_data.to(self.device)
        if not self.src_only:
            fix_tgt_data, _ = next(iter(self.tgt_trainset_loader))
            fix_tgt_data = fix_tgt_data.to(self.device)

        best_acc = 0
        best_loss = 1e15
        iteration = 0
        if self.resume_iters:
            print("resuming step %d ..." % self.resume_iters)
            iteration = self.resume_iters
            self.load_checkpoint(self.resume_iters)
            best_loss, best_acc = self.eval()

        while iteration < self.num_iters:
            self.model.train()
            self.optimizer.zero_grad()
            loss = 0.0

            if self.src_only:
                tgt_domain_loss = torch.zeros(1)
                tgt_recon_loss = torch.zeros(1)
                tgt_diff_loss = torch.zeros(1)

            else:
                try:
                    tgt_data, _ = next(tgt_data_iter)
                except (NameError, StopIteration):
                    # (re)create the target iterator when missing or exhausted
                    tgt_data_iter = iter(self.tgt_trainset_loader)
                    tgt_data, _ = next(tgt_data_iter)

                tgt_data = tgt_data.to(self.device)
                tgt_batch_size = len(tgt_data)

                if iteration > self.active_domain_loss_step:
                    p = float(iteration - self.active_domain_loss_step) / (
                        self.num_iters - self.active_domain_loss_step)
                    p = 2. / (1. + np.exp(-10 * p)) - 1

                    _, tgt_domain_output, tgt_private_code, tgt_shared_code, tgt_recon = self.model(
                        tgt_data, mode='target', p=p)
                    tgt_domain_label = torch.ones((tgt_batch_size, ),
                                                  dtype=torch.long,
                                                  device=self.device)
                    tgt_domain_loss = sim_criterion(tgt_domain_output,
                                                    tgt_domain_label)
                    loss += self.gamma_weight * tgt_domain_loss

                else:
                    _, tgt_domain_output, tgt_private_code, tgt_shared_code, tgt_recon = self.model(
                        tgt_data, mode='target')
                    tgt_domain_loss = torch.zeros(1)

                tgt_recon_loss = recon_criterion(tgt_recon, tgt_data)
                tgt_diff_loss = diff_criterion(tgt_private_code,
                                               tgt_shared_code)

                loss += (self.alpha_weight * tgt_recon_loss +
                         self.beta_weight * tgt_diff_loss)

            try:
                src_data, src_class_label = next(src_data_iter)
            except (NameError, StopIteration):
                # (re)create the source iterator when missing or exhausted
                src_data_iter = iter(self.src_trainset_loader)
                src_data, src_class_label = next(src_data_iter)

            src_data, src_class_label = src_data.to(
                self.device), src_class_label.to(self.device)
            src_batch_size = src_data.size(0)

            if iteration > self.active_domain_loss_step:
                p = float(iteration - self.active_domain_loss_step) / (
                    self.num_iters - self.active_domain_loss_step)
                p = 2. / (1. + np.exp(-10 * p)) - 1

                src_class_output, src_domain_output, src_private_code, src_shared_code, src_recon = self.model(
                    src_data, mode='source', p=p)
                src_domain_label = torch.zeros((src_batch_size, ),
                                               dtype=torch.long,
                                               device=self.device)
                src_domain_loss = sim_criterion(src_domain_output,
                                                src_domain_label)
                loss += self.gamma_weight * src_domain_loss

            else:
                src_class_output, src_domain_output, src_private_code, src_shared_code, src_recon = self.model(
                    src_data, mode='source')
                src_domain_loss = torch.zeros(1)

            src_class_loss = task_criterion(src_class_output, src_class_label)
            src_recon_loss = recon_criterion(src_recon, src_data)
            src_diff_loss = diff_criterion(src_private_code, src_shared_code)

            loss += (src_class_loss + self.alpha_weight * src_recon_loss +
                     self.beta_weight * src_diff_loss)

            loss.backward()
            self.optimizer = exp_lr_scheduler(
                optimizer=self.optimizer,
                step=iteration,
                init_lr=self.lr,
                lr_decay_step=self.num_iters_decay,
                step_decay_weight=self.step_decay_weight)
            self.optimizer.step()

            # Output training stats
            if (iteration + 1) % self.log_interval == 0:
                print(
                    'Iteration: {:5d} / {:d} loss: {:.6f} loss_src_class: {:.6f} loss_src_domain: {:.6f} loss_src_recon: {:.6f} loss_src_diff: {:.6f} loss_tgt_domain: {:.6f} loss_tgt_recon: {:.6f} loss_tgt_diff: {:.6f}'
                    .format(iteration + 1, self.num_iters, loss.item(),
                            src_class_loss.item(), src_domain_loss.item(),
                            src_recon_loss.item(), src_diff_loss.item(),
                            tgt_domain_loss.item(), tgt_recon_loss.item(),
                            tgt_diff_loss.item()))

                if self.use_wandb:
                    import wandb
                    wandb.log(
                        {
                            "loss": loss.item(),
                            "loss_src_class": src_class_loss.item(),
                            "loss_src_domain": src_domain_loss.item(),
                            "loss_src_recon": src_recon_loss.item(),
                            "loss_src_diff": src_diff_loss.item(),
                            "loss_tgt_domain": tgt_domain_loss.item(),
                            "loss_tgt_recon": tgt_recon_loss.item(),
                            "loss_tgt_diff": tgt_diff_loss.item()
                        },
                        step=iteration + 1)

            # Save model checkpoints
            if (iteration + 1) % self.save_interval == 0 and iteration > 0:
                val_loss, val_acc = self.eval()
                if self.use_wandb:
                    import wandb
                    wandb.log({
                        "val_loss": val_loss,
                        "val_acc": val_acc
                    },
                              step=iteration + 1,
                              commit=False)

                self.save_checkpoint(iteration)

                if (val_acc > best_acc):
                    print('val acc: %.2f > %.2f' % (val_acc, best_acc))
                    best_acc = val_acc
                if (val_loss < best_loss):
                    print('val loss: %.4f < %.4f' % (val_loss, best_loss))
                    best_loss = val_loss

                _, _, _, _, rec_all = self.model(fix_src_data,
                                                 mode='source',
                                                 rec_scheme='all')
                _, _, _, _, rec_share = self.model(fix_src_data,
                                                   mode='source',
                                                   rec_scheme='share')
                _, _, _, _, rec_private = self.model(fix_src_data,
                                                     mode='source',
                                                     rec_scheme='private')
                vutils.save_image(torch.cat(
                    (fix_src_data, rec_all, rec_share, rec_private)),
                                  os.path.join(self.example_dir,
                                               '%d_src.png' % (iteration + 1)),
                                  nrow=16,
                                  normalize=True)

                if not self.src_only:
                    _, _, _, _, rec_all = self.model(fix_tgt_data,
                                                     mode='target',
                                                     rec_scheme='all')
                    _, _, _, _, rec_share = self.model(fix_tgt_data,
                                                       mode='target',
                                                       rec_scheme='share')
                    _, _, _, _, rec_private = self.model(fix_tgt_data,
                                                         mode='target',
                                                         rec_scheme='private')
                    vutils.save_image(torch.cat(
                        (fix_tgt_data, rec_all, rec_share, rec_private)),
                                      os.path.join(
                                          self.example_dir,
                                          '%d_tgt.png' % (iteration + 1)),
                                      nrow=16,
                                      normalize=True)

            iteration += 1
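
Note: the keyword signature used here (step, init_lr, lr_decay_step, step_decay_weight) matches the scheduler shape in the DSN reference implementation, which applies a smooth exponential decay per training step; a sketch under that assumption:

def exp_lr_scheduler(optimizer, step, init_lr, lr_decay_step, step_decay_weight):
    # lr = init_lr * step_decay_weight ** (step / lr_decay_step), recomputed
    # from the initial rate each call so per-iteration calls do not compound.
    current_lr = init_lr * (step_decay_weight ** (step / lr_decay_step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr
    return optimizer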
Exemplo n.º 16
0
    def train(self):

        curr_iter = 0
        acc_list = []

        for epoch in range(self.opt.nepochs):
            source_trainloader = DataLoader(self.source_dataset,
                                            batch_size=self.opt.batchSize,
                                            shuffle=True,
                                            num_workers=1)
            targetloader = DataLoader(self.target_dataset,
                                      batch_size=self.opt.batchSize,
                                      shuffle=True,
                                      num_workers=1)
            self.netG.train()
            self.netF.train()
            self.netC.train()
            self.netD.train()

            len_dataloader = min(len(source_trainloader), len(targetloader))

            for (data_src,
                 data_tar) in tqdm.tqdm(zip(enumerate(source_trainloader),
                                            enumerate(targetloader)),
                                        total=len_dataloader,
                                        leave=False):
                self.real_label_val = 0.9 + torch.rand(1).item() * 0.1
                self.fake_label_val = 0 + torch.rand(1).item() * 0.1
                ###########################
                # Forming input variables
                ###########################

                batch_idx, (src_inputs, src_labels) = data_src
                _, (tgt_inputs, _) = data_tar
                #src_inputs_unnorm = (((src_inputs*self.std[0]) + self.mean[0]) - 0.5)*2

                if src_inputs.shape[0] != tgt_inputs.shape[0]:
                    continue
                # Creating one hot vector
                size = src_inputs.shape[0]
                labels_onehot = np.zeros((size, self.nclasses + 1),
                                         dtype=np.float32)
                for num in range(size):
                    labels_onehot[num, src_labels[num]] = 1
                src_labels_onehot = torch.from_numpy(labels_onehot)

                labels_onehot = np.zeros((size, self.nclasses + 1),
                                         dtype=np.float32)
                for num in range(size):
                    labels_onehot[num, self.nclasses] = 1
                tgt_labels_onehot = torch.from_numpy(labels_onehot)

                reallabel = torch.FloatTensor(size).fill_(self.real_label_val)
                fakelabel = torch.FloatTensor(size).fill_(self.fake_label_val)
                if self.opt.gpu >= 0:
                    reallabel, fakelabel = reallabel.cuda(), fakelabel.cuda()
                reallabelv = Variable(reallabel)
                fakelabelv = Variable(fakelabel)

                if self.opt.gpu >= 0:
                    src_inputs, src_labels = src_inputs.cuda(
                    ), src_labels.cuda()
                    #src_inputs_unnorm = src_inputs_unnorm.cuda()
                    tgt_inputs = tgt_inputs.cuda()
                    src_labels_onehot = src_labels_onehot.cuda()
                    tgt_labels_onehot = tgt_labels_onehot.cuda()

                # Wrapping in variable
                src_inputsv, src_labelsv = Variable(src_inputs), Variable(
                    src_labels)
                #src_inputs_unnormv = Variable(src_inputs_unnorm)
                tgt_inputsv = Variable(tgt_inputs)
                src_labels_onehotv = Variable(src_labels_onehot)
                tgt_labels_onehotv = Variable(tgt_labels_onehot)

                ###########################
                # Updates
                ###########################

                # Updating D network

                self.netD.zero_grad()
                src_emb = self.netF(src_inputsv)
                src_emb_cat = torch.cat((src_labels_onehotv, src_emb), 1)
                src_gen = self.netG(src_emb_cat)

                tgt_emb = self.netF(tgt_inputsv)
                tgt_emb_cat = torch.cat((tgt_labels_onehotv, tgt_emb), 1)
                tgt_gen = self.netG(tgt_emb_cat)

                # Real source-domain images: discriminate as real and classify
                src_realoutputD_s, src_realoutputD_c = self.netD(src_inputsv)
                errD_src_real_s = self.criterion_s(src_realoutputD_s,
                                                   reallabelv)
                errD_src_real_c = self.criterion_c(src_realoutputD_c,
                                                   src_labelsv)

                # Generated source-domain images: discriminate as fake
                src_fakeoutputD_s, src_fakeoutputD_c = self.netD(src_gen)
                errD_src_fake_s = self.criterion_s(src_fakeoutputD_s,
                                                   fakelabelv)

                # Generated target-domain images: discriminate as fake
                tgt_fakeoutputD_s, tgt_fakeoutputD_c = self.netD(tgt_gen)
                errD_tgt_fake_s = self.criterion_s(tgt_fakeoutputD_s,
                                                   fakelabelv)

                errD = errD_src_real_c + errD_src_real_s + errD_src_fake_s + errD_tgt_fake_s
                errD.backward(retain_graph=True)
                self.optimizerD.step()

                self.netG.zero_grad()
                src_fakeoutputD_s, src_fakeoutputD_c = self.netD(src_gen)
                # G wants its generated images to still be classified correctly
                errG_c = self.criterion_c(src_fakeoutputD_c, src_labelsv)
                # G's goal is for the generated images to be judged real
                errG_s = self.criterion_s(src_fakeoutputD_s, reallabelv)
                errG = errG_c + errG_s
                errG.backward(retain_graph=True)
                self.optimizerG.step()

                # Updating C network
                # C trains the embedding to be easier to classify
                self.netC.zero_grad()
                outC = self.netC(src_emb)
                errC = self.criterion_c(outC, src_labelsv)
                errC.backward(retain_graph=True)
                self.optimizerC.step()

                # Updating F network
                self.netF.zero_grad()
                # Same loss as errC above, but F's gradients were just zeroed, so backprop it again
                errF_fromC = self.criterion_c(outC, src_labelsv)
                # Push the generated source images to be classified more accurately
                src_fakeoutputD_s, src_fakeoutputD_c = self.netD(src_gen)
                errF_src_fromD = self.criterion_c(
                    src_fakeoutputD_c, src_labelsv) * (self.opt.adv_weight)

                # Make the generated target images be judged real (but why is there no matching term for the generated source?)
                tgt_fakeoutputD_s, tgt_fakeoutputD_c = self.netD(tgt_gen)
                errF_tgt_fromD = self.criterion_s(
                    tgt_fakeoutputD_s,
                    reallabelv) * (self.opt.adv_weight * self.opt.alpha)

                errF = errF_fromC + errF_src_fromD + errF_tgt_fromD
                errF.backward()
                self.optimizerF.step()

                curr_iter += 1

                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizerD = utils.exp_lr_scheduler(
                        self.optimizerD, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerF = utils.exp_lr_scheduler(
                        self.optimizerF, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerC = utils.exp_lr_scheduler(
                        self.optimizerC, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)

                print('ErrD: %.2f, ErrG: %.2f, ErrC: %.2f, ErrF: %.2f' %
                      (errD.item(), errG.item(), errC.item(), errF.item()))

            # Validate every epoch
            test_acc = self.validate(epoch + 1)
            acc_list.append(test_acc)
        return acc_list
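
Note: re-drawing real_label_val from [0.9, 1.0) and fake_label_val from [0.0, 0.1) at the top of every batch is label smoothing with per-batch label noise, a common GAN stabilization trick that keeps the discriminator from becoming overconfident. The construction in isolation (illustrative batch size):

import torch

real_label_val = 0.9 + torch.rand(1).item() * 0.1   # soft "real" target in [0.9, 1.0)
fake_label_val = 0.0 + torch.rand(1).item() * 0.1   # soft "fake" target in [0.0, 0.1)
reallabel = torch.full((8,), real_label_val)         # targets for criterion_s on real images
fakelabel = torch.full((8,), fake_label_val)         # targets for criterion_s on generated images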
Exemplo n.º 17
0
    def train(self):

        curr_iter = 0

        reallabel = torch.FloatTensor(self.opt.batchSize).fill_(
            self.real_label_val)
        fakelabel = torch.FloatTensor(self.opt.batchSize).fill_(
            self.fake_label_val)
        if self.opt.gpu >= 0:
            reallabel = reallabel.cuda()
            fakelabel = fakelabel.cuda()
        reallabelv = Variable(reallabel)
        fakelabelv = Variable(fakelabel)

        for epoch in range(self.opt.nepochs):

            self.netG.train()
            self.netF.train()
            self.netC.train()
            self.netD.train()

            for i, (datas, datat) in enumerate(
                    zip(self.source_trainloader,
                        self.target_trainloader)):

                ###########################
                # Forming input variables
                ###########################

                src_inputs, src_labels = datas
                tgt_inputs, __ = datat
                src_inputs_unnorm = ((
                    (src_inputs * self.std[0]) + self.mean[0]) - 0.5) * 2
                tgt_inputs_unnorm = ((
                    (tgt_inputs * self.std[0]) + self.mean[0]) - 0.5) * 2

                # Creating one hot vector
                labels_onehot = np.zeros(
                    (self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, src_labels[num]] = 1
                src_labels_onehot = torch.from_numpy(labels_onehot)

                labels_onehot = np.zeros(
                    (self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, self.nclasses] = 1
                tgt_labels_onehot = torch.from_numpy(labels_onehot)

                # feed variables to gpu
                if self.opt.gpu >= 0:
                    src_inputs, src_labels = src_inputs.cuda(
                    ), src_labels.cuda()
                    src_inputs_unnorm = src_inputs_unnorm.cuda()
                    tgt_inputs_unnorm = tgt_inputs_unnorm.cuda()
                    tgt_inputs = tgt_inputs.cuda()
                    src_labels_onehot = src_labels_onehot.cuda()
                    tgt_labels_onehot = tgt_labels_onehot.cuda()

                # Wrapping in variable
                src_inputsv, src_labelsv = Variable(src_inputs), Variable(
                    src_labels)
                src_inputs_unnormv = Variable(src_inputs_unnorm)
                tgt_inputsv = Variable(tgt_inputs)
                tgt_inputs_unnormv = Variable(tgt_inputs_unnorm)
                src_labels_onehotv = Variable(src_labels_onehot)
                tgt_labels_onehotv = Variable(tgt_labels_onehot)

                ###########################
                # Updates
                ###########################

                # Mix source and target domain images
                mix_ratio = np.random.beta(self.opt.alpha, self.opt.alpha)
                mix_ratio = round(mix_ratio, 2)
                # clip the mixup_ratio
                if (mix_ratio >= 0.5 and mix_ratio <
                    (0.5 + self.opt.clip_thr)):
                    mix_ratio = 0.5 + self.opt.clip_thr
                if (mix_ratio > (0.5 - self.opt.clip_thr) and mix_ratio < 0.5):
                    mix_ratio = 0.5 - self.opt.clip_thr

                # Define labels for mixed images
                mix_label = torch.FloatTensor(
                    self.opt.batchSize).fill_(mix_ratio)
                if self.opt.gpu >= 0:
                    mix_label = mix_label.cuda()
                mix_labelv = Variable(mix_label)

                mix_samples = mix_ratio * src_inputs_unnormv + (
                    1 - mix_ratio) * tgt_inputs_unnormv

                # Define the label for mixed input
                labels_onehot = np.zeros(
                    (self.opt.batchSize, self.nclasses + 1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, src_labels[num]] = mix_ratio
                    labels_onehot[num, self.nclasses] = 1.0 - mix_ratio
                mix_labels_onehot = torch.from_numpy(labels_onehot)

                if self.opt.gpu >= 0:
                    mix_labels_onehot = mix_labels_onehot.cuda()
                mix_labels_onehotv = Variable(mix_labels_onehot)

                # Generating images for both domains (add mixed images)

                src_emb, src_mn, src_sd = self.netF(src_inputsv)
                tgt_emb, tgt_mn, tgt_sd = self.netF(tgt_inputsv)

                # Generate mean and std for mixed samples
                mix_mn = src_mn * mix_ratio + tgt_mn * (1.0 - mix_ratio)
                mix_sd = src_sd * mix_ratio + tgt_sd * (1.0 - mix_ratio)

                src_mn_sd = torch.cat((src_mn, src_sd), 1)
                outC_src_logit, outC_src = self.netC(src_mn_sd)

                src_emb_cat = torch.cat((src_mn, src_sd, src_labels_onehotv),
                                        1)
                src_gen = self.netG(src_emb_cat)

                tgt_emb_cat = torch.cat((tgt_mn, tgt_sd, tgt_labels_onehotv),
                                        1)
                tgt_gen = self.netG(tgt_emb_cat)

                mix_emb_cat = torch.cat((mix_mn, mix_sd, mix_labels_onehotv),
                                        1)
                mix_gen = self.netG(mix_emb_cat)

                # Updating D network

                self.netD.zero_grad()

                src_realoutputD_s, src_realoutputD_c, src_realoutputD_t = self.netD(
                    src_inputs_unnormv)
                errD_src_real_s = self.criterion_s(src_realoutputD_s,
                                                   reallabelv)
                errD_src_real_c = self.criterion_c(src_realoutputD_c,
                                                   src_labelsv)

                src_fakeoutputD_s, src_fakeoutputD_c, _ = self.netD(src_gen)
                errD_src_fake_s = self.criterion_s(src_fakeoutputD_s,
                                                   fakelabelv)

                tgt_realoutputD_s, tgt_realoutputD_c, tgt_realoutputD_t = self.netD(
                    tgt_inputs_unnormv)
                tgt_fakeoutputD_s, tgt_fakeoutputD_c, _ = self.netD(tgt_gen)
                errD_tgt_fake_s = self.criterion_s(tgt_fakeoutputD_s,
                                                   fakelabelv)

                mix_s, _, mix_t = self.netD(mix_samples)
                if (mix_ratio > 0.5):
                    tmp_margin = 2 * mix_ratio - 1.
                    errD_mix_t = F.triplet_margin_loss(mix_t,
                                                       src_realoutputD_t,
                                                       tgt_realoutputD_t,
                                                       margin=tmp_margin)
                else:
                    tmp_margin = 1. - 2 * mix_ratio
                    errD_mix_t = F.triplet_margin_loss(mix_t,
                                                       tgt_realoutputD_t,
                                                       src_realoutputD_t,
                                                       margin=tmp_margin)
                errD_mix_s = self.criterion_s(mix_s, mix_labelv)
                errD_mix = errD_mix_s + errD_mix_t

                mix_gen_s, _, _ = self.netD(mix_gen)
                errD_mix_gen = self.criterion_s(mix_gen_s, fakelabelv)

                errD = errD_src_real_c + errD_src_real_s + errD_src_fake_s + errD_tgt_fake_s + errD_mix + errD_mix_gen
                errD.backward(retain_graph=True)
                self.optimizerD.step()

                # Updating G network

                self.netG.zero_grad()

                src_fakeoutputD_s, src_fakeoutputD_c, _ = self.netD(src_gen)
                errG_src_c = self.criterion_c(src_fakeoutputD_c, src_labelsv)
                errG_src_s = self.criterion_s(src_fakeoutputD_s, reallabelv)

                mix_gen_s, _, _ = self.netD(mix_gen)
                errG_mix_gen_s = self.criterion_s(mix_gen_s, reallabelv)

                errG = errG_src_c + errG_src_s + errG_mix_gen_s
                errG.backward(retain_graph=True)
                self.optimizerG.step()

                # Updating C network

                self.netC.zero_grad()
                errC = self.criterion_c(outC_src_logit, src_labelsv)
                errC.backward(retain_graph=True)
                self.optimizerC.step()

                # Updating F network

                self.netF.zero_grad()
                err_KL_src = torch.mean(
                    0.5 *
                    torch.sum(torch.exp(src_sd) + src_mn**2 - 1. - src_sd, 1))
                err_KL_tgt = torch.mean(
                    0.5 *
                    torch.sum(torch.exp(tgt_sd) + tgt_mn**2 - 1. - tgt_sd, 1))
                err_KL = (err_KL_src + err_KL_tgt) * (self.opt.KL_weight)

                errF_fromC = self.criterion_c(outC_src_logit, src_labelsv)

                src_fakeoutputD_s, src_fakeoutputD_c, _ = self.netD(src_gen)
                errF_src_fromD = self.criterion_c(
                    src_fakeoutputD_c, src_labelsv) * (self.opt.adv_weight)

                tgt_fakeoutputD_s, tgt_fakeoutputD_c, _ = self.netD(tgt_gen)
                errF_tgt_fromD = self.criterion_s(
                    tgt_fakeoutputD_s,
                    reallabelv) * (self.opt.adv_weight * self.opt.gamma)

                mix_gen_s, _, _ = self.netD(mix_gen)
                errF_mix_fromD = self.criterion_s(mix_gen_s, reallabelv) * (
                    self.opt.adv_weight * self.opt.delta)

                errF = err_KL + errF_fromC + errF_src_fromD + errF_tgt_fromD + errF_mix_fromD
                errF.backward()
                self.optimizerF.step()

                curr_iter += 1

                # print training information
                if ((i + 1) % 50 == 0):
                    text_format = 'epoch: {}, iteration: {}, errD: {}, errG: {}, ' \
                                  + 'errC: {}, errF: {}'
                    train_text = text_format.format(epoch + 1, i + 1, \
                                                    errD.item(), errG.item(), errC.item(), errF.item())
                    print(train_text)

                # Visualization
                if i == 1:
                    vutils.save_image(
                        (src_gen.data / 2) + 0.5,
                        '%s/source_generation/source_gen_%d.png' %
                        (self.opt.outf, epoch))
                    vutils.save_image(
                        (tgt_gen.data / 2) + 0.5,
                        '%s/target_generation/target_gen_%d.png' %
                        (self.opt.outf, epoch))
                    vutils.save_image((mix_gen.data / 2) + 0.5,
                                      '%s/mix_generation/mix_gen_%d.png' %
                                      (self.opt.outf, epoch))
                    vutils.save_image((mix_samples.data / 2) + 0.5,
                                      '%s/mix_images/mix_samples_%d.png' %
                                      (self.opt.outf, epoch))

                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizerD = utils.exp_lr_scheduler(
                        self.optimizerD, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerF = utils.exp_lr_scheduler(
                        self.optimizerF, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)
                    self.optimizerC = utils.exp_lr_scheduler(
                        self.optimizerC, epoch, self.opt.lr, self.opt.lrd,
                        curr_iter)

            # Validate every epoch
            self.validate(epoch + 1)
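
Note: the mix-ratio logic in this example, a Beta(alpha, alpha) sample pushed at least clip_thr away from 0.5, can be isolated into a small helper; this restates the same logic (the function name is mine, not the repository's):

import numpy as np

def sample_mix_ratio(alpha, clip_thr):
    # Beta(alpha, alpha) sample, rounded to 2 decimals, then clipped away
    # from 0.5 so the mixed image is never an exact 50/50 domain blend.
    r = round(float(np.random.beta(alpha, alpha)), 2)
    if 0.5 <= r < 0.5 + clip_thr:
        r = 0.5 + clip_thr
    elif 0.5 - clip_thr < r < 0.5:
        r = 0.5 - clip_thr
    return r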
Exemplo n.º 18
0
    def train(self):
        
        curr_iter = 0
        
        reallabel = torch.FloatTensor(self.opt.batchSize).fill_(self.real_label_val)
        fakelabel = torch.FloatTensor(self.opt.batchSize).fill_(self.fake_label_val)
        if self.opt.gpu>=0:
            reallabel, fakelabel = reallabel.cuda(), fakelabel.cuda()
        reallabelv = Variable(reallabel) 
        fakelabelv = Variable(fakelabel) 
        
        for epoch in range(self.opt.nepochs):
            
            self.netG.train()    
            self.netF.train()    
            self.netC.train()    
            self.netD.train()    
        
            for i, (datas, datat) in enumerate(zip(self.source_trainloader, self.targetloader)):
                
                ###########################
                # Forming input variables
                ###########################
                
                src_inputs, src_labels = datas
                tgt_inputs, __ = datat       
                src_inputs_unnorm = (((src_inputs*self.std[0]) + self.mean[0]) - 0.5)*2

                # Creating one hot vector
                labels_onehot = np.zeros((self.opt.batchSize, self.nclasses+1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, src_labels[num]] = 1
                src_labels_onehot = torch.from_numpy(labels_onehot)

                labels_onehot = np.zeros((self.opt.batchSize, self.nclasses+1), dtype=np.float32)
                for num in range(self.opt.batchSize):
                    labels_onehot[num, self.nclasses] = 1
                tgt_labels_onehot = torch.from_numpy(labels_onehot)
                
                if self.opt.gpu>=0:
                    src_inputs, src_labels = src_inputs.cuda(), src_labels.cuda()
                    src_inputs_unnorm = src_inputs_unnorm.cuda() 
                    tgt_inputs = tgt_inputs.cuda()
                    src_labels_onehot = src_labels_onehot.cuda()
                    tgt_labels_onehot = tgt_labels_onehot.cuda()
                
                # Wrapping in variable
                src_inputsv, src_labelsv = Variable(src_inputs), Variable(src_labels)
                src_inputs_unnormv = Variable(src_inputs_unnorm)
                tgt_inputsv = Variable(tgt_inputs)
                src_labels_onehotv = Variable(src_labels_onehot)
                tgt_labels_onehotv = Variable(tgt_labels_onehot)
                
                ###########################
                # Updates
                ###########################
                
                # Updating D network
                
                self.netD.zero_grad()
                src_emb = self.netF(src_inputsv)
                src_emb_cat = torch.cat((src_labels_onehotv, src_emb), 1)
                src_gen = self.netG(src_emb_cat)

                tgt_emb = self.netF(tgt_inputsv)
                tgt_emb_cat = torch.cat((tgt_labels_onehotv, tgt_emb),1)
                tgt_gen = self.netG(tgt_emb_cat)

                src_realoutputD_s, src_realoutputD_c, _ = self.netD(src_inputs_unnormv)   
                errD_src_real_s = self.criterion_s(src_realoutputD_s, reallabelv) 
                errD_src_real_c = self.criterion_c(src_realoutputD_c, src_labelsv) 

                src_fakeoutputD_s, src_fakeoutputD_c, _ = self.netD(src_gen)
                errD_src_fake_s = self.criterion_s(src_fakeoutputD_s, fakelabelv)

                tgt_fakeoutputD_s, tgt_fakeoutputD_c, _ = self.netD(tgt_gen)          
                errD_tgt_fake_s = self.criterion_s(tgt_fakeoutputD_s, fakelabelv)

                errD = errD_src_real_c + errD_src_real_s + errD_src_fake_s + errD_tgt_fake_s
                errD.backward(retain_graph=True)    
                self.optimizerD.step()
                

                # Updating G network
                
                self.netG.zero_grad()       
                src_fakeoutputD_s, src_fakeoutputD_c, _ = self.netD(src_gen)
                errG_c = self.criterion_c(src_fakeoutputD_c, src_labelsv)
                errG_s = self.criterion_s(src_fakeoutputD_s, reallabelv)
                errG = errG_c + errG_s
                errG.backward(retain_graph=True)
                self.optimizerG.step()
                

                # Updating C network
                
                self.netC.zero_grad()
                outC = self.netC(src_emb)   
                errC = self.criterion_c(outC, src_labelsv)
                errC.backward(retain_graph=True)    
                self.optimizerC.step()

                
                # Updating F network

                self.netF.zero_grad()
                errF_fromC = self.criterion_c(outC, src_labelsv)        

                src_fakeoutputD_s, src_fakeoutputD_c, _ = self.netD(src_gen)
                errF_src_fromD = self.criterion_c(src_fakeoutputD_c, src_labelsv)*(self.opt.adv_weight)

                tgt_fakeoutputD_s, tgt_fakeoutputD_c, _ = self.netD(tgt_gen)
                errF_tgt_fromD = self.criterion_s(tgt_fakeoutputD_s, reallabelv)*(self.opt.adv_weight*self.opt.alpha)
                
                errF = errF_fromC + errF_src_fromD + errF_tgt_fromD
                errF.backward()
                self.optimizerF.step()        
                
                curr_iter += 1
                
                # Visualization
                if i == 1:
                    vutils.save_image((src_gen.data/2)+0.5, '%s/visualization/source_gen_%d.png' %(self.opt.outf, epoch))
                    vutils.save_image((tgt_gen.data/2)+0.5, '%s/visualization/target_gen_%d.png' %(self.opt.outf, epoch))
                    
                # Learning rate scheduling
                if self.opt.lrd:
                    self.optimizerD = utils.exp_lr_scheduler(self.optimizerD, epoch, self.opt.lr, self.opt.lrd, curr_iter)    
                    self.optimizerF = utils.exp_lr_scheduler(self.optimizerF, epoch, self.opt.lr, self.opt.lrd, curr_iter)
                    self.optimizerC = utils.exp_lr_scheduler(self.optimizerC, epoch, self.opt.lr, self.opt.lrd, curr_iter)                  
            
            # Validate every epoch
            self.validate(epoch+1)
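
As an aside, the per-sample NumPy loops that build the one-hot labels at the top of this example can be vectorized. The helper below is only a sketch (make_onehot is not part of this codebase), assuming src_labels is a LongTensor of class indices:

    import torch

    def make_onehot(src_labels, batch_size, nclasses):
        # Source rows: 1 at the ground-truth class index.
        src_onehot = torch.zeros(batch_size, nclasses + 1)
        src_onehot.scatter_(1, src_labels.unsqueeze(1), 1.0)
        # Target rows: 1 at the extra (nclasses-th) "generated" slot.
        tgt_onehot = torch.zeros(batch_size, nclasses + 1)
        tgt_onehot[:, nclasses] = 1.0
        return src_onehot, tgt_onehot
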
Exemplo n.º 19
0
    def train(self):
        
        curr_iter = 0
        
        reallabel = torch.FloatTensor(self.opt.batchSize).fill_(self.real_label_val)
        fakelabel = torch.FloatTensor(self.opt.batchSize).fill_(self.fake_label_val)
        if self.opt.gpu>=0:
            reallabel, fakelabel = reallabel.cuda(), fakelabel.cuda()
        reallabelv = Variable(reallabel) 
        fakelabelv = Variable(fakelabel) 
        
        for epoch in range(self.opt.nepochs):
            
            self.netF1.train()
            self.netF2.train()
            self.netC1.train()
            self.netC2.train()
            self.netC3.train()
            self.netG.train()
            self.netD.train()

            for trainset_id in range(len(self.source_trainloader)):

                for i, (datas, datat) in enumerate(zip(self.source_trainloader[trainset_id], self.targetloader)):

                    ###########################
                    # Forming input variables
                    ###########################
                    
                    src_inputs, src_class_labels = datas
                    tgt_inputs, __ = datat

                    src_domain_id = 1
                    src_domain_labels = torch.LongTensor(self.opt.batchSize).fill_(src_domain_id)
                    tgt_domain_labels = torch.LongTensor(self.opt.batchSize).fill_(0)

                    src_inputs_unnorm = (((src_inputs*self.std[0]) + self.mean[0]) - 0.5)*2
                    tgt_inputs_unnorm = (((tgt_inputs*self.std[0]) + self.mean[0]) - 0.5)*2

                    
                    if self.opt.gpu>=0:
                        src_inputs, src_class_labels = src_inputs.cuda(), src_class_labels.cuda()
                        tgt_inputs = tgt_inputs.cuda()
                        src_domain_labels = src_domain_labels.cuda()
                        tgt_domain_labels = tgt_domain_labels.cuda()
                        src_inputs_unnorm = src_inputs_unnorm.cuda() 
                        tgt_inputs_unnorm = tgt_inputs_unnorm.cuda()
                    
                    # Wrapping tensors in Variables
                    src_inputsv, src_class_labelsv = Variable(src_inputs), Variable(src_class_labels)
                    tgt_inputsv = Variable(tgt_inputs)
                    src_domain_labelsv = Variable(src_domain_labels)
                    tgt_domain_labelsv = Variable(tgt_domain_labels)
                    src_inputs_unnormv = Variable(src_inputs_unnorm)
                    tgt_inputs_unnormv = Variable(tgt_inputs_unnorm)
                    

                    ###########################
                    # Updates
                    ###########################
                    
                    # Updating D network
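                    # Here D has three heads: real/fake (s), object class (c) and domain (d).
                    # Real images supervise every applicable head; generated images are pushed toward fake.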
                    self.netD.zero_grad()
                    src_f1 = self.netF1(src_inputsv)
                    src_f2 = self.netF2(src_inputsv)
                    src_f1f2_cat = torch.cat((src_f1, src_f2), 1)
                    src_gen = self.netG(src_f1f2_cat)

                    tgt_f1 = self.netF1(tgt_inputsv)
                    tgt_f2 = self.netF2(tgt_inputsv)
                    tgt_f1f2_cat = torch.cat((tgt_f1, tgt_f2), 1)
                    tgt_gen = self.netG(tgt_f1f2_cat)

                    src_realoutputD_s, src_realoutputD_c, src_realoutputD_d = self.netD(src_inputs_unnormv)   
                    errD_src_real_s = self.criterion_s(src_realoutputD_s, reallabelv) 
                    errD_src_real_c = self.criterion_c(src_realoutputD_c, src_class_labelsv) 
                    errD_src_real_d = self.criterion_c(src_realoutputD_d, src_domain_labelsv) 

                    tgt_realoutputD_s, __, tgt_realoutputD_d = self.netD(tgt_inputs_unnormv)   
                    errD_tgt_real_s = self.criterion_s(tgt_realoutputD_s, reallabelv) 
                    # errD_tgt_real_c = self.criterion_c(tgt_realoutputD_c, tgt_class_labelsv)
                    errD_tgt_real_d = self.criterion_c(tgt_realoutputD_d, tgt_domain_labelsv) 

                    src_fakeoutputD_s, __, __ = self.netD(src_gen)
                    errD_src_fake_s = self.criterion_s(src_fakeoutputD_s, fakelabelv)

                    tgt_fakeoutputD_s, __, __ = self.netD(tgt_gen)
                    errD_tgt_fake_s = self.criterion_s(tgt_fakeoutputD_s, fakelabelv)

                    errD = (errD_src_real_s + errD_src_real_c + errD_src_real_d) + 1 * (errD_tgt_real_s + errD_tgt_real_d) + errD_src_fake_s + errD_tgt_fake_s
                    errD.backward(retain_graph=True)    
                    self.optimizerD.step()
                    

                    # Updating G network
                    self.netG.zero_grad()       
                    src_fakeoutputD_s, src_fakeoutputD_c, src_fakeoutputD_d = self.netD(src_gen)
                    errG_src_s = self.criterion_s(src_fakeoutputD_s, reallabelv)
                    errG_src_c = self.criterion_c(src_fakeoutputD_c, src_class_labelsv)
                    errG_src_d = self.criterion_c(src_fakeoutputD_d, src_domain_labelsv)

                    tgt_fakeoutputD_s, __, tgt_fakeoutputD_d = self.netD(tgt_gen)
                    errG_tgt_s = self.criterion_s(tgt_fakeoutputD_s, reallabelv)
                    # errG_tgt_c = self.criterion_c(tgt_fakeoutputD_c, tgt_class_labelsv)
                    errG_tgt_d = self.criterion_c(tgt_fakeoutputD_d, tgt_domain_labelsv)

                    errG = (errG_src_s + errG_src_c + errG_src_d) + 1 * (errG_tgt_s + errG_tgt_d)
                    errG.backward(retain_graph=True)
                    self.optimizerG.step()


                    # Updating C3 network: skipped for now; this update may not work as intended


                    # Updating C2 Network
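                    # C2 is trained as a domain classifier on the F2 features of both source and target batches.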
                    self.netC2.zero_grad()
                    outC2_src = self.netC2(src_f2) 
                    outC2_tgt = self.netC2(tgt_f2)

                    errC2 = self.criterion_c(outC2_src, src_domain_labelsv) + self.criterion_c(outC2_tgt, tgt_domain_labelsv)
                    errC2.backward(retain_graph=True)    
                    self.optimizerC2.step()


                    # Updating C1 Network
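                    # C1 is trained as a class classifier on the source F1 features.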
                    self.netC1.zero_grad()
                    outC1_src = self.netC1(src_f1) 
                    errC1 = self.criterion_c(outC1_src, src_class_labelsv)
                    errC1.backward(retain_graph=True)    
                    self.optimizerC1.step()


                    # Updating F2 Network
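                    # F2 is updated with C2's domain loss plus real/fake and domain feedback from D on the
                    # generated images, with the target terms scaled by adv_weight*alpha.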
                    self.netF2.zero_grad()
                    errF2_fromC = self.criterion_c(outC2_src, src_domain_labelsv) + self.criterion_c(outC2_tgt, tgt_domain_labelsv)

                    src_fakeoutputD_s, __, src_fakeoutputD_d = self.netD(src_gen)
                    errF2_src_fromD_s = self.criterion_s(src_fakeoutputD_s, reallabelv)*(self.opt.adv_weight)
                    errF2_src_fromD_d = self.criterion_c(src_fakeoutputD_d, src_domain_labelsv)*(self.opt.adv_weight)

                    tgt_fakeoutputD_s, __, tgt_fakeoutputD_d = self.netD(tgt_gen)
                    errF2_tgt_fromD_s = self.criterion_s(tgt_fakeoutputD_s, reallabelv)*(self.opt.adv_weight*self.opt.alpha)
                    errF2_tgt_fromD_d = self.criterion_c(tgt_fakeoutputD_d, tgt_domain_labelsv)*(self.opt.adv_weight*self.opt.alpha)
                    
                    errF2 = errF2_fromC + (errF2_src_fromD_s + errF2_src_fromD_d) + (errF2_tgt_fromD_s + errF2_tgt_fromD_d)
                    errF2.backward(retain_graph=True)
                    self.optimizerF2.step()        


                    # Updating F1 Network
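                    # F1 is updated with C1's class loss plus feedback from D: class consistency on generated
                    # source and a real/fake term on generated target, scaled by adv_weight (and alpha for the target).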
                    self.netF1.zero_grad()
                    errF1_fromC = self.criterion_c(outC1_src, src_class_labelsv)

                    src_fakeoutputD_s, src_fakeoutputD_c, __ = self.netD(src_gen)
                    errF1_src_fromD_s = self.criterion_s(src_fakeoutputD_s, reallabelv)*(self.opt.adv_weight)
                    errF1_src_fromD_c = self.criterion_c(src_fakeoutputD_c, src_class_labelsv)*(self.opt.adv_weight)

                    tgt_fakeoutputD_s, __, __ = self.netD(tgt_gen)
                    errF1_tgt_fromD_s = self.criterion_s(tgt_fakeoutputD_s, reallabelv)*(self.opt.adv_weight*self.opt.alpha)
                    
                    errF1 = errF1_fromC + errF1_src_fromD_s + errF1_src_fromD_c + 1 * errF1_tgt_fromD_s
                    errF1.backward()
                    self.optimizerF1.step()        


                    curr_iter += 1

                    # Visualization
                    if i == 1:
                        vutils.save_image((src_inputsv.data/2)+0.5, '%s/visualization/source_input_%d_%d.png' %(self.opt.outf, epoch, trainset_id))
                        vutils.save_image((tgt_inputsv.data/2)+0.5, '%s/visualization/target_input_%d.png' %(self.opt.outf, epoch))
                        vutils.save_image((src_gen.data/2)+0.5, '%s/visualization/source_gen_%d_%d.png' %(self.opt.outf, epoch, trainset_id))
                        vutils.save_image((tgt_gen.data/2)+0.5, '%s/visualization/target_gen_%d.png' %(self.opt.outf, epoch))
                        
                    # Learning rate scheduling
                    if self.opt.lrd:
                        self.optimizerF1 = utils.exp_lr_scheduler(self.optimizerF1, epoch, self.opt.lr, self.opt.lrd, curr_iter)
                        self.optimizerF2 = utils.exp_lr_scheduler(self.optimizerF2, epoch, self.opt.lr, self.opt.lrd, curr_iter)
                        self.optimizerC1 = utils.exp_lr_scheduler(self.optimizerC1, epoch, self.opt.lr, self.opt.lrd, curr_iter)
                        self.optimizerC2 = utils.exp_lr_scheduler(self.optimizerC2, epoch, self.opt.lr, self.opt.lrd, curr_iter)
                        self.optimizerC3 = utils.exp_lr_scheduler(self.optimizerC3, epoch, self.opt.lr, self.opt.lrd, curr_iter)
                        self.optimizerG = utils.exp_lr_scheduler(self.optimizerG, epoch, self.opt.lr, self.opt.lrd, curr_iter)
                        self.optimizerD = utils.exp_lr_scheduler(self.optimizerD, epoch, self.opt.lr, self.opt.lrd, curr_iter)
            
            # Validate every epoch
            self.validate(epoch+1)
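
A note on learning-rate scheduling: utils.exp_lr_scheduler is not defined in these snippets. Judging only from how it is called above (optimizer, epoch, base lr, decay rate, current iteration), a minimal inverse-time-decay sketch could look like the following; the real helper may differ:

    def exp_lr_scheduler(optimizer, epoch, init_lr, lrd, nevals):
        # Hypothetical decay schedule: lr = init_lr / (1 + nevals * lrd).
        lr = init_lr / (1 + nevals * lrd)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return optimizer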