Example #1
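
All of the snippets on this page are excerpts from larger training scripts, so they assume a shared import header roughly like the one below. Generator, Classifier, dataset_read, and the other project-local names come from each repository's own modules; the exact import paths shown here are assumptions:

import os
from time import gmtime, strftime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from model import Generator, Classifier   # assumed project-local path
from datasets import dataset_read         # assumed project-local path
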
    def __init__(self,
                 args,
                 batch_size=64,
                 source='mnist',
                 target='usps',
                 learning_rate=0.0002,
                 interval=100,
                 optimizer='adam',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10):
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        self.lambda_1 = args.lambda_1
        self.lambda_2 = args.lambda_2
        if self.source == 'svhn':
            self.scale = True
        else:
            self.scale = False
        print('dataset loading')
        self.datasets, self.dataset_test = dataset_read(source,
                                                        target,
                                                        self.batch_size,
                                                        scale=self.scale,
                                                        all_use=self.all_use)
        print('load finished!')
        self.G = Generator(source=source, target=target)
        self.C1 = Classifier(source=source, target=target)
        self.C2 = Classifier(source=source, target=target)
        if args.eval_only:
            # restore the generator and both classifiers saved by test()
            self.G = torch.load('%s/%s_to_%s_model_epoch%s_G.pt' %
                                (self.checkpoint_dir, self.source, self.target,
                                 args.resume_epoch))
            self.C1 = torch.load('%s/%s_to_%s_model_epoch%s_C1.pt' %
                                 (self.checkpoint_dir, self.source, self.target,
                                  args.resume_epoch))
            self.C2 = torch.load('%s/%s_to_%s_model_epoch%s_C2.pt' %
                                 (self.checkpoint_dir, self.source, self.target,
                                  args.resume_epoch))

        self.G.cuda()
        self.C1.cuda()
        self.C2.cuda()
        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate
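
The solver is driven by an argparse namespace; a hypothetical construction looks like the sketch below (the flag names mirror the attributes read in __init__, and the default values are placeholders):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_abs_diff', action='store_true')
parser.add_argument('--lambda_1', type=float, default=0.01)
parser.add_argument('--lambda_2', type=float, default=0.01)
parser.add_argument('--eval_only', action='store_true')
parser.add_argument('--resume_epoch', type=int, default=100)
args = parser.parse_args()

solver = Solver(args, batch_size=64, source='mnist', target='usps',
                checkpoint_dir='checkpoint', save_epoch=10)
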
Example #2
    def __init__(self, args):

        super().__init__()

        self.opt_losses = args.opt_losses
        self.check_list = [
            'disc_ds_di_C_D', 'disc_ds_di_G', 'ring', 'confusion_G'
        ]
        assert all([ol in self.check_list
                    for ol in self.opt_losses]), 'Check loss entries'
        opt_losses_str = ",".join(map(str, self.opt_losses))
        timestring = strftime("%Y-%m-%d_%H-%M-%S",
                              gmtime()) + "_{}_optloss={}_src={}".format(
                                  args.exp_name, opt_losses_str, args.source)
        self.logdir = os.path.join('./logs', timestring)
        self.logger = SummaryWriter(log_dir=self.logdir)
        self.device = torch.device("cuda" if args.use_cuda else "cpu")

        self.src_domain_code = np.repeat(np.array([[1, 0]]),
                                         args.batch_size,
                                         axis=0)
        self.trg_domain_code = np.repeat(np.array([[0, 1]]),
                                         args.batch_size,
                                         axis=0)
        self.src_domain_code = torch.FloatTensor(self.src_domain_code).to(
            self.device)
        self.trg_domain_code = torch.FloatTensor(self.trg_domain_code).to(
            self.device)

        self.source = args.source
        self.target = args.target
        self.num_k = args.num_k
        self.checkpoint_dir = args.checkpoint_dir
        self.save_epoch = args.save_epoch
        self.use_abs_diff = args.use_abs_diff

        self.mi_k = 1
        self.delta = 0.01
        self.mi_coeff = 0.0001
        self.interval = 10  # log to TensorBoard every self.interval batches
        self.batch_size = args.batch_size
        self.which_opt = 'adam'
        self.lr = args.lr
        self.scale = 32
        self.global_step = 0

        print('Loading datasets')
        self.dataset_train, self.dataset_test = dataset_read(
            args.data_dir, self.source, self.target, self.batch_size,
            self.scale)
        print('Done!')

        self.total_batches = {
            'train': self.get_dataset_size('train'),
            'test': self.get_dataset_size('test')
        }

        self.G = Generator(source=self.source, target=self.target)
        self.FD = Feature_Discriminator()
        self.R = Reconstructor()
        self.MI = Mine()

        self.C = nn.ModuleDict({
            'ds': Classifier(source=self.source, target=self.target),
            'di': Classifier(source=self.source, target=self.target),
            'ci': Classifier(source=self.source, target=self.target)
        })

        self.D = nn.ModuleDict({
            'ds': Disentangler(),
            'di': Disentangler(),
            'ci': Disentangler()
        })

        # All modules in the same dict
        self.components = nn.ModuleDict({
            'G': self.G,
            'FD': self.FD,
            'R': self.R,
            'MI': self.MI
        })

        self.xent_loss = nn.CrossEntropyLoss().to(self.device)
        self.adv_loss = nn.BCEWithLogitsLoss().to(self.device)
        self.ring_loss = RingLoss(type='auto', loss_weight=1.0).to(self.device)
        self.set_optimizer(lr=self.lr)
        self.to_device()
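
For reference, the domain-code construction above simply tiles a one-hot row per sample: source rows are [1, 0] and target rows are [0, 1]. With a batch size of 4:

import numpy as np

src_domain_code = np.repeat(np.array([[1, 0]]), 4, axis=0)
print(src_domain_code)
# [[1 0]
#  [1 0]
#  [1 0]
#  [1 0]]
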
Example #3
class Solver(object):
    def __init__(self, args, batch_size=64, source='mnist',
                 target='usps', learning_rate=0.0002, interval=100,
                 optimizer='adam', num_k=4, all_use=False,
                 checkpoint_dir=None, save_epoch=10):
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        self.lambda_1 = args.lambda_1
        self.lambda_2 = args.lambda_2
        if self.source == 'svhn':
            self.scale = True
        else:
            self.scale = False
        print('dataset loading')
        self.datasets, self.dataset_test = dataset_read(source, target, self.batch_size, scale=self.scale,
                                                        all_use=self.all_use)
        print('load finished!')
        self.G = Generator(source=source, target=target)
        self.C1 = Classifier(source=source, target=target)
        self.C2 = Classifier(source=source, target=target)
        if args.eval_only:
            # restore the generator and both classifiers saved by test()
            self.G = torch.load(
                '%s/%s_to_%s_model_epoch%s_G.pt' % (self.checkpoint_dir, self.source, self.target, args.resume_epoch))
            self.C1 = torch.load(
                '%s/%s_to_%s_model_epoch%s_C1.pt' % (self.checkpoint_dir, self.source, self.target, args.resume_epoch))
            self.C2 = torch.load(
                '%s/%s_to_%s_model_epoch%s_C2.pt' % (self.checkpoint_dir, self.source, self.target, args.resume_epoch))

        self.G.cuda()
        self.C1.cuda()
        self.C2.cuda()
        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate

    def set_optimizer(self, which_opt='momentum', lr=0.001, momentum=0.9):
        if which_opt == 'momentum':
            self.opt_g = optim.SGD(self.G.parameters(),
                                   lr=lr, weight_decay=0.0005,
                                   momentum=momentum)

            self.opt_c1 = optim.SGD(self.C1.parameters(),
                                    lr=lr, weight_decay=0.0005,
                                    momentum=momentum)
            self.opt_c2 = optim.SGD(self.C2.parameters(),
                                    lr=lr, weight_decay=0.0005,
                                    momentum=momentum)

        if which_opt == 'adam':
            self.opt_g = optim.Adam(self.G.parameters(),
                                    lr=lr, weight_decay=0.0005)

            self.opt_c1 = optim.Adam(self.C1.parameters(),
                                     lr=lr, weight_decay=0.0005)
            self.opt_c2 = optim.Adam(self.C2.parameters(),
                                     lr=lr, weight_decay=0.0005)

    def reset_grad(self):
        self.opt_g.zero_grad()
        self.opt_c1.zero_grad()
        self.opt_c2.zero_grad()

    def ent(self, output):
        return - torch.mean(output * torch.log(output + 1e-6))

    def discrepancy(self, out1, out2):
        return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))

    def train(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        
        # initialize an L1 loss for distribution alignment
        criterionConsistency = nn.L1Loss().cuda()
        
        self.G.train()
        self.C1.train()
        self.C2.train()
        torch.cuda.manual_seed(1)

        Tensor = torch.cuda.FloatTensor

        for batch_idx, data in enumerate(self.datasets):
            img_t = data['T']
            img_s = data['S']
            label_s = data['S_label']
            if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            img_t = img_t.cuda()
            label_s = Variable(label_s.long().cuda())

            # for usps and mnist (source)
            z = Variable(Tensor(np.random.normal(0,1, (2048, 48))))
            # for svhn (source)
            #z = Variable(Tensor(np.random.normal(0,1, (8192, 128))))

            img_s = Variable(img_s)
            img_t = Variable(img_t)
            self.reset_grad()
            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)

            # for usps and mnist (source)
            feat_s_kl = feat_s.view(-1,48)
            # for svhn (source)
            #feat_s_kl = feat_s.view(-1,128)
            loss_kld = F.kl_div(F.log_softmax(feat_s_kl, dim=1), F.softmax(z, dim=1))

            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2 + self.lambda_1 * loss_kld
            loss_s.backward()
            self.opt_g.step()
            self.opt_c1.step()
            self.opt_c2.step()
            self.reset_grad()

            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)
            feat_t = self.G(img_t)
            output_t1 = self.C1(feat_t)
            output_t2 = self.C2(feat_t)

            # for usps and mnist (source)
            feat_s_kl = feat_s.view(-1,48)
            # for svhn (source)
            #feat_s_kl = feat_s.view(-1,128)
            loss_kld = F.kl_div(F.log_softmax(feat_s_kl, dim=1), F.softmax(z, dim=1))

            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2 + self.lambda_1 * loss_kld
            loss_dis = self.discrepancy(output_t1, output_t2)
            loss = loss_s - loss_dis
            loss.backward()
            self.opt_c1.step()
            self.opt_c2.step()
            self.reset_grad()

            for i in range(self.num_k):
                feat_t = self.G(img_t)
                output_t1 = self.C1(feat_t)
                output_t2 = self.C2(feat_t)
                
                # get x_rt
                feat_t_recon = self.G(img_t, is_deconv=True)

                feat_z_recon = self.G.decode(z)
                
                # distribution alignment loss 
                loss_dal = criterionConsistency(feat_t_recon, feat_z_recon) 
                    
                #updated loss function
                loss_dis = self.discrepancy(output_t1, output_t2) + self.lambda_2 * loss_dal
                
                loss_dis.backward()
                self.opt_g.step()
                self.reset_grad()
            if batch_idx > 500:
                return batch_idx

            if batch_idx % self.interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss1: {:.6f}\t Loss2: {:.6f}\t  Discrepancy: {:.6f}'.format(
                    epoch, batch_idx, 100,
                    100. * batch_idx / 70000, loss_s1.item(), loss_s2.item(), loss_dis.item()))
                if record_file:
                    record = open(record_file, 'a')
                    record.write('%s %s %s\n' % (loss_dis.item(), loss_s1.item(), loss_s2.item()))
                    record.close()
            torch.save(self.G,
                        '%s/%s_to_%s_model_epoch%s_G.pt' % (self.checkpoint_dir, self.source, self.target, epoch))
        return batch_idx

    def test(self, epoch, record_file=None, save_model=False):
        self.G.eval()
        self.C1.eval()
        self.C2.eval()
        test_loss = 0
        correct1 = 0
        correct2 = 0
        correct3 = 0
        size = 0
        for batch_idx, data in enumerate(self.dataset_test):
            img = data['T']
            label = data['T_label']
            img, label = img.cuda(), label.long().cuda()
            img, label = Variable(img, volatile=True), Variable(label)
            feat = self.G(img)
            output1 = self.C1(feat)
            output2 = self.C2(feat)
            test_loss += F.nll_loss(output1, label).item()
            output_ensemble = output1 + output2
            pred1 = output1.data.max(1)[1]
            pred2 = output2.data.max(1)[1]
            pred_ensemble = output_ensemble.data.max(1)[1]
            k = label.data.size()[0]
            correct1 += pred1.eq(label.data).cpu().sum()
            correct2 += pred2.eq(label.data).cpu().sum()
            correct3 += pred_ensemble.eq(label.data).cpu().sum()
            size += k
        test_loss = test_loss / size
        print(
            '\nTest set: Average loss: {:.4f}, Accuracy C1: {}/{} ({:.0f}%) Accuracy C2: {}/{} ({:.0f}%) Accuracy Ensemble: {}/{} ({:.0f}%) \n'.format(
                test_loss, correct1, size,
                100. * correct1 / size, correct2, size, 100. * correct2 / size, correct3, size, 100. * correct3 / size))
        if save_model and epoch % self.save_epoch == 0:
            torch.save(self.G,
                       '%s/%s_to_%s_model_epoch%s_G.pt' % (self.checkpoint_dir, self.source, self.target, epoch))
            torch.save(self.C1,
                       '%s/%s_to_%s_model_epoch%s_C1.pt' % (self.checkpoint_dir, self.source, self.target, epoch))
            torch.save(self.C2,
                       '%s/%s_to_%s_model_epoch%s_C2.pt' % (self.checkpoint_dir, self.source, self.target, epoch))
        if record_file:
            record = open(record_file, 'a')
            print('recording %s' % record_file)
            record.write('%s %s %s\n' % (float(correct1) / size, float(correct2) / size, float(correct3) / size))
            record.close()
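
A minimal driver loop for this solver (a sketch; the epoch count and record-file names are placeholders, and args is the namespace described under Example #1):

solver = Solver(args, batch_size=64, source='mnist', target='usps',
                checkpoint_dir='checkpoint', save_epoch=10)
for epoch in range(200):
    solver.train(epoch, record_file='record_train.txt')
    solver.test(epoch, record_file='record_test.txt', save_model=True)
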
Example #4
    def __init__(self,
                 args,
                 batch_size=64,
                 source='svhn',
                 target='mnist',
                 learning_rate=0.0002,
                 interval=1,
                 optimizer='adam',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10):

        timestring = strftime("%Y-%m-%d_%H-%M-%S",
                              gmtime()) + "_%s" % args.exp_name
        self.logdir = os.path.join('./logs', timestring)
        self.logger = SummaryWriter(log_dir=self.logdir)
        self.device = torch.device("cuda" if args.use_cuda else "cpu")

        self.src_domain_code = np.repeat(np.array([[1, 0]]),
                                         batch_size,
                                         axis=0)
        self.trg_domain_code = np.repeat(np.array([[0, 1]]),
                                         batch_size,
                                         axis=0)
        self.src_domain_code = torch.FloatTensor(self.src_domain_code).to(
            self.device)
        self.trg_domain_code = torch.FloatTensor(self.trg_domain_code).to(
            self.device)

        self.source = source
        self.target = target
        self.num_k = num_k
        self.mi_k = 1
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        self.delta = 0.01
        self.mi_coeff = 0.0001
        self.interval = interval
        self.batch_size = batch_size
        self.lr = learning_rate
        self.scale = False

        print('dataset loading')
        self.datasets, self.dataset_test = dataset_read(source,
                                                        target,
                                                        self.batch_size,
                                                        scale=self.scale,
                                                        all_use=self.all_use)
        print('load finished!')

        self.G = Generator(source=source, target=target)
        self.FD = Feature_Discriminator()
        self.R = Reconstructor()
        self.MI = Mine()

        self.C = nn.ModuleDict({
            'ds': Classifier(source=source, target=target),
            'di': Classifier(source=source, target=target),
            'ci': Classifier(source=source, target=target)
        })

        self.D = nn.ModuleDict({
            'ds': Disentangler(),
            'di': Disentangler(),
            'ci': Disentangler()
        })

        # All modules in the same dict
        self.modules = nn.ModuleDict({
            'G': self.G,
            'FD': self.FD,
            'R': self.R,
            'MI': self.MI
        })

        if args.eval_only:
            # restore the saved generator for evaluation
            self.G = torch.load('%s/%s_to_%s_model_epoch%s_G.pt' %
                                (self.checkpoint_dir, self.source, self.target,
                                 args.resume_epoch))

        self.xent_loss = nn.CrossEntropyLoss().cuda()
        self.adv_loss = nn.BCEWithLogitsLoss().cuda()
        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.to_device()
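
Putting every component in one nn.ModuleDict, as the "All modules in the same dict" comment notes, lets a single call enumerate or move all parameters; a small illustration of that pattern (the Linear layers are stand-ins):

import torch.nn as nn

modules = nn.ModuleDict({
    'G': nn.Linear(8, 4),   # stand-ins for Generator, Reconstructor, ...
    'R': nn.Linear(4, 8),
})
print(sum(p.numel() for p in modules.parameters()))  # parameters of every entry
modules.to('cpu')                                    # one device move for all of them
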
Example #5
    def __init__(self,
                 args,
                 batch_size=64,
                 source='source',
                 target='target',
                 learning_rate=0.0002,
                 interval=100,
                 optimizer='adam',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10,
                 leave_one_num=-1,
                 model_name=''):
        self.args = args
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.leave_one_num = leave_one_num

        print('dataset loading')
        download()
        self.data_train, self.data_val, self.data_test = dataset_read(
            source,
            target,
            self.batch_size,
            is_resize=args.is_resize,
            leave_one_num=self.leave_one_num,
            dataset=args.dataset,
            sensor_num=args.sensor_num)
        print('load finished!')

        self.G = Generator(source=source,
                           target=target,
                           is_resize=args.is_resize,
                           dataset=args.dataset,
                           sensor_num=args.sensor_num)
        self.LC = Classifier(source=source,
                             target=target,
                             is_resize=args.is_resize,
                             dataset=args.dataset)
        self.DC = DomainClassifier(source=source,
                                   target=target,
                                   is_resize=args.is_resize,
                                   dataset=args.dataset)

        if args.eval_only:
            self.data_val = self.data_test
            self.G = torch.load(r'checkpoint_DANN/best_model_G' + model_name +
                                '.pt')
            self.LC = torch.load(r'checkpoint_DANN/best_model_C1' +
                                 model_name + '.pt')
            self.DC = torch.load(r'checkpoint_DANN/best_model_C2' +
                                 model_name + '.pt')

        self.G.cuda()
        self.LC.cuda()
        self.DC.cuda()
        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate
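
torch.load on a whole module, as in the eval_only branch above, pickles the class itself, so Generator and the classifiers must be importable under the same module paths at load time. Saving state dicts is the more portable alternative (a sketch of that pattern, not what this code does):

G = Generator(source='source', target='target', is_resize=args.is_resize,
              dataset=args.dataset, sensor_num=args.sensor_num)
torch.save(G.state_dict(), 'checkpoint_DANN/best_model_G_state.pt')     # at save time
G.load_state_dict(torch.load('checkpoint_DANN/best_model_G_state.pt'))  # at load time
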
Example #6
class SolverDANN(object):
    def __init__(self,
                 args,
                 batch_size=64,
                 source='source',
                 target='target',
                 learning_rate=0.0002,
                 interval=100,
                 optimizer='adam',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10,
                 leave_one_num=-1,
                 model_name=''):
        self.args = args
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.leave_one_num = leave_one_num

        print('dataset loading')
        download()
        self.data_train, self.data_val, self.data_test = dataset_read(
            source,
            target,
            self.batch_size,
            is_resize=args.is_resize,
            leave_one_num=self.leave_one_num,
            dataset=args.dataset,
            sensor_num=args.sensor_num)
        print('load finished!')

        self.G = Generator(source=source,
                           target=target,
                           is_resize=args.is_resize,
                           dataset=args.dataset,
                           sensor_num=args.sensor_num)
        self.LC = Classifier(source=source,
                             target=target,
                             is_resize=args.is_resize,
                             dataset=args.dataset)
        self.DC = DomainClassifier(source=source,
                                   target=target,
                                   is_resize=args.is_resize,
                                   dataset=args.dataset)

        if args.eval_only:
            self.data_val = self.data_test
            self.G = torch.load(r'checkpoint_DANN/best_model_G' + model_name +
                                '.pt')
            self.LC = torch.load(r'checkpoint_DANN/best_model_C1' +
                                 model_name + '.pt')
            self.DC = torch.load(r'checkpoint_DANN/best_model_C2' +
                                 model_name + '.pt')

        self.G.cuda()
        self.LC.cuda()
        self.DC.cuda()
        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate

    def set_optimizer(self, which_opt='momentum', lr=0.001, momentum=0.9):
        if which_opt == 'momentum':
            self.opt_g = optim.SGD(self.G.parameters(),
                                   lr=lr,
                                   weight_decay=0.0005,
                                   momentum=momentum)

            self.opt_lc = optim.SGD(self.LC.parameters(),
                                    lr=lr,
                                    weight_decay=0.0005,
                                    momentum=momentum)
            self.opt_dc = optim.SGD(self.DC.parameters(),
                                    lr=lr,
                                    weight_decay=0.0005,
                                    momentum=momentum)

        if which_opt == 'adam':
            self.opt_g = optim.Adam(self.G.parameters(),
                                    lr=lr,
                                    weight_decay=0.0005)

            self.opt_lc = optim.Adam(self.LC.parameters(),
                                     lr=lr,
                                     weight_decay=0.0005)
            self.opt_dc = optim.Adam(self.DC.parameters(),
                                     lr=lr,
                                     weight_decay=0.0005)

    def reset_grad(self):
        self.opt_g.zero_grad()
        self.opt_lc.zero_grad()
        self.opt_dc.zero_grad()

    def ent(self, output):
        return -torch.mean(output * torch.log(output + 1e-6))

    def discrepancy(self, out1, out2):
        return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))

    def train(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        self.G.train()
        self.LC.train()
        self.DC.train()
        torch.cuda.manual_seed(1)
        for batch_idx, data in enumerate(self.data_train):
            img_t = data['T']
            img_s = data['S']
            label_s = data['S_label']
            domain_label_s = torch.zeros(img_s.shape[0])
            domain_label_t = torch.ones(img_t.shape[0])
            if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            img_t = img_t.cuda()
            label_s = Variable(label_s.long().cuda())
            domain_label_s = Variable(domain_label_s.long().cuda())
            domain_label_t = Variable(domain_label_t.long().cuda())
            img_s = Variable(img_s)
            img_t = Variable(img_t)
            self.reset_grad()

            feat_s = self.G(img_s)
            output_label_s = self.LC(feat_s)
            loss_label_s = criterion(output_label_s, label_s)
            loss_label_s.backward()
            self.opt_g.step()
            self.opt_lc.step()
            self.reset_grad()

            feat_s = self.G(img_s)
            output_domain_s = self.DC(feat_s)
            feat_t = self.G(img_t)
            output_domain_t = self.DC(feat_t)

            # The objective of the domain classifier is to classify the domain of data accurately.
            loss_domain_s = criterion(output_domain_s, domain_label_s)
            loss_domain_t = criterion(output_domain_t, domain_label_t)
            loss_domain = loss_domain_s + loss_domain_t
            loss_domain.backward()
            self.opt_dc.step()
            self.reset_grad()

            # One objective of the feature generator is to confuse the domain classifier.
            feat_s = self.G(img_s)
            output_domain_s = self.DC(feat_s)
            feat_t = self.G(img_t)
            output_domain_t = self.DC(feat_t)

            loss_domain_s = criterion(output_domain_s, domain_label_s)
            loss_domain_t = criterion(output_domain_t, domain_label_t)
            loss_domain = -loss_domain_s - loss_domain_t
            loss_domain.backward()
            self.opt_g.step()
            self.reset_grad()

            if batch_idx > 500:
                return batch_idx
        return batch_idx

    def test(self, epoch, record_file=None, save_model=False):
        self.G.eval()
        self.LC.eval()
        self.DC.eval()
        correct = 0.0
        size = 0.0
        for batch_idx, data in enumerate(self.data_val):
            img = data['T']
            label = data['T_label']
            img, label = img.cuda(), label.long().cuda()
            img, label = Variable(img, volatile=True), Variable(label)
            # label = label.squeeze()
            feat = self.G(img)
            output1 = self.LC(feat)
            pred1 = output1.data.max(1)[1]
            k = label.data.size()[0]
            correct += pred1.eq(label.data).cpu().sum()
            size += k
#        if save_model and epoch % self.save_epoch == 0:
#            torch.save(self.G,
#                       '%s/%s_to_%s_model_epoch%s_G.pt' % (self.checkpoint_dir, self.source, self.target, epoch))
#            torch.save(self.C1,
#                       '%s/%s_to_%s_model_epoch%s_C1.pt' % (self.checkpoint_dir, self.source, self.target, epoch))
#            torch.save(self.C2,
#                       '%s/%s_to_%s_model_epoch%s_C2.pt' % (self.checkpoint_dir, self.source, self.target, epoch))

        if record_file:
            record = open(record_file, 'a')
            record.write('%s\n' % (float(correct) / size, ))
            record.close()
        return float(correct) / size, epoch, size, self.G, self.LC, self.DC

    def test_best(self, G, LC, DC):
        G.eval()
        LC.eval()
        DC.eval()
        test_loss = 0
        correct = 0
        size = 0
        for batch_idx, data in enumerate(self.data_test):
            img = data['T']
            label = data['T_label']
            img, label = img.cuda(), label.long().cuda()
            img, label = Variable(img, volatile=True), Variable(label)
            label = label.squeeze()
            feat = G(img)
            output = LC(feat)
            test_loss += F.nll_loss(output, label).item()
            pred1 = output.data.max(1)[1]
            k = label.data.size()[0]
            correct += pred1.eq(label.data).cpu().sum()
            size += k
        test_loss = test_loss / size
        print('Best test target acc:', 100.0 * correct.numpy() / size, '%')
        return correct.numpy() / size

    def calc_correct_ensemble(self, G, LC, DC, x, y):
        x, y = x.cuda(), y.long().cuda()
        x, y = Variable(x, volatile=True), Variable(y)
        y = y.squeeze()
        feat = G(x)
        output = LC(feat)
        pred = output.data.max(1)[1]
        correct_num = pred.eq(y.data).cpu().sum()
        if len(y.data.size()) == 0:
            print('Error, the size of y is 0!')
            return 0, 0
        size_data = y.data.size()[0]
        return correct_num, size_data

    def calc_test_acc(self, G, LC, DC, set_name='T'):
        correct_all = 0
        size_all = 0
        for batch_idx, data in enumerate(self.data_test):
            correct_num, size_data = self.calc_correct_ensemble(
                G, LC, DC, data[set_name], data[set_name + '_label'])
            if 0 != size_data:
                correct_all += correct_num
                size_all += size_data
        return correct_all.numpy() / size_all

    def test_ensemble(self, G, LC, DC):
        G.eval()
        LC.eval()
        DC.eval()
        acc_s = self.calc_test_acc(G, LC, DC, set_name='S')
        print('Final test source acc:', 100.0 * acc_s, '%')
        acc_t = self.calc_test_acc(G, LC, DC, set_name='T')
        print('Final test target acc:', 100.0 * acc_t, '%')
        return acc_s, acc_t

    def input_feature(self):
        feature_vec = np.zeros(0)
        label_vec = np.zeros(0)
        domain_vec = np.zeros(0)
        for batch_idx, data in enumerate(self.data_test):
            if data['S'].shape[0] != self.batch_size or \
            data['T'].shape[0] != self.batch_size:
                continue
            if batch_idx > 6:
                break

            feature_s = data['S'].reshape((self.batch_size, -1))
            label_s = data['S_label'].squeeze()
            domain_s = np.zeros(label_s.shape)

            feature_t = data['T'].reshape((self.batch_size, -1))
            label_t = data['T_label'].squeeze()

            domain_t = np.ones(label_t.shape)

            feature_c = np.concatenate([feature_s, feature_t])

            if 0 == feature_vec.shape[0]:
                feature_vec = np.copy(feature_c)
            else:
                feature_vec = np.r_[feature_vec, feature_c]
            label_c = np.concatenate([label_s, label_t])
            domain_c = np.concatenate([domain_s, domain_t])
            label_vec = np.concatenate([label_vec, label_c])
            domain_vec = np.concatenate([domain_vec, domain_c])

        return feature_vec, label_vec, domain_vec
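
The sign flip in train above (loss_domain = -loss_domain_s - loss_domain_t) is what makes the training adversarial: the domain classifier descends on the domain loss while the generator ascends on it. A toy, CPU-only sketch of the same two-step pattern (all shapes here are made up):

import torch
import torch.nn as nn

G = nn.Linear(8, 4)                      # stand-in feature generator
D = nn.Linear(4, 2)                      # stand-in domain classifier
opt_g = torch.optim.Adam(G.parameters())
opt_d = torch.optim.Adam(D.parameters())
xent = nn.CrossEntropyLoss()
x = torch.randn(16, 8)                   # a fake batch
dom = torch.zeros(16, dtype=torch.long)  # its domain labels

# Step 1: D learns to classify the domain accurately.
loss_d = xent(D(G(x)), dom)
opt_d.zero_grad()
loss_d.backward()
opt_d.step()

# Step 2: G maximizes the same loss (note the minus sign), confusing D.
loss_g = -xent(D(G(x)), dom)
opt_g.zero_grad()
loss_g.backward()
opt_g.step()
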
Example #7
class Solver(object):
    def __init__(self,
                 args,
                 batch_size=64,
                 source='source',
                 target='target',
                 learning_rate=0.0002,
                 interval=100,
                 optimizer='adam',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10,
                 leave_one_num=-1,
                 model_name=''):
        self.args = args
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.leave_one_num = leave_one_num

        print('dataset loading')
        download()
        self.data_train, self.data_val, self.data_test = dataset_read(
            source,
            target,
            self.batch_size,
            is_resize=args.is_resize,
            leave_one_num=self.leave_one_num,
            dataset=args.dataset,
            sensor_num=args.sensor_num)
        print('load finished!')

        self.G = Generator(source=source,
                           target=target,
                           is_resize=args.is_resize,
                           dataset=args.dataset,
                           sensor_num=args.sensor_num)
        self.C1 = Classifier(source=source,
                             target=target,
                             is_resize=args.is_resize,
                             dataset=args.dataset)
        self.C2 = Classifier(source=source,
                             target=target,
                             is_resize=args.is_resize,
                             dataset=args.dataset)

        if args.eval_only:
            self.data_val = self.data_test
            self.G = torch.load(r'checkpoint/best_model_G' + model_name +
                                '.pt')
            self.C1 = torch.load(r'checkpoint/best_model_C1' + model_name +
                                 '.pt')
            self.C2 = torch.load(r'checkpoint/best_model_C2' + model_name +
                                 '.pt')

        self.G.cuda()
        self.C1.cuda()
        self.C2.cuda()
        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate

    def set_optimizer(self, which_opt='momentum', lr=0.001, momentum=0.9):
        if which_opt == 'momentum':
            self.opt_g = optim.SGD(self.G.parameters(),
                                   lr=lr,
                                   weight_decay=0.0005,
                                   momentum=momentum)

            self.opt_c1 = optim.SGD(self.C1.parameters(),
                                    lr=lr,
                                    weight_decay=0.0005,
                                    momentum=momentum)
            self.opt_c2 = optim.SGD(self.C2.parameters(),
                                    lr=lr,
                                    weight_decay=0.0005,
                                    momentum=momentum)

        if which_opt == 'adam':
            self.opt_g = optim.Adam(self.G.parameters(),
                                    lr=lr,
                                    weight_decay=0.0005)

            self.opt_c1 = optim.Adam(self.C1.parameters(),
                                     lr=lr,
                                     weight_decay=0.0005)
            self.opt_c2 = optim.Adam(self.C2.parameters(),
                                     lr=lr,
                                     weight_decay=0.0005)

    def reset_grad(self):
        self.opt_g.zero_grad()
        self.opt_c1.zero_grad()
        self.opt_c2.zero_grad()

    def ent(self, output):
        return -torch.mean(output * torch.log(output + 1e-6))

    def discrepancy(self, out1, out2):
        return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))

    def train_source_only(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        self.G.train()
        self.C1.train()
        self.C2.train()
        torch.cuda.manual_seed(1)

        for batch_idx, data in enumerate(self.data_train):
            img_s = data['S']
            label_s = data['S_label']
            if img_s.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            label_s = Variable(label_s.long().cuda())
            label_s = label_s.squeeze()
            img_s = Variable(img_s)
            self.reset_grad()
            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)

            # print(label_s.shape)
            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2
            loss_s.backward()
            self.opt_g.step()
            self.opt_c1.step()
            self.opt_c2.step()
            self.reset_grad()

            if batch_idx > 500:
                return batch_idx

            if batch_idx % self.interval == 0:
                if record_file:
                    record = open(record_file, 'a')
                    record.write('%s \n' % (loss_s.item()))
                    record.close()
        return batch_idx

    def train(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        self.G.train()
        self.C1.train()
        self.C2.train()
        torch.cuda.manual_seed(1)

        for batch_idx, data in enumerate(self.data_train):
            img_t = data['T']
            img_s = data['S']
            label_s = data['S_label']
            if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            img_t = img_t.cuda()
            label_s = Variable(label_s.long().cuda())
            label_s = label_s.squeeze()

            img_s = Variable(img_s)
            img_t = Variable(img_t)
            self.reset_grad()
            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)

            # print(label_s.shape)
            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2
            loss_s.backward()
            self.opt_g.step()
            self.opt_c1.step()
            self.opt_c2.step()
            self.reset_grad()

            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)
            feat_t = self.G(img_t)
            output_t1 = self.C1(feat_t)
            output_t2 = self.C2(feat_t)

            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2
            loss_dis = self.discrepancy(output_t1, output_t2)
            loss = loss_s - 4 * loss_dis  # 1: 92.9; 2: 93.1; 3: 93.5; 5:93.46%
            loss.backward()
            self.opt_c1.step()
            self.opt_c2.step()
            self.reset_grad()

            for i in range(self.num_k):
                #
                feat_t = self.G(img_t)
                output_t1 = self.C1(feat_t)
                output_t2 = self.C2(feat_t)
                loss_dis = self.discrepancy(output_t1, output_t2)
                loss_dis.backward()
                self.opt_g.step()
                self.reset_grad()
            if batch_idx > 500:
                return batch_idx

            if batch_idx % self.interval == 0:
                if record_file:
                    record = open(record_file, 'a')
                    record.write(
                        '%s %s %s\n' %
                        (loss_dis.item(), loss_s1.item(), loss_s2.item()))
                    record.close()
        return batch_idx

    def train_onestep(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        self.G.train()
        self.C1.train()
        self.C2.train()
        torch.cuda.manual_seed(1)

        for batch_idx, data in enumerate(self.data_train):
            img_t = data['T']
            img_s = data['S']
            label_s = data['S_label']
            if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            img_t = img_t.cuda()
            label_s = Variable(label_s.long().cuda())
            img_s = Variable(img_s)
            img_t = Variable(img_t)
            self.reset_grad()
            feat_s = self.G(img_s)
            output_s1 = self.C1(feat_s)
            output_s2 = self.C2(feat_s)
            loss_s1 = criterion(output_s1, label_s)
            loss_s2 = criterion(output_s2, label_s)
            loss_s = loss_s1 + loss_s2
            loss_s.backward(retain_graph=True)
            feat_t = self.G(img_t)
            self.C1.set_lambda(1.0)
            self.C2.set_lambda(1.0)
            output_t1 = self.C1(feat_t, reverse=True)
            output_t2 = self.C2(feat_t, reverse=True)
            loss_dis = -self.discrepancy(output_t1, output_t2)
            #loss_dis.backward()
            self.opt_c1.step()
            self.opt_c2.step()
            self.opt_g.step()
            self.reset_grad()
            if batch_idx > 500:
                return batch_idx

            if batch_idx % self.interval == 0:
                if record_file:
                    record = open(record_file, 'a')
                    record.write(
                        '%s %s %s\n' %
                        (loss_dis.item(), loss_s1.item(), loss_s2.item()))
                    record.close()
        return batch_idx

    def test(self, epoch, record_file=None, save_model=False):
        self.G.eval()
        self.C1.eval()
        self.C2.eval()
        test_loss = 0.0
        correct1 = 0.0
        correct2 = 0.0
        correct3 = 0.0
        size = 0.0
        for batch_idx, data in enumerate(self.data_val):
            img = data['T']
            label = data['T_label']
            img, label = img.cuda(), label.long().cuda()
            img, label = Variable(img, volatile=True), Variable(label)
            # label = label.squeeze()
            feat = self.G(img)
            output1 = self.C1(feat)
            output2 = self.C2(feat)
            test_loss += F.nll_loss(output1, label).item()
            output_ensemble = output1 + output2
            pred1 = output1.data.max(1)[1]
            pred2 = output2.data.max(1)[1]
            pred_ensemble = output_ensemble.data.max(1)[1]
            k = label.data.size()[0]
            correct1 += pred1.eq(label.data).cpu().sum()
            correct2 += pred2.eq(label.data).cpu().sum()
            correct3 += pred_ensemble.eq(label.data).cpu().sum()
            size += k
        test_loss = test_loss / size
        #        if save_model and epoch % self.save_epoch == 0:
        #            torch.save(self.G,
        #                       '%s/%s_to_%s_model_epoch%s_G.pt' % (self.checkpoint_dir, self.source, self.target, epoch))
        #            torch.save(self.C1,
        #                       '%s/%s_to_%s_model_epoch%s_C1.pt' % (self.checkpoint_dir, self.source, self.target, epoch))
        #            torch.save(self.C2,
        #                       '%s/%s_to_%s_model_epoch%s_C2.pt' % (self.checkpoint_dir, self.source, self.target, epoch))

        if record_file:
            record = open(record_file, 'a')
            record.write('%s %s %s\n' %
                         (float(correct1) / size, float(correct2) / size,
                          float(correct3) / size))
            record.close()
        return float(correct3) / size, epoch, size, self.G, self.C1, self.C2

    def test_best(self, G, C1, C2):
        G.eval()
        C1.eval()
        C2.eval()
        test_loss = 0
        correct1 = 0
        correct2 = 0
        correct3 = 0
        size = 0
        for batch_idx, data in enumerate(self.data_test):
            img = data['T']
            label = data['T_label']
            img, label = img.cuda(), label.long().cuda()
            img, label = Variable(img, volatile=True), Variable(label)
            label = label.squeeze()
            feat = G(img)
            output1 = C1(feat)
            output2 = C2(feat)
            test_loss += F.nll_loss(output1, label).item()
            output_ensemble = output1 + output2
            pred1 = output1.data.max(1)[1]
            pred2 = output2.data.max(1)[1]
            pred_ensemble = output_ensemble.data.max(1)[1]
            k = label.data.size()[0]
            correct1 += pred1.eq(label.data).cpu().sum()
            correct2 += pred2.eq(label.data).cpu().sum()
            correct3 += pred_ensemble.eq(label.data).cpu().sum()
            size += k
        test_loss = test_loss / size
        print('Best test target acc:', 100.0 * correct3.numpy() / size, '%')
        return correct3.numpy() / size

    def calc_correct_ensemble(self, G, C1, C2, x, y):
        x, y = x.cuda(), y.long().cuda()
        x, y = Variable(x, volatile=True), Variable(y)
        y = y.squeeze()
        feat = G(x)
        output1 = C1(feat)
        output2 = C2(feat)
        output_ensemble = output1 + output2
        pred_ensemble = output_ensemble.data.max(1)[1]
        correct_num = pred_ensemble.eq(y.data).cpu().sum()
        if len(y.data.size()) == 0:
            return 0, 0
        size_data = y.data.size()[0]
        return correct_num, size_data

    def calc_test_acc(self, G, C1, C2, set_name='T'):
        correct_all = 0
        size_all = 0
        for batch_idx, data in enumerate(self.data_test):
            correct_num, size_data = self.calc_correct_ensemble(
                G, C1, C2, data[set_name], data[set_name + '_label'])
            if 0 != size_data:
                correct_all += correct_num
                size_all += size_data
        return correct_all.numpy() / size_all

    def test_ensemble(self, G, C1, C2):
        G.eval()
        C1.eval()
        C2.eval()
        acc_s = self.calc_test_acc(G, C1, C2, set_name='S')
        print('Final test source acc:', 100.0 * acc_s, '%')
        acc_t = self.calc_test_acc(G, C1, C2, set_name='T')
        print('Final test target acc:', 100.0 * acc_t, '%')
        return acc_s, acc_t

    def input_feature(self):
        feature_vec = np.zeros(0)
        label_vec = np.zeros(0)
        domain_vec = np.zeros(0)
        for batch_idx, data in enumerate(self.data_test):
            if data['S'].shape[0] != self.batch_size or \
            data['T'].shape[0] != self.batch_size:
                continue
            if batch_idx > 6:
                break

            feature_s = data['S'].reshape((self.batch_size, -1))
            label_s = data['S_label'].squeeze()
            domain_s = np.zeros(label_s.shape)

            feature_t = data['T'].reshape((self.batch_size, -1))
            label_t = data['T_label'].squeeze()

            domain_t = np.ones(label_t.shape)

            feature_c = np.concatenate([feature_s, feature_t])

            if 0 == feature_vec.shape[0]:
                feature_vec = np.copy(feature_c)
            else:
                feature_vec = np.r_[feature_vec, feature_c]
            label_c = np.concatenate([label_s, label_t])
            domain_c = np.concatenate([domain_s, domain_t])
            label_vec = np.concatenate([label_vec, label_c])
            domain_vec = np.concatenate([domain_vec, domain_c])

        return feature_vec, label_vec, domain_vec

    def tsne_feature(self):
        self.G.eval()
        feature_vec = torch.tensor(()).cuda()
        label_vec = np.zeros(0)
        domain_vec = np.zeros(0)
        for batch_idx, data in enumerate(self.data_test):
            if data['S'].shape[0] != self.batch_size or \
            data['T'].shape[0] != self.batch_size:
                continue
            if batch_idx > 6:
                break
            img_s = data['S']
            label_s = data['S_label'].squeeze()
            domain_s = np.zeros(label_s.shape)

            img_t = data['T']
            label_t = data['T_label'].squeeze()
            domain_t = np.ones(label_t.shape)

            img_c = np.vstack([img_s, img_t])

            img_c = torch.from_numpy(img_c)
            img_c = img_c.cuda()
            img_c = Variable(img_c, volatile=True)

            feat_c = self.G(img_c)
            feature_vec = torch.cat((feature_vec, feat_c), 0)

            label_c = np.concatenate([label_s, label_t])
            domain_c = np.concatenate([domain_s, domain_t])
            label_vec = np.concatenate([label_vec, label_c])
            domain_vec = np.concatenate([domain_vec, domain_c])

        return feature_vec.cpu().detach().numpy(), label_vec, domain_vec
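
The (features, labels, domains) triples returned by input_feature and tsne_feature plug directly into t-SNE for the usual adaptation scatter plots (assuming scikit-learn is available):

from sklearn.manifold import TSNE

feats, labels, domains = solver.tsne_feature()
embedding = TSNE(n_components=2).fit_transform(feats)  # (N, 2) points, colored by label or domain
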
Example #8
class Solver(object):
    def __init__(self,
                 args,
                 batch_size=256,
                 source='mnist',
                 target='usps',
                 learning_rate=0.02,
                 interval=100,
                 optimizer='momentum',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10):
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        self.alpha = args.alpha
        self.beta = args.beta
        if self.source == 'svhn':
            self.scale = True
        else:
            self.scale = False
        print('dataset loading')
        self.datasets, self.dataset_test = dataset_read(source,
                                                        target,
                                                        self.batch_size,
                                                        scale=self.scale,
                                                        all_use=self.all_use)
        print('load finished!')
        self.G = Generator(source=source, target=target)
        self.C = Classifier(source=source, target=target)

        if args.eval_only:
            # restore the generator and classifier saved by test()
            self.G = torch.load('%s/%s_to_%s_model_epoch%s_G.pt' %
                                (self.checkpoint_dir, self.source, self.target,
                                 args.resume_epoch))
            self.C = torch.load('%s/%s_to_%s_model_epoch%s_C.pt' %
                                (self.checkpoint_dir, self.source, self.target,
                                 args.resume_epoch))

        self.G.cuda()
        self.C.cuda()
        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate

    def set_optimizer(self, which_opt='momentum', lr=0.02, momentum=0.9):
        if which_opt == 'momentum':
            self.opt_g = optim.SGD(self.G.parameters(),
                                   lr=lr,
                                   weight_decay=0.0005,
                                   momentum=momentum)

            self.opt_c = optim.SGD(self.C.parameters(),
                                   lr=lr,
                                   weight_decay=0.0005,
                                   momentum=momentum)

        if which_opt == 'adam':
            self.opt_g = optim.Adam(self.G.parameters(),
                                    lr=lr,
                                    weight_decay=0.0005)

            self.opt_c = optim.Adam(self.C.parameters(),
                                    lr=lr,
                                    weight_decay=0.0005)

    def reset_grad(self):
        self.opt_g.zero_grad()
        self.opt_c.zero_grad()

    def get_entropy_loss(self, p_softmax):
        mask = p_softmax.ge(0.000001)
        mask_out = torch.masked_select(p_softmax, mask)
        entropy = -(torch.sum(mask_out * torch.log(mask_out)))
        return 0.1 * (entropy / float(p_softmax.size(0)))

    def discrepancy(self, out1, out2):
        return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))

    def train(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        # initialize an L1 loss for DAL (distribution alignment loss)
        criterionDAL = nn.L1Loss().cuda()

        self.G.train()
        self.C.train()
        torch.cuda.manual_seed(1)

        Tensor = torch.cuda.FloatTensor

        for batch_idx, data in enumerate(self.datasets):
            img_t = data['T']
            img_s = data['S']
            label_s = data['S_label']
            if img_s.size()[0] < self.batch_size or img_t.size()[0] < self.batch_size:
                break
            img_s = img_s.cuda()
            img_t = img_t.cuda()
            label_s = Variable(label_s.long().cuda())

            # for mnist or usps (source)
            zn = Variable(Tensor(np.random.normal(0, 1, (4096, 48))))
            # for svhn (source)
            #zn = Variable(Tensor(np.random.normal(0,1, (16384, 128))))

            img_s = Variable(img_s)
            img_t = Variable(img_t)

            self.reset_grad()

            feat_s = self.G(img_s)
            output_s = self.C(feat_s)
            feat_t = self.G(img_t)
            output_t = self.C(feat_t)

            # for mnist or usps (source)
            feat_s_kl = feat_s.view(-1, 48)
            # for svhn (source)
            #feat_s_kl = feat_s.view(-1,128)

            loss_kld_s = F.kl_div(F.log_softmax(feat_s_kl, dim=1), F.softmax(zn, dim=1))

            loss_s = criterion(output_s, label_s)

            loss = loss_s + self.alpha * loss_kld_s
            loss.backward()

            self.opt_g.step()
            self.opt_c.step()
            self.reset_grad()

            feat_t = self.G(img_t)
            output_t = self.C(feat_t)
            feat_t_recon = self.G(img_t, is_deconv=True)

            feat_zn_recon = self.G.decode(zn)
            # DAL
            loss_dal = criterionDAL(feat_t_recon, feat_zn_recon)

            # entropy loss
            t_prob = F.softmax(output_t, dim=1)
            t_entropy_loss = self.get_entropy_loss(t_prob)

            loss = t_entropy_loss + self.beta * loss_dal
            loss.backward()

            self.opt_g.step()
            self.opt_c.step()
            self.reset_grad()

            if batch_idx > 500:
                return batch_idx

            if batch_idx % self.interval == 0:
                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t   Entropy: {:.6f}'
                    .format(epoch, batch_idx, 100, 100. * batch_idx / 70000,
                            loss_s.item(), t_entropy_loss.item()))
                if record_file:
                    record = open(record_file, 'a')
                    record.write('%s %s\n' %
                                 (t_entropy_loss.item(), loss_s.item()))
                    record.close()
            torch.save(
                self.G, '%s/%s_to_%s_model_epoch%s_G.pt' %
                (self.checkpoint_dir, self.source, self.target, epoch))
        return batch_idx

    def test(self, epoch, record_file=None, save_model=False):
        self.G.eval()
        self.C.eval()

        test_loss = 0
        correct = 0
        size = 0
        for batch_idx, data in enumerate(self.dataset_test):
            img = data['T']
            label = data['T_label']
            img, label = img.cuda(), label.long().cuda()
            img, label = Variable(img, volatile=True), Variable(label)
            feat = self.G(img)
            output = self.C(feat)

            test_loss += F.nll_loss(output, label).item()
            pred = output.data.max(1)[1]

            k = label.data.size()[0]
            correct += pred.eq(label.data).cpu().sum()

            size += k
        test_loss = test_loss / size
        print(
            '\nTest set: Average loss: {:.4f}, Accuracy C: {}/{} ({:.0f}%) \n'.
            format(test_loss, correct, size, 100. * correct / size))
        if save_model and epoch % self.save_epoch == 0:
            torch.save(
                self.G, '%s/%s_to_%s_model_epoch%s_G.pt' %
                (self.checkpoint_dir, self.source, self.target, epoch))
            torch.save(
                self.C, '%s/%s_to_%s_model_epoch%s_C.pt' %
                (self.checkpoint_dir, self.source, self.target, epoch))
        if record_file:
            record = open(record_file, 'a')
            print('recording %s' % record_file)
            record.write('%s\n' % (float(correct) / size))
            record.close()
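
The masked entropy term in get_entropy_loss rewards confident target predictions; a CPU check of its two extremes (the helper below mirrors the method above):

import torch

def entropy_term(p_softmax):  # mirrors Solver.get_entropy_loss
    mask = p_softmax.ge(0.000001)
    mask_out = torch.masked_select(p_softmax, mask)
    entropy = -(torch.sum(mask_out * torch.log(mask_out)))
    return 0.1 * (entropy / float(p_softmax.size(0)))

uniform = torch.full((4, 10), 0.1)  # maximum-entropy predictions
one_hot = torch.eye(10)[:4]         # fully confident predictions (zeros are masked out)
print(entropy_term(uniform))        # ~0.2303 (= 0.1 * ln 10)
print(entropy_term(one_hot))        # 0.0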