Example #1
    def __init__(self,
                 args,
                 batch_size=128,
                 target='mnistm',
                 learning_rate=0.0002,
                 interval=10,
                 optimizer='adam',
                 checkpoint_dir=None,
                 save_epoch=10):
        self.batch_size = batch_size
        self.target = target
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.interval = interval
        self.lr = learning_rate
        self.best_correct = 0
        self.args = args
        if self.args.use_target:
            self.ndomain = self.args.ndomain
        else:
            self.ndomain = self.args.ndomain - 1

        # load source and target domains
        self.datasets, self.dataset_test, self.dataset_size = dataset_read(
            target, self.batch_size)
        self.niter = self.dataset_size // self.batch_size  # iterations per epoch (integer division)
        print('Dataset loaded!')

        # define the feature extractor and GCN-based classifier
        self.G = Generator(self.args.net)
        self.GCN = GCN(nfeat=args.nfeat, nclasses=args.nclasses)
        self.G.cuda()
        self.GCN.cuda()
        print('Model initialized!')

        if self.args.load_checkpoint is not None:
            self.state = torch.load(self.args.load_checkpoint)
            self.G.load_state_dict(self.state['G'])
            self.GCN.load_state_dict(self.state['GCN'])
            print('Model loaded from:', self.args.load_checkpoint)

        # initialize statistics (prototypes and adjacency matrix)
        if self.args.load_checkpoint is None:
            self.mean = torch.zeros(args.nclasses * self.ndomain,
                                    args.nfeat).cuda()
            self.adj = torch.zeros(args.nclasses * self.ndomain,
                                   args.nclasses * self.ndomain).cuda()
            print('Statistics initialized!')
        else:
            self.mean = self.state['mean'].cuda()
            self.adj = self.state['adj'].cuda()
            print('Statistics loaded!')

        # define the optimizer
        self.set_optimizer(which_opt=optimizer, lr=self.lr)
        print('Optimizer defined!')
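    # set_optimizer() is called in every solver in these examples but never
    # defined in the snippets. A minimal sketch for this G + GCN pair,
    # assuming `from torch import optim` and the usual opt_* attribute
    # naming (both assumptions, not from the snippet):
    def set_optimizer(self, which_opt='adam', lr=0.0002, momentum=0.9):
        if which_opt == 'sgd':
            self.opt_g = optim.SGD(self.G.parameters(), lr=lr,
                                   momentum=momentum, weight_decay=0.0005)
            self.opt_gcn = optim.SGD(self.GCN.parameters(), lr=lr,
                                     momentum=momentum, weight_decay=0.0005)
        elif which_opt == 'adam':
            self.opt_g = optim.Adam(self.G.parameters(), lr=lr,
                                    weight_decay=0.0005)
            self.opt_gcn = optim.Adam(self.GCN.parameters(), lr=lr,
                                      weight_decay=0.0005)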
Example #2
    def __init__(self,
                 args,
                 batch_size=64,
                 source='svhn',
                 target='mnist',
                 learning_rate=0.0002,
                 interval=100,
                 optimizer='adam',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10):
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        if self.source == 'svhn':
            self.scale = True
        else:
            self.scale = False
        print('dataset loading')
        self.datasets, self.dataset_test = dataset_read(source,
                                                        target,
                                                        self.batch_size,
                                                        scale=self.scale,
                                                        all_use=self.all_use)
        print('load finished!')
        self.G = Generator(source=source, target=target)
        self.C1 = Classifier(source=source, target=target)
        self.C2 = Classifier(source=source, target=target)
        self.D1 = D(source=source, target=target)
        self.D2 = D(source=source, target=target)
        if args.eval_only:
            # restore G and the two classifiers via load_state_dict; the
            # _C1/_C2 checkpoint names follow the convention used by the
            # other solvers (assumption)
            self.G.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_G.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))
            self.C1.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_C1.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))
            self.C2.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_C2.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))

        self.G.cuda()
        self.C1.cuda()
        self.C2.cuda()
        self.D1.cuda()
        self.D2.cuda()
        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate
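    # The eval-only branch above restores checkpoints named
    # '<dir>/<source>_to_<target>_model_epoch<N>_<component>.pt'. A sketch
    # of the save side those paths imply (the method name save_model is an
    # assumption, not from the snippet):
    def save_model(self, epoch):
        for name, net in [('G', self.G), ('C1', self.C1), ('C2', self.C2)]:
            torch.save(
                net.state_dict(),
                '%s/%s_to_%s_model_epoch%s_%s.pt' %
                (self.checkpoint_dir, self.source, self.target, epoch, name))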
Example #3
    def __init__(self,
                 args,
                 batch_size=64,
                 target="It doesn't matter",
                 learning_rate=0.0002,
                 interval=100,
                 optimizer='adam',
                 checkpoint_dir=None,
                 save_epoch=10):
        self.batch_size = batch_size
        self.target = target
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff

        print('dataset loading')
        #self.datasets, self.dataset_test = dataset_read(target, self.batch_size)
        self.datasets = dataset_read(target, self.batch_size, args)
        #print(self.dataset['S1'].shape)

        print('load finished!')
        self.G = Generator()
        self.C1 = Classifier()
        self.C2 = Classifier()
        print('model_loaded')

        if args.eval_only:
            # NOTE: self.source is never set in this class, so these paths
            # cannot be built as written; the _C1/_C2 names follow the
            # checkpoint convention used by the other solvers (assumption)
            self.G.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_G.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))
            self.C1.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_C1.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))
            self.C2.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_C2.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))
        """
        self.G.cuda()
        self.C1.cuda()
        self.C2.cuda()
        """

        self.G
        self.C1
        self.C2

        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate
        print('initialize complete')
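# Usage sketch: the enclosing class is not shown in these snippets, so the
# name `Solver` and the exact `args` fields are assumptions inferred from
# the attribute reads above (use_abs_diff, eval_only, resume_epoch):
if __name__ == '__main__':
    from argparse import Namespace
    args = Namespace(use_abs_diff=False, eval_only=False, resume_epoch=None)
    solver = Solver(args, batch_size=64, learning_rate=0.0002,
                    optimizer='adam')
    print('interval:', solver.interval)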
Example #4
    def __init__(self, args):

        super().__init__()

        self.opt_losses = args.opt_losses
        self.check_list = [
            'disc_ds_di_C_D', 'disc_ds_di_G', 'ring', 'confusion_G'
        ]
        assert all([ol in self.check_list
                    for ol in self.opt_losses]), 'Check loss entries'
        opt_losses_str = ",".join(map(str, self.opt_losses))
        timestring = strftime("%Y-%m-%d_%H-%M-%S",
                              gmtime()) + "_{}_optloss={}_src={}".format(
                                  args.exp_name, opt_losses_str, args.source)
        self.logdir = os.path.join('./logs', timestring)
        self.logger = SummaryWriter(log_dir=self.logdir)
        self.device = torch.device("cuda" if args.use_cuda else "cpu")

        # one-hot domain labels: [1, 0] = source, [0, 1] = target
        self.src_domain_code = np.repeat(np.array([[1, 0]]),
                                         args.batch_size,
                                         axis=0)
        self.trg_domain_code = np.repeat(np.array([[0, 1]]),
                                         args.batch_size,
                                         axis=0)
        self.src_domain_code = torch.FloatTensor(self.src_domain_code).to(
            self.device)
        self.trg_domain_code = torch.FloatTensor(self.trg_domain_code).to(
            self.device)

        self.source = args.source
        self.target = args.target
        self.num_k = args.num_k
        self.checkpoint_dir = args.checkpoint_dir
        self.save_epoch = args.save_epoch
        self.use_abs_diff = args.use_abs_diff

        self.mi_k = 1
        self.delta = 0.01
        self.mi_coeff = 0.0001
        self.interval = 10  # log to TensorBoard every 10 steps
        self.batch_size = args.batch_size
        self.which_opt = 'adam'
        self.lr = args.lr
        self.scale = 32
        self.global_step = 0

        print('Loading datasets')
        self.dataset_train, self.dataset_test = dataset_read(
            args.data_dir, self.source, self.target, self.batch_size,
            self.scale)
        print('Done!')

        self.total_batches = {
            'train': self.get_dataset_size('train'),
            'test': self.get_dataset_size('test')
        }

        self.G = Generator(source=self.source, target=self.target)
        self.FD = Feature_Discriminator()
        self.R = Reconstructor()
        self.MI = Mine()

        self.C = nn.ModuleDict({
            'ds': Classifier(source=self.source, target=self.target),
            'di': Classifier(source=self.source, target=self.target),
            'ci': Classifier(source=self.source, target=self.target)
        })

        self.D = nn.ModuleDict({
            'ds': Disentangler(),
            'di': Disentangler(),
            'ci': Disentangler()
        })

        # All modules in the same dict
        self.components = nn.ModuleDict({
            'G': self.G,
            'FD': self.FD,
            'R': self.R,
            'MI': self.MI
        })

        self.xent_loss = nn.CrossEntropyLoss().to(self.device)
        self.adv_loss = nn.BCEWithLogitsLoss().to(self.device)
        self.ring_loss = RingLoss(type='auto', loss_weight=1.0).to(self.device)
        self.set_optimizer(lr=self.lr)
        self.to_device()
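    # to_device() is called above but not defined in the snippet. A minimal
    # sketch, assuming it only moves the grouped modules onto self.device:
    def to_device(self):
        for module_dict in (self.components, self.C, self.D):
            module_dict.to(self.device)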
Example #5
    def __init__(self,
                 args,
                 batch_size=64,
                 source='svhn',
                 target='mnist',
                 learning_rate=0.0002,
                 interval=100,
                 optimizer='adam',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10):
        # one-hot domain labels: [1, 0] = source, [0, 1] = target
        self.src_domain_code = np.repeat(np.array([[1, 0]]),
                                         batch_size,
                                         axis=0)
        self.tgt_domain_code = np.repeat(np.array([[0, 1]]),
                                         batch_size,
                                         axis=0)
        # plain tensors (torch.autograd.Variable is deprecated);
        # requires_grad defaults to False
        self.src_domain_code = torch.FloatTensor(self.src_domain_code).cuda()
        self.tgt_domain_code = torch.FloatTensor(self.tgt_domain_code).cuda()
        self.batch_size = batch_size

        self.source = source
        self.target = target
        self.num_k = num_k
        self.mi_k = 1
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        self.belta = 0.01  # (sic) presumably 'beta'
        self.mi_para = 0.0001
        # both branches of the original if/else set the same value, so
        # scaling is disabled regardless of source here
        self.scale = False
        print('dataset loading')
        self.datasets, self.dataset_test = dataset_read(source,
                                                        target,
                                                        self.batch_size,
                                                        scale=self.scale,
                                                        all_use=self.all_use)
        print('load finished!')
        self.G = Generator(source=source, target=target)
        self.D0 = Disentangler()
        self.D1 = Disentangler()
        self.D2 = Disentangler()

        self.C0 = Classifier(source=source, target=target)
        self.C1 = Classifier(source=source, target=target)
        self.C2 = Classifier(source=source, target=target)
        self.FD = Feature_Discriminator()
        self.R = Reconstructor()
        # Mutual information network estimation
        self.M = Mine()

        if args.eval_only:
            # restore G for evaluation; restoring the other components would
            # follow the same pattern
            self.G.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_G.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))

        self.G.cuda()
        self.C0.cuda()
        self.C1.cuda()
        self.C2.cuda()

        self.D0.cuda()
        self.D1.cuda()
        self.D2.cuda()
        self.FD.cuda()
        self.R.cuda()
        self.M.cuda()

        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate
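    # Mine() together with mi_k and mi_para points to a MINE-style mutual
    # information estimate. A sketch of the Donsker-Varadhan lower bound,
    # assuming self.M scores a pair of feature batches (the two-argument
    # signature and the method name are assumptions):
    def mi_estimate(self, feat_a, feat_b):
        shuffled = feat_b[torch.randperm(feat_b.size(0))]
        joint = self.M(feat_a, feat_b)       # T(x, y) under the joint
        marginal = self.M(feat_a, shuffled)  # T(x, y') under the product
        return joint.mean() - torch.log(torch.exp(marginal).mean())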
Example #6
    def __init__(self,
                 args,
                 batch_size=64,
                 source='svhn',
                 target='mnist',
                 learning_rate=0.0002,
                 interval=1,
                 optimizer='adam',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10):

        timestring = strftime("%Y-%m-%d_%H-%M-%S",
                              gmtime()) + "_%s" % args.exp_name
        self.logdir = os.path.join('./logs', timestring)
        self.logger = SummaryWriter(log_dir=self.logdir)
        self.device = torch.device("cuda" if args.use_cuda else "cpu")

        # one-hot domain labels: [1, 0] = source, [0, 1] = target
        self.src_domain_code = np.repeat(np.array([[1, 0]]),
                                         batch_size,
                                         axis=0)
        self.trg_domain_code = np.repeat(np.array([[0, 1]]),
                                         batch_size,
                                         axis=0)
        self.src_domain_code = torch.FloatTensor(self.src_domain_code).to(
            self.device)
        self.trg_domain_code = torch.FloatTensor(self.trg_domain_code).to(
            self.device)

        self.source = source
        self.target = target
        self.num_k = num_k
        self.mi_k = 1
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        self.delta = 0.01
        self.mi_coeff = 0.0001
        self.interval = interval
        self.batch_size = batch_size
        self.lr = learning_rate
        self.scale = False

        print('dataset loading')
        self.datasets, self.dataset_test = dataset_read(source,
                                                        target,
                                                        self.batch_size,
                                                        scale=self.scale,
                                                        all_use=self.all_use)
        print('load finished!')

        self.G = Generator(source=source, target=target)
        self.FD = Feature_Discriminator()
        self.R = Reconstructor()
        self.MI = Mine()

        self.C = nn.ModuleDict({
            'ds': Classifier(source=source, target=target),
            'di': Classifier(source=source, target=target),
            'ci': Classifier(source=source, target=target)
        })

        self.D = nn.ModuleDict({
            'ds': Disentangler(),
            'di': Disentangler(),
            'ci': Disentangler()
        })

        # All modules in the same dict
        self.modules = nn.ModuleDict({
            'G': self.G,
            'FD': self.FD,
            'R': self.R,
            'MI': self.MI
        })

        if args.eval_only:
            # restore G for evaluation; restoring the other components would
            # follow the same pattern
            self.G.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_G.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))

        # use self.device so the use_cuda=False path keeps working
        self.xent_loss = nn.CrossEntropyLoss().to(self.device)
        self.adv_loss = nn.BCEWithLogitsLoss().to(self.device)
        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.to_device()
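    # Design note: the 'ds'/'di'/'ci' keys (presumably domain-specific,
    # domain-invariant and class-irrelevant branches) keep each group in one
    # nn.ModuleDict, so a single optimizer can cover a whole group, e.g.
    # (a sketch; the attribute name is an assumption):
    #     self.opt_c = torch.optim.Adam(self.C.parameters(), lr=self.lr)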
Example #7
    def __init__(self,
                 args,
                 batch_size=64,
                 source='source',
                 target='target',
                 learning_rate=0.0002,
                 interval=100,
                 optimizer='adam',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10,
                 leave_one_num=-1,
                 model_name=''):
        self.args = args
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.leave_one_num = leave_one_num

        print('dataset loading')
        download()
        self.data_train, self.data_val, self.data_test = dataset_read(
            source,
            target,
            self.batch_size,
            is_resize=args.is_resize,
            leave_one_num=self.leave_one_num,
            dataset=args.dataset,
            sensor_num=args.sensor_num)
        print('load finished!')

        self.G = Generator(source=source,
                           target=target,
                           is_resize=args.is_resize,
                           dataset=args.dataset,
                           sensor_num=args.sensor_num)
        self.LC = Classifier(source=source,
                             target=target,
                             is_resize=args.is_resize,
                             dataset=args.dataset)
        self.DC = DomainClassifier(source=source,
                                   target=target,
                                   is_resize=args.is_resize,
                                   dataset=args.dataset)

        if args.eval_only:
            self.data_val = self.data_test
            self.G = torch.load(r'checkpoint_DANN/best_model_G' + model_name +
                                '.pt')
            self.LC = torch.load(r'checkpoint_DANN/best_model_C1' +
                                 model_name + '.pt')
            self.DC = torch.load(r'checkpoint_DANN/best_model_C2' +
                                 model_name + '.pt')

        self.G.cuda()
        self.LC.cuda()
        self.DC.cuda()
        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate
Example #8
    def __init__(self,
                 args,
                 batch_size=64,
                 target='mnist',
                 learning_rate=0.0002,
                 interval=100,
                 optimizer='adam',
                 checkpoint_dir=None,
                 save_epoch=10):
        self.batch_size = batch_size
        self.target = target
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.dl_type = args.dl_type

        self.args = args

        self.best_loss = 9999999
        self.best_acc = 0
        print('dataset loading')
        if args.data == 'digits':
            if args.dl_type == 'original':
                self.datasets, self.dataset_test, self.dataset_valid = dataset_read(
                    target, self.batch_size)
            elif args.dl_type == 'hard_cluster':
                self.datasets, self.dataset_test, self.dataset_valid = dataset_hard_cluster(
                    target, self.batch_size, args.num_domain)
            elif args.dl_type in ('soft_cluster', 'source_only',
                                  'source_target_only'):
                self.datasets, self.dataset_test, self.dataset_valid = dataset_combined(
                    target, self.batch_size, args.num_domain,
                    args.office_directory)
            else:
                raise Exception('Type of experiment undefined')

            print('load finished!')
            num_classes = 10
            num_domains = args.num_domain
            self.num_domains = num_domains
            self.entropy_wt = 0.01
            self.msda_wt = 0.1
            self.kl_wt = args.kl_wt
            self.to_detach = args.to_detach
            self.G = Generator_digit()
            self.C1 = Classifier_digit()
            self.C2 = Classifier_digit()
            self.DP = DP_Digit(num_domains)
        elif args.data == 'cars':
            if args.dl_type in ('soft_cluster', 'source_target_only',
                                'source_only'):
                self.datasets, self.dataset_test, self.dataset_valid = cars_combined(
                    target, self.batch_size)
            print('load finished!')
            self.entropy_wt = 0.1
            self.msda_wt = 0.25
            self.to_detach = args.to_detach
            num_classes = 163
            num_domains = args.num_domain
            self.num_domains = num_domains
            self.G = Generator_cars()
            self.C1 = Classifier_cars(num_classes)
            self.C2 = Classifier_cars(num_classes)
            self.DP = DP_cars(num_domains)
        elif args.data == 'office':
            if args.dl_type in ('soft_cluster', 'source_target_only',
                                'source_only'):
                self.datasets, self.dataset_test, self.dataset_valid = office_combined(
                    target, self.batch_size, args.office_directory, args.seed)

            print('load finished!')
            self.entropy_wt = 1
            self.msda_wt = 0.25
            self.kl_wt = args.kl_wt
            self.to_detach = args.to_detach
            num_classes = 31
            num_domains = args.num_domain
            self.num_domains = num_domains
            self.G = Generator_office()
            self.C1 = Classifier_office(num_classes)
            self.C2 = Classifier_office(num_classes)
            self.DP = DP_office(num_domains)
        # print(self.dataset['S1'].shape)
        print('model_loaded')

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        print('ARGS EVAL ONLY : ', args.eval_only)
        if args.eval_only:
            print('Loading state from: ',
                  '%s/%s_model_best.pth' % (self.checkpoint_dir, self.target))
            checkpoint = torch.load('%s/%s_model_best.pth' %
                                    (self.checkpoint_dir, self.target))
            self.G.load_state_dict(checkpoint['G_state_dict'])
            self.C1.load_state_dict(checkpoint['C1_state_dict'])
            self.C2.load_state_dict(checkpoint['C2_state_dict'])
            self.DP.load_state_dict(checkpoint['DP_state_dict'])

            self.opt_g.load_state_dict(checkpoint['G_state_dict_opt'])
            self.opt_c1.load_state_dict(checkpoint['C1_state_dict_opt'])
            self.opt_c2.load_state_dict(checkpoint['C2_state_dict_opt'])
            self.opt_dp.load_state_dict(checkpoint['DP_state_dict_opt'])

        self.G.cuda()
        self.C1.cuda()
        self.C2.cuda()
        self.DP.cuda()
        self.interval = interval
        milestones = [100]  # same schedule for all datasets
        self.sche_g = torch.optim.lr_scheduler.MultiStepLR(self.opt_g,
                                                           milestones,
                                                           gamma=0.1)
        self.sche_c1 = torch.optim.lr_scheduler.MultiStepLR(self.opt_c1,
                                                            milestones,
                                                            gamma=0.1)
        self.sche_c2 = torch.optim.lr_scheduler.MultiStepLR(self.opt_c2,
                                                            milestones,
                                                            gamma=0.1)
        self.sche_dp = torch.optim.lr_scheduler.MultiStepLR(self.opt_dp,
                                                            milestones,
                                                            gamma=0.1)

        self.lr = learning_rate
        print('initialize complete')
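    # The MultiStepLR schedulers above decay each learning rate by 10x at
    # epoch 100. A sketch of the per-epoch step the training loop would need
    # (the method name is an assumption):
    def step_schedulers(self):
        for sche in (self.sche_g, self.sche_c1, self.sche_c2, self.sche_dp):
            sche.step()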
Example #9
    def __init__(self,
                 args,
                 batch_size=128,
                 source='usps',
                 target='mnist',
                 learning_rate=0.0002,
                 interval=100,
                 optimizer='adam',
                 num_k=4,
                 all_use=False,
                 checkpoint_dir=None,
                 save_epoch=10):
        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        self.class_num = 10
        self.num_k1 = 8
        self.num_k2 = 1
        self.num_k3 = 8
        self.num_k4 = 1
        #self.offset =0.1
        self.output_cr_t_C_label = np.zeros(self.batch_size)

        if self.source == 'svhn':
            self.scale = True
        else:
            self.scale = False
        print('dataset loading')
        self.datasets, self.dataset_test = dataset_read(source,
                                                        target,
                                                        self.batch_size,
                                                        scale=self.scale,
                                                        all_use=self.all_use)
        print('load finished!')
        self.G = Generator(source=source, target=target)
        self.C = Classifier(source=source, target=target)
        self.C1 = Classifier(source=source, target=target)
        self.C2 = Classifier(source=source, target=target)
        self.D = discriminator(source=source, target=target)
        if args.eval_only:
            self.G.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_G.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))
            self.C.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_C.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))
            self.C1.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_C1.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))
            self.C2.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_C2.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))
            self.D.load_state_dict(
                torch.load('%s/%s_to_%s_model_epoch%s_D.pt' %
                           (self.checkpoint_dir, self.source, self.target,
                            args.resume_epoch)))

        self.G.cuda()
        self.C.cuda()
        self.C1.cuda()
        self.C2.cuda()
        self.D.cuda()
        self.interval = interval

        self.set_optimizer(which_opt=optimizer, lr=learning_rate)
        self.lr = learning_rate
Example #10
    def __init__(
            self, 
            args,
            batch_size=64, 
            source='svhn', 
            target='mnist',
            learning_rate=0.0002, 
            interval=100, 
            optimizer='adam', 
            num_k=4,
            all_use=False, 
            checkpoint_dir=None, 
            save_epoch=10,
            num_classifiers_train=2, 
            num_classifiers_test=20,
            init='kaiming_u', 
            use_init=False, 
            dis_metric='L1'
    ):

        self.batch_size = batch_size
        self.source = source
        self.target = target
        self.num_k = num_k
        self.checkpoint_dir = checkpoint_dir
        self.save_epoch = save_epoch
        self.use_abs_diff = args.use_abs_diff
        self.all_use = all_use
        self.num_classifiers_train = num_classifiers_train
        self.num_classifiers_test = num_classifiers_test
        self.init = init
        self.dis_metric = dis_metric
        self.use_init = use_init

        if self.source == 'svhn':
            self.scale = True
        else:
            self.scale = False

        print('dataset loading')
        self.datasets, self.dataset_test = dataset_read(
            source, target,
            self.batch_size,
            scale=self.scale,
            all_use=self.all_use
        )
        print('load finished!')

        self.G = Generator(source=source, target=target)
        self.C = Classifier(
            source=source, target=target,
            num_classifiers_train=self.num_classifiers_train,
            num_classifiers_test=self.num_classifiers_test,
            init=self.init,
            use_init=self.use_init
        )

        if args.eval_only:
            self.G.load_state_dict(
                torch.load('{}/{}_to_{}_model_epoch{}_G.pt'.format(
                    self.checkpoint_dir, self.source, self.target,
                    args.resume_epoch)))

            self.C.load_state_dict(
                torch.load('{}/{}_to_{}_model_epoch{}_C.pt'.format(
                    self.checkpoint_dir, self.source, self.target,
                    args.resume_epoch)))

        self.G.cuda()
        self.C.cuda()
        self.interval = interval
        self.writer = SummaryWriter()

        self.opt_c, self.opt_g = self.set_optimizer(
            which_opt=optimizer,
            lr=learning_rate
        )
        self.lr = learning_rate

        # Learning rate scheduler
        self.scheduler_g = optim.lr_scheduler.CosineAnnealingLR(
            self.opt_g,
            float(args.max_epoch)
        )
        self.scheduler_c = optim.lr_scheduler.CosineAnnealingLR(
            self.opt_c,
            float(args.max_epoch)
        )
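    # dis_metric='L1' together with args.use_abs_diff suggests an MCD-style
    # classifier discrepancy. A sketch of the usual L1 form over softmax
    # outputs, assuming `import torch.nn.functional as F` (the method name
    # and the import are assumptions):
    def discrepancy(self, out1, out2):
        return torch.mean(
            torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))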