Example #1
def train_cnn(config):
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])

    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    res_dir = os.path.join(
        config['res_dir'],
        'slim{}-snr{}-snrp{}-lr{}'.format(config['slim'], config['snr'],
                                          config['snrp'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    set_log_config(res_dir)
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(list(extractor.parameters()) +
                           list(classifier.parameters()),
                           lr=config['lr'])

    def train(extractor, classifier, config, epoch):
        extractor.train()
        classifier.train()

        for step, (features,
                   labels) in enumerate(config['source_train_loader']):
            if torch.cuda.is_available():
                features, labels = features.cuda(), labels.cuda()

            optimizer.zero_grad()
            preds, _ = classifier(extractor(features))

            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()

    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on source_test_loader')
            test(extractor, classifier, config['source_test_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, epoch,
                                  config['models'])
            draw_tsne(extractor,
                      classifier,
                      config['source_test_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      config['models'],
                      separate=True)
            draw_tsne(extractor,
                      classifier,
                      config['source_test_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      config['models'],
                      separate=False)
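Every example in this listing calls a shared test helper that is not shown. Below is a minimal sketch of a compatible implementation, assuming the classifier returns a (logits, extra) tuple as in the training loops and that accuracy is the metric of interest; it is hypothetical, not the authors' exact helper:

import torch

def test(extractor, classifier, test_loader, epoch):
    # evaluate classification accuracy on one loader (sketch)
    extractor.eval()
    classifier.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for features, labels in test_loader:
            if torch.cuda.is_available():
                features, labels = features.cuda(), labels.cuda()
            feature = extractor(features)
            feature = feature.view(feature.size(0), -1)
            preds, _ = classifier(feature)
            correct += (preds.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    print('epoch {}, accuracy {:.2f}%'.format(epoch, 100.0 * correct / total))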
Example #2
def train_dann(config):
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])

    critic = Critic2(n_flattens=config['n_flattens'],
                     n_hiddens=config['n_hiddens'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    criterion = torch.nn.CrossEntropyLoss()
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    res_dir = os.path.join(
        config['res_dir'],
        'VIS-slim{}-targetLabel{}-mmd{}-bnm{}-vat{}-ent{}-ew{}-bn{}-bs{}-lr{}'.
        format(config['slim'], config['target_labeling'], config['mmd'],
               config['bnm'], config['vat'], config['ent'], config['bnm_ew'],
               config['bn'], config['batch_size'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    set_log_config(res_dir)
    logging.debug('train_dann')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    optimizer = optim.Adam([
        {'params': extractor.parameters()},
        {'params': classifier.parameters()},
        {'params': critic.parameters()},
    ], lr=config['lr'])

    vat_loss = VAT(extractor, classifier, n_power=1, radius=3.5)
    if torch.cuda.is_available():
        vat_loss = vat_loss.cuda()

    def dann(input_data, alpha):
        feature = extractor(input_data)
        feature = feature.view(feature.size(0), -1)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output, _ = classifier(feature)
        domain_output = critic(reverse_feature)

        return class_output, domain_output, feature

    def train(extractor, classifier, critic, config, epoch):
        extractor.train()
        classifier.train()
        critic.train()

        gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1
        mmd_loss = MMD_loss()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if config['slim'] > 0:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len_source_loader
        for i in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            if config['slim'] > 0:
                data_target_semi, label_target_semi = next(iter_target_semi)
                if i % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])

            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()
                label_target = label_target.cuda()
                if config['slim'] > 0:
                    data_target_semi = data_target_semi.cuda()
                    label_target_semi = label_target_semi.cuda()

            optimizer.zero_grad()

            class_output_s, domain_output, feature_s = dann(
                input_data=data_source, alpha=gamma)
            err_s_label = loss_class(class_output_s, label_source)
            domain_label = torch.zeros(data_source.size(0), dtype=torch.long,
                                       device=data_source.device)
            err_s_domain = loss_domain(domain_output, domain_label)

            # Training model using target data
            domain_label = torch.ones(data_target.size(0), dtype=torch.long,
                                      device=data_target.device)
            class_output_t, domain_output, feature_t = dann(
                input_data=data_target, alpha=gamma)
            err_t_domain = loss_domain(domain_output, domain_label)

            err = err_s_label + err_s_domain + err_t_domain

            if config['mmd'] == 1:
                err += config['bnm_ew'] * mmd_loss(feature_s, feature_t)

            if config['bnm'] == 1 and epoch >= config['startiter']:
                err_t_bnm = config['bnm_ew'] * get_loss_bnm(class_output_t)
                err += err_t_bnm
                if i == 1:
                    print('epoch {}, loss_t_bnm {:.2f}'.format(
                        epoch, err_t_bnm.item()))

            if config['ent'] == 1 and epoch >= config['startiter']:
                err_t_ent = config['bnm_ew'] * get_loss_entropy(class_output_t)
                err += err_t_ent
                if i == 1:
                    print('epoch {}, loss_t_ent {:.2f}'.format(
                        epoch, err_t_ent.item()))

            if config['vat'] == 1 and epoch >= config['startiter']:
                err_t_vat = config['bnm_ew'] * vat_loss(
                    data_target, class_output_t)
                err += err_t_vat
                if i == 1:
                    print('epoch {}, loss_t_vat {:.2f}'.format(
                        epoch, err_t_vat.item()))

            if config['slim'] > 0:
                feature_target_semi = extractor(data_target_semi)
                feature_target_semi = feature_target_semi.view(
                    feature_target_semi.size(0), -1)
                preds_target_semi, _ = classifier(feature_target_semi)
                err_t_class_semi = loss_class(preds_target_semi,
                                              label_target_semi)
                err += err_t_class_semi
                if i == 1:
                    print('epoch {}, err_t_class_semi {:.2f}'.format(
                        epoch, err_t_class_semi.item()))

            if i == 1:
                print(
                    'epoch {}, err_s_label {:.2f}, err_s_domain {:.2f}, err_t_domain {:.2f}, total err {:.2f}'
                    .format(epoch, err_s_label.item(), err_s_domain.item(),
                            err_t_domain.item(), err.item()))

            err.backward()
            optimizer.step()

    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, critic, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            # print('test on source_test_loader')
            # test(extractor, classifier, config['source_test_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            title = 'DANN'
            if config['bnm'] == 1 and config['vat'] == 1:
                title = '(b) Proposed'
            elif config['bnm'] == 1:
                title = 'BNM'
            elif config['vat'] == 1:
                title = 'VADA'
            elif config['mmd'] == 1:
                title = 'DCTLN'
            elif config['ent'] == 1:
                title = 'EntMin'
            # draw_confusion_matrix(extractor, classifier, config['target_test_loader'], res_dir, epoch, config['models'])
            draw_tsne(extractor,
                      classifier,
                      config['source_train_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      title,
                      separate=True)
            draw_tsne(extractor,
                      classifier,
                      config['source_train_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      title,
                      separate=False)
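train_dann depends on ReverseLayerF, the gradient reversal layer from DANN (Ganin & Lempitsky, 2015): identity on the forward pass, sign-flipped and scaled gradient on the backward pass, so the extractor learns domain-invariant features while the critic learns to separate domains. A standard implementation sketch:

import torch
from torch.autograd import Function

class ReverseLayerF(Function):
    # identity on forward; gradient multiplied by -alpha on backward

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # flip the gradient so the extractor is trained adversarially
        return grad_output.neg() * ctx.alpha, None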
Example #3
def train_dann_vat(config):
    extractor = Extractor(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'])
    classifier = Classifier2(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'], n_class=config['n_class'])
    critic = Critic2(n_flattens=config['n_flattens'], n_hiddens=config['n_hiddens'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    # loss weights: the original snippet references dw/cw/sw/tw/ew without
    # defining them; assumed here to come from config
    dw, cw = config['dw'], config['cw']
    sw, tw, ew = config['sw'], config['tw'], config['ew']

    res_dir = os.path.join(config['res_dir'],
                           'BNM-slim{}-targetLabel{}-sw{}-tw{}-ew{}-lr{}'.format(
                               config['slim'], config['target_labeling'], sw, tw, ew, config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    criterion = torch.nn.CrossEntropyLoss()
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()
    cent = ConditionalEntropyLoss()
    vat_loss = VAT(extractor, classifier, n_power=1, radius=3.5)
    if torch.cuda.is_available():
        cent = cent.cuda()
        vat_loss = vat_loss.cuda()

    print('train_dann_vat')
    print(config)

    set_log_config(res_dir)
    logging.debug('train_dann_vat')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    optimizer = optim.Adam([
        {'params': extractor.parameters()},
        {'params': classifier.parameters()},
        {'params': critic.parameters()},
    ], lr=config['lr'])

    def dann(input_data, alpha):
        feature = extractor(input_data)
        feature = feature.view(feature.size(0), -1)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output, _ = classifier(feature)
        domain_output = critic(reverse_feature)

        return class_output, domain_output, feature


    def train(extractor, classifier, critic, config, epoch):
        extractor.train()
        classifier.train()
        critic.train()

        gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if config['slim'] > 0:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len_source_loader
        for i in range(1, num_iter+1):
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            if config['slim'] > 0:
                data_target_semi, label_target_semi = next(iter_target_semi)
                if i % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])

            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(), label_source.cuda()
                data_target, label_target = data_target.cuda(), label_target.cuda()
                if config['slim'] > 0:
                    data_target_semi, label_target_semi = data_target_semi.cuda(), label_target_semi.cuda()

            optimizer.zero_grad()

            class_output_s, domain_output_s, _ = dann(input_data=data_source, alpha=gamma)
            err_s_class = loss_class(class_output_s, label_source)
            domain_label_s = torch.zeros(data_source.size(0), dtype=torch.long, device=data_source.device)
            err_s_domain = loss_domain(domain_output_s, domain_label_s)

            # Training model using target data
            class_output_t, domain_output_t, _ = dann(input_data=data_target, alpha=gamma)
            domain_label_t = torch.ones(data_target.size(0), dtype=torch.long, device=data_target.device)
            err_t_domain = loss_domain(domain_output_t, domain_label_t)
            # target regularizer: BNM loss (entropy alternative commented out)
            # err_t_entropy = get_loss_entropy(class_output_t)
            err_t_entropy = get_loss_bnm(class_output_t)

            # virtual adversarial loss.
            err_s_vat = vat_loss(data_source, class_output_s)
            err_t_vat = vat_loss(data_target, class_output_t)

            err_domain = err_s_domain + err_t_domain

            err_all = (dw * err_domain +
                       cw * err_s_class +
                       sw * err_s_vat +
                       tw * err_t_vat +
                       ew * err_t_entropy)

            if config['slim'] > 0:
                feature_target_semi = extractor(data_target_semi)
                feature_target_semi = feature_target_semi.view(feature_target_semi.size(0), -1)
                preds_target_semi, _ = classifier(feature_target_semi)
                err_t_class_semi = loss_class(preds_target_semi, label_target_semi)
                err_all += err_t_class_semi

            err_all.backward()
            optimizer.step()


    for epoch in range(1, config['n_epochs'] + 1):
        print('epoch {}'.format(epoch))
        print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
        train(extractor, classifier, critic, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on source_test_loader')
            test(extractor, classifier, config['source_test_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)

        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(extractor, classifier, config['target_test_loader'], res_dir, epoch, config['models'])
            draw_tsne(extractor, classifier, config['source_test_loader'], config['target_test_loader'], res_dir, epoch, config['models'], separate=True)
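Both train_dann and train_dann_vat call VAT(extractor, classifier, n_power=1, radius=3.5) with a batch and its current logits. Below is a minimal sketch of virtual adversarial training (Miyato et al., 2018) matching that constructor and call signature; the exact module used here is not shown, so treat this as an assumption-laden reimplementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

def _l2_normalize(d):
    # scale each sample's perturbation to unit L2 norm
    d_flat = d.view(d.size(0), -1)
    return (d_flat / (d_flat.norm(dim=1, keepdim=True) + 1e-8)).view_as(d)

class VAT(nn.Module):
    def __init__(self, extractor, classifier, n_power=1, radius=3.5):
        super().__init__()
        self.extractor = extractor
        self.classifier = classifier
        self.n_power = n_power
        self.radius = radius

    def _predict(self, x):
        feature = self.extractor(x)
        out = self.classifier(feature.view(feature.size(0), -1))
        return out[0] if isinstance(out, tuple) else out

    def forward(self, x, logits):
        probs = F.softmax(logits.detach(), dim=1)
        # power iteration: find the direction that most changes the prediction
        d = _l2_normalize(torch.randn_like(x))
        for _ in range(self.n_power):
            d.requires_grad_(True)
            kl = F.kl_div(F.log_softmax(self._predict(x + 1e-6 * d), dim=1),
                          probs, reduction='batchmean')
            d = _l2_normalize(torch.autograd.grad(kl, d)[0])
        # penalize sensitivity to an adversarial perturbation of size radius
        adv_logits = self._predict(x + self.radius * d)
        return F.kl_div(F.log_softmax(adv_logits, dim=1), probs,
                        reduction='batchmean')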
Example #4
def train_ddc(config):
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])

    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()

    res_dir = os.path.join(
        config['res_dir'],
        'slim{}-targetLabel{}-normal{}-{}-dilation{}-lr{}-mmdgamma{}'.format(
            config['slim'], config['target_labeling'], config['normal'],
            config['network'], config['dilation'], config['lr'],
            config['mmd_gamma']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    criterion = torch.nn.CrossEntropyLoss()

    set_log_config(res_dir)
    logging.debug('train_ddc')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(config)

    optimizer = optim.Adam(list(extractor.parameters()) +
                           list(classifier.parameters()),
                           lr=config['lr'])

    def train(extractor, classifier, config, epoch):
        extractor.train()
        classifier.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if config['slim'] > 0:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len_source_loader
        for i in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if config['slim'] > 0:
                data_target_semi, label_target_semi = next(iter_target_semi)
                if i % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])

            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()
                if config['slim'] > 0:
                    data_target_semi = data_target_semi.cuda()
                    label_target_semi = label_target_semi.cuda()

            optimizer.zero_grad()

            source = extractor(data_source)
            source = source.view(source.size(0), -1)
            target = extractor(data_target)
            target = target.view(target.size(0), -1)

            preds, _ = classifier(source)
            loss_cls = criterion(preds, label_source)

            loss_mmd = mmd_linear(source, target)

            #gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1
            loss = loss_cls + config['mmd_gamma'] * loss_mmd

            if config['slim'] > 0:
                feature_target_semi = extractor(data_target_semi)
                feature_target_semi = feature_target_semi.view(
                    feature_target_semi.size(0), -1)
                preds_target_semi, _ = classifier(feature_target_semi)
                loss += criterion(preds_target_semi, label_target_semi)

            if i % 50 == 0:
                print(
                    'loss_cls {}, loss_mmd {}, gamma {}, total loss {}'.format(
                        loss_cls.item(), loss_mmd.item(), config['mmd_gamma'],
                        loss.item()))
            loss.backward()
            optimizer.step()

    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            print('test on source_test_loader')
            test(extractor, classifier, config['source_test_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, epoch,
                                  config['models'])
            draw_tsne(extractor,
                      classifier,
                      config['source_test_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      config['models'],
                      separate=True)
            draw_tsne(extractor,
                      classifier,
                      config['source_test_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      config['models'],
                      separate=False)
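train_ddc follows Deep Domain Confusion: a source classification loss plus an MMD penalty between source and target embeddings, weighted by mmd_gamma. The mmd_linear helper is not included in the listing; as a sketch, one common linear-kernel estimate is the squared distance between batch means (other linear-time MMD estimators are also in common use):

import torch

def mmd_linear(source, target):
    # linear-kernel MMD: squared Euclidean distance between the mean
    # source embedding and the mean target embedding (sketch)
    delta = source.mean(dim=0) - target.mean(dim=0)
    return (delta * delta).sum()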
Example #5
def train_wasserstein(config):
    extractor = Extractor(n_flattens=config['n_flattens'],
                          n_hiddens=config['n_hiddens'])
    # extractor = InceptionV1(num_classes=32)
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])
    critic = Critic(n_flattens=config['n_flattens'],
                    n_hiddens=config['n_hiddens'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    triplet_type = config['triplet_type']
    gamma = config['w_gamma']
    weight_wd = config['w_weight']
    weight_triplet = config['t_weight']
    t_margin = config['t_margin']
    t_confidence = config['t_confidence']
    k_critic = 3
    k_clf = 1
    TRIPLET_START_INDEX = 95

    if triplet_type == 'none':
        res_dir = os.path.join(
            config['res_dir'],
            'bs{}-lr{}-w{}-gamma{}'.format(config['batch_size'], config['lr'],
                                           weight_wd, gamma))
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        extractor_path = os.path.join(res_dir, "extractor.pth")
        classifier_path = os.path.join(res_dir, "classifier.pth")
        critic_path = os.path.join(res_dir, "critic.pth")
        EPOCH_START = 1
        TEST_INTERVAL = 10

    else:
        TEST_INTERVAL = 1
        w_dir = os.path.join(
            config['res_dir'],
            'bs{}-lr{}-w{}-gamma{}'.format(config['batch_size'], config['lr'],
                                           weight_wd, gamma))
        if not os.path.exists(w_dir):
            os.makedirs(w_dir)
        res_dir = os.path.join(
            w_dir, '{}_t_weight{}_margin{}_confidence{}'.format(
                triplet_type, weight_triplet, t_margin, t_confidence))
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)
        extractor_path = os.path.join(w_dir, "extractor.pth")
        classifier_path = os.path.join(w_dir, "classifier.pth")
        critic_path = os.path.join(w_dir, "critic.pth")

        if os.path.exists(extractor_path):
            extractor.load_state_dict(torch.load(extractor_path))
            classifier.load_state_dict(torch.load(classifier_path))
            critic.load_state_dict(torch.load(critic_path))
            print('load models')
            EPOCH_START = TRIPLET_START_INDEX
        else:
            EPOCH_START = 1

    set_log_config(res_dir)
    print('start epoch {}'.format(EPOCH_START))
    print('triplet type {}'.format(triplet_type))
    print(config)

    logging.debug('train_wt')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    criterion = torch.nn.CrossEntropyLoss()
    softmax_layer = nn.Softmax(dim=1)

    critic_opt = torch.optim.Adam(critic.parameters(), lr=config['lr'])
    classifier_opt = torch.optim.Adam(classifier.parameters(), lr=config['lr'])
    feature_opt = torch.optim.Adam(extractor.parameters(),
                                   lr=config['lr'] / 10)

    def train(extractor, classifier, critic, config, epoch):
        extractor.train()
        classifier.train()
        critic.train()

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for step in range(1, num_iter):
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if step % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()

            # 1. train critic
            set_requires_grad(extractor, requires_grad=False)
            set_requires_grad(classifier, requires_grad=False)
            set_requires_grad(critic, requires_grad=True)
            with torch.no_grad():
                h_s = extractor(data_source)
                h_s = h_s.view(h_s.size(0), -1)
                h_t = extractor(data_target)
                h_t = h_t.view(h_t.size(0), -1)

            for j in range(k_critic):
                gp = gradient_penalty(critic, h_s, h_t)
                critic_s = critic(h_s)
                critic_t = critic(h_t)
                wasserstein_distance = critic_s.mean() - critic_t.mean()
                critic_cost = -wasserstein_distance + gamma * gp

                critic_opt.zero_grad()
                critic_cost.backward()
                critic_opt.step()

                if step == 10 and j == 0:
                    print('EPOCH {}, DISCRIMINATOR: wd {}, gp {}, loss {}'.
                          format(epoch, wasserstein_distance.item(),
                                 (gamma * gp).item(), critic_cost.item()))
                    logging.debug(
                        'EPOCH {}, DISCRIMINATOR: wd {}, gp {}, loss {}'.
                        format(epoch, wasserstein_distance.item(),
                               (gamma * gp).item(), critic_cost.item()))

            # 2. train feature and class_classifier
            set_requires_grad(extractor, requires_grad=True)
            set_requires_grad(classifier, requires_grad=True)
            set_requires_grad(critic, requires_grad=False)
            for _ in range(k_clf):
                h_s = extractor(data_source)
                h_s = h_s.view(h_s.size(0), -1)
                h_t = extractor(data_target)
                h_t = h_t.view(h_t.size(0), -1)

                source_preds, _ = classifier(h_s)
                clf_loss = criterion(source_preds, label_source)
                wasserstein_distance = critic(h_s).mean() - critic(h_t).mean()

                if triplet_type != 'none' and epoch >= TRIPLET_START_INDEX:
                    target_preds, _ = classifier(h_t)
                    target_labels = target_preds.data.max(1)[1]
                    target_logits = softmax_layer(target_preds)
                    if triplet_type == 'all':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() >
                            t_margin)[0]
                        images = torch.cat((h_s, h_t[triplet_index]), 0)
                        labels = torch.cat(
                            (label_source, target_labels[triplet_index]), 0)
                    elif triplet_type == 'src':
                        images = h_s
                        labels = label_source
                    elif triplet_type == 'tgt':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() >
                            t_confidence)[0]
                        images = h_t[triplet_index]
                        labels = target_labels[triplet_index]
                    elif triplet_type == 'sep':
                        triplet_index = np.where(
                            target_logits.data.max(1)[0].cpu().numpy() >
                            t_confidence)[0]
                        images = h_t[triplet_index]
                        labels = target_labels[triplet_index]
                        t_loss_sep, _ = triplet_loss(extractor, {
                            "X": images,
                            "y": labels
                        }, t_confidence)
                        images = h_s
                        labels = label_source

                    t_loss, _ = triplet_loss(extractor, {
                        "X": images,
                        "y": labels
                    }, t_margin)
                    loss = clf_loss + \
                        weight_wd * wasserstein_distance + \
                        weight_triplet * t_loss
                    if triplet_type == 'sep':
                        loss += t_loss_sep
                    feature_opt.zero_grad()
                    classifier_opt.zero_grad()
                    loss.backward()
                    feature_opt.step()
                    classifier_opt.step()

                    if step == 10:
                        print(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, t_loss {}, total loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    weight_triplet * t_loss.item(),
                                    loss.item()))
                        logging.debug(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {}, t_loss {}, total loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    weight_triplet * t_loss.item(),
                                    loss.item()))

                else:
                    loss = clf_loss + weight_wd * wasserstein_distance
                    feature_opt.zero_grad()
                    classifier_opt.zero_grad()
                    loss.backward()
                    feature_opt.step()
                    classifier_opt.step()

                    if step == 10:
                        print(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {},  loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    loss.item()))
                        logging.debug(
                            'EPOCH {}, CLASSIFIER: clf_loss {}, wd {},  loss {}'
                            .format(epoch, clf_loss.item(),
                                    weight_wd * wasserstein_distance.item(),
                                    loss.item()))

    # pretrain(model, config, pretrain_epochs=20)
    for epoch in range(EPOCH_START, config['n_epochs'] + 1):
        train(extractor, classifier, critic, config, epoch)
        if epoch % TEST_INTERVAL == 0:
            # print('test on source_test_loader')
            # test(extractor, classifier, config['source_test_loader'], epoch)
            # print('test on target_train_loader')
            # test(model, config['target_train_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            if triplet_type == 'none':
                title = '(a) WDGRL'
            else:
                title = '(b) TLADA'
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, epoch,
                                  title)
            draw_tsne(extractor,
                      classifier,
                      config['source_train_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      title,
                      separate=True)
            # draw_tsne(extractor, classifier, config['source_test_loader'], config['target_test_loader'], res_dir, epoch, title, separate=False)
    if triplet_type == 'none':
        torch.save(extractor.state_dict(), extractor_path)
        torch.save(classifier.state_dict(), classifier_path)
        torch.save(critic.state_dict(), critic_path)
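train_wasserstein alternates two phases per step, WDGRL-style: k_critic critic updates on frozen features, then k_clf feature/classifier updates against the frozen critic. It relies on set_requires_grad and gradient_penalty, neither of which is shown; minimal sketches, assuming the standard WGAN-GP penalty (Gulrajani et al., 2017) and equal source/target batch sizes:

import torch

def set_requires_grad(model, requires_grad=True):
    # freeze or unfreeze every parameter of a sub-network
    for param in model.parameters():
        param.requires_grad = requires_grad

def gradient_penalty(critic, h_s, h_t):
    # WGAN-GP: penalize the critic's gradient norm on random interpolates
    # between source and target features
    alpha = torch.rand(h_s.size(0), 1, device=h_s.device)
    interpolates = (alpha * h_s + (1 - alpha) * h_t).requires_grad_(True)
    critic_out = critic(interpolates)
    grads = torch.autograd.grad(outputs=critic_out,
                                inputs=interpolates,
                                grad_outputs=torch.ones_like(critic_out),
                                create_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()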
Example #6
def train_dctln(config):
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier2(n_flattens=config['n_flattens'],
                             n_hiddens=config['n_hiddens'],
                             n_class=config['n_class'])

    critic = Critic2(n_flattens=config['n_flattens'],
                     n_hiddens=config['n_hiddens'])
    if torch.cuda.is_available():
        extractor = extractor.cuda()
        classifier = classifier.cuda()
        critic = critic.cuda()

    criterion = torch.nn.CrossEntropyLoss()
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    res_dir = os.path.join(
        config['res_dir'],
        'slim{}-snr{}-lr{}'.format(config['slim'], config['snr'],
                                   config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    set_log_config(res_dir)
    logging.debug('train_dctln')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    optimizer = optim.Adam([
        {'params': extractor.parameters()},
        {'params': classifier.parameters()},
        {'params': critic.parameters()},
    ], lr=config['lr'])

    def dann(input_data, alpha):
        feature = extractor(input_data)
        feature = feature.view(feature.size(0), -1)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output, _ = classifier(feature)
        domain_output = critic(reverse_feature)

        return class_output, domain_output, feature

    def train(extractor, classifier, critic, config, epoch):
        extractor.train()
        classifier.train()
        critic.train()

        gamma = 2 / (1 + math.exp(-10 * (epoch) / config['n_epochs'])) - 1

        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        if config['slim'] > 0:
            iter_target_semi = iter(config['target_train_semi_loader'])
            len_target_semi_loader = len(config['target_train_semi_loader'])

        num_iter = len_source_loader
        for i in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, _ = next(iter_target)
            if config['slim'] > 0:
                data_target_semi, label_target_semi = next(iter_target_semi)

            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if config['slim'] > 0:
                if i % len_target_semi_loader == 0:
                    iter_target_semi = iter(config['target_train_semi_loader'])

            if torch.cuda.is_available():
                data_source = data_source.cuda()
                label_source = label_source.cuda()
                data_target = data_target.cuda()
                if config['slim'] > 0:
                    data_target_semi = data_target_semi.cuda()
                    label_target_semi = label_target_semi.cuda()

            optimizer.zero_grad()

            # DCTLN pairs the domain-adversarial loss with an MMD term;
            # the MMD branch is disabled in this snippet:
            # source = extractor(data_source).view(data_source.size(0), -1)
            # target = extractor(data_target).view(data_target.size(0), -1)
            # loss_mmd = mmd_linear(source, target)

            class_output_s, domain_output, _ = dann(input_data=data_source,
                                                    alpha=gamma)
            err_s_label = loss_class(class_output_s, label_source)
            domain_label = torch.zeros(data_source.size(0), dtype=torch.long,
                                       device=data_source.device)
            err_s_domain = loss_domain(domain_output, domain_label)

            # Training model using target data
            domain_label = torch.ones(data_target.size(0), dtype=torch.long,
                                      device=data_target.device)
            class_output_t, domain_output, _ = dann(input_data=data_target,
                                                    alpha=gamma)
            err_t_domain = loss_domain(domain_output, domain_label)

            err = err_s_label + err_s_domain + err_t_domain

            if config['slim'] > 0:
                class_output_semi_t, _, _ = dann(input_data=data_target_semi,
                                                 alpha=gamma)
                err_t_label = loss_class(class_output_semi_t,
                                         label_target_semi)
                err += err_t_label

            err.backward()
            optimizer.step()

    for epoch in range(1, config['n_epochs'] + 1):
        train(extractor, classifier, critic, config, epoch)
        if epoch % config['TEST_INTERVAL'] == 0:
            # print('test on source_test_loader')
            # test(extractor, classifier, config['source_test_loader'], epoch)
            print('test on target_test_loader')
            test(extractor, classifier, config['target_test_loader'], epoch)
        if epoch % config['VIS_INTERVAL'] == 0:
            draw_confusion_matrix(extractor, classifier,
                                  config['target_test_loader'], res_dir, epoch,
                                  config['models'])
            draw_tsne(extractor,
                      classifier,
                      config['source_train_loader'],
                      config['target_test_loader'],
                      res_dir,
                      epoch,
                      config['models'],
                      separate=True)
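Several of the loops above regularize target predictions through get_loss_bnm and get_loss_entropy. Neither helper appears in the listing; these are hedged sketches following Batch Nuclear-norm Maximization (Cui et al., 2020) and plain entropy minimization:

import torch
import torch.nn.functional as F

def get_loss_bnm(logits):
    # BNM: maximizing the nuclear norm of the softmax prediction matrix
    # encourages confident and diverse target predictions (sketch; the
    # batch-size scaling is an assumption)
    probs = F.softmax(logits, dim=1)
    return -torch.norm(probs, p='nuc') / probs.size(0)

def get_loss_entropy(logits):
    # entropy minimization: push target predictions toward low entropy
    probs = F.softmax(logits, dim=1)
    return -(probs * torch.log(probs + 1e-8)).sum(dim=1).mean()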