Esempio n. 1
0
def train_cnn(config):
    """Train a source-only extractor/classifier pair and evaluate periodically.

    The feature extractor is chosen by ``config['network']``
    (``'inceptionv1'``, ``'inceptionv1s'`` or, for any other value, the plain
    ``Extractor``).  Both networks are optimised with Nesterov SGD on
    ``config['source_train_loader']`` using cross-entropy loss.  Every
    ``config['TEST_INTERVAL']`` epochs ``test`` is called on the source and
    target test loaders.

    Args:
        config: dict of settings.  Keys read here: ``network``, ``dilation``
            (Inception variants only), ``n_flattens``, ``n_hiddens``,
            ``n_class``, ``res_dir``, ``normal``, ``lr``, ``n_epochs``,
            ``TEST_INTERVAL``, ``source_train_loader``,
            ``source_test_loader`` and ``target_test_loader``.
    """
    backbone = config['network']
    if backbone == 'inceptionv1':
        feature_net = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif backbone == 'inceptionv1s':
        feature_net = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        feature_net = Extractor(n_flattens=config['n_flattens'],
                                n_hiddens=config['n_hiddens'])
    head = Classifier(n_flattens=config['n_flattens'],
                      n_hiddens=config['n_hiddens'],
                      n_class=config['n_class'])

    on_gpu = torch.cuda.is_available()
    if on_gpu:
        feature_net = feature_net.cuda()
        head = head.cuda()

    # Results and logs go to a run-specific sub-directory.
    out_root = config['res_dir']
    run_name = 'normal{}-{}-lr{}'.format(config['normal'], backbone,
                                         config['lr'])
    res_dir = os.path.join(out_root, run_name)
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    set_log_config(res_dir)
    for item in (feature_net, head, config):
        logging.debug(item)

    criterion = torch.nn.CrossEntropyLoss()
    trainable = [*feature_net.parameters(), *head.parameters()]
    sgd = optim.SGD(trainable, lr=config['lr'], nesterov=True, momentum=0.9)

    def run_epoch(epoch):
        """Run one pass over the source training loader."""
        feature_net.train()
        head.train()

        # The scheduler helper is called once per epoch and hands back the
        # optimizer to use for that epoch.
        optimizer = inv_lr_scheduler(sgd,
                                     epoch,
                                     gamma=0.01,
                                     power=0.75,
                                     lr=config['lr'],
                                     weight_decay=0.0005)

        for features, labels in config['source_train_loader']:
            if on_gpu:
                features = features.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            logits = head(feature_net(features))
            batch_loss = criterion(logits, labels)
            batch_loss.backward()
            optimizer.step()

    last_epoch = config['n_epochs']
    for epoch in range(1, last_epoch + 1):
        run_epoch(epoch)
        if epoch % config['TEST_INTERVAL'] != 0:
            continue
        print('test on source_test_loader')
        test(feature_net, head, config['source_test_loader'], epoch)
        print('test on target_test_loader')
        test(feature_net, head, config['target_test_loader'], epoch)
Esempio n. 2
0
def train_fada(config):
    """Run FADA step 1 (source pre-training) and step 2 (DCD training).

    Step 1 trains the extractor and classifier with cross-entropy on
    ``config['source_train_loader']`` and prints the mean per-batch accuracy
    on ``config['target_test_loader']`` after every epoch.

    Step 2 trains the DCD discriminator on sample pairs returned by
    ``dataloader.sample_groups``.  The extractor output is not
    back-propagated through, so only the discriminator is updated.

    Args:
        config: dict of settings.  Keys read here: ``network``, ``dilation``
            (Inception variants only), ``n_flattens``, ``n_hiddens``,
            ``n_class``, ``res_dir``, ``snr``, ``lr``, ``n_epoches_1``,
            ``n_epoches_2``, ``source_train_loader``, ``target_train_loader``
            and ``target_test_loader``.
    """
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])
    critic = Critic2(n_flattens=config['n_flattens'],
                     n_hiddens=config['n_hiddens'])

    # Every model and every batch is placed on the same module-level `device`
    # so that models and inputs can never end up on different devices.
    extractor = extractor.to(device)
    classifier = classifier.to(device)
    critic = critic.to(device)

    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    res_dir = os.path.join(config['res_dir'],
                           'snr{}-lr{}'.format(config['snr'], config['lr']))
    os.makedirs(res_dir, exist_ok=True)

    set_log_config(res_dir)
    logging.debug('train_fada')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    optimizer = optim.Adam([{
        'params': extractor.parameters()
    }, {
        'params': classifier.parameters()
    }, {
        'params': critic.parameters()
    }],
                           lr=config['lr'])

    # TODO(review): the DCD input width (128) and its learning rate (0.001)
    # are hard-coded rather than taken from config.
    discriminator = main_models.DCD(input_features=128).to(device)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.001)

    # -------------- step 1: pre-train extractor and classifier on source ----
    for epoch in range(config['n_epoches_1']):
        extractor.train()
        classifier.train()
        for data, labels in config['source_train_loader']:
            data = data.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            y_pred = classifier(extractor(data))
            loss = loss_class(y_pred, labels)
            loss.backward()
            optimizer.step()

        # Evaluate in eval mode and without autograd, so the test batches do
        # not go through train-mode layer behaviour or build graphs.
        extractor.eval()
        classifier.eval()
        acc = 0
        with torch.no_grad():
            for data, labels in config['target_test_loader']:
                data = data.to(device)
                labels = labels.to(device)
                y_test_pred = classifier(extractor(data))
                acc += (torch.max(y_test_pred,
                                  1)[1] == labels).float().mean().item()

        accuracy = round(acc / float(len(config['target_test_loader'])), 3)

        print("step1----Epoch %d/%d  accuracy: %.3f " %
              (epoch + 1, config['n_epoches_1'], accuracy))

    # -------------- step 2: train the DCD discriminator ---------------------
    # NOTE(review): the extractor is put in training mode here even though it
    # is not updated; consider eval() if it is meant to be fixed in this step.
    extractor.train()

    mini_batch_size = 40  # mini-batch training is more stable

    for epoch in range(config['n_epoches_2']):
        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len(config['source_train_loader'])

        # NOTE(review): this loop walks the whole source loader but only the
        # LAST (source, target) batch pair survives it and reaches
        # sample_groups below.  Confirm whether sampling was meant to happen
        # once per batch instead.
        for i in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            # Restart the target loader once it has been consumed.
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])

        data_source = data_source.to(device)
        label_source = label_source.to(device)
        data_target = data_target.to(device)
        label_target = label_target.to(device)

        groups, _ = dataloader.sample_groups(data_source,
                                             label_source,
                                             data_target,
                                             label_target,
                                             seed=epoch)

        # Pairs are addressed by one flat index over 4 groups, each assumed to
        # hold len(groups[1]) pairs; the group number is the training label.
        group_size = len(groups[1])
        n_iters = 4 * group_size

        loss_mean = []
        X1 = []
        X2 = []
        ground_truths = []
        for index, flat_index in enumerate(torch.randperm(n_iters).tolist()):
            ground_truth, offset = divmod(flat_index, group_size)
            x1, x2 = groups[ground_truth][offset]
            X1.append(x1)
            X2.append(x2)
            ground_truths.append(ground_truth)

            # Train once a full mini-batch has been collected; a trailing
            # partial batch is not used.
            if (index + 1) % mini_batch_size == 0:
                batch_x1 = torch.stack(X1).to(device)
                batch_x2 = torch.stack(X2).to(device)
                batch_y = torch.tensor(ground_truths,
                                       dtype=torch.long,
                                       device=device)

                optimizer_D.zero_grad()
                # No gradient flows into the extractor in this step.
                with torch.no_grad():
                    X_cat = torch.cat(
                        [extractor(batch_x1),
                         extractor(batch_x2)], 1)
                y_pred = discriminator(X_cat)
                loss = loss_domain(y_pred, batch_y)
                loss.backward()
                optimizer_D.step()
                loss_mean.append(loss.item())
                X1 = []
                X2 = []
                ground_truths = []

        # With fewer than mini_batch_size pairs no update happened at all.
        epoch_loss = np.mean(loss_mean) if loss_mean else float('nan')
        print("step2----Epoch %d/%d loss:%.3f" %
              (epoch + 1, config['n_epoches_2'], epoch_loss))
