Example 1
def train_pada(config):
    if config['network'] == 'inceptionv1':
        extractor_s = InceptionV1(num_classes=32)
        extractor_t = InceptionV1(num_classes=32)
    elif config['network'] == 'inceptionv1s':
        extractor_s = InceptionV1s(num_classes=32)
        extractor_t = InceptionV1s(num_classes=32)
    else:
        extractor_s = Extractor(n_flattens=config['n_flattens'],
                                n_hiddens=config['n_hiddens'])
        extractor_t = Extractor(n_flattens=config['n_flattens'],
                                n_hiddens=config['n_hiddens'])

    classifier_s = Classifier(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'],
                              n_class=config['n_class'])
    classifier_t = Classifier(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'],
                              n_class=config['n_class'])

    if torch.cuda.is_available():
        extractor_s = extractor_s.cuda()
        classifier_s = classifier_s.cuda()

        extractor_t = extractor_t.cuda()
        classifier_t = classifier_t.cuda()

    cdan_random = config['random_layer']
    res_dir = os.path.join(
        config['res_dir'],
        'normal{}-{}-cons{}-lr{}'.format(config['normal'], config['network'],
                                         config['pada_cons_w'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    print('train_pada')
    print(config)

    set_log_config(res_dir)
    logging.debug('train_pada')
    # logging.debug(extractor)
    # logging.debug(classifier)
    logging.debug(config)

    if config['models'] == 'PADA':
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    elif cdan_random:
        random_layer = RandomLayer([config['n_flattens'], config['n_class']],
                                   config['n_hiddens'])
        ad_net = AdversarialNetwork(config['n_hiddens'], config['n_hiddens'])
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'] * config['n_class'],
                                    config['n_hiddens'])
    ad_net = ad_net.cuda()
    optimizer_s = torch.optim.Adam([{
        'params': extractor_s.parameters(),
        'lr': config['lr']
    }, {
        'params': classifier_s.parameters(),
        'lr': config['lr']
    }])
    optimizer_t = torch.optim.Adam([{
        'params': extractor_t.parameters(),
        'lr': config['lr']
    }, {
        'params': classifier_t.parameters(),
        'lr': config['lr']
    }])

    optimizer_ad = torch.optim.Adam(ad_net.parameters(), lr=config['lr'])

    def train_stage1(extractor_s, classifier_s, config, epoch):
        extractor_s.train()
        classifier_s.train()

        # STAGE 1: train extractor_s and classifier_s on the labeled source
        # data; once training is done, both models are frozen
        iter_source = iter(config['source_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        for step in range(1, len_source_loader + 1):
            data_source, label_source = next(iter_source)
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()

            optimizer_s.zero_grad()

            h_s = extractor_s(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            source_preds = classifier_s(h_s)
            cls_loss = nn.CrossEntropyLoss()(source_preds, label_source)

            cls_loss.backward()
            optimizer_s.step()

    def train(extractor_s, classifier_s, extractor_t, classifier_t, ad_net,
              config, epoch):
        start_epoch = 0

        # STAGE 2: DANN-style training with the new extractor and classifier.
        # The difference is that each source batch is passed through both
        # extractor_s (frozen) and extractor_t.

        extractor_t.train()
        classifier_t.train()
        ad_net.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for step in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()
                label_target = label_target.cuda()

            optimizer_t.zero_grad()
            optimizer_ad.zero_grad()

            h_s = extractor_t(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor_t(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier_t(h_s)
            cls_loss = nn.CrossEntropyLoss()(source_preds, label_source)
            softmax_output_s = nn.Softmax(dim=1)(source_preds)

            target_preds = classifier_t(h_t)
            softmax_output_t = nn.Softmax(dim=1)(target_preds)
            if config['target_labeling'] == 1:
                cls_loss += nn.CrossEntropyLoss()(target_preds, label_target)

            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            if epoch > start_epoch:
                gamma = 2 / (1 + math.exp(-10 *
                                          (epoch) / config['n_epochs'])) - 1
                if config['models'] == 'CDAN-E':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN(
                        [feature, softmax_output], ad_net, gamma, entropy,
                        loss_func.calc_coeff(num_iter * (epoch - start_epoch) +
                                             step), random_layer)
                elif config['models'] == 'CDAN':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net,
                                            gamma, None, None, random_layer)
                elif config['models'] == 'PADA':
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
                else:
                    raise ValueError('Method cannot be recognized.')
            else:
                d_loss = 0

            # consistency loss: keep extractor_t's source features close to
            # those of the frozen extractor_s
            h_s_prev = extractor_s(data_source)
            h_s_prev = h_s_prev.view(h_s_prev.size(0), -1)
            cons_loss = nn.L1Loss()(h_s, h_s_prev)

            loss = cls_loss + d_loss + config['pada_cons_w'] * cons_loss
            loss.backward()
            optimizer_t.step()
            if epoch > start_epoch:
                optimizer_ad.step()
            if step % 20 == 0:
                print(
                    'Train Epoch {} closs {:.6f}, dloss {:.6f}, cons_loss {:.6f}, Loss {:.6f}'
                    .format(epoch, cls_loss.item(), d_loss.item(),
                            cons_loss.item(), loss.item()))

    for epoch in range(1, config['n_epochs'] + 1):
        train_stage1(extractor_s, classifier_s, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on source_test_loader')
            test(extractor_s, classifier_s, config['source_test_loader'],
                 epoch)
            # print('test on target_test_loader')
            # accuracy = test(extractor_s, classifier_s, config['target_test_loader'], epoch)

    extractor_t.load_state_dict(extractor_s.state_dict())
    classifier_t.load_state_dict(classifier_s.state_dict())

    for param in extractor_s.parameters():
        param.requires_grad = False
    for param in classifier_s.parameters():
        param.requires_grad = False

    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor_s, classifier_s, extractor_t, classifier_t, ad_net,
              config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            # print('test on source_test_loader')
            # test(extractor, classifier, config['source_test_loader'], epoch)
            print('test on target_test_loader')
            accuracy = test(extractor_t, classifier_t,
                            config['target_test_loader'], epoch)

        if epoch % config['VIS_INTERVAL'] == 0:
            title = config['models']
            draw_confusion_matrix(extractor_t, classifier_t,
                                  config['target_test_loader'], res_dir, epoch,
                                  title)
            draw_tsne(extractor_t,
                      classifier_t,
                      config['source_train_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      title,
                      separate=True)
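
The PADA branch above relies on loss_func.DANN(feature, ad_net, gamma), whose definition is not part of this snippet. A minimal sketch of a compatible implementation, assuming ad_net ends in a sigmoid and that the first half of the batch is source data (the real helper may fold gamma into the adversarial network instead):

import torch
import torch.nn as nn
from torch.autograd import Function

class GradReverse(Function):
    # gradient reversal: identity on the forward pass,
    # -gamma * grad on the backward pass
    @staticmethod
    def forward(ctx, x, gamma):
        ctx.gamma = gamma
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.gamma * grad_output, None

def DANN(feature, ad_net, gamma):
    # feature is the source/target concatenation built in the training loop;
    # the discriminator learns to separate the domains while the reversed
    # gradient trains the extractor to confuse it
    ad_out = ad_net(GradReverse.apply(feature, gamma))
    batch_size = ad_out.size(0) // 2
    domain_label = torch.cat((torch.ones(batch_size, 1),
                              torch.zeros(batch_size, 1))).to(ad_out.device)
    return nn.BCELoss()(ad_out, domain_label)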
Example 2
def train_cdan(config):
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32)
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32)
    else:
        extractor = Extractor(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'], bn=config['bn'])

    classifier = Classifier(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'], n_class=config['n_class'])
    vat_loss = VAT(extractor, classifier, n_power=1, radius=3.5)

    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        vat_loss = vat_loss.cuda()

    cdan_random = config['random_layer'] 
    res_dir = os.path.join(config['res_dir'], 'slim{}-targetLabel{}-snr{}-snrp{}-lr{}'.format(config['slim'],
                                                                        config['target_labeling'],
                                                                        config['snr'],
                                                                        config['snrp'],
                                                                        config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    print('train_cdan')
    #print(extractor)
    #print(classifier)
    print(config)

    set_log_config(res_dir)
    logging.debug('train_cdan')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    if config['models'] == 'DANN':
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    elif cdan_random:
        random_layer = RandomLayer([config['n_flattens'], config['n_class']], config['n_hiddens'])
        ad_net = AdversarialNetwork(config['n_hiddens'], config['n_hiddens'])
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'] * config['n_class'], config['n_hiddens'])
    ad_net = ad_net.cuda()
    optimizer = torch.optim.Adam([
        {'params': extractor.parameters(), 'lr': config['lr']},
        {'params': classifier.parameters(), 'lr': config['lr']}
        ])
    optimizer_ad = torch.optim.Adam(ad_net.parameters(), lr=config['lr'])


    extractor_path = os.path.join(res_dir, "extractor.pth")
    classifier_path = os.path.join(res_dir, "classifier.pth")
    adnet_path = os.path.join(res_dir, "adnet.pth")

    def train(extractor, classifier, ad_net, config, epoch):
        start_epoch = 0

        extractor.train()
        classifier.train()
        ad_net.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if config['slim'] > 0:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len_source_loader
        for step in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            if config['slim'] > 0:
                data_target_semi, label_target_semi = next(iter_target_semi)
                if step % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target, label_target = data_target.cuda(), label_target.cuda()
                if config['slim'] > 0:
                    data_target_semi, label_target_semi = data_target_semi.cuda(), label_target_semi.cuda()

            optimizer.zero_grad()
            optimizer_ad.zero_grad()

            h_s = extractor(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier(h_s)
            cls_loss = nn.CrossEntropyLoss()(source_preds, label_source)
            softmax_output_s = nn.Softmax(dim=1)(source_preds)

            target_preds = classifier(h_t)
            softmax_output_t = nn.Softmax(dim=1)(target_preds)
            if config['target_labeling'] == 1:
                cls_loss += nn.CrossEntropyLoss()(target_preds, label_target)

            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            if epoch > start_epoch:
                gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1
                if config['models'] == 'CDAN-E':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net, gamma, entropy, loss_func.calc_coeff(num_iter*(epoch-start_epoch)+step), random_layer)
                elif config['models'] == 'CDAN':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net, gamma, None, None, random_layer)
                elif config['models'] == 'DANN':
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
                else:
                    raise ValueError('Method cannot be recognized.')
            else:
                d_loss = 0

            loss = cls_loss + d_loss

            err_t_bnm = get_loss_bnm(target_preds)
            err_s_vat = vat_loss(data_source, source_preds)
            err_t_vat = vat_loss(data_target, target_preds)
            loss += 1.0 * err_s_vat + 1.0 * err_t_vat + 1.0 * err_t_bnm

            if config['slim'] > 0:
                feature_target_semi = extractor(data_target_semi)
                feature_target_semi = feature_target_semi.view(feature_target_semi.size(0), -1)
                preds_target_semi = classifier(feature_target_semi)
                loss += nn.CrossEntropyLoss()(preds_target_semi, label_target_semi)

            loss.backward()
            optimizer.step()
            if epoch > start_epoch:
                optimizer_ad.step()
            if step % 100 == 0:
                print('Train Epoch {} closs {:.6f}, dloss {:.6f}, Loss {:.6f}'.format(epoch, cls_loss.item(), d_loss.item(), loss.item()))

    if config['testonly'] == 0:
        best_accuracy = 0
        best_model_index = -1
        for epoch in range(1, config['n_epochs'] + 1):
            train(extractor, classifier, ad_net, config, epoch)
            if epoch % config['TEST_INTERVAL'] == 0:
                # print('test on source_test_loader')
                # test(extractor, classifier, config['source_test_loader'], epoch)
                print('test on target_test_loader')
                accuracy = test(extractor, classifier, config['target_test_loader'], epoch)

                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    best_model_index = epoch
                    torch.save(extractor.state_dict(), extractor_path)
                    torch.save(classifier.state_dict(), classifier_path)
                    torch.save(ad_net.state_dict(), adnet_path)
                print('epoch {} accuracy: {:.6f}, best accuracy {:.6f} on epoch {}'.format(epoch, accuracy, best_accuracy, best_model_index))


            if epoch % config['VIS_INTERVAL'] == 0:
                title = config['models']
                draw_confusion_matrix(extractor, classifier, config['target_test_loader'], res_dir, epoch, title)
                draw_tsne(extractor, classifier, config['source_train_loader'], config['target_test_loader'], res_dir, epoch, title, separate=True)
                # draw_tsne(extractor, classifier, config['source_test_loader'], config['target_test_loader'], res_dir, epoch, title, separate=False)
    else:
        if os.path.exists(extractor_path) and os.path.exists(classifier_path) and os.path.exists(adnet_path):
            extractor.load_state_dict(torch.load(extractor_path))
            classifier.load_state_dict(torch.load(classifier_path))
            ad_net.load_state_dict(torch.load(adnet_path))
            print('Test only mode, model loaded')

            # print('test on source_test_loader')
            # test(extractor, classifier, config['source_test_loader'], -1)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], -1)

            title = config['models']
            draw_confusion_matrix(extractor, classifier, config['target_test_loader'], res_dir, -1, title)
            # draw_tsne(extractor, classifier, config['source_test_loader'], config['target_test_loader'], res_dir, -1, title, separate=True)
            draw_tsne(extractor, classifier, config['source_train_loader'], config['target_test_loader'], res_dir, -1, title, separate=True)
        else:
            print('no saved model found')
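
get_loss_bnm is also external to this snippet. Batch Nuclear-norm Maximization (BNM) encourages target predictions that are simultaneously confident and class-diverse by maximizing the nuclear norm of the batch prediction matrix; a short sketch consistent with the get_loss_bnm(target_preds) call above (the actual helper may differ):

import torch
import torch.nn as nn

def get_loss_bnm(preds):
    # softmax over classes gives a (batch x n_class) probability matrix;
    # its nuclear norm (sum of singular values) grows with both prediction
    # confidence and class diversity, so we return it negated as a loss
    softmax_out = nn.Softmax(dim=1)(preds)
    return -torch.norm(softmax_out, p='nuc') / softmax_out.size(0)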
Example 3
def train_cdan_vat(config):
    extractor = Extractor(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'])
    classifier = Classifier(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'], n_class=config['n_class'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    xi = 1e-06
    ip = 1
    eps = 15
    vat = VirtualAdversarialPerturbationGenerator(extractor, classifier, xi=xi, eps=eps, ip=ip)
    # the two names below are used by the VAT step in train() but never
    # defined in the original snippet; an MSE consistency criterion with
    # unit weight is an assumed stand-in
    target_consistency_criterion = nn.MSELoss()
    target_vat_loss_weight = 1.0

    cdan_random = config['random_layer'] 
    res_dir = os.path.join(config['res_dir'], 'random{}-bs{}-lr{}'.format(cdan_random, config['batch_size'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    print('train_cdan_vat')
    print(extractor)
    print(classifier)
    print(config)

    set_log_config(res_dir)
    logging.debug('train_cdan_vat')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    if config['models'] == 'DANN':
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'], config['n_hiddens'])
    elif cdan_random:
        random_layer = RandomLayer([config['n_flattens'], config['n_class']], config['n_hiddens'])
        ad_net = AdversarialNetwork(config['n_hiddens'], config['n_hiddens'])
        random_layer.cuda()
    else:
        random_layer = None
        ad_net = AdversarialNetwork(config['n_flattens'] * config['n_class'], config['n_hiddens'])
    ad_net = ad_net.cuda()
    optimizer = torch.optim.Adam([
        {'params': extractor.parameters(), 'lr': config['lr']},
        {'params': classifier.parameters(), 'lr': config['lr']}
        ])
    optimizer_ad = torch.optim.Adam(ad_net.parameters(), lr=config['lr'])


    extractor_path = os.path.join(res_dir, "extractor.pth")
    classifier_path = os.path.join(res_dir, "classifier.pth")
    adnet_path = os.path.join(res_dir, "adnet.pth")

    def train(extractor, classifier, ad_net, config, epoch):
        start_epoch = 0

        extractor.train()
        classifier.train()
        ad_net.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for step in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target = data_target.cuda()

            optimizer.zero_grad()
            optimizer_ad.zero_grad()

            h_s = extractor(data_source)
            h_s = h_s.view(h_s.size(0), -1)
            h_t = extractor(data_target)
            h_t = h_t.view(h_t.size(0), -1)

            source_preds = classifier(h_s)
            cls_loss = nn.CrossEntropyLoss()(source_preds, label_source)
            softmax_output_s = nn.Softmax(dim=1)(source_preds)

            target_preds = classifier(h_t)
            softmax_output_t = nn.Softmax(dim=1)(target_preds)

            feature = torch.cat((h_s, h_t), 0)
            softmax_output = torch.cat((softmax_output_s, softmax_output_t), 0)

            if epoch > start_epoch:
                gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1
                if config['models'] == 'CDAN-E':
                    entropy = loss_func.Entropy(softmax_output)
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net, gamma, entropy, loss_func.calc_coeff(num_iter*(epoch-start_epoch)+step), random_layer)
                elif config['models'] == 'CDAN':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net, gamma, None, None, random_layer)
                elif config['models'] == 'DANN':
                    d_loss = loss_func.DANN(feature, ad_net, gamma)
                elif config['models'] == 'CDAN_VAT':
                    d_loss = loss_func.CDAN([feature, softmax_output], ad_net, gamma, None, None, random_layer)
                else:
                    raise ValueError('Method cannot be recognized.')
            else:
                d_loss = 0

            loss = cls_loss + d_loss
            loss.backward()
            optimizer.step()

            vat_adv, clean_vat_logits = vat(data_target)
            vat_adv_inputs = data_target + vat_adv
            adv_vat_features = extractor(vat_adv_inputs)
            adv_vat_logits = classifier(adv_vat_features)
            target_vat_loss = target_consistency_criterion(adv_vat_logits, clean_vat_logits)
            vat_loss = target_vat_loss_weight * target_vat_loss
            vat_loss.backward()
            # optimizer.step()

            if epoch > start_epoch:
                optimizer_ad.step()
            if step % 20 == 0:
                print('Train Epoch {} closs {:.6f}, dloss {:.6f}, vat_loss {:.6f}, Loss {:.6f}'.format(epoch, cls_loss.item(), d_loss.item(), vat_loss.item(), loss.item()))

    if config['testonly'] == 0:
        best_accuracy = 0
        best_model_index = -1
        for epoch in range(1, config['n_epochs'] + 1):
            train(extractor, classifier, ad_net, config, epoch)
            if epoch % config['TEST_INTERVAL'] == 0:
                print('test on source_test_loader')
                test(extractor, classifier, config['source_test_loader'], epoch)
                print('test on target_test_loader')
                accuracy = test(extractor, classifier, config['target_test_loader'], epoch)

                if accuracy > best_accuracy:
                    best_accuracy = accuracy
                    best_model_index = epoch
                    torch.save(extractor.state_dict(), extractor_path)
                    torch.save(classifier.state_dict(), classifier_path)
                    torch.save(ad_net.state_dict(), adnet_path)
                print('epoch {} accuracy: {:.6f}, best accuracy {:.6f} on epoch {}'.format(epoch, accuracy, best_accuracy, best_model_index))


            if epoch % config['VIS_INTERVAL'] == 0:
                title = config['models']
                draw_confusion_matrix(extractor, classifier, config['target_test_loader'], res_dir, epoch, title)
                draw_tsne(extractor, classifier, config['source_test_loader'], config['target_test_loader'], res_dir, epoch, title, separate=True)
                # draw_tsne(extractor, classifier, config['source_test_loader'], config['target_test_loader'], res_dir, epoch, title, separate=False)
    else:
        if os.path.exists(extractor_path) and os.path.exists(classifier_path) and os.path.exists(adnet_path):
            extractor.load_state_dict(torch.load(extractor_path))
            classifier.load_state_dict(torch.load(classifier_path))
            ad_net.load_state_dict(torch.load(adnet_path))
            print('Test only mode, model loaded')

            print('test on source_test_loader')
            test(extractor, classifier, config['source_test_loader'], -1)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], -1)

            title = config['models']
            draw_confusion_matrix(extractor, classifier, config['target_test_loader'], res_dir, -1, title)
            draw_tsne(extractor, classifier, config['source_test_loader'], config['target_test_loader'], res_dir, -1, title, separate=True)
        else:
            print('no saved model found')
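
The VirtualAdversarialPerturbationGenerator used above is defined elsewhere. A compact sketch of the usual VAT power-iteration scheme, matching the vat_adv, clean_vat_logits = vat(data_target) call site (the constructor signature comes from the snippet; the internals are an assumption):

import torch
import torch.nn.functional as F

class VirtualAdversarialPerturbationGenerator:
    def __init__(self, extractor, classifier, xi=1e-6, eps=15, ip=1):
        self.extractor = extractor
        self.classifier = classifier
        self.xi, self.eps, self.ip = xi, eps, ip

    def _logits(self, x):
        h = self.extractor(x)
        return self.classifier(h.view(h.size(0), -1))

    def __call__(self, x):
        with torch.no_grad():
            clean_logits = self._logits(x)
        # power iteration: find the input direction that changes the
        # prediction the most
        d = torch.randn_like(x)
        for _ in range(self.ip):
            d = (self.xi * F.normalize(d.view(d.size(0), -1), dim=1)).view_as(x)
            d.requires_grad_()
            adv_logits = self._logits(x + d)
            dist = F.kl_div(F.log_softmax(adv_logits, dim=1),
                            F.softmax(clean_logits, dim=1),
                            reduction='batchmean')
            d = torch.autograd.grad(dist, d)[0]
        # scale the final direction to the eps-ball radius
        r_adv = (self.eps * F.normalize(d.view(d.size(0), -1), dim=1)).view_as(x)
        return r_adv.detach(), clean_logits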
Example 4
def train_wasserstein(config):
    extractor = Extractor(n_flattens=config['n_flattens'],
                          n_hiddens=config['n_hiddens'])
    # extractor = InceptionV1(num_classes=32)
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])
    critic = Critic(n_flattens=config['n_flattens'],
                    n_hiddens=config['n_hiddens'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    triplet_type = config['triplet_type']
    gamma = config['w_gamma']
    weight_wd = config['w_weight']
    weight_triplet = config['t_weight']
    t_margin = config['t_margin']
    t_confidence = config['t_confidence']
    k_critic = 3
    k_clf = 1
    TRIPLET_START_INDEX = 95

    if triplet_type == 'none':
        res_dir = os.path.join(
            config['res_dir'],
            'bs{}-lr{}-w{}-gamma{}'.format(config['batch_size'], config['lr'],
                                           weight_wd, gamma))
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        extractor_path = os.path.join(res_dir, "extractor.pth")
        classifier_path = os.path.join(res_dir, "classifier.pth")
        critic_path = os.path.join(res_dir, "critic.pth")
        EPOCH_START = 1
        TEST_INTERVAL = 10

    else:
        TEST_INTERVAL = 1
        w_dir = os.path.join(
            config['res_dir'],
            'bs{}-lr{}-w{}-gamma{}'.format(config['batch_size'], config['lr'],
                                           weight_wd, gamma))
        if not os.path.exists(w_dir):
            os.makedirs(w_dir)
        res_dir = os.path.join(
            w_dir, '{}_t_weight{}_margin{}_confidence{}'.format(
                triplet_type, weight_triplet, t_margin, t_confidence))
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        extractor_path = os.path.join(w_dir, "extractor.pth")
        classifier_path = os.path.join(w_dir, "classifier.pth")
        critic_path = os.path.join(w_dir, "critic.pth")

        if os.path.exists(extractor_path):
            extractor.load_state_dict(torch.load(extractor_path))
            classifier.load_state_dict(torch.load(classifier_path))
            critic.load_state_dict(torch.load(critic_path))
            print('load models')
            EPOCH_START = TRIPLET_START_INDEX
        else:
            EPOCH_START = 1

    set_log_config(res_dir)
    print('start epoch {}'.format(EPOCH_START))
    print('triplet type {}'.format(triplet_type))
    print(config)

    logging.debug('train_wt')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    criterion = torch.nn.CrossEntropyLoss()
    softmax_layer = nn.Softmax(dim=1)

    critic_opt = torch.optim.Adam(critic.parameters(), lr=config['lr'])
    classifier_opt = torch.optim.Adam(classifier.parameters(), lr=config['lr'])
    feature_opt = torch.optim.Adam(extractor.parameters(),
                                   lr=config['lr'] / 10)

    def train(extractor, classifier, critic, config, epoch):
        extractor.train()
        classifier.train()
        critic.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for step in range(1, num_iter):
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()

            # 1. train critic
            set_requires_grad(extractor, requires_grad=False)
            set_requires_grad(classifier, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)
            with torch.no_grad():
                h_s = extractor(data_source)
                h_s = h_s.view(h_s.size(0), -1)
                h_t = extractor(data_target)
                h_t = h_t.view(h_t.size(0), -1)

            for j in range(k_critic):
                gp = gradient_penalty(critic, h_s, h_t)
                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()
                critic_cost = -wasserstein_distance + gamma * gp

                critic_opt.zero_grad()
                critic_cost.backward()
                critic_opt.step()

                if step == 10 and j == 0:
                    print('EPOCH {}, DISCRIMINATOR: wd {}, gp {}, loss {}'.
                          format(epoch, wasserstein_distance.item(),
                                 (gamma * gp).item(), critic_cost.item()))
                    logging.debug(
                        'EPOCH {}, DISCRIMINATOR: wd {}, gp {}, loss {}'.
                        format(epoch, wasserstein_distance.item(),
                               (gamma * gp).item(), critic_cost.item()))

            # 2. train feature and class_classifier
            set_requires_grad(extractor, requires_grad=True)
            set_requires_grad(classifier, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(k_clf):
                h_s = extractor(data_source)
                h_s = h_s.view(h_s.size(0), -1)
                h_t = extractor(data_target)
                h_t = h_t.view(h_t.size(0), -1)

                source_preds, _ = classifier(h_s)
                clf_loss = criterion(source_preds, label_source)
                wasserstein_distance = critic(h_s).mean() - critic(h_t).mean()

                if triplet_type != 'none' and epoch >= TRIPLET_START_INDEX:
                    target_preds, _ = classifier(h_t)
                    target_labels = target_preds.data.max(1)[1]
                    target_logits = softmax_layer(target_preds)
                    if triplet_type == 'all':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() >
                            t_margin)[0]
                        images = torch.cat((h_s, h_t[triplet_index]), 0)
                        labels = torch.cat(
                            (label_source, target_labels[triplet_index]), 0)
                    elif triplet_type == 'src':
                        images = h_s
                        labels = label_source
                    elif triplet_type == 'tgt':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() >
                            t_confidence)[0]
                        images = h_t[triplet_index]
                        labels = target_labels[triplet_index]
                    elif triplet_type == 'sep':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() >
                            t_confidence)[0]
                        images = h_t[triplet_index]
                        labels = target_labels[triplet_index]
                        t_loss_sep, _ = triplet_loss(extractor, {
                            "X": images,
                            "y": labels
                        }, t_confidence)
                        images = h_s
                        labels = label_source

                    t_loss, _ = triplet_loss(extractor, {
                        "X": images,
                        "y": labels
                    }, t_margin)
                    loss = clf_loss + \
                        weight_wd * wasserstein_distance + \
                        weight_triplet * t_loss
                    if triplet_type == 'sep':
                        loss += t_loss_sep
                    feature_opt.zero_grad()
                    classifier_opt.zero_grad()
                    loss.backward()
                    feature_opt.step()
                    classifier_opt.step()

                    if step == 10:
                        print(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, t_loss {}, total loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    weight_triplet * t_loss.item(),
                                    loss.item()))
                        logging.debug(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, t_loss {}, total loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    weight_triplet * t_loss.item(),
                                    loss.item()))

                else:
                    loss = clf_loss + weight_wd * wasserstein_distance
                    feature_opt.zero_grad()
                    classifier_opt.zero_grad()
                    loss.backward()
                    feature_opt.step()
                    classifier_opt.step()

                    if step == 10:
                        print(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {},  loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    loss.item()))
                        logging.debug(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {},  loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    loss.item()))

    # pretrain(model, config, pretrain_epochs=20)
    for epoch in range(EPOCH_START, config['n_epochs'] + 1):
        train(extractor, classifier, critic, config, epoch)
        if epoch % TEST_INTERVAL == 0:
            # print('test on source_test_loader')
            # test(extractor, classifier, config['source_test_loader'], epoch)
            # print('test on target_train_loader')
            # test(model, config['target_train_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            if triplet_type == 'none':
                title = '(a) WDGRL'
            else:
                title = '(b) TLADA'
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, epoch,
                                  title)
            draw_tsne(extractor,
                      classifier,
                      config['source_train_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      title,
                      separate=True)
            # draw_tsne(extractor, classifier, config['source_test_loader'], config['target_test_loader'], res_dir, epoch, title, separate=False)
    if triplet_type == 'none':
        torch.save(extractor.state_dict(), extractor_path)
        torch.save(classifier.state_dict(), classifier_path)
        torch.save(critic.state_dict(), critic_path)
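
gradient_penalty is another helper not shown here. A standard WGAN-GP style sketch, assuming source and target feature batches of equal size (Gulrajani-style penalty on random interpolates; the original helper's body may differ):

import torch

def gradient_penalty(critic, h_s, h_t):
    # interpolate between source and target features with a random weight
    alpha = torch.rand(h_s.size(0), 1, device=h_s.device)
    interpolates = (alpha * h_s + (1 - alpha) * h_t).requires_grad_(True)
    critic_out = critic(interpolates)
    # penalize deviations of the critic's gradient norm from 1
    grads = torch.autograd.grad(outputs=critic_out, inputs=interpolates,
                                grad_outputs=torch.ones_like(critic_out),
                                create_graph=True, retain_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()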