Example #1
 def __init__(self, in_features, out_features, latent_domain_num=2):
     super(MD_MMD_Layer, self).__init__()
     self.latent_domain_num = latent_domain_num
     self.in_features = in_features
     self.out_features = out_features
     self.aux_classifier = nn.Linear(self.in_features,
                                     self.latent_domain_num)
     # Use an nn.ModuleList so the per-domain linear layers are registered
     # as submodules: their parameters are seen by the optimizer and move
     # with .to()/.cuda(). A plain Python list is not tracked.
     self.layers = nn.ModuleList(
         nn.Linear(self.in_features, self.out_features)
         for _ in range(self.latent_domain_num))
     self.cluster_ciriterion = ClusterLoss1()
     self.entropy_ciriterion = EntropyLoss()
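
ClusterLoss1 and EntropyLoss are referenced above but not defined in the example. For orientation, here is a minimal sketch of one common EntropyLoss formulation (the mean Shannon entropy of the softmax outputs, a standard regularizer in domain adaptation); the actual implementation used by these examples may differ.

import torch.nn as nn
import torch.nn.functional as F

class EntropyLoss(nn.Module):
    # Mean Shannon entropy of the softmax outputs. One common formulation;
    # the EntropyLoss assumed by the examples may differ.
    def forward(self, logits):
        log_p = F.log_softmax(logits, dim=1)
        return -(log_p.exp() * log_p).sum(dim=1).mean()
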
Example #2
 def __init__(self, in_features, out_features, latent_domain_num=2):
     super(MDCL, self).__init__()
     # Note: is_available() must be called; the bare attribute is always truthy.
     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     self.latent_domain_num = latent_domain_num
     self.in_features, self.out_features = in_features, out_features
     self.aux_classifier = nn.Linear(self.in_features,
                                     self.latent_domain_num)
     # As in Example #1, an nn.ModuleList registers the per-domain layers
     # properly; DEVICE is applied here because it is defined locally above.
     self.layers = nn.ModuleList(
         nn.Linear(self.in_features, self.out_features).to(DEVICE)
         for _ in range(self.latent_domain_num))
     self.cluster_ciriterion = ClusterLoss1()
     self.entropy_ciriterion = EntropyLoss()
     self.moving_center = None
     self.moving_factor = 0.9
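
The moving_center and moving_factor attributes suggest an exponential moving average of feature centers, but the update itself is not part of this example. A hypothetical method body consistent with moving_factor = 0.9 might look like:

    def update_moving_center(self, batch_center):
        # EMA update: keep 90% of the old center and blend in 10% of the
        # current batch's center (detached so no gradient flows through it).
        if self.moving_center is None:
            self.moving_center = batch_center.detach()
        else:
            self.moving_center = (self.moving_factor * self.moving_center
                                  + (1.0 - self.moving_factor) * batch_center.detach())
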
Example #3
    print(net)
    net.avgpool = nn.AvgPool2d(28, 28)
    net.fc = nn.Sequential(nn.Dropout(0.2), nn.Linear(512, 5))  # OULU's paper.
    load_file = None
    #'/gpfs/data/denizlab/Users/bz1030/KneeNet/KneeProject/model/model_torch/model_flatten_linear_layer/model_weights3/epoch_3.pth'
    if load_file:
        net.load_state_dict(torch.load(load_file))
        start_epoch = 3
    else:
        start_epoch = 0
    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-4)
    # Network

    print('############### Model Finished ####################')
    criterion = EntropyLoss(beta=beta)

    train_losses = []
    val_losses = []
    val_mse = []
    val_kappa = []
    val_acc = []

    best_dice = 0
    prev_model = None
    iteration = 500
    train_started = time.time()
    with open(output_file_path, 'a+') as f:
        f.write('######## Train Start #######\n')
    for epoch in range(start_epoch, EPOCH):
        train_loader = data.DataLoader(dataset_train, batch_size=8, shuffle=True)
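
The hard-coded start_epoch = 3 above must be kept in sync with the commented-out epoch_3.pth checkpoint by hand. A sketch of a (hypothetical) resume pattern, using the net, epoch, and load_file names from this snippet, stores the epoch inside the checkpoint file and removes that coupling:

import torch

# saving, e.g. at the end of each epoch:
torch.save({'epoch': epoch, 'model_state': net.state_dict()},
           'epoch_{}.pth'.format(epoch))

# resuming:
if load_file:
    checkpoint = torch.load(load_file)
    net.load_state_dict(checkpoint['model_state'])
    start_epoch = checkpoint['epoch'] + 1  # resume after the saved epoch
else:
    start_epoch = 0
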
Example #4
    def train(self):
        start_time = time.time()
        self.features.train()
        self.bottleneck_layer.train()
        self.classifier.train()

        # prepare data
        src_loader, tar_loader, tar_test_loader = data_loader_dict[
            self.cfg.dataset](self.cfg, val=False)
        src_iter_len, tar_iter_len = len(src_loader), len(tar_loader)
        print("data_size[src: {:.0f}, tar: {:.0f}]".format(
            len(src_loader.dataset), len(tar_loader.dataset)))

        #loss
        classifier_ciriterion = nn.CrossEntropyLoss()
        entropy_ciriterion = EntropyLoss()
        MMD_ciriterion = MMDLoss()

        optimizer = optim.Adam(
                self.features.get_param_groups(self.cfg.learning_rate) + \
                self.bottleneck_layer.get_param_groups(self.cfg.new_layer_learning_rate) + \
                self.classifier.get_param_groups(self.cfg.new_layer_learning_rate),
                weight_decay = 0.01)

        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.cfg.max_iter)

        # train
        best_src_acc, best_tar_acc, best_tar_test_acc = 0.0, 0.0, 0.0
        early_stop_acc = -1.0
        epoch_src_acc, epoch_tar_acc = 0.0, 0.0
        epoch_src_correct, epoch_tar_correct = 0.0, 0.0
        move_average_loss = 0.0
        move_factor = 0.9
        test_acc_list = []
        for _iter in range(self.cfg.max_iter):
            if _iter % src_iter_len == 0:
                src_iter = iter(src_loader)

                epoch_src_acc = epoch_src_correct / (src_iter_len *
                                                     self.cfg.batch_size)
                if epoch_src_acc > best_src_acc:
                    best_src_acc = epoch_src_acc

                print("\n" + "-" * 100)
                print(
                    "Iter[{:02d}/{:03d}] Acc[src:{:.4f}, tar:{:.4f}] Best Acc[src:{:.4f}, tar:{:.4f}] Src Update"
                    .format(_iter, self.cfg.max_iter, epoch_src_acc,
                            epoch_tar_acc, best_src_acc, best_tar_acc))
                print("-" * 100 + "\n")
                epoch_src_correct = 0.0

            if _iter % tar_iter_len == 0 or _iter == self.cfg.max_iter - 1:
                tar_iter = iter(tar_loader)

                epoch_tar_acc = epoch_tar_correct / (tar_iter_len *
                                                     self.cfg.batch_size)
                if epoch_tar_acc > best_tar_acc:
                    best_tar_acc = epoch_tar_acc

                test_tar_acc = self.test(
                    tar_test_loader
                )  # After each full pass over tar_loader, evaluate on the whole test set: the training target loader is drop_last and applies random transforms, so its running accuracy is not comparable.
                test_acc_list.append(test_tar_acc)
                if test_tar_acc > best_tar_test_acc:
                    best_tar_test_acc = test_tar_acc

                print("\n" + "-" * 100)
                print(
                    "Iter[{:02d}/{:03d}] Acc[src:{:.3f}, tar:{:.4f}, test:{:.4f}] Best Acc[src:{:.3f}, tar:{:.4f}, test:{:.4f}]"
                    .format(_iter, self.cfg.max_iter, epoch_src_acc,
                            epoch_tar_acc, test_tar_acc, best_src_acc,
                            best_tar_acc, best_tar_test_acc))
                print("-" * 100 + "\n")
                epoch_tar_correct = 0.0
                if _iter > self.cfg.early_stop_iter:
                    if early_stop_acc <= 0.0:
                        early_stop_acc = test_tar_acc
                    if self.cfg.early_stop:
                        break

            X_src, y_src = next(src_iter)  # use next(it); the .next() method is Python 2 only
            X_tar, y_tar = next(tar_iter)
            X_src, y_src = X_src.to(DEVICE), y_src.to(DEVICE)
            X_tar, y_tar = X_tar.to(DEVICE), y_tar.to(DEVICE)
            optimizer.zero_grad()

            # forward
            src_features, src_cluster_loss, src_aux_entropy_loss = self.bottleneck_layer(
                self.features(X_src))
            src_outputs = self.classifier(src_features)

            tar_features, tar_cluster_loss, tar_aux_entropy_loss = self.bottleneck_layer(
                self.features(X_tar))
            tar_outputs = self.classifier(tar_features)

            # loss
            classifier_loss = classifier_ciriterion(src_outputs, y_src)
            entropy_loss = entropy_ciriterion(tar_outputs)
            inter_MMD_loss = MMD_ciriterion(src_features, tar_features)

            loss_factor = 2.0 / (
                1.0 + math.exp(-10 * _iter / self.cfg.max_iter)) - 1.0
            loss = classifier_loss + \
                   entropy_loss * self.cfg.entropy_loss_weight * loss_factor + \
                   inter_MMD_loss * self.cfg.inter_MMD_loss_weight * loss_factor + \
                   (src_aux_entropy_loss + tar_aux_entropy_loss) * self.cfg.aux_entropy_loss_weight * loss_factor + \
                   (src_cluster_loss + tar_cluster_loss) * self.cfg.cluster_loss_weight * loss_factor

            # optimize
            loss.backward()
            optimizer.step()

            #lr_scheduler
            lr_scheduler.step()

            # stat
            iter_loss = loss.item()
            move_average_loss = move_average_loss * move_factor + iter_loss * (
                1.0 - move_factor)

            pred_src = src_outputs.argmax(dim=1)
            pred_tar = tar_outputs.argmax(dim=1)
            epoch_src_correct += (y_src == pred_src).double().sum().item()
            epoch_tar_correct += (y_tar == pred_tar).double().sum().item()
            print(
                "Iter[{:02d}/{:03d}] Loss[M-Ave:{:.4f}\titer:{:.4f}\tCla:{:.4f}\tMMD:{:.4f}]"
                .format(_iter, self.cfg.max_iter, move_average_loss, iter_loss,
                        classifier_loss.item(), inter_MMD_loss.item()))
            print(
                "Iter[{:02d}/{:03d}] Ent-Loss[aux_src:{:.4f}\taux_tar:{:.4f}\tClaEnt:{:.4f}]"
                .format(_iter, self.cfg.max_iter, src_aux_entropy_loss.item(),
                        tar_aux_entropy_loss.item(), entropy_loss.item()))
            print("Iter[{:02d}/{:03d}] Cluster-Loss[src:{:.4f}\ttar:{:.4f}]\n".
                  format(_iter, self.cfg.max_iter, src_cluster_loss.item(),
                         tar_cluster_loss.item()))
        time_pass = time.time() - start_time
        print("Train finish in {:.0f}m {:.0f}s".format(time_pass // 60,
                                                       time_pass % 60))
        return best_tar_test_acc, early_stop_acc, test_acc_list
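
The loss_factor used above is the ramp-up schedule popularized by gradient-reversal domain adaptation: factor(p) = 2 / (1 + exp(-10 * p)) - 1 for training progress p = _iter / max_iter in [0, 1]. It keeps the transfer losses near zero early in training and approaches 1.0 toward the end. A tiny standalone check of its shape:

import math

def ramp_up(progress):
    # progress in [0, 1]; rises smoothly from 0 toward 1
    return 2.0 / (1.0 + math.exp(-10.0 * progress)) - 1.0

for p in (0.0, 0.1, 0.5, 1.0):
    print("p={:.1f} -> factor={:.4f}".format(p, ramp_up(p)))
# p=0.0 -> 0.0000, p=0.1 -> 0.4621, p=0.5 -> 0.9866, p=1.0 -> 0.9999
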
Example #5
    def __init__(self, data,
                 train_loader=None,
                 test_loader=None,
                 total_epoch=200,
                 alpha=0.1,
                 epsilon=0.1,
                 use_cuda=False,
                 resume=False,
                 ckpt_filename=None,
                 resume_filename=None,
                 privacy_flag=True,
                 privacy_option='maxent-arl',
                 print_interval_train=10,
                 print_interval_test=10
                 ):
        # data info
        self.data = data
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.n_sensitive_class = self.data.n_sensitive_class
        self.n_target_class = self.data.n_target_class

        # models
        self.adv_net = data.adversary_net
        self.target_net = data.target_net
        self.discriminator_net = data.discriminator_net

        # optimizer
        self.optimizer = data.optimizer
        self.discriminator_optimizer = data.discriminator_optimizer
        self.adv_optimizer = data.adv_optimizer
        self.target_optimizer = data.target_optimizer

        # loss
        self.kl_loss = nn.KLDivLoss()
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.entropy_loss = EntropyLoss()
        self.nll_loss = nn.NLLLoss()
        self.mse_loss = nn.MSELoss()

        # filename
        self.log_file_name = ckpt_filename + "_log.txt"
        self.adv_log_file_name = ckpt_filename + "_adv_log.txt"
        self.target_log_file_name = ckpt_filename + "_target_log.txt"
        self.checkpoint_filename = ckpt_filename
        self.adv_checkpoint_filename = ckpt_filename + "_adv.ckpt"
        self.target_checkpoint_filename = ckpt_filename + "_target.ckpt"

        # algorithm and visualization parameters
        self.alpha = torch.tensor([alpha*1.0], requires_grad=True)
        self.resume = resume
        self.epoch = 0
        self.gamma_param = 0.01
        self.plot_interval = 10
        self.print_interval_train = print_interval_train
        self.print_interval_test = print_interval_test
        self.use_cuda = use_cuda
        self.privacy_flag = privacy_flag
        self.privacy_option = privacy_option

        # local variables
        self.uniform = torch.tensor(1 / (self.data.n_sensitive_class)).repeat(self.data.n_sensitive_class)
        self.target_label = torch.zeros(0, dtype=torch.long)
        self.sensitive_label = torch.zeros(0, dtype=torch.long)
        self.sensitive_label_onehot = torch.FloatTensor(0, self.data.n_sensitive_class)
        self.target_label_onehot = torch.FloatTensor(0, self.data.n_target_class)
        self.inputs = torch.zeros(0, 0, 0)
        self.inputs.requires_grad = False
        self.batch_uniform = torch.FloatTensor(0, self.data.n_sensitive_class)
        self.epsilon = torch.tensor([epsilon]).float()

        if resume:
            assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
            if self.use_cuda:
                checkpoint = torch.load(os.path.join('checkpoint/', resume_filename))
            else:
                checkpoint = torch.load(os.path.join('checkpoint/', resume_filename),
                                        map_location=lambda storage, loc: storage)
            self.net = checkpoint['net']
            self.best_acc = 0  # checkpoint['acc']
            self.start_epoch = 0  # checkpoint['epoch']
            self.total_epoch = total_epoch  # + self.start_epoch

            for param in self.net.parameters():
                param.requires_grad = True
        else:
            self.net = data.net
            self.best_acc = 0
            self.start_epoch = 0
            self.total_epoch = total_epoch

        if self.use_cuda:
            self.net = self.net.cuda()
            self.discriminator_net = self.discriminator_net.cuda()
            self.adv_net = self.adv_net.cuda()
            self.target_net = self.target_net.cuda()
            self.net = nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))
            self.target_net = nn.DataParallel(self.target_net, device_ids=range(torch.cuda.device_count()))
            self.discriminator_net = nn.DataParallel(self.discriminator_net, device_ids=range(torch.cuda.device_count()))
            self.adv_net = nn.DataParallel(self.adv_net, device_ids=range(torch.cuda.device_count()))
            cudnn.benchmark = True
            self.inputs = self.inputs.cuda()
            self.target_label = self.target_label.cuda()
            self.sensitive_label = self.sensitive_label.cuda()
            self.sensitive_label_onehot = self.sensitive_label_onehot.cuda()
            self.target_label_onehot = self.target_label_onehot.cuda()
            self.uniform = self.uniform.cuda()
            self.batch_uniform = self.batch_uniform.cuda()
            self.alpha = self.alpha.cuda()

        self.best_loss = 1e16
        self.adv_best_acc = 0
        self.target_best_acc = 0
        self.t_losses, self.t_top1 = AverageMeter(), AverageMeter()
        self.d_losses, self.d_top1 = AverageMeter(), AverageMeter()
        self.e_losses, self.losses = AverageMeter(), AverageMeter()
        self.t_top5, self.d_top5 = AverageMeter(), AverageMeter()
        self.adv_losses, self.adv_top1 = AverageMeter(), AverageMeter()
        self.adv_top5, self.entropy_losses = AverageMeter(), AverageMeter()
        self.target_losses, self.target_top1 = AverageMeter(), AverageMeter()
        self.target_top5, self.target_entropy_losses = AverageMeter(), AverageMeter()
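
AverageMeter is instantiated repeatedly above but never defined in the example. The widely used formulation from the PyTorch ImageNet example tracks a running value, sum, count, and average; the version assumed here may differ:

class AverageMeter:
    # Tracks the latest value and the running average of a metric.
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
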