Esempio n. 3
0
def train_dctln(config):
    """Train a DANN-style model, optionally with labelled target batches.

    Each step combines:
      * source classification loss,
      * domain loss on source batches (label 0) and target batches (label 1),
        computed on gradient-reversed features,
      * when ``config['slim'] > 0``, a classification loss on batches from
        ``config['target_train_semi_loader']``.

    Every ``config['TEST_INTERVAL']`` epochs ``test`` is called on the target
    test loader.  Every ``config['VIS_INTERVAL']`` epochs a confusion matrix
    and a t-SNE plot are drawn.

    Args:
        config: dict of settings.  Keys read here: ``network``, ``dilation``
            (Inception variants only), ``n_flattens``, ``n_hiddens``,
            ``n_class``, ``res_dir``, ``slim``, ``snr``, ``lr``, ``n_epochs``,
            ``TEST_INTERVAL``, ``VIS_INTERVAL``, ``models``,
            ``source_train_loader``, ``target_train_loader``,
            ``target_test_loader`` and, when ``slim > 0``,
            ``target_train_semi_loader``.
    """
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])
    critic = Critic2(n_flattens=config['n_flattens'],
                     n_hiddens=config['n_hiddens'])

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    res_dir = os.path.join(
        config['res_dir'],
        'slim{}-snr{}-lr{}'.format(config['slim'], config['snr'],
                                   config['lr']))
    os.makedirs(res_dir, exist_ok=True)

    set_log_config(res_dir)
    logging.debug('train_dann')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    optimizer = optim.Adam([{
        'params': extractor.parameters()
    }, {
        'params': classifier.parameters()
    }, {
        'params': critic.parameters()
    }],
                           lr=config['lr'])

    def dann(input_data, alpha):
        """Return (class logits, domain logits, flattened features).

        The critic sees the features through ``ReverseLayerF`` with
        coefficient ``alpha``; the classifier sees them directly.
        """
        feature = extractor(input_data)
        feature = feature.view(feature.size(0), -1)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = classifier(feature)
        domain_output = critic(reverse_feature)

        return class_output, domain_output, feature

    def train(extractor, classifier, critic, config, epoch):
        """Run one epoch; its length is that of the source training loader."""
        extractor.train()
        classifier.train()
        critic.train()

        # Reversal coefficient: grows from 0 towards 1 as epoch increases.
        gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1

        # The labelled-target ("semi") loader is only used when slim > 0.
        use_semi = config['slim'] > 0

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if use_semi:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len(config['source_train_loader'])
        for i in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            # Restart a target loader as soon as it has been consumed.
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])

            if use_semi:
                data_target_semi, label_target_semi = next(iter_target_semi)
                if i % len_target_semi_loader == 0:
                    iter_target_semi = iter(
                        config['target_train_semi_loader'])

            if use_cuda:
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()
                if use_semi:
                    data_target_semi = data_target_semi.cuda()
                    label_target_semi = label_target_semi.cuda()

            optimizer.zero_grad()

            # Source batch: classification loss + domain loss (label 0).
            class_output_s, domain_output, _ = dann(input_data=data_source,
                                                    alpha=gamma)
            err_s_label = loss_class(class_output_s, label_source)
            # Labels are created on the logits' device so this also runs on
            # CPU-only machines.
            domain_label = torch.zeros(data_source.size(0),
                                       dtype=torch.long,
                                       device=domain_output.device)
            err_s_domain = loss_domain(domain_output, domain_label)

            # Target batch: domain loss only (label 1).
            _, domain_output, _ = dann(input_data=data_target, alpha=gamma)
            domain_label = torch.ones(data_target.size(0),
                                      dtype=torch.long,
                                      device=domain_output.device)
            err_t_domain = loss_domain(domain_output, domain_label)

            err = 1.0 * err_s_label + err_s_domain + err_t_domain

            # Labelled target batch: classification loss.
            if use_semi:
                class_output_semi_t, _, _ = dann(input_data=data_target_semi,
                                                 alpha=gamma)
                err_t_label = loss_class(class_output_semi_t,
                                         label_target_semi)
                err = err + err_t_label

            err.backward()
            optimizer.step()

    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, critic, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, epoch,
                                  config['models'])
            draw_tsne(extractor,
                      classifier,
                      config['source_train_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      config['models'],
                      separate=True)
Esempio n. 4
0
def train_pada(config):
    """Two-stage training: source-only pre-training, then adversarial adaptation.

    Stage 1 trains ``extractor_s``/``classifier_s`` on labelled source data
    for ``config['n_epochs']`` epochs.  Their weights are then copied into
    ``extractor_t``/``classifier_t`` and the source pair has ``requires_grad``
    switched off.

    Stage 2 trains the target pair for another ``config['n_epochs']`` epochs
    with three loss terms:
      * classification loss on source batches (and on target batches when
        ``config['target_labeling'] == 1``),
      * an adversarial loss selected by ``config['models']``
        (``'PADA'``, ``'CDAN'`` or ``'CDAN-E'``),
      * an L1 constraint, weighted by ``config['pada_cons_w']``, between the
        source features of ``extractor_t`` and those of the frozen
        ``extractor_s``.

    Args:
        config: dict of settings.  Keys read here: ``network``,
            ``n_flattens``, ``n_hiddens``, ``n_class``, ``random_layer``,
            ``res_dir``, ``normal``, ``pada_cons_w``, ``lr``, ``models``,
            ``n_epochs``, ``target_labeling``, ``TEST_INTERVAL``,
            ``VIS_INTERVAL``, ``source_train_loader``,
            ``target_train_loader``, ``source_test_loader`` and
            ``target_test_loader``.

    Raises:
        ValueError: during stage 2 if ``config['models']`` is not one of
            ``'PADA'``, ``'CDAN'`` or ``'CDAN-E'``.
    """
    if config['network'] == 'inceptionv1':
        extractor_s = InceptionV1(num_classes=32)
        extractor_t = InceptionV1(num_classes=32)
    elif config['network'] == 'inceptionv1s':
        extractor_s = InceptionV1s(num_classes=32)
        extractor_t = InceptionV1s(num_classes=32)
    else:
        extractor_s = Extractor(n_flattens=config['n_flattens'],
                                n_hiddens=config['n_hiddens'])
        extractor_t = Extractor(n_flattens=config['n_flattens'],
                                n_hiddens=config['n_hiddens'])

    classifier_s = Classifier(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'],
                              n_class=config['n_class'])
    classifier_t = Classifier(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'],
                              n_class=config['n_class'])

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        extractor_s = extractor_s.cuda()
        classifier_s = classifier_s.cuda()

        extractor_t = extractor_t.cuda()
        classifier_t = classifier_t.cuda()

    cdan_random = config['random_layer']
    res_dir = os.path.join(
        config['res_dir'],
        'normal{}-{}-cons{}-lr{}'.format(config['normal'], config['network'],
                                         config['pada_cons_w'], config['lr']))
    os.makedirs(res_dir, exist_ok=True)

    print('train_pada')
    print(config)

    set_log_config(res_dir)
    logging.debug('train_pada')
    logging.debug(config)

    # The adversarial network's input width depends on what it is fed:
    # raw features, a random-layer projection, or feature x class size.
    if config['models'] == 'PADA':
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    elif cdan_random:
        random_layer = RandomLayer([config['n_flattens'], config['n_class']],
                                   config['n_hiddens'])
        ad_net = AdversarialNetwork(config['n_hiddens'], config['n_hiddens'])
    else:
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'] * config['n_class'],
                                    config['n_hiddens'])
    # Only move to the GPU when one exists, like the other networks above.
    if use_cuda:
        if random_layer is not None:
            random_layer.cuda()
        ad_net = ad_net.cuda()

    optimizer_s = torch.optim.Adam([{
        'params': extractor_s.parameters(),
        'lr': config['lr']
    }, {
        'params': classifier_s.parameters(),
        'lr': config['lr']
    }])
    optimizer_t = torch.optim.Adam([{
        'params': extractor_t.parameters(),
        'lr': config['lr']
    }, {
        'params': classifier_t.parameters(),
        'lr': config['lr']
    }])

    optimizer_ad = torch.optim.Adam(ad_net.parameters(), lr=config['lr'])

    ce_loss = nn.CrossEntropyLoss()
    l1_loss = nn.L1Loss()
    softmax = nn.Softmax(dim=1)

    def train_stage1(extractor_s, classifier_s, config, epoch):
        """Stage 1: one epoch of supervised training on labelled source data."""
        extractor_s.train()
        classifier_s.train()

        for data_source, label_source in config['source_train_loader']:
            if use_cuda:
                data_source = data_source.cuda()
                label_source = label_source.cuda()

            optimizer_s.zero_grad()

            h_s = extractor_s(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            source_preds = classifier_s(h_s)
            cls_loss = ce_loss(source_preds, label_source)

            cls_loss.backward()
            optimizer_s.step()

    def train(extractor_s, classifier_s, extractor_t, classifier_t, ad_net,
              config, epoch):
        """Stage 2: one epoch of adversarial training of the target pair.

        Source batches go through both ``extractor_t`` and the frozen
        ``extractor_s``; the L1 distance between the two feature sets is the
        constraint loss.
        """
        start_epoch = 0

        extractor_t.train()
        classifier_t.train()
        ad_net.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len(config['source_train_loader'])
        for step in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            # Restart the target loader once it has been consumed.
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if use_cuda:
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()
                label_target = label_target.cuda()

            optimizer_t.zero_grad()
            optimizer_ad.zero_grad()

            h_s = extractor_t(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor_t(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier_t(h_s)
            cls_loss = ce_loss(source_preds, label_source)
            softmax_output_s = softmax(source_preds)

            target_preds = classifier_t(h_t)
            softmax_output_t = softmax(target_preds)
            if config['target_labeling'] == 1:
                cls_loss += ce_loss(target_preds, label_target)

            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            if epoch > start_epoch:
                gamma = 2 / (1 + math.exp(-10 *
                                          (epoch) / config['n_epochs'])) - 1
                if config['models'] == 'CDAN-E':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN(
                        [feature, softmax_output], ad_net, gamma, entropy,
                        loss_func.calc_coeff(num_iter * (epoch - start_epoch) +
                                             step), random_layer)
                elif config['models'] == 'CDAN':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net,
                                            gamma, None, None, random_layer)
                elif config['models'] == 'PADA':
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
                else:
                    raise ValueError('Method cannot be recognized.')
            else:
                # A zero tensor (not the int 0) so that .item() below works.
                d_loss = cls_loss.new_zeros(())

            # Constraint loss against the frozen source extractor.  Its
            # parameters do not require grad, so no graph is needed.  The
            # features are flattened the same way as h_s so shapes match.
            # NOTE(review): extractor_s stays in whatever train/eval mode it
            # had at the end of stage 1 - confirm that is intended.
            with torch.no_grad():
                h_s_prev = extractor_s(data_source)
                h_s_prev = h_s_prev.view(h_s_prev.size(0), -1)
            cons_loss = l1_loss(h_s, h_s_prev)

            loss = cls_loss + d_loss + config['pada_cons_w'] * cons_loss
            loss.backward()
            optimizer_t.step()
            if epoch > start_epoch:
                optimizer_ad.step()
            if (step) % 20 == 0:
                print(
                    'Train Epoch {} closs {:.6f}, dloss {:.6f}, cons_loss {:.6f}, Loss {:.6f}'
                    .format(epoch, cls_loss.item(), d_loss.item(),
                            cons_loss.item(), loss.item()))

    # ---- stage 1: source-only training -------------------------------------
    for epoch in range(1, config['n_epochs'] + 1):
        train_stage1(extractor_s, classifier_s, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on source_test_loader')
            test(extractor_s, classifier_s, config['source_test_loader'],
                 epoch)

    # Start the target pair from the source weights, then freeze the source
    # pair.
    extractor_t.load_state_dict(extractor_s.state_dict())
    classifier_t.load_state_dict(classifier_s.state_dict())

    for param in extractor_s.parameters():
        param.requires_grad = False
    for param in classifier_s.parameters():
        param.requires_grad = False

    # ---- stage 2: adversarial training of the target pair -------------------
    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor_s, classifier_s, extractor_t, classifier_t, ad_net,
              config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on target_test_loader')
            test(extractor_t, classifier_t, config['target_test_loader'],
                 epoch)

        if epoch % config['VIS_INTERVAL'] == 0:
            title = config['models']
            draw_confusion_matrix(extractor_t, classifier_t,
                                  config['target_test_loader'], res_dir, epoch,
                                  title)
            draw_tsne(extractor_t,
                      classifier_t,
                      config['source_train_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      title,
                      separate=True)
Esempio n. 5
0
def train_wasserstein(config):
    # extractor = Extractor(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'])
    extractor = InceptionV1(num_classes=32)
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])
    critic = Critic(n_flattens=config['n_flattens'],
                    n_hiddens=config['n_hiddens'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    triplet_type = config['triplet_type']
    gamma = config['w_gamma']
    weight_wd = config['w_weight']
    weight_triplet = config['t_weight']
    t_margin = config['t_margin']
    t_confidence = config['t_confidence']
    k_critic = 3
    k_clf = 1
    TRIPLET_START_INDEX = 95

    if triplet_type == 'none':
        res_dir = os.path.join(
            config['res_dir'],
            'bs{}-lr{}-w{}-gamma{}'.format(config['batch_size'], config['lr'],
                                           weight_wd, gamma))
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        extractor_path = os.path.join(res_dir, "extractor.pth")
        classifier_path = os.path.join(res_dir, "classifier.pth")
        critic_path = os.path.join(res_dir, "critic.pth")
        EPOCH_START = 1
        TEST_INTERVAL = 10

    else:
        TEST_INTERVAL = 1
        w_dir = os.path.join(
            config['res_dir'],
            'bs{}-lr{}-w{}-gamma{}'.format(config['batch_size'], config['lr'],
                                           weight_wd, gamma))
        if not os.path.exists(w_dir):
            os.makedirs(w_dir)
        res_dir = os.path.join(
            w_dir, '{}_t_weight{}_margin{}_confidence{}'.format(
                triplet_type, weight_triplet, t_margin, t_confidence))
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        extractor_path = os.path.join(w_dir, "extractor.pth")
        classifier_path = os.path.join(w_dir, "classifier.pth")
        critic_path = os.path.join(w_dir, "critic.pth")

        if os.path.exists(extractor_path):
            extractor.load_state_dict(torch.load(extractor_path))
            classifier.load_state_dict(torch.load(classifier_path))
            critic.load_state_dict(torch.load(critic_path))
            print('load models')
            EPOCH_START = TRIPLET_START_INDEX
        else:
            EPOCH_START = 1

    set_log_config(res_dir)
    print('start epoch {}'.format(EPOCH_START))
    print('triplet type {}'.format(triplet_type))
    print(config)

    logging.debug('train_wt')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    criterion = torch.nn.CrossEntropyLoss()
    softmax_layer = nn.Softmax(dim=1)

    critic_opt = torch.optim.Adam(critic.parameters(), lr=config['lr'])
    classifier_opt = torch.optim.Adam(classifier.parameters(), lr=config['lr'])
    feature_opt = torch.optim.Adam(extractor.parameters(),
                                   lr=config['lr'] / 10)

    def train(extractor, classifier, critic, config, epoch):
        extractor.train()
        classifier.train()
        critic.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for step in range(1, num_iter):
            data_source, label_source = iter_source.next()
            data_target, _ = iter_target.next()
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(
                ), label_source.cuda()
                data_target = data_target.cuda()

            # 1. train critic
            set_requires_grad(extractor, requires_grad=False)
            set_requires_grad(classifier, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)
            with torch.no_grad():
                h_s = extractor(data_source)
                h_s = h_s.view(h_s.size(0), -1)
                h_t = extractor(data_target)
                h_t = h_t.view(h_t.size(0), -1)

            for j in range(k_critic):
                gp = gradient_penalty(critic, h_s, h_t)
                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()
                critic_cost = -wasserstein_distance + gamma * gp

                critic_opt.zero_grad()
                critic_cost.backward()
                critic_opt.step()

                if step == 10 and j == 0:
                    print('EPOCH {}, DISCRIMINATOR: wd {}, gp {}, loss {}'.
                          format(epoch, wasserstein_distance.item(),
                                 (gamma * gp).item(), critic_cost.item()))
                    logging.debug(
                        'EPOCH {}, DISCRIMINATOR: wd {}, gp {}, loss {}'.
                        format(epoch, wasserstein_distance.item(),
                               (gamma * gp).item(), critic_cost.item()))

            # 2. train feature and class_classifier
            set_requires_grad(extractor, requires_grad=True)
            set_requires_grad(classifier, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(k_clf):
                h_s = extractor(data_source)
                h_s = h_s.view(h_s.size(0), -1)
                h_t = extractor(data_target)
                h_t = h_t.view(h_t.size(0), -1)

                source_preds = classifier(h_s)
                clf_loss = criterion(source_preds, label_source)
                wasserstein_distance = critic(h_s).mean() - critic(h_t).mean()

                if triplet_type != 'none' and epoch >= TRIPLET_START_INDEX:
                    target_preds = classifier(h_t)
                    target_labels = target_preds.data.max(1)[1]
                    target_logits = softmax_layer(target_preds)
                    if triplet_type == 'all':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() >
                            t_margin)[0]
                        images = torch.cat((h_s, h_t[triplet_index]), 0)
                        labels = torch.cat(
                            (label_source, target_labels[triplet_index]), 0)
                    elif triplet_type == 'src':
                        images = h_s
                        labels = label_source
                    elif triplet_type == 'tgt':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() >
                            t_confidence)[0]
                        images = h_t[triplet_index]
                        labels = target_labels[triplet_index]
                    elif triplet_type == 'sep':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() >
                            t_confidence)[0]
                        images = h_t[triplet_index]
                        labels = target_labels[triplet_index]
                        t_loss_sep, _ = triplet_loss(extractor, {
                            "X": images,
                            "y": labels
                        }, t_confidence)
                        images = h_s
                        labels = label_source

                    t_loss, _ = triplet_loss(extractor, {
                        "X": images,
                        "y": labels
                    }, t_margin)
                    loss = clf_loss + \
                        weight_wd * wasserstein_distance + \
                        weight_triplet * t_loss
                    if triplet_type == 'sep':
                        loss += t_loss_sep
                    feature_opt.zero_grad()
                    classifier_opt.zero_grad()
                    loss.backward()
                    feature_opt.step()
                    classifier_opt.step()

                    if step == 10:
                        print(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, t_loss {}, total loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    weight_triplet * t_loss.item(),
                                    loss.item()))
                        logging.debug(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, t_loss {}, total loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    weight_triplet * t_loss.item(),
                                    loss.item()))

                else:
                    loss = clf_loss + weight_wd * wasserstein_distance
                    feature_opt.zero_grad()
                    classifier_opt.zero_grad()
                    loss.backward()
                    feature_opt.step()
                    classifier_opt.step()

                    if step == 10:
                        print(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {},  loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    loss.item()))
                        logging.debug(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {},  loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    loss.item()))

    # pretrain(model, config, pretrain_epochs=20)
    for epoch in range(EPOCH_START, config['n_epochs'] + 1):
        train(extractor, classifier, critic, config, epoch)
        if epoch % TEST_INTERVAL == 0:
            # print('test on source_test_loader')
            # test(extractor, classifier, config['source_test_loader'], epoch)
            # print('test on target_train_loader')
            # test(model, config['target_train_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            if triplet_type == 'none':
                title = '(a) WDGRL'
            else:
                title = '(b) TLADA'
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, epoch,
                                  title)
            draw_tsne(extractor,
                      classifier,
                      config['source_train_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      title,
                      separate=True)
            # draw_tsne(extractor, classifier, config['source_test_loader'], config['target_test_loader'], res_dir, epoch, title, separate=False)
    if triplet_type == 'none':
        torch.save(extractor.state_dict(), extractor_path)
        torch.save(classifier.state_dict(), classifier_path)
        torch.save(critic.state_dict(), critic_path)
Esempio n. 6
0
def train_dann_vat(config):
    """Train a DANN model with VAT and target-entropy regularisation.

    An ``Extractor``, a ``Classifier`` and a ``Critic2`` domain discriminator
    are optimised jointly by a single Adam optimizer. The discriminator sees
    the features through ``ReverseLayerF``, which is called with a per-epoch
    coefficient ``gamma``. Virtual adversarial losses on the source and target
    batches and an entropy loss on the target predictions are added to the
    objective.

    The training loop only runs when ``config['testonly'] == 0``; otherwise
    the function returns after building the models and configuring logging.

    Args:
        config (dict): Run configuration. Keys read here: ``n_flattens``,
            ``n_hiddens``, ``n_class``, ``res_dir``, ``slim``,
            ``target_labeling``, ``lr``, ``n_epochs``, ``testonly``,
            ``TEST_INTERVAL``, ``VIS_INTERVAL``, ``models`` and the loaders
            ``source_train_loader``, ``target_train_loader``,
            ``source_test_loader``, ``target_test_loader``.

    Returns:
        None.
    """
    # Evaluate once so every device decision below is consistent; the
    # original moved some objects to CUDA unconditionally, which fails on
    # CPU-only hosts.
    use_cuda = torch.cuda.is_available()

    extractor = Extractor(n_flattens=config['n_flattens'],
                          n_hiddens=config['n_hiddens'])
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])
    critic = Critic2(n_flattens=config['n_flattens'],
                     n_hiddens=config['n_hiddens'])
    if use_cuda:
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    res_dir = os.path.join(
        config['res_dir'],
        'slim{}-target_labeling{}-lr{}'.format(config['slim'],
                                               config['target_labeling'],
                                               config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()
    vat_loss = VAT(extractor, classifier, n_power=1, radius=3.5)
    if use_cuda:
        vat_loss = vat_loss.cuda()

    print('train_dann_vat')
    print(extractor)
    print(classifier)
    print(critic)
    print(config)

    set_log_config(res_dir)
    logging.debug('train_dann_vat')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    optimizer = optim.Adam([{
        'params': extractor.parameters()
    }, {
        'params': classifier.parameters()
    }, {
        'params': critic.parameters()
    }],
                           lr=config['lr'])

    # Loss weights: domain, source classification, source VAT, target VAT.
    # The target entropy term shares ``tw`` with the target VAT term.
    dw, cw, sw, tw = 1, 1, 1, 1

    def dann(input_data, alpha):
        """Return (class logits, domain logits, flattened features)."""
        feature = extractor(input_data)
        feature = feature.view(feature.size(0), -1)
        # The critic is fed through the reversal layer; the classifier sees
        # the plain features.
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = classifier(feature)
        domain_output = critic(reverse_feature)

        return class_output, domain_output, feature

    def train(extractor, classifier, critic, config, epoch):
        """Run one epoch; its length is the number of source batches."""
        extractor.train()
        classifier.train()
        critic.train()

        # Reversal coefficient: grows from 0 towards 1 as epoch / n_epochs
        # increases.
        gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for i in range(1, num_iter + 1):
            # Built-in next() instead of the non-standard ``.next()`` method.
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            # Restart the target loader once it has been fully consumed.
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if use_cuda:
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()

            optimizer.zero_grad()

            # Source batch: classification loss + domain loss (label 0).
            class_output_s, domain_output_s, _ = dann(
                input_data=data_source, alpha=gamma)
            err_s_class = loss_class(class_output_s, label_source)
            domain_label_s = torch.zeros(data_source.size(0),
                                         dtype=torch.long,
                                         device=data_source.device)
            err_s_domain = loss_domain(domain_output_s, domain_label_s)

            # Target batch: domain loss (label 1) + prediction entropy.
            class_output_t, domain_output_t, _ = dann(
                input_data=data_target, alpha=gamma)
            domain_label_t = torch.ones(data_target.size(0),
                                        dtype=torch.long,
                                        device=data_target.device)
            err_t_domain = loss_domain(domain_output_t, domain_label_t)
            err_t_entropy = get_loss_entropy(class_output_t)

            # Virtual adversarial losses on both domains.
            err_s_vat = vat_loss(data_source, class_output_s)
            err_t_vat = vat_loss(data_target, class_output_t)

            err_domain = 0.5 * (err_s_domain + err_t_domain)

            err_all = (dw * err_domain + cw * err_s_class + sw * err_s_vat +
                       tw * err_t_vat + tw * err_t_entropy)

            if i % 20 == 0:
                # Argument order matches the labels in the format string
                # (err_t_vat before err_s_vat).
                print(
                    'err_s_class {:.2f}, err_s_domain {:.2f}, gamma {:.2f}, err_t_domain {:.2f}, err_t_vat {:.2f}, err_s_vat {:.2f}, err_all {:.2f}'
                    .format(err_s_class.item(), err_s_domain.item(), gamma,
                            err_t_domain.item(), err_t_vat.item(),
                            err_s_vat.item(), err_all.item()))

            err_all.backward()
            optimizer.step()

    if config['testonly'] == 0:
        best_accuracy = 0
        best_model_index = -1
        for epoch in range(1, config['n_epochs'] + 1):
            train(extractor, classifier, critic, config, epoch)
            if epoch % config['TEST_INTERVAL'] == 0:
                print('test on source_test_loader')
                test(extractor, classifier, config['source_test_loader'],
                     epoch)
                print('test on target_test_loader')
                # The value returned by test() is treated as the accuracy.
                accuracy = test(extractor, classifier,
                                config['target_test_loader'], epoch)

                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    best_model_index = epoch
                print(
                    'epoch {} accuracy: {:.6f}, best accuracy {:.6f} on epoch {}'
                    .format(epoch, accuracy, best_accuracy, best_model_index))

            if epoch % config['VIS_INTERVAL'] == 0:
                title = config['models']
                draw_confusion_matrix(extractor, classifier,
                                      config['target_test_loader'], res_dir,
                                      epoch, title)
                draw_tsne(extractor,
                          classifier,
                          config['source_test_loader'],
                          config['target_test_loader'],
                          res_dir,
                          epoch,
                          title,
                          separate=True)
Esempio n. 7
0
def train_cdan_vat(config):
    """Train (or, in test-only mode, evaluate) a CDAN/DANN model with VAT.

    Builds an ``Extractor`` and a ``Classifier`` (updated by one Adam
    optimizer) and an ``AdversarialNetwork`` (updated by a second Adam
    optimizer). The adversarial loss is selected by ``config['models']``:
    ``'DANN'``, ``'CDAN'``, ``'CDAN-E'`` or ``'CDAN_VAT'``. After every
    update a virtual-adversarial consistency loss is computed on the target
    batch and reported.

    With ``config['testonly'] == 0`` the model is trained and the weights
    with the best value returned by ``test`` on the target test loader are
    saved in the result directory. Otherwise previously saved weights are
    loaded and evaluated.

    Args:
        config (dict): Run configuration. Keys read here: ``n_flattens``,
            ``n_hiddens``, ``n_class``, ``random_layer``, ``res_dir``,
            ``batch_size``, ``lr``, ``models``, ``n_epochs``, ``testonly``,
            ``TEST_INTERVAL``, ``VIS_INTERVAL`` and the loaders
            ``source_train_loader``, ``target_train_loader``,
            ``source_test_loader``, ``target_test_loader``.

    Returns:
        None.

    Raises:
        ValueError: During training, if ``config['models']`` is not one of
            the supported method names.
    """
    # Evaluate once so every device decision below is consistent; the
    # original moved ad_net / random_layer to CUDA unconditionally.
    use_cuda = torch.cuda.is_available()

    extractor = Extractor(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'])
    classifier = Classifier(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'], n_class=config['n_class'])
    if use_cuda:
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    xi = 1e-06
    ip = 1
    eps = 15
    vat = VirtualAdversarialPerturbationGenerator(extractor, classifier, xi=xi, eps=eps, ip=ip)

    cdan_random = config['random_layer']
    res_dir = os.path.join(config['res_dir'], 'random{}-bs{}-lr{}'.format(cdan_random, config['batch_size'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    # Banner names this function (it previously read 'train_cdan', copied
    # from the sibling routine).
    print('train_cdan_vat')
    print(extractor)
    print(classifier)
    print(config)

    set_log_config(res_dir)
    logging.debug('train_cdan_vat')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    # The discriminator input width depends on how features and predictions
    # are combined: features only (DANN), a random projection of both, or
    # n_flattens * n_class inputs when no random layer is used.
    if config['models'] == 'DANN':
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    elif cdan_random:
        random_layer = RandomLayer([config['n_flattens'], config['n_class']], config['n_hiddens'])
        ad_net = AdversarialNetwork(config['n_hiddens'], config['n_hiddens'])
    else:
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'] * config['n_class'], config['n_hiddens'])
    if use_cuda:
        if random_layer is not None:
            random_layer.cuda()
        ad_net = ad_net.cuda()

    optimizer = torch.optim.Adam([
        {'params': extractor.parameters(), 'lr': config['lr']},
        {'params': classifier.parameters(), 'lr': config['lr']}
        ])
    optimizer_ad = torch.optim.Adam(ad_net.parameters(), lr=config['lr'])

    extractor_path = os.path.join(res_dir, "extractor.pth")
    classifier_path = os.path.join(res_dir, "classifier.pth")
    adnet_path = os.path.join(res_dir, "adnet.pth")

    def train(extractor, classifier, ad_net, config, epoch):
        """Run one epoch; its length is the number of source batches."""
        # The adversarial loss is only active for epochs after start_epoch.
        start_epoch = 0

        extractor.train()
        classifier.train()
        ad_net.train()

        cross_entropy = nn.CrossEntropyLoss()
        softmax = nn.Softmax(dim=1)
        model_name = config['models']

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for step in range(1, num_iter + 1):
            # Built-in next() instead of the non-standard ``.next()`` method.
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            # Restart the target loader once it has been fully consumed.
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if use_cuda:
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target = data_target.cuda()

            optimizer.zero_grad()
            optimizer_ad.zero_grad()

            h_s = extractor(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier(h_s)
            cls_loss = cross_entropy(source_preds, label_source)
            softmax_output_s = softmax(source_preds)

            target_preds = classifier(h_t)
            softmax_output_t = softmax(target_preds)

            # Source rows first, then target rows, for both tensors.
            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            if epoch > start_epoch:
                gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1
                if model_name == 'CDAN-E':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net, gamma, entropy, loss_func.calc_coeff(num_iter*(epoch-start_epoch)+step), random_layer)
                elif model_name in ('CDAN', 'CDAN_VAT'):
                    # CDAN_VAT uses the plain CDAN adversarial loss.
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net, gamma, None, None, random_layer)
                elif model_name == 'DANN':
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
                else:
                    raise ValueError('Method cannot be recognized.')
            else:
                # Zero *tensor* (not the int 0) so that d_loss.item() in the
                # progress print below remains valid.
                d_loss = cls_loss.new_zeros(())

            loss = cls_loss + d_loss
            loss.backward()
            optimizer.step()

            # NOTE(review): target_consistency_criterion and
            # target_vat_loss_weight are not defined in this function; they
            # are presumably module-level names - confirm.
            vat_adv, clean_vat_logits = vat(data_target)
            vat_adv_inputs = data_target + vat_adv
            adv_vat_features = extractor(vat_adv_inputs)
            adv_vat_logits = classifier(adv_vat_features)
            target_vat_loss = target_consistency_criterion(adv_vat_logits, clean_vat_logits)
            vat_loss = target_vat_loss_weight * target_vat_loss
            # NOTE(review): optimizer.step() has already run above and
            # optimizer.zero_grad() clears gradients at the start of the next
            # iteration, so the gradients produced here are never applied by
            # `optimizer`. Kept as in the original; confirm this is intended.
            vat_loss.backward()

            if epoch > start_epoch:
                optimizer_ad.step()
            if (step) % 20 == 0:
                print('Train Epoch {} closs {:.6f}, dloss {:.6f}, vat_loss {:.6f}, Loss {:.6f}'.format(epoch, cls_loss.item(), d_loss.item(), vat_loss.item(), loss.item()))

    if config['testonly'] == 0:
        best_accuracy = 0
        best_model_index = -1
        for epoch in range(1, config['n_epochs'] + 1):
            train(extractor, classifier, ad_net, config, epoch)
            if epoch % config['TEST_INTERVAL'] == 0:
                print('test on source_test_loader')
                test(extractor, classifier, config['source_test_loader'], epoch)
                print('test on target_test_loader')
                accuracy = test(extractor, classifier, config['target_test_loader'], epoch)

                # Checkpoint only when the target-test result improves.
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    best_model_index = epoch
                    torch.save(extractor.state_dict(), extractor_path)
                    torch.save(classifier.state_dict(), classifier_path)
                    torch.save(ad_net.state_dict(), adnet_path)
                print('epoch {} accuracy: {:.6f}, best accuracy {:.6f} on epoch {}'.format(epoch, accuracy, best_accuracy, best_model_index))

            if epoch % config['VIS_INTERVAL'] == 0:
                title = config['models']
                draw_confusion_matrix(extractor, classifier, config['target_test_loader'], res_dir, epoch, title)
                draw_tsne(extractor, classifier, config['source_test_loader'], config['target_test_loader'], res_dir, epoch, title, separate=True)
    else:
        if os.path.exists(extractor_path) and os.path.exists(classifier_path) and os.path.exists(adnet_path):
            # Remap to CPU when CUDA is unavailable so checkpoints written on
            # a GPU host can still be loaded; default behaviour otherwise.
            map_location = None if use_cuda else 'cpu'
            extractor.load_state_dict(torch.load(extractor_path, map_location=map_location))
            classifier.load_state_dict(torch.load(classifier_path, map_location=map_location))
            ad_net.load_state_dict(torch.load(adnet_path, map_location=map_location))
            print('Test only mode, model loaded')

            print('test on source_test_loader')
            test(extractor, classifier, config['source_test_loader'], -1)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], -1)

            title = config['models']
            draw_confusion_matrix(extractor, classifier, config['target_test_loader'], res_dir, -1, title)
            draw_tsne(extractor, classifier, config['source_test_loader'], config['target_test_loader'], res_dir, -1, title, separate=True)
        else:
            print('no saved model found')
Esempio n. 8
0
def train_deepcoral(config):
    """Train a feature extractor and classifier with a CORAL alignment loss.

    The objective per batch is the source cross-entropy plus
    ``config['mmd_gamma']`` times ``CORAL(source_features, target_features)``.
    When ``config['slim'] > 0`` a cross-entropy term on batches from
    ``config['target_train_semi_loader']`` is added as well.

    Args:
        config (dict): Run configuration. Keys read here: ``network``,
            ``dilation``, ``n_flattens``, ``n_hiddens``, ``n_class``,
            ``res_dir``, ``normal``, ``lr``, ``mmd_gamma``, ``slim``,
            ``n_epochs``, ``TEST_INTERVAL``, ``VIS_INTERVAL``, ``models`` and
            the loaders ``source_train_loader``, ``target_train_loader``,
            ``source_test_loader``, ``target_test_loader`` (plus
            ``target_train_semi_loader`` when ``slim > 0``).

    Returns:
        None.
    """
    use_cuda = torch.cuda.is_available()

    # Backbone selection; any unrecognised name falls back to Extractor.
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])

    if use_cuda:
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    res_dir = os.path.join(
        config['res_dir'], 'normal{}-{}-dilation{}-lr{}-mmdgamma{}'.format(
            config['normal'], config['network'], config['dilation'],
            config['lr'], config['mmd_gamma']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    criterion = torch.nn.CrossEntropyLoss()

    set_log_config(res_dir)
    logging.debug('train_deepcoral')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    optimizer = optim.Adam(list(extractor.parameters()) +
                           list(classifier.parameters()),
                           lr=config['lr'])

    def train(extractor, classifier, config, epoch):
        """Run one epoch; its length is the number of source batches."""
        extractor.train()
        classifier.train()

        semi_supervised = config['slim'] > 0

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if semi_supervised:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len_source_loader
        for i in range(1, num_iter + 1):
            # Built-in next() instead of the non-standard ``.next()`` method,
            # which newer DataLoader iterators no longer provide.
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if semi_supervised:
                data_target_semi, label_target_semi = next(iter_target_semi)
                # Restart the labelled-target loader once it is exhausted.
                if i % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])
            # Restart the target loader once it has been fully consumed.
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if use_cuda:
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()
                if semi_supervised:
                    data_target_semi = data_target_semi.cuda()
                    label_target_semi = label_target_semi.cuda()

            optimizer.zero_grad()

            # Flatten extractor outputs to (batch, -1) before use.
            source = extractor(data_source)
            source = source.view(source.size(0), -1)
            target = extractor(data_target)
            target = target.view(target.size(0), -1)

            preds = classifier(source)
            loss_cls = criterion(preds, label_source)

            loss_coral = CORAL(source, target)

            # Fixed alignment weight from the config (no per-epoch schedule).
            loss = loss_cls + config['mmd_gamma'] * loss_coral

            if semi_supervised:
                feature_target_semi = extractor(data_target_semi)
                feature_target_semi = feature_target_semi.view(
                    feature_target_semi.size(0), -1)
                preds_target_semi = classifier(feature_target_semi)
                err_t_class_semi = criterion(preds_target_semi,
                                             label_target_semi)
                loss = loss + err_t_class_semi

            if i % 50 == 0:
                print('loss_cls {}, loss_coral {}, gamma {}, total loss {}'.
                      format(loss_cls.item(), loss_coral.item(),
                             config['mmd_gamma'], loss.item()))
            loss.backward()
            optimizer.step()

    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on source_test_loader')
            test(extractor, classifier, config['source_test_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, epoch,
                                  config['models'])
            # Two t-SNE plots per visualisation step: separate=True, then
            # separate=False.
            draw_tsne(extractor,
                      classifier,
                      config['source_test_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      config['models'],
                      separate=True)
            draw_tsne(extractor,
                      classifier,
                      config['source_test_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      config['models'],
                      separate=False)
Esempio n. 9
0
def train_cdan(config):
    """Train (or, in test-only mode, evaluate) a CDAN/DANN model.

    Builds a feature extractor and a ``Classifier`` (updated by one Adam
    optimizer) and an ``AdversarialNetwork`` (updated by a second Adam
    optimizer). The adversarial loss is selected by ``config['models']``:
    ``'DANN'``, ``'CDAN'`` or ``'CDAN-E'``. The objective additionally
    contains VAT losses on source and target batches and
    ``get_loss_bnm(target_preds)``; a supervised target term is added when
    ``config['target_labeling'] == 1`` and another one on the
    ``target_train_semi_loader`` batches when ``config['slim'] > 0``.

    With ``config['testonly'] == 0`` the model is trained and the weights
    with the best value returned by ``test`` on the target test loader are
    saved in the result directory. Otherwise previously saved weights are
    loaded and evaluated.

    Args:
        config (dict): Run configuration. Keys read here: ``network``,
            ``n_flattens``, ``n_hiddens``, ``bn``, ``n_class``,
            ``random_layer``, ``res_dir``, ``slim``, ``target_labeling``,
            ``snr``, ``snrp``, ``lr``, ``models``, ``n_epochs``,
            ``testonly``, ``TEST_INTERVAL``, ``VIS_INTERVAL`` and the loaders
            ``source_train_loader``, ``target_train_loader``,
            ``target_test_loader`` (plus ``target_train_semi_loader`` when
            ``slim > 0``).

    Returns:
        None.

    Raises:
        ValueError: During training, if ``config['models']`` is not one of
            the supported method names.
    """
    # Evaluate once so every device decision below is consistent; the
    # original moved VAT / ad_net / random_layer to CUDA unconditionally.
    use_cuda = torch.cuda.is_available()

    # Backbone selection; any unrecognised name falls back to Extractor.
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32)
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32)
    else:
        extractor = Extractor(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'], bn=config['bn'])

    classifier = Classifier(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'], n_class=config['n_class'])
    vat_loss = VAT(extractor, classifier, n_power=1, radius=3.5)

    if use_cuda:
        vat_loss = vat_loss.cuda()
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    cdan_random = config['random_layer']
    res_dir = os.path.join(config['res_dir'], 'slim{}-targetLabel{}-snr{}-snrp{}-lr{}'.format(config['slim'],
                                                                        config['target_labeling'],
                                                                        config['snr'],
                                                                        config['snrp'],
                                                                        config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    print('train_cdan')
    print(config)

    set_log_config(res_dir)
    logging.debug('train_cdan')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    # The discriminator input width depends on how features and predictions
    # are combined: features only (DANN), a random projection of both, or
    # n_flattens * n_class inputs when no random layer is used.
    if config['models'] == 'DANN':
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    elif cdan_random:
        random_layer = RandomLayer([config['n_flattens'], config['n_class']], config['n_hiddens'])
        ad_net = AdversarialNetwork(config['n_hiddens'], config['n_hiddens'])
    else:
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'] * config['n_class'], config['n_hiddens'])
    if use_cuda:
        if random_layer is not None:
            random_layer.cuda()
        ad_net = ad_net.cuda()

    optimizer = torch.optim.Adam([
        {'params': extractor.parameters(), 'lr': config['lr']},
        {'params': classifier.parameters(), 'lr': config['lr']}
        ])
    optimizer_ad = torch.optim.Adam(ad_net.parameters(), lr=config['lr'])

    extractor_path = os.path.join(res_dir, "extractor.pth")
    classifier_path = os.path.join(res_dir, "classifier.pth")
    adnet_path = os.path.join(res_dir, "adnet.pth")

    def train(extractor, classifier, ad_net, config, epoch):
        """Run one epoch; its length is the number of source batches."""
        # The adversarial loss is only active for epochs after start_epoch.
        start_epoch = 0

        extractor.train()
        classifier.train()
        ad_net.train()

        cross_entropy = nn.CrossEntropyLoss()
        softmax = nn.Softmax(dim=1)
        model_name = config['models']
        semi_supervised = config['slim'] > 0

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if semi_supervised:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len_source_loader
        for step in range(1, num_iter + 1):
            # Built-in next() instead of the non-standard ``.next()`` method.
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            if semi_supervised:
                data_target_semi, label_target_semi = next(iter_target_semi)
                # Restart the labelled-target loader once it is exhausted.
                if step % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])
            # Restart the target loader once it has been fully consumed.
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if use_cuda:
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target, label_target = data_target.cuda(), label_target.cuda()
                if semi_supervised:
                    data_target_semi, label_target_semi = data_target_semi.cuda(), label_target_semi.cuda()

            optimizer.zero_grad()
            optimizer_ad.zero_grad()

            h_s = extractor(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier(h_s)
            cls_loss = cross_entropy(source_preds, label_source)
            softmax_output_s = softmax(source_preds)

            target_preds = classifier(h_t)
            softmax_output_t = softmax(target_preds)
            if config['target_labeling'] == 1:
                # Supervised target term; reported as part of "closs" below.
                cls_loss = cls_loss + cross_entropy(target_preds, label_target)

            # Source rows first, then target rows, for both tensors.
            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            if epoch > start_epoch:
                gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1
                if model_name == 'CDAN-E':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net, gamma, entropy, loss_func.calc_coeff(num_iter*(epoch-start_epoch)+step), random_layer)
                elif model_name == 'CDAN':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net, gamma, None, None, random_layer)
                elif model_name == 'DANN':
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
                else:
                    raise ValueError('Method cannot be recognized.')
            else:
                # Zero *tensor* (not the int 0) so that d_loss.item() in the
                # progress print below remains valid.
                d_loss = cls_loss.new_zeros(())

            loss = cls_loss + d_loss

            # Regularisers: get_loss_bnm on the target predictions and VAT on
            # both domains, each with weight 1.0.
            err_t_bnm = get_loss_bnm(target_preds)
            err_s_vat = vat_loss(data_source, source_preds)
            err_t_vat = vat_loss(data_target, target_preds)
            loss = loss + 1.0 * err_s_vat + 1.0 * err_t_vat + 1.0 * err_t_bnm

            if semi_supervised:
                feature_target_semi = extractor(data_target_semi)
                feature_target_semi = feature_target_semi.view(feature_target_semi.size(0), -1)
                preds_target_semi = classifier(feature_target_semi)
                loss = loss + cross_entropy(preds_target_semi, label_target_semi)

            loss.backward()
            optimizer.step()
            if epoch > start_epoch:
                optimizer_ad.step()
            if (step) % 100 == 0:
                print('Train Epoch {} closs {:.6f}, dloss {:.6f}, Loss {:.6f}'.format(epoch, cls_loss.item(), d_loss.item(), loss.item()))

    if config['testonly'] == 0:
        best_accuracy = 0
        best_model_index = -1
        for epoch in range(1, config['n_epochs'] + 1):
            train(extractor, classifier, ad_net, config, epoch)
            if epoch % config['TEST_INTERVAL'] == 0:
                print('test on target_test_loader')
                accuracy = test(extractor, classifier, config['target_test_loader'], epoch)

                # Checkpoint only when the target-test result improves.
                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    best_model_index = epoch
                    torch.save(extractor.state_dict(), extractor_path)
                    torch.save(classifier.state_dict(), classifier_path)
                    torch.save(ad_net.state_dict(), adnet_path)
                print('epoch {} accuracy: {:.6f}, best accuracy {:.6f} on epoch {}'.format(epoch, accuracy, best_accuracy, best_model_index))

            if epoch % config['VIS_INTERVAL'] == 0:
                title = config['models']
                draw_confusion_matrix(extractor, classifier, config['target_test_loader'], res_dir, epoch, title)
                # t-SNE uses the source *train* loader against the target
                # test loader.
                draw_tsne(extractor, classifier, config['source_train_loader'], config['target_test_loader'], res_dir, epoch, title, separate=True)
    else:
        if os.path.exists(extractor_path) and os.path.exists(classifier_path) and os.path.exists(adnet_path):
            # Remap to CPU when CUDA is unavailable so checkpoints written on
            # a GPU host can still be loaded; default behaviour otherwise.
            map_location = None if use_cuda else 'cpu'
            extractor.load_state_dict(torch.load(extractor_path, map_location=map_location))
            classifier.load_state_dict(torch.load(classifier_path, map_location=map_location))
            ad_net.load_state_dict(torch.load(adnet_path, map_location=map_location))
            print('Test only mode, model loaded')

            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], -1)

            title = config['models']
            draw_confusion_matrix(extractor, classifier, config['target_test_loader'], res_dir, -1, title)
            draw_tsne(extractor, classifier, config['source_train_loader'], config['target_test_loader'], res_dir, -1, title, separate=True)
        else:
            print('no saved model found')
Esempio n. 10
0
def train_mcd(config):
    """Train a feature generator and two classifiers on their target discrepancy.

    Builds a generator ``G`` and two classifiers ``C1``/``C2`` and, for every
    epoch, runs either the three-step update (``train``) or the single-step
    variant (``train_onestep``).  Both classifiers are periodically evaluated
    on the target test loader, and confusion-matrix / t-SNE plots are drawn.

    Args:
        config: dict of settings.  Keys read here: ``mcd_swd``, ``inception``,
            ``n_flattens``, ``n_hiddens``, ``n_class``, ``lr``, ``res_dir``,
            ``normal``, ``network``, ``dilation``, ``source_train_loader``,
            ``target_train_loader``, ``target_test_loader``, ``n_epochs``,
            ``mcd_onestep``, ``TEST_INTERVAL``, ``VIS_INTERVAL``, ``models``.

    Returns:
        None.  Results are printed, logged and written under the result dir.
    """
    def discrepancy(p1, p2):
        """Return the discrepancy between two classifier outputs.

        Uses the sliced-Wasserstein variant when ``config['mcd_swd'] == 1``,
        otherwise the plain MCD discrepancy helper.
        """
        if config['mcd_swd'] == 1:
            return discrepancy_slice_wasserstein(p1, p2)
        return discrepancy_mcd(p1, p2)

    if config['inception'] == 1:
        # G = InceptionV4(num_classes=32)
        G = InceptionV1(num_classes=32)
    else:
        G = Extractor(n_flattens=config['n_flattens'],
                      n_hiddens=config['n_hiddens'])
    C1 = Classifier(n_flattens=config['n_flattens'],
                    n_hiddens=config['n_hiddens'],
                    n_class=config['n_class'])
    C2 = Classifier(n_flattens=config['n_flattens'],
                    n_hiddens=config['n_hiddens'],
                    n_class=config['n_class'])
    if torch.cuda.is_available():
        G = G.cuda()
        C1 = C1.cuda()
        C2 = C2.cuda()

    # One optimizer per network so that each step can update a chosen subset.
    opt_g = optim.Adam(G.parameters(), lr=config['lr'])
    opt_c1 = optim.Adam(C1.parameters(), lr=config['lr'])
    opt_c2 = optim.Adam(C2.parameters(), lr=config['lr'])

    def zero_grads():
        """Clear the gradients held by all three optimizers."""
        opt_g.zero_grad()
        opt_c1.zero_grad()
        opt_c2.zero_grad()

    criterion = torch.nn.CrossEntropyLoss()
    res_dir = os.path.join(
        config['res_dir'], 'normal{}-{}-dilation{}-swd{}-lr{}'.format(
            config['normal'], config['network'], config['dilation'],
            config['mcd_swd'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    set_log_config(res_dir)
    logging.debug('train_mcd')
    logging.debug(G)
    logging.debug(C1)
    logging.debug(C2)
    logging.debug(config)

    def train(G, C1, C2, config, epoch):
        """Run one epoch of the three-step update over the source loader.

        The epoch length is the number of source batches; the target loader
        is restarted whenever it has been fully consumed.
        """
        G.train()
        C1.train()
        C2.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len(config['source_train_loader'])
        for i in range(1, num_iter + 1):
            # Builtin next() instead of the removed ``iterator.next()`` alias.
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()

            # Step 1: source classification error, updates G, C1 and C2.
            zero_grads()
            feat_s = G(data_source)
            output_s1 = C1(feat_s)
            output_s2 = C2(feat_s)
            loss_s1 = criterion(output_s1, label_source)
            loss_s2 = criterion(output_s2, label_source)
            loss_s = loss_s1 + loss_s2
            loss_s.backward()
            opt_g.step()
            opt_c1.step()
            opt_c2.step()

            # Step 2: source classification error minus the target
            # discrepancy; only the two classifiers are stepped.
            zero_grads()
            feat_s = G(data_source)
            output_s1 = C1(feat_s)
            output_s2 = C2(feat_s)
            feat_t = G(data_target)
            output_t1 = C1(feat_t)
            output_t2 = C2(feat_t)
            loss_s1 = criterion(output_s1, label_source)
            loss_s2 = criterion(output_s2, label_source)
            loss_s = loss_s1 + loss_s2
            loss_dis = discrepancy(output_t1, output_t2)
            loss = loss_s - loss_dis
            loss.backward()
            opt_c1.step()
            opt_c2.step()

            # Step 3: update the feature extractor only, on source error plus
            # target discrepancy (target forward pass first, then source).
            zero_grads()
            feat_t = G(data_target)
            output_t1 = C1(feat_t)
            output_t2 = C2(feat_t)
            loss_dis = discrepancy(output_t1, output_t2)

            feat_s = G(data_source)
            output_s1 = C1(feat_s)
            output_s2 = C2(feat_s)
            loss_s1 = criterion(output_s1, label_source)
            loss_s2 = criterion(output_s2, label_source)
            loss_s = loss_s1 + loss_s2
            loss = loss_s + loss_dis
            loss.backward()
            opt_g.step()

            if i % 20 == 0:
                message = (
                    'Train Epoch: {} Loss1: {:.6f}\t Loss2: {:.6f}\t  Discrepancy: {:.6f}'
                    .format(epoch, loss_s1.item(), loss_s2.item(),
                            loss_dis.item()))
                print(message)
                logging.debug(message)

    def train_onestep(G, C1, C2, config, epoch):
        """Run one epoch of the single-backward variant.

        Source classification loss and the negated target discrepancy are
        summed and back-propagated once, then all three optimizers step.
        """
        G.train()
        C1.train()
        C2.train()

        gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len(config['source_train_loader'])
        for i in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()

            zero_grads()

            set_requires_grad(G, requires_grad=True)
            set_requires_grad(C1, requires_grad=True)
            set_requires_grad(C2, requires_grad=True)
            feat_s = G(data_source)
            output_s1 = C1(feat_s)
            output_s2 = C2(feat_s)
            loss_s1 = criterion(output_s1, label_source)
            loss_s2 = criterion(output_s2, label_source)
            loss_s = loss_s1 + loss_s2

            # NOTE(review): G is frozen here, *before* backward() is called on
            # the combined loss; whether G still receives the source gradient
            # depends on set_requires_grad -- confirm this is intended.
            set_requires_grad(G, requires_grad=False)
            set_requires_grad(C1, requires_grad=True)
            set_requires_grad(C2, requires_grad=True)
            # Target features are computed without autograd tracking, so no
            # gradient from the discrepancy term can flow back into G.
            with torch.no_grad():
                feat_t = G(data_target)
            reverse_feature_t = ReverseLayerF.apply(feat_t, gamma)
            output_t1 = C1(reverse_feature_t)
            output_t2 = C2(reverse_feature_t)

            # Negated: minimising the total loss maximises the discrepancy.
            loss_dis = -discrepancy(output_t1, output_t2)
            loss = loss_s + loss_dis
            loss.backward()
            opt_c1.step()
            opt_c2.step()
            opt_g.step()

            if i % 20 == 0:
                print(
                    'Train Epoch: {}, Loss1: {:.6f}\t Loss2: {:.6f}\t  Discrepancy: {:.6f}'
                    .format(epoch, loss_s1.item(), loss_s2.item(),
                            loss_dis.item()))

    for epoch in range(1, config['n_epochs'] + 1):
        if config['mcd_onestep'] == 1:
            train_onestep(G, C1, C2, config, epoch)
        else:
            train(G, C1, C2, config, epoch)

        if epoch % config['TEST_INTERVAL'] == 0:
            # Evaluate both classifier heads on the target test set.
            print('C1 on target_test_loader')
            logging.debug('C1 on target_test_loader')
            test(G, C1, config['target_test_loader'], epoch)
            print('C2 on target_test_loader')
            logging.debug('C2 on target_test_loader')
            test(G, C2, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(G, C1, config['target_test_loader'], res_dir,
                                  epoch, config['models'])
            # Same t-SNE inputs drawn twice: once separated, once combined.
            for separate in (True, False):
                draw_tsne(G,
                          C1,
                          config['source_train_loader'],
                          config['target_test_loader'],
                          res_dir,
                          epoch,
                          config['models'],
                          separate=separate)
Esempio n. 11
0
def train_tcl_vat(config):
    """Train an extractor/classifier with a domain discriminator and sample selection.

    Each step concatenates a source and a target batch, predicts classes and
    domains for both, and optimises ``Ly + traded * Ld + tradet * Lt``.  From
    epoch ``config['startiter']`` on, source samples whose combined class and
    domain loss exceeds a threshold get weight 0 in ``Ly`` and ``Ld``.  The
    model with the best target-test accuracy is saved to the result directory.

    Args:
        config: dict of settings.  Keys read here: ``network``,
            ``n_flattens``, ``n_hiddens``, ``n_class``, ``res_dir``, ``slim``,
            ``snr``, ``snrp``, ``Lythred``, ``Ldthred``, ``lambdad``, ``lr``,
            ``source_train_loader``, ``target_train_loader``,
            ``target_train_semi_loader`` (only when ``slim > 0``),
            ``target_test_loader``, ``n_epochs``, ``startiter``, ``traded``,
            ``tradet``, ``TEST_INTERVAL``, ``VIS_INTERVAL``, ``models``.

    Returns:
        None.
    """
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32)
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32)
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    res_dir = os.path.join(
        config['res_dir'],
        'slim{}-snr{}-snrp{}-Lythred{}-Ldthred{}-lambdad{}-lr{}'.format(
            config['slim'], config['snr'], config['snrp'], config['Lythred'],
            config['Ldthred'], config['lambdad'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    print('train_tcl')
    print(config)

    set_log_config(res_dir)
    logging.debug('train_tcl')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    # Guarded like the extractor/classifier so CPU-only hosts do not crash.
    if use_cuda:
        ad_net = ad_net.cuda()

    optimizer = torch.optim.Adam([{
        'params': extractor.parameters(),
        'lr': config['lr']
    }, {
        'params': classifier.parameters(),
        'lr': config['lr']
    }],
                                 weight_decay=0.0001)
    optimizer_ad = torch.optim.Adam(ad_net.parameters(),
                                    lr=config['lr'],
                                    weight_decay=0.0001)
    print(ad_net)

    extractor_path = os.path.join(res_dir, "extractor.pth")
    classifier_path = os.path.join(res_dir, "classifier.pth")
    adnet_path = os.path.join(res_dir, "adnet.pth")

    def cal_Ly(source_y_softmax, source_d, label):
        """Return the selectively weighted source classification loss.

        Args:
            source_y_softmax: class predictions for the source batch, after
                softmax.
            source_d: domain predictions for the source batch.
            label: ground-truth class labels of the source batch.

        Returns:
            Tuple ``(Ly, source_weight, source_num)``: the mean of the
            per-sample class loss multiplied by the 0/1 selection weight, the
            detached weight vector, and the number of selected samples.
        """
        # Threshold on the combined loss, derived from the two probability
        # thresholds in the config.
        agey = -math.log(config['Lythred'])
        aged = -math.log(1.0 - config['Ldthred'])
        age = agey + config['lambdad'] * aged

        # For every sample pick the softmax probability of its true label.
        # Sized from the actual batch, not config['batch_size'], so a short
        # final batch does not break the indexing.
        the_index = torch.arange(label.size(0),
                                 device=source_y_softmax.device)
        y_label = source_y_softmax[the_index, label]
        y_loss = -torch.log(y_label + 1e-8)

        d_loss = -torch.log(1.0 - source_d)
        d_loss = d_loss.view(-1)

        weight_loss = y_loss + config['lambdad'] * d_loss

        # Samples at or above the threshold get weight 0; detached so the
        # selection itself receives no gradient.
        weight_var = (weight_loss < age).float().detach()
        Ly = torch.mean(y_loss * weight_var)

        source_weight = weight_var.data.clone()
        source_num = float(torch.sum(source_weight))
        return Ly, source_weight, source_num

    def cal_Lt(target_y_softmax):
        """Return the mean entropy of the target softmax predictions."""
        Gt_en = -torch.sum(
            (target_y_softmax * torch.log(target_y_softmax + 1e-8)), 1)
        return torch.mean(Gt_en)

    def train(extractor, classifier, ad_net, config, epoch):
        """Run one training epoch over the source loader."""
        extractor.train()
        classifier.train()
        ad_net.train()

        use_semi = config['slim'] > 0
        # Depends only on the epoch, so compute it once per epoch.
        gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len(config['source_train_loader'])
        if use_semi:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        for step in range(1, num_iter + 1):
            # Builtin next() instead of the removed ``iterator.next()`` alias.
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if use_semi:
                data_target_semi, label_target_semi = next(iter_target_semi)
                if step % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])

            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()
                if use_semi:
                    data_target_semi = data_target_semi.cuda()
                    label_target_semi = label_target_semi.cuda()

            n_source = data_source.size(0)
            n_target = data_target.size(0)

            # Domain labels: 1 for source rows, 0 for target rows.
            domain_label = torch.cat([
                torch.ones(n_source, 1, dtype=torch.float),
                torch.zeros(n_target, 1, dtype=torch.float)
            ], 0).to(data_source.device)

            inputs = torch.cat([data_source, data_target], 0)
            features = extractor(inputs)
            y_var = classifier(features)
            features = features.view(features.size(0), -1)
            d_var = ad_net(features, gamma)
            y_softmax_var = nn.Softmax(dim=1)(y_var)

            # Split back into source / target parts by the real batch sizes.
            source_y = y_var[:n_source]
            source_y_softmax = y_softmax_var[:n_source]
            target_y_softmax = y_softmax_var[n_source:]
            source_d = d_var[:n_source]

            if epoch < config['startiter']:
                # Warm-up: plain classification and domain losses, no weights.
                Ly = nn.CrossEntropyLoss()(source_y, label_source)
                Ld = nn.BCELoss()(d_var, domain_label)
            else:
                Ly, source_weight, source_num = cal_Ly(source_y_softmax,
                                                       source_d, label_source)
                # Target samples always count fully in the domain loss.
                target_weight = torch.ones(n_target,
                                           device=source_weight.device)
                domain_weight = torch.cat([source_weight, target_weight], 0)
                domain_weight = domain_weight.view(-1, 1)
                domain_criterion = nn.BCELoss(weight=domain_weight)
                Ld = domain_criterion(d_var, domain_label)

            # Entropy of the target class predictions.
            Lt = cal_Lt(target_y_softmax)

            loss = Ly + config['traded'] * Ld + config['tradet'] * Lt

            if use_semi:
                # Supervised loss on the labelled target subset.
                feature_target_semi = extractor(data_target_semi)
                feature_target_semi = feature_target_semi.view(
                    feature_target_semi.size(0), -1)
                preds_target_semi = classifier(feature_target_semi)
                loss = loss + nn.CrossEntropyLoss()(preds_target_semi,
                                                    label_target_semi)

            optimizer.zero_grad()
            optimizer_ad.zero_grad()
            loss.backward()
            optimizer.step()
            optimizer_ad.step()

    best_accuracy = 0
    best_model_index = -1
    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, ad_net, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on target_test_loader')
            accuracy = test(extractor, classifier,
                            config['target_test_loader'], epoch)

            # Keep only the checkpoint with the best target accuracy so far.
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_model_index = epoch
                torch.save(extractor.state_dict(), extractor_path)
                torch.save(classifier.state_dict(), classifier_path)
                torch.save(ad_net.state_dict(), adnet_path)
            print(
                'epoch {} accuracy: {:.6f}, best accuracy {:.6f} on epoch {}'.
                format(epoch, accuracy, best_accuracy, best_model_index))

        if epoch % config['VIS_INTERVAL'] == 0:
            title = config['models']
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, epoch,
                                  title)
            # Same t-SNE inputs drawn twice: once separated, once combined.
            for separate in (True, False):
                draw_tsne(extractor,
                          classifier,
                          config['source_train_loader'],
                          config['target_test_loader'],
                          res_dir,
                          epoch,
                          title,
                          separate=separate)
Esempio n. 12
0
def train_cdan_vat(config):
    """Train an extractor/classifier with an adversarial domain loss plus VAT.

    The total loss per step is the (optionally importance-weighted) source
    classification loss, the domain loss selected by ``config['models']``
    (``'CDAN-E'``, ``'CDAN_VAT'`` or ``'DANN_VAT'``), the target entropy loss
    and virtual-adversarial losses on the source and target batches.

    Args:
        config: dict of settings.  Keys read here: ``network``,
            ``n_flattens``, ``n_hiddens``, ``bn``, ``n_class``,
            ``random_layer``, ``res_dir``, ``normal``, ``dilation``, ``lr``,
            ``models``, ``source_train_loader``, ``target_train_loader``,
            ``source_test_loader``, ``target_test_loader``, ``n_epochs``,
            ``iw``, ``target_labeling``, ``TEST_INTERVAL``, ``VIS_INTERVAL``.

    Returns:
        None.

    Raises:
        ValueError: if ``config['models']`` is not one of the supported names.
    """
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32)
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32)
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'],
                              bn=config['bn'])
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    cdan_random = config['random_layer']
    res_dir = os.path.join(
        config['res_dir'],
        'normal{}-{}-dilation{}-lr{}'.format(config['normal'],
                                             config['network'],
                                             config['dilation'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    print('train_cdan_vat')
    print(config)

    set_log_config(res_dir)
    logging.debug('train_cdan_vat')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    # The discriminator's input width depends on what is fed to it: raw
    # features (DANN), a random projection, or the feature x class product.
    if config['models'] == 'DANN_VAT':
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    elif cdan_random:
        random_layer = RandomLayer([config['n_flattens'], config['n_class']],
                                   config['n_hiddens'])
        ad_net = AdversarialNetwork(config['n_hiddens'], config['n_hiddens'])
        if use_cuda:
            random_layer.cuda()
    else:
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'] * config['n_class'],
                                    config['n_hiddens'])
    # Guarded like the extractor/classifier so CPU-only hosts do not crash.
    if use_cuda:
        ad_net = ad_net.cuda()
    optimizer = torch.optim.Adam([{
        'params': extractor.parameters(),
        'lr': config['lr']
    }, {
        'params': classifier.parameters(),
        'lr': config['lr']
    }])
    optimizer_ad = torch.optim.Adam(ad_net.parameters(), lr=config['lr'])

    vat_loss = VAT(extractor, classifier, n_power=1, radius=3.5)
    if use_cuda:
        vat_loss = vat_loss.cuda()

    def train(extractor, classifier, ad_net, config, epoch):
        """Run one training epoch over the source loader."""
        start_epoch = 0

        extractor.train()
        classifier.train()
        ad_net.train()

        # Depends only on the epoch, so compute it once instead of per use.
        gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len(config['source_train_loader'])
        for step in range(1, num_iter + 1):
            # Builtin next() instead of the removed ``iterator.next()`` alias.
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()
                label_target = label_target.cuda()

            # Per-sample importance weights from the discriminator output,
            # computed without tracking gradients.  This pass always runs
            # (even when iw <= 0), as in the original flow.
            with torch.no_grad():
                if 'CDAN' in config['models']:
                    h_s = extractor(data_source)
                    h_s = h_s.view(h_s.size(0), -1)
                    source_preds = classifier(h_s)
                    softmax_output_s = nn.Softmax(dim=1)(source_preds)

                    # Outer product of class probabilities and features.
                    op_out = torch.bmm(softmax_output_s.unsqueeze(2),
                                       h_s.unsqueeze(1))
                    # NOTE(review): when random_layer is used, ad_net was
                    # built with n_hiddens inputs but is fed the flattened
                    # outer product here -- verify the widths agree.
                    ad_out = ad_net(op_out.view(
                        -1,
                        softmax_output_s.size(1) * h_s.size(1)),
                                    gamma,
                                    training=False)
                    dom_entropy = 1 - (torch.abs(0.5 - ad_out))**config['iw']
                    # Flatten: a (B, 1) weight times the (B,) per-sample loss
                    # would broadcast to (B, B) and erase the per-sample
                    # weighting.
                    dom_weight = dom_entropy.view(-1)

                elif 'DANN' in config['models']:
                    h_s = extractor(data_source)
                    h_s = h_s.view(h_s.size(0), -1)
                    ad_out = ad_net(h_s, gamma, training=False)
                    dom_entropy = 1 - (torch.abs(0.5 - ad_out))**config['iw']
                    dom_weight = dom_entropy.view(-1)

            optimizer.zero_grad()
            optimizer_ad.zero_grad()

            h_s = extractor(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier(h_s)
            softmax_output_s = nn.Softmax(dim=1)(source_preds)

            if config['iw'] > 0:
                # Importance-weighted source classification loss.
                cls_loss = nn.CrossEntropyLoss(reduction='none')(source_preds,
                                                                 label_source)
                cls_loss = torch.mean(dom_weight * cls_loss)
            else:
                cls_loss = nn.CrossEntropyLoss()(source_preds, label_source)

            target_preds = classifier(h_t)
            softmax_output_t = nn.Softmax(dim=1)(target_preds)
            if config['target_labeling'] == 1:
                cls_loss += nn.CrossEntropyLoss()(target_preds, label_target)

            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            if epoch > start_epoch:
                if config['models'] == 'CDAN-E':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN(
                        [feature, softmax_output], ad_net, gamma, entropy,
                        loss_func.calc_coeff(num_iter * (epoch - start_epoch) +
                                             step), random_layer)
                elif config['models'] == 'CDAN_VAT':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net,
                                            gamma, None, None, random_layer)
                elif config['models'] == 'DANN_VAT':
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
                else:
                    raise ValueError('Method cannot be recognized.')
            else:
                d_loss = 0

            # Target entropy loss.
            err_t_entropy = get_loss_entropy(softmax_output_t)

            # Virtual adversarial losses on both domains.
            err_s_vat = vat_loss(data_source, source_preds)
            err_t_vat = vat_loss(data_target, target_preds)

            loss = cls_loss + d_loss + err_t_entropy + err_s_vat + err_t_vat
            loss.backward()
            optimizer.step()
            if epoch > start_epoch:
                optimizer_ad.step()
            if (step) % 20 == 0:
                # float() handles both a tensor and the plain-int fallback 0.
                print('Train Epoch {} closs {:.6f}, dloss {:.6f}, Loss {:.6f}'.
                      format(epoch, cls_loss.item(), float(d_loss),
                             loss.item()))

    best_accuracy = 0
    best_model_index = -1
    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, ad_net, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on target_test_loader')
            accuracy = test(extractor, classifier,
                            config['target_test_loader'], epoch)

            # Only the best score is tracked; no checkpoint is written here.
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_model_index = epoch
            print(
                'epoch {} accuracy: {:.6f}, best accuracy {:.6f} on epoch {}'.
                format(epoch, accuracy, best_accuracy, best_model_index))

        if epoch % config['VIS_INTERVAL'] == 0:
            title = config['models']
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, epoch,
                                  title)
            draw_tsne(extractor,
                      classifier,
                      config['source_train_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      title,
                      separate=True)
            # The combined plot uses the source *test* loader.
            draw_tsne(extractor,
                      classifier,
                      config['source_test_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      title,
                      separate=False)
Esempio n. 13
0
def train_cdan_ican(config):
    """Train a DANN/CDAN-style model with iCAN pseudo-label selection.

    This opening section builds the feature extractor, classifier,
    discriminator activation and adversarial network, creates their Adam
    optimizers, prepares the result directory and logs the setup. The nested
    helpers that follow share these objects through the enclosing scope.

    Args:
        config: dict-like settings. This section reads 'batch_size',
            'n_flattens', 'n_hiddens', 'n_class', 'random_layer', 'models',
            'lr' and 'res_dir'.
    """
    BATCH_SIZE = config['batch_size']
    n_flat = config['n_flattens']
    n_hid = config['n_hiddens']

    extractor = Extractor(n_flattens=n_flat, n_hiddens=n_hid)
    classifier = Classifier(n_flattens=n_flat,
                            n_hiddens=n_hid,
                            n_class=config['n_class'])
    disc_activate = Contrast_ReLU_activate(INI_DISC_WEIGHT_SCALE,
                                           INI_DISC_BIAS)

    # A random layer is only used when it is requested and the model is not
    # plain DANN; every other combination feeds the raw features to ad_net.
    cdan_random = config['random_layer']
    if config['models'] != 'DANN' and cdan_random:
        random_layer = RandomLayer([n_flat, config['n_class']], n_hid)
        ad_net = AdversarialNetwork(n_hid, n_hid)
        # NOTE(review): unlike the modules below, this move to the GPU is
        # not guarded by torch.cuda.is_available().
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = AdversarialNetwork(n_flat, n_hid)

    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        disc_activate = disc_activate.cuda()
        ad_net = ad_net.cuda()

    base_lr = config['lr']
    # One parameter group per module, both at the base learning rate.
    optimizer = torch.optim.Adam([
        {'params': module.parameters(), 'lr': base_lr}
        for module in (extractor, classifier)
    ])
    optimizer_ad = torch.optim.Adam(ad_net.parameters(), lr=base_lr)
    pseudo_optimizer = torch.optim.Adam(disc_activate.parameters(),
                                        lr=base_lr)

    class_criterion = nn.CrossEntropyLoss()

    run_name = 'random{}-bs{}-lr{}'.format(cdan_random, config['batch_size'],
                                           config['lr'])
    res_dir = os.path.join(config['res_dir'], run_name)
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    # Echo the setup to stdout first, then to the log file.
    setup_summary = ('train_cdan_ican', extractor, classifier, ad_net, config)
    for entry in setup_summary:
        print(entry)

    set_log_config(res_dir)
    for entry in setup_summary:
        logging.debug(entry)

    def select_samples_ican(extractor, classifier, ad_net, disc_activate,
                            config, epoch, epoch_acc_s):
        """Collect confidently classified target samples as a pseudo-set.

        All four modules are switched to evaluation mode, then every batch
        of config['target_train_loader'] is scored. A sample is kept when
        its discriminator weight exceeds the bias returned by
        ``disc_activate`` and its top class score reaches the confidence
        threshold ``1 / (1 + exp(-2.4 * epoch_acc_s))``.

        Args:
            extractor, classifier, ad_net, disc_activate: modules used for
                scoring.
            config: settings dict; 'target_train_loader' is read.
            epoch: current epoch, used only in the progress message.
            epoch_acc_s: accuracy value that drives the threshold.

        Returns:
            A list of ``(input, (predicted_label, weight))`` tuples where
            ``input`` and ``predicted_label`` are CPU tensors and ``weight``
            is a Python float.
        """
        for module in (extractor, classifier, ad_net, disc_activate):
            set_training_mode(module, False)

        selected = []
        confid_threshold = 1 / (1 + np.exp(-2.4 * epoch_acc_s))
        min_score = float(confid_threshold)
        # Running count of predictions equal to the loader's labels; it is
        # accumulated but not reported.
        matched_predictions = 0

        # Original author's question: why would this run on the target test
        # set? The training loader is what is used here.
        for target_inputs, target_labels in iter(
                config['target_train_loader']):
            target_inputs = target_inputs.cuda()
            # Original author's note: the paper uses domain label 1 for the
            # target domain; this code passes zeros.
            domain_labels_t = torch.FloatTensor(
                [0.] * len(target_inputs)).cuda()

            embeddings = extractor(target_inputs)
            class_t = classifier(embeddings)
            domain_out_t = ad_net(embeddings, training=False)
            disc_weight_t, w_t, b_t = disc_activate(domain_out_t,
                                                    domain_labels_t)
            # NOTE(review): if the classifier returns raw logits rather than
            # probabilities, comparing top_prob with a sigmoid-derived
            # threshold may not be meaningful - confirm.
            top_prob, preds_t = torch.max(class_t, 1)

            for idx, weight in enumerate(disc_weight_t):
                if weight > b_t and top_prob[idx] >= min_score:
                    selected.append((target_inputs[idx].cpu(),
                                     (preds_t[idx].cpu(), float(weight))))
            matched_predictions += preds_t.eq(
                target_labels.cuda()).cpu().sum()

        # Each entry holds [features, predicted class label, predicted
        # domain weight].
        print(
            'Epoch {}, Stage Select_Sample, accuracy {}, confident threshold {}, pseudo number {}, b_t {}'
            .format(epoch, epoch_acc_s, confid_threshold, len(selected),
                    b_t))
        draw_dict['confid_threshold_point'].append(
            float("%.4f" % confid_threshold))

        return selected

    # TODO: 为什么不在上一个函数中直接更新呢?选择pseudo-set之后就更新disc-activate的模型参数,完全可以合并成一步
    def update_ican(extractor, classifier, ad_net, disc_activate, config,
                    Pseudo_set, epoch):
        """Fit the discriminator activation on the pseudo-labelled samples.

        Only ``disc_activate`` is put in training mode and only its
        optimizer (``pseudo_optimizer``) is stepped; the extractor,
        classifier and adversarial network are switched to evaluation mode.
        Does nothing when ``Pseudo_set`` is empty.

        Args:
            extractor, classifier, ad_net, disc_activate: modules involved.
            config: settings dict (not read by this function).
            Pseudo_set: sequence of ``(input, (label, weight))`` samples.
            epoch: current epoch, used only in the progress message.
        """
        if not len(Pseudo_set):
            return

        for frozen in (extractor, classifier, ad_net):
            set_training_mode(frozen, False)
        set_training_mode(disc_activate, True)

        n_batches = 0
        n_samples = 0
        loss_sum = 0.0
        n_agree = 0

        # Half of the regular batch size is drawn from the pseudo-set.
        pseudo_loader = torch.utils.data.DataLoader(
            Pseudo_set, batch_size=int(BATCH_SIZE / 2), shuffle=True)

        for batch_inputs, batch_targets in pseudo_loader:
            n_batches += 1
            n_samples += len(batch_inputs)

            # Targets are (label, weight); the stored weight is not used.
            stored_labels, _ = batch_targets[0], batch_targets[1]
            batch_inputs = batch_inputs.cuda()
            stored_labels = stored_labels.cuda()
            domain_labels = torch.FloatTensor([0.] * len(batch_inputs)).cuda()

            embeddings = extractor(batch_inputs)
            class_scores = classifier(embeddings)
            domain_out = ad_net(embeddings, training=False)
            disc_weight, disc_w, disc_b = disc_activate(
                domain_out, domain_labels)

            pseudo_optimizer.zero_grad()

            # The loss is taken against freshly computed predictions rather
            # than the labels stored when the set was selected (open TODO
            # from the original author; the shape of disc_weight was also
            # flagged for checking).
            fresh_preds = torch.max(class_scores, 1)[1]
            batch_loss = compute_new_loss(class_scores, fresh_preds,
                                          disc_weight)
            loss_sum += float(batch_loss)

            # Agreement between fresh and stored predictions; the original
            # author notes this "accuracy" carries little meaning because
            # both come from the same argmax rule.
            n_agree += int(
                torch.sum(fresh_preds.squeeze() == stored_labels.squeeze()))

            batch_loss.backward()
            pseudo_optimizer.step()

            epoch_discrim_lambda = 1.0 / (abs(disc_w)**(1. / 4))
            epoch_discrim_bias = disc_b

        mean_loss = loss_sum / n_batches
        agreement = n_agree / n_samples

        print(
            'Epoch {}, Phase: {}, Loss: {:.4f} Acc: {:.4f} Disc_Lam: {:.6f} Disc_bias: {:.4f} '
            .format(epoch, 'Pseudo_train', mean_loss, agreement,
                    epoch_discrim_lambda, epoch_discrim_bias))

    def prepare_dataset(epoch, pseudo_set):
        """Build the loaders and cached batch lists used by one training epoch.

        The pseudo-set batch size is derived from the number of batches in
        config['source_train_loader'], so that one pass over the source
        loader roughly covers the pseudo-set once.

        Args:
            epoch: current epoch, used only in the progress message.
            pseudo_set: sequence of pseudo-labelled samples (may be empty).

        Returns:
            Tuple ``(dset_loaders, target_dict, pseudo_dict,
            pseudo_source_dict, source_batchsize, pseudo_batchsize)`` where
            the three ``*_dict`` values are lists of materialized batches.
        """
        source_loader = config['source_train_loader']
        dset_loaders = {'source': source_loader}
        # len() of a loader is its number of batches, not of samples.
        source_size = len(source_loader)
        pseudo_size = len(pseudo_set)
        source_batchsize = BATCH_SIZE

        if pseudo_size == 0:
            dset_loaders['pseudo'] = []
            dset_loaders['pseudo_source'] = []
            pseudo_batchsize = 0
        else:
            # With fewer pseudo samples than source batches the floor
            # division yields 0, which DataLoader rejects with ValueError;
            # clamp to a batch size of at least one sample.
            pseudo_batchsize = max(1, pseudo_size // source_size)
            dset_loaders['pseudo'] = torch.utils.data.DataLoader(
                pseudo_set,
                batch_size=pseudo_batchsize,
                shuffle=True,
                drop_last=False)
            dset_loaders['pseudo_source'] = source_loader

        print(
            'Epoch {}, Stage prepare_dataset, pseudo_size {}, num batch each epoch: {}, pseudo_batchsize {}'
            .format(epoch, pseudo_size, source_size, pseudo_batchsize))

        # Materialize the batches so they can be indexed by position later.
        target_dict = [(i, j) for (i, j) in config['target_train_loader']]
        if pseudo_size > 0:
            pseudo_dict = [(i, j) for (i, j) in dset_loaders['pseudo']]
            pseudo_source_dict = [(i, j)
                                  for (i, j) in dset_loaders['pseudo_source']]
        else:
            pseudo_dict = []
            pseudo_source_dict = []

        return (dset_loaders, target_dict, pseudo_dict, pseudo_source_dict,
                source_batchsize, pseudo_batchsize)

    def train(extractor, classifier, ad_net, disc_activate, config, epoch):
        """Run one epoch of adversarial training over the source loader.

        For every source batch a target batch is drawn (the target iterator
        is restarted whenever the step count is a multiple of the target
        loader length). The loss is the source cross-entropy plus, once
        ``epoch > start_epoch``, the adversarial loss selected by
        config['models'].

        Args:
            extractor, classifier, ad_net, disc_activate: modules put in
                training mode; ``disc_activate`` is otherwise unused here.
            config: settings dict; 'source_train_loader',
                'target_train_loader', 'n_epochs' and 'models' are read.
            epoch: current epoch number.

        Raises:
            ValueError: if config['models'] is not a recognized method.
        """
        start_epoch = 0

        extractor.train()
        classifier.train()
        ad_net.train()
        disc_activate.train()

        source_loader = config['source_train_loader']
        target_loader = config['target_train_loader']
        iter_source = iter(source_loader)
        iter_target = iter(target_loader)
        len_target_loader = len(target_loader)
        num_iter = len(source_loader)

        # Stateless modules: build them once instead of once per step.
        cross_entropy = nn.CrossEntropyLoss()
        softmax = nn.Softmax(dim=1)

        for step in range(1, num_iter + 1):
            # Built-in next() works on every iterator; the .next() method is
            # a Python 2 idiom that recent DataLoader iterators lack.
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if step % len_target_loader == 0:
                iter_target = iter(target_loader)
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()

            optimizer.zero_grad()
            optimizer_ad.zero_grad()

            h_s = extractor(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier(h_s)
            cls_loss = cross_entropy(source_preds, label_source)
            softmax_output_s = softmax(source_preds)

            target_preds = classifier(h_t)
            softmax_output_t = softmax(target_preds)

            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            if epoch > start_epoch:
                gamma = 2 / (1 + math.exp(-10 *
                                          (epoch) / config['n_epochs'])) - 1
                model_name = config['models']
                if model_name == 'CDAN-E':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN(
                        [feature, softmax_output], ad_net, gamma, entropy,
                        loss_func.calc_coeff(num_iter * (epoch - start_epoch) +
                                             step), random_layer)
                elif model_name in ('CDAN', 'CDAN_ICAN'):
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net,
                                            gamma, None, None, random_layer)
                elif model_name == 'DANN':
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
                else:
                    raise ValueError('Method cannot be recognized.')
            else:
                d_loss = 0

            loss = cls_loss + d_loss
            loss.backward()
            optimizer.step()
            if epoch > start_epoch:
                optimizer_ad.step()
            if (step) % 20 == 0:
                # float() accepts both a one-element tensor and the plain 0
                # used when the adversarial term is disabled.
                print('Train Epoch {} closs {:.6f}, dloss {:.6f}, Loss {:.6f}'.
                      format(epoch, cls_loss.item(), float(d_loss),
                             loss.item()))

    # def do_forward(extractor, classifier, ad_net, disc_activate, src_features, all_features, labels):

    #     # 预测source features的class labels
    #     bottle = extractor(src_features)
    #     class_pred = classifier(bottle)

    #     dom_pred = ad_net(bottle)

    #     return class_pred, dom_pred.squeeze(1)

    def do_training(dset_loaders, target_dict, source_batchsize,
                    pseudo_batchsize, pseudo_dict, pseudo_source_dict):
        """Run one pass over dset_loaders['source'], updating the main models.

        For each source batch the classification loss and a
        BCE-with-logits domain loss are summed and back-propagated, then
        ``optimizer`` and ``optimizer_ad`` are stepped. Because
        ``pre_epochs`` is 0, the "pretrain" branch is only taken when
        ``epoch <= 0``; otherwise source batches are fused with pseudo-set
        batches (or used alone when the pseudo-set is empty).

        ``epoch``, ``config``, the models and the optimizers are not
        parameters: they are read from the enclosing scope.

        Args:
            dset_loaders: dict with 'source', 'pseudo' and 'pseudo_source'
                iterables.
            target_dict: list of pre-fetched target batches.
            source_batchsize: size used for the source part of the weight
                tensors.
            pseudo_batchsize: size used for the pseudo part of the domain
                weight tensor.
            pseudo_dict: list of pre-fetched pseudo-set batches.
            pseudo_source_dict: list of pre-fetched source batches paired
                with the pseudo-set.
        """

        batch_count = 0
        target_pointer = 0
        target_pointer = 0
        pseudo_pointer = 0
        pseudo_source_pointer = 0
        INI_MAIN_THRESH = -0.8
        # pre_epochs = 10
        pre_epochs = 0

        # Only the discriminator activation stays in evaluation mode here.
        set_training_mode(extractor, True)
        set_training_mode(classifier, True)
        set_training_mode(ad_net, True)
        set_training_mode(disc_activate, False)

        # class_count = 0
        # epoch_loss = 0.0
        # epoch_corrects = 0
        # domain_epoch_loss = 0.0
        # ini_w_main = torch.FloatTensor([float(INI_MAIN_THRESH)]).cuda()
        # epoch_batch_count = 0
        # total_epoch_loss = 0.0
        # domain_epoch_corrects = 0
        # domain_counts = 0

        for data in dset_loaders['source']:
            inputs, labels = data

            batch_count += 1

            # ---------------- reset exceeded datasets --------------------
            # The pointer rewinds once it reaches the last index, so for
            # lists longer than one the final cached batch is never read.
            if target_pointer >= len(target_dict) - 1:
                target_pointer = 0
                target_dict = [(i, j)
                               for (i, j) in config['target_train_loader']]

            target_inputs = target_dict[target_pointer][0]

            # `epoch` is the loop variable of the enclosing function.
            if epoch <= pre_epochs:
                # Train CAN using source_train and target_train; target_train
                # is used in full, without filtering.
                # -------------------- pretrain model -----------------------
                domain_inputs = torch.cat((inputs, target_inputs), 0)
                # domain_labels = torch.FloatTensor([1.]*BATCH_SIZE + [0.]*BATCH_SIZE)
                domain_labels = torch.FloatTensor([1.] * inputs.size(0) +
                                                  [0.] * target_inputs.size(0))
                domain_inputs, domain_labels = domain_inputs.cuda(
                ), domain_labels.cuda()

                inputs, labels = inputs.cuda(), labels.cuda()

                # print('inputs {}, target_inputs {}, domain_inputs {}, domain_labels {}'.format(inputs.size(), target_inputs.size(), domain_inputs.size(), domain_labels.size()))

                # Class predictions on the source batch
                class_outputs = classifier(extractor(inputs))

                # Domain predictions on the source and target batches
                domain_outputs = ad_net(extractor(domain_inputs)).squeeze()

                target_pointer += 1
                # epoch_discrim_bias = 0.5

                # ------------ training classification statistics --------------
                criterion = nn.CrossEntropyLoss()
                class_loss = criterion(class_outputs, labels)

            else:
                # -------------- train with pseudo sample model -------------
                # The target domain uses the filtered pseudo-set data
                pseudo_weights = torch.FloatTensor([])
                pseudo_size = len(pseudo_dict)

                # Reset the index positions
                if (pseudo_pointer >=
                        len(pseudo_dict) - 1) and (len(pseudo_dict) != 0):
                    pseudo_pointer = 0
                    pseudo_dict = [(i, j) for (i, j) in dset_loaders['pseudo']]

                if (pseudo_source_pointer >= len(pseudo_source_dict) - 1) and (
                        len(pseudo_source_dict) != 0):
                    pseudo_source_pointer = 0
                    pseudo_source_dict = [
                        (i, j) for (i, j) in dset_loaders['pseudo_source']
                    ]

                if pseudo_size == 0:
                    # If the pseudo-set is empty, fall back to all of
                    # source_train and target_train.

                    domain_inputs = torch.cat((inputs, target_inputs), 0)
                    # domain_labels = torch.FloatTensor([1.]*int(BATCH_SIZE / 2)+
                    # [0.]*int(BATCH_SIZE / 2))
                    domain_labels = torch.FloatTensor([1.] * inputs.size(0) +
                                                      [0.] *
                                                      target_inputs.size(0))

                    fuse_inputs = inputs
                    fuse_labels = labels

                else:
                    pseudo_inputs, pseudo_labels, pseudo_weights = pseudo_dict[pseudo_pointer][0], \
                                    pseudo_dict[pseudo_pointer][1][0], pseudo_dict[pseudo_pointer][1][1]
                    pseudo_source_inputs = pseudo_source_dict[
                        pseudo_source_pointer][0]

                    # TODO: why was it done this way? source + pseudo + target + source
                    # domain_inputs = torch.cat((inputs, pseudo_inputs, target_inputs, pseudo_source_inputs),0)
                    # domain_labels = torch.FloatTensor([1.]*inputs.size(0) + [0.]*pseudo_inputs.size(0) +
                    #                                     [0.]*target_inputs.size(0)+[1.]*pseudo_source_inputs.size(0))
                    domain_inputs = torch.cat((inputs, pseudo_inputs), 0)
                    domain_labels = torch.FloatTensor([1.] * inputs.size(0) +
                                                      [0.] *
                                                      pseudo_inputs.size(0))

                    fuse_inputs = torch.cat((inputs, pseudo_inputs), 0)
                    fuse_labels = torch.cat((labels, pseudo_labels), 0)

                    # print('inputs {}, pseudo_inputs {}, target_inputs {}, domain_inputs {}'.format(inputs.size(), pseudo_inputs.size(), target_inputs.size(), domain_inputs.size()))
                    # print('domain_labels {}, fuse_inputs {}, fuse_labels {}'.format(domain_labels.size(), fuse_inputs.size(), fuse_labels.size()))

                inputs, labels = fuse_inputs.cuda(), fuse_labels.cuda()
                domain_inputs, domain_labels = domain_inputs.cuda(
                ), domain_labels.cuda()

                # Weight vector handed to compute_new_loss: class weights
                # (ones for source, stored weights for pseudo samples)
                # followed by domain weights (zeros for source, ones for
                # pseudo samples).
                # NOTE(review): the sizes come from source_batchsize and
                # pseudo_batchsize, not from the actual batch - confirm this
                # holds for a smaller final batch.
                source_weight_tensor = torch.FloatTensor([1.] *
                                                         source_batchsize)
                pseudo_weights_tensor = pseudo_weights.float()
                class_weights_tensor = torch.cat(
                    (source_weight_tensor, pseudo_weights_tensor), 0)
                dom_weights_tensor = torch.FloatTensor([0.] *
                                                       source_batchsize +
                                                       [1.] * pseudo_batchsize)

                ini_weight = torch.cat(
                    (class_weights_tensor, dom_weights_tensor),
                    0).squeeze().cuda()

                class_outputs = classifier(extractor(inputs))
                domain_outputs = ad_net(extractor(domain_inputs)).squeeze()

                # ------------ training classification statistics --------------
                # _, preds = torch.max(class_outputs, 1)
                # class_count += len(preds)
                class_loss = compute_new_loss(class_outputs, labels,
                                              ini_weight)

                # epoch_loss += float(class_loss)
                # epoch_corrects += int(torch.sum(preds.squeeze() == labels.squeeze()))

                target_pointer += 1
                pseudo_pointer += 1
                pseudo_source_pointer += 1

            # zero the parameter gradients
            optimizer.zero_grad()
            optimizer_ad.zero_grad()

            # ----------- calculate pred domain labels and losses -----------
            domain_criterion = nn.BCEWithLogitsLoss()
            domain_labels = domain_labels.squeeze()
            domain_loss = domain_criterion(domain_outputs, domain_labels)
            # domain_epoch_loss += float(domain_loss)

            # ------ calculate pseudo predicts and losses with weights and threshold lambda -------
            total_loss = class_loss + 1.0 * domain_loss

            # total_epoch_loss += float(total_loss)
            print('class_loss {}, domain_loss {}'.format(
                class_loss.item(), domain_loss.item()))

            #  -------  backward + optimize in training and Pseudo-training phase -------
            total_loss.backward()
            optimizer.step()
            optimizer_ad.step()

    def train_ican(extractor, classifier, ad_net, disc_activate, config,
                   epoch):
        """Run one iCAN epoch: select pseudo-labels, update, then train.

        Args:
            extractor, classifier, ad_net, disc_activate: modules passed on
                to the individual stages.
            config: settings dict; 'source_test_loader' is read here.
            epoch: current epoch number.
        """
        # 1. Accuracy on the source test loader drives pseudo-label
        #    selection.
        source_accuracy = test(extractor, classifier,
                               config['source_test_loader'], epoch)

        # 2. Build the pseudo-labelled set.
        pseudo_samples = select_samples_ican(extractor, classifier, ad_net,
                                             disc_activate, config, epoch,
                                             source_accuracy)

        # 3. Train disc_activate on the pseudo-set, which updates the
        #    discriminator threshold.
        update_ican(extractor, classifier, ad_net, disc_activate, config,
                    pseudo_samples, epoch)

        # 4. Prepare the data used for the main training pass, combining the
        #    source loader with the pseudo-set.
        prepared = prepare_dataset(epoch, pseudo_samples)
        (loaders, target_batches, pseudo_batches, pseudo_source_batches,
         source_bs, pseudo_bs) = prepared

        # 5. Main training pass.
        do_training(loaders, target_batches, source_bs, pseudo_bs,
                    pseudo_batches, pseudo_source_batches)

    # function done

    # Epoch driver. The loop variable has to keep the name `epoch`: the
    # nested do_training helper reads it from this enclosing scope.
    last_epoch = config['n_epochs']
    for epoch in range(1, last_epoch + 1):
        train_ican(extractor, classifier, ad_net, disc_activate, config, epoch)

        # Periodic evaluation on the source and target test loaders.
        if epoch % config['TEST_INTERVAL'] == 0:
            evaluations = (
                ('test on source_test_loader', 'source_test_loader'),
                ('test on target_test_loader', 'target_test_loader'),
            )
            for banner, loader_key in evaluations:
                print(banner)
                test(extractor, classifier, config[loader_key], epoch)

        # Periodic visualisation: confusion matrix and t-SNE plot.
        if epoch % config['VIS_INTERVAL'] == 0:
            title = config['models']
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir,
                                  epoch, title)
            draw_tsne(extractor, classifier, config['source_test_loader'],
                      config['target_test_loader'], res_dir, epoch, title,
                      separate=True)
Esempio n. 14
0
def train_dann_mm2(config):
    """Train a DANN model followed by an entropy step on target features.

    This opening section builds the extractor (chosen by
    config['network']), the classifier and the critic, prepares the result
    directory and logging, and creates one Adam optimizer per module. The
    nested helpers that follow share these objects through the enclosing
    scope.

    Args:
        config: dict-like settings. This section reads 'network',
            'n_flattens', 'n_hiddens', 'bn', 'n_class', 'res_dir', 'snr'
            and 'lr'.
    """
    n_flat = config['n_flattens']
    n_hid = config['n_hiddens']

    inception_variants = {
        'inceptionv1': InceptionV1,
        'inceptionv1s': InceptionV1s,
    }
    network = config['network']
    if network in inception_variants:
        extractor = inception_variants[network](num_classes=32)
    else:
        extractor = Extractor(n_flattens=n_flat,
                              n_hiddens=n_hid,
                              bn=config['bn'])
    classifier = Classifier(n_flattens=n_flat,
                            n_hiddens=n_hid,
                            n_class=config['n_class'])
    critic = Critic2(n_flattens=n_flat, n_hiddens=n_hid)

    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()
        # The model summary is only printed on the GPU path.
        summary(extractor, (1, 5120))

    run_name = 'snr{}-lr{}'.format(config['snr'], config['lr'])
    res_dir = os.path.join(config['res_dir'], run_name)
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    set_log_config(res_dir)
    for entry in ('train_dann_mm2', extractor, classifier, critic, config):
        logging.debug(entry)

    criterion = torch.nn.CrossEntropyLoss()
    base_lr = config['lr']
    optimizer_e = optim.Adam(extractor.parameters(), lr=base_lr)
    optimizer_cls = optim.Adam(classifier.parameters(), lr=base_lr)
    optimizer_critic = optim.Adam(critic.parameters(), lr=base_lr)

    def dann(input_data, alpha):
        """Forward pass returning class output, domain output and features.

        The extractor output is flattened per sample; the classifier sees it
        directly while the critic sees it through ``ReverseLayerF`` applied
        with ``alpha``.

        Returns:
            Tuple ``(class_output, domain_output, feature)``.
        """
        flat = extractor(input_data)
        flat = flat.view(flat.size(0), -1)
        reversed_flat = ReverseLayerF.apply(flat, alpha)
        return classifier(flat), critic(reversed_flat), flat

    def entropy(F1, feat, lamda, eta=1.0):
        """Return ``lamda`` times the mean prediction entropy of ``F1(feat)``.

        ``F1`` is called with ``reverse=True`` and the negated ``eta``; its
        output is soft-maxed over dim 1. A 1e-5 offset keeps the logarithm
        finite.
        """
        probs = F.softmax(F1(feat, reverse=True, eta=-eta), dim=1)
        per_sample = torch.sum(probs * (torch.log(probs + 1e-5)), 1)
        return -lamda * torch.mean(per_sample)

    def adentropy(F1, feat, lamda, eta=1.0):
        """Return ``lamda`` times the mean negative entropy of ``F1(feat)``.

        ``F1`` is called with ``reverse=True`` and ``eta`` unchanged; its
        output is soft-maxed over dim 1. A 1e-5 offset keeps the logarithm
        finite.
        """
        probs = F.softmax(F1(feat, reverse=True, eta=eta), dim=1)
        per_sample = torch.sum(probs * (torch.log(probs + 1e-5)), 1)
        return lamda * torch.mean(per_sample)

    def entropy_softmax(output, lamda):
        """Return ``lamda`` times the mean entropy of the rows of ``output``.

        ``output`` is used as given (no softmax is applied here); a 1e-5
        offset keeps the logarithm finite.
        """
        per_sample = torch.sum(output * (torch.log(output + 1e-5)), 1)
        return -lamda * torch.mean(per_sample)

    def adentropy_softmax(output, lamda):
        """Return ``lamda`` times the mean negative entropy of ``output`` rows.

        ``output`` is used as given (no softmax is applied here); a 1e-5
        offset keeps the logarithm finite.
        """
        per_sample = torch.sum(output * (torch.log(output + 1e-5)), 1)
        return lamda * torch.mean(per_sample)

    def train(extractor, classifier, critic, config, epoch):
        """Run one epoch of DANN training plus an entropy step on target data.

        Each iteration first minimizes the source classification loss
        together with the domain losses on a source and a target batch
        (domain label 0 for source, 1 for target), then performs a second
        update of extractor and classifier using the entropy of the
        classifier output on the target batch.

        Args:
            extractor, classifier, critic: modules put in training mode.
            config: settings dict; 'n_epochs', 'source_train_loader' and
                'target_train_loader' are read.
            epoch: current epoch number, used to schedule ``gamma``.
        """
        extractor.train()
        classifier.train()
        critic.train()

        # Gradient-reversal strength, scheduled from the epoch number.
        gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1

        source_loader = config['source_train_loader']
        target_loader = config['target_train_loader']
        iter_source = iter(source_loader)
        iter_target = iter(target_loader)
        len_target_loader = len(target_loader)
        num_iter = len(source_loader)
        for i in range(1, num_iter + 1):
            # Built-in next() works on every iterator; the .next() method is
            # a Python 2 idiom that recent DataLoader iterators lack.
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if i % len_target_loader == 0:
                iter_target = iter(target_loader)
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()

            optimizer_e.zero_grad()
            optimizer_cls.zero_grad()
            optimizer_critic.zero_grad()

            # Training model using source data
            class_output_s, domain_output, _ = dann(input_data=data_source,
                                                    alpha=gamma)
            err_s_label = criterion(class_output_s, label_source)
            # Create the labels on the critic output's device instead of an
            # unconditional .cuda(), matching the guarded data transfer
            # above.
            domain_label = torch.zeros(data_source.size(0),
                                       dtype=torch.long,
                                       device=domain_output.device)
            err_s_domain = criterion(domain_output, domain_label)

            # Training model using target data
            _, domain_output, _ = dann(input_data=data_target, alpha=gamma)
            domain_label = torch.ones(data_target.size(0),
                                      dtype=torch.long,
                                      device=domain_output.device)
            err_t_domain = criterion(domain_output, domain_label)
            err = err_s_label + err_s_domain + err_t_domain

            if i % 100 == 0:
                print(
                    'err_s_label {:.2f}, err_s_domain {:.2f}, gamma {:.2f}, err_t_domain {:.2f}, total err {:.2f}'
                    .format(err_s_label.item(), err_s_domain.item(), gamma,
                            err_t_domain.item(), err.item()))

            err.backward()
            optimizer_e.step()
            optimizer_cls.step()
            optimizer_critic.step()

            # Second update: entropy term on the target features.
            optimizer_e.zero_grad()
            optimizer_cls.zero_grad()
            feature_t = extractor(data_target)
            feature_t = feature_t.view(feature_t.size(0), -1)
            # entropy_loss = adentropy(classifier, feature_t, 1)
            entropy_loss = entropy(classifier, feature_t, 1)
            entropy_loss.backward()
            optimizer_e.step()
            optimizer_cls.step()

    # Epoch driver: train, then periodically evaluate and visualise.
    last_epoch = config['n_epochs']
    for epoch in range(1, last_epoch + 1):
        train(extractor, classifier, critic, config, epoch)

        if epoch % config['TEST_INTERVAL'] == 0:
            # Evaluation on the source test loader is disabled; only the
            # target test loader is evaluated.
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)

        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir,
                                  epoch, config['models'])
            # Two t-SNE plots of the same loaders: separate=True first, then
            # separate=False.
            for separate_flag in (True, False):
                draw_tsne(extractor,
                          classifier,
                          config['source_train_loader'],
                          config['target_test_loader'],
                          res_dir,
                          epoch,
                          config['models'],
                          separate=separate_flag)
Esempio n. 15
0
def train_cdan_iw(config):
    """Train an extractor/classifier with importance-weighted adversarial adaptation.

    Builds the feature extractor, the label classifier and the domain
    discriminator (``ad_net``), then runs ``config['n_epochs']`` epochs of
    training. Every ``TEST_INTERVAL`` epochs the model is evaluated on the
    target test loader and the best-scoring weights are saved under the
    result directory; every ``VIS_INTERVAL`` epochs a confusion matrix and
    t-SNE plots are drawn.

    Args:
        config: dict of settings. Keys read here: ``network``, ``n_flattens``,
            ``n_hiddens``, ``n_class``, ``random_layer``, ``res_dir``,
            ``normal``, ``dilation``, ``iw``, ``lr``, ``models``
            (``'CDAN_EIW'``, ``'CDAN_IW'`` or ``'DANN_IW'``), ``n_epochs``,
            ``TEST_INTERVAL``, ``VIS_INTERVAL``, ``source_train_loader``,
            ``target_train_loader`` and ``target_test_loader``.

    Raises:
        ValueError: if ``config['models']`` is not one of the supported
            method names.
    """
    # Single switch for every device move, so CPU-only hosts do not crash on
    # an unconditional ``.cuda()`` call.
    use_cuda = torch.cuda.is_available()

    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32)
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32)
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])

    if use_cuda:
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        #summary(extractor, (1, 5120))

    cdan_random = config['random_layer']
    res_dir = os.path.join(
        config['res_dir'],
        'normal{}-{}-dilation{}-iw{}-lr{}'.format(config['normal'],
                                                  config['network'],
                                                  config['dilation'],
                                                  config['iw'], config['lr']))
    os.makedirs(res_dir, exist_ok=True)

    print('train_cdan_iw')
    print(config)

    set_log_config(res_dir)
    logging.debug('train_cdan')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    # The discriminator input width depends on how features and predictions
    # are combined: raw features (DANN), a random projection, or the full
    # outer product of features and class probabilities (CDAN).
    if config['models'] == 'DANN_IW':
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    elif cdan_random:
        random_layer = RandomLayer([config['n_flattens'], config['n_class']],
                                   config['n_hiddens'])
        ad_net = AdversarialNetwork(config['n_hiddens'], config['n_hiddens'])
        if use_cuda:
            random_layer.cuda()
    else:
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'] * config['n_class'],
                                    config['n_hiddens'])
    if use_cuda:
        ad_net = ad_net.cuda()

    optimizer = torch.optim.Adam([{
        'params': extractor.parameters(),
        'lr': config['lr']
    }, {
        'params': classifier.parameters(),
        'lr': config['lr']
    }],
                                 weight_decay=0.0001)
    optimizer_ad = torch.optim.Adam(ad_net.parameters(),
                                    lr=config['lr'],
                                    weight_decay=0.0001)
    print(ad_net)

    extractor_path = os.path.join(res_dir, "extractor.pth")
    classifier_path = os.path.join(res_dir, "classifier.pth")
    adnet_path = os.path.join(res_dir, "adnet.pth")

    def train(extractor, classifier, ad_net, config, epoch):
        """Run one epoch over the source loader, pairing each batch with a target batch."""
        start_epoch = 0
        model_name = config['models']
        # Fail fast with the intended error instead of an UnboundLocalError
        # further down when the method name is unknown.
        if model_name not in ('CDAN_EIW', 'CDAN_IW', 'DANN_IW'):
            raise ValueError('Method cannot be recognized.')

        extractor.train()
        classifier.train()
        ad_net.train()

        source_loader = config['source_train_loader']
        target_loader = config['target_train_loader']
        iter_source = iter(source_loader)
        iter_target = iter(target_loader)
        len_target_loader = len(target_loader)
        num_iter = len(source_loader)

        # Both values depend only on the epoch, so compute them once.
        adapt = epoch > start_epoch
        gamma = 2 / (1 + math.exp(-10 * epoch / config['n_epochs'])) - 1

        for step in range(1, num_iter + 1):
            # Built-in next() works on every iterator; the ``.next()`` method
            # is a Python 2 leftover that current DataLoader iterators lack.
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            # Restart the target loader once it has been fully consumed.
            if step % len_target_loader == 0:
                iter_target = iter(target_loader)
            if use_cuda:
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()

            # ---- per-sample importance weights (no gradients needed) ----
            # ``None`` means "unweighted"; this is what CDAN_EIW uses, which
            # previously crashed because no weight was ever assigned for it.
            dom_weight = None
            with torch.no_grad():
                if model_name == 'CDAN_IW':
                    h_s = extractor(data_source)
                    h_s = h_s.view(h_s.size(0), -1)
                    # NOTE(review): the target features are not used for the
                    # weights. The forward pass is kept only so that any
                    # train-mode side effects of the extractor stay as
                    # before -- confirm whether it can be dropped.
                    h_t = extractor(data_target)
                    h_t = h_t.view(h_t.size(0), -1)

                    source_preds = classifier(h_s)
                    softmax_output_s = nn.Softmax(dim=1)(source_preds)

                    # Uniform "class probabilities": only the shape of the
                    # softmax output is used. ones_like keeps device/dtype.
                    weights = torch.ones_like(softmax_output_s).unsqueeze(2)
                    op_out = torch.bmm(weights, h_s.unsqueeze(1))
                    # NOTE(review): the flattened width is n_class * n_feat,
                    # which matches ad_net only when no random layer is in
                    # use -- confirm against the configuration.
                    ad_out = ad_net(op_out.view(
                        -1,
                        softmax_output_s.size(1) * h_s.size(1)),
                                    1,
                                    training=False)
                    # Samples the discriminator is confident about (output
                    # far from 0.5) receive a weight above 1.
                    dom_weight = 1 + (torch.abs(0.5 - ad_out))**config['iw']

                elif model_name == 'DANN_IW':
                    h_s = extractor(data_source)
                    h_s = h_s.view(h_s.size(0), -1)
                    ad_out = ad_net(h_s, 1, training=False)
                    # Weighting is currently disabled for DANN_IW (all ones).
                    dom_weight = torch.ones_like(ad_out)

            optimizer.zero_grad()
            optimizer_ad.zero_grad()

            h_s = extractor(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier(h_s)
            softmax_output_s = nn.Softmax(dim=1)(source_preds)

            target_preds = classifier(h_t)
            softmax_output_t = nn.Softmax(dim=1)(target_preds)

            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            cls_loss = nn.CrossEntropyLoss(reduction='none')(source_preds,
                                                             label_source)
            if dom_weight is not None:
                # Flatten so the product is element-wise per sample. If the
                # discriminator output is (batch, 1) -- presumably the case,
                # TODO confirm -- multiplying it with the (batch,) loss would
                # broadcast to (batch, batch) and cancel the weighting.
                cls_loss = dom_weight.reshape(-1) * cls_loss
            cls_loss = torch.mean(cls_loss)

            if adapt:
                if model_name == 'CDAN_EIW':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN(
                        [feature, softmax_output], ad_net, gamma, entropy,
                        loss_func.calc_coeff(num_iter * (epoch - start_epoch) +
                                             step), random_layer)
                elif model_name == 'CDAN_IW':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net,
                                            gamma, None, None, random_layer)
                else:  # 'DANN_IW'
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
            else:
                # Keep a tensor so the ``.item()`` call in the log line works.
                d_loss = torch.zeros_like(cls_loss)

            loss = cls_loss + d_loss
            loss.backward()
            optimizer.step()
            if adapt:
                optimizer_ad.step()
            if step % 20 == 0:
                print('Train Epoch {} closs {:.6f}, dloss {:.6f}, Loss {:.6f}'.
                      format(epoch, cls_loss.item(), d_loss.item(),
                             loss.item()))

    best_accuracy = 0
    best_model_index = -1
    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, ad_net, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on target_test_loader')
            accuracy = test(extractor, classifier,
                            config['target_test_loader'], epoch)

            # Checkpoint all three networks whenever target accuracy improves.
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_model_index = epoch
                torch.save(extractor.state_dict(), extractor_path)
                torch.save(classifier.state_dict(), classifier_path)
                torch.save(ad_net.state_dict(), adnet_path)
            print(
                'epoch {} accuracy: {:.6f}, best accuracy {:.6f} on epoch {}'.
                format(epoch, accuracy, best_accuracy, best_model_index))

        if epoch % config['VIS_INTERVAL'] == 0:
            title = config['models']
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, epoch,
                                  title)
            # Two t-SNE plots over the same loaders, differing only in the
            # ``separate`` flag.
            for separate in (True, False):
                draw_tsne(extractor,
                          classifier,
                          config['source_train_loader'],
                          config['target_test_loader'],
                          res_dir,
                          epoch,
                          title,
                          separate=separate)