Example #1
def eval_src_threemodal(model1, model2, model3, data_loader):
    """Evaluation for target encoder by source classifier on target dataset."""
    # set eval state for Dropout and BN layers
    model1.eval()
    model2.eval()
    model3.eval()

    # init loss and accuracy
    loss = 0
    acc = 0

    # set loss function
    criterion = nn.CrossEntropyLoss()

    TP = 0.
    TN = 0.
    FP = 0.
    FN = 0.

    # for (images, labels) in data_loader:
    for (image1, image2, image3, label) in data_loader:

        img1_name = image1[0][0]
        img1 = image1[1]
        img2_name = image2[0][0]
        img2 = image2[1]
        img3_name = image3[0][0]
        img3 = image3[1]
        image1 = make_variable(img1)
        image2 = make_variable(img2)
        image3 = make_variable(img3)

        preds1 = model1(image1)
        preds2 = model2(image2)
        preds3 = model3(image3)

        probability1 = torch.nn.functional.softmax(preds1,
                                                   dim=1)[:,
                                                          1].detach().tolist()
        probability2 = torch.nn.functional.softmax(preds2,
                                                   dim=1)[:,
                                                          1].detach().tolist()
        probability3 = torch.nn.functional.softmax(preds3,
                                                   dim=1)[:,
                                                          1].detach().tolist()

        probability1_value = np.array(probability1)
        probability2_value = np.array(probability2)
        probability3_value = np.array(probability3)

        # fuse the three modality scores by simple averaging
        probability = (probability1_value[0] + probability2_value[0] +
                       probability3_value[0]) / 3

        # print("{} {} {} {:.8f} {:.8f} {:.8f}".format(img1_name, img2_name, img3_name,
        #                                              probability1_value[0], probability2_value[0],
        #                                              probability3_value[0]))

        print("{} {} {} {:.8f}".format(img1_name, img2_name, img3_name,
                                       probability2_value[0]))
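
Helpers such as make_variable are defined elsewhere in this repo and not shown on this page. A minimal sketch, assuming make_variable simply moves a tensor to the GPU when one is available (the name is a holdover from the pre-0.4 torch.autograd.Variable API):

import torch

def make_variable(tensor):
    # Hedged sketch: the repo's real helper is not shown on this page.
    # Assumed behavior: move the tensor to the GPU when one is available.
    return tensor.cuda() if torch.cuda.is_available() else tensor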
Example #2
def eval_fusionmodal(model, data_loader):
    """Evaluation for target encoder by source classifier on target dataset."""
    # set eval state for Dropout and BN layers
    model.eval()

    # init loss and accuracy
    loss = 0
    acc = 0

    # set loss function
    criterion = nn.CrossEntropyLoss()

    TP = 0.
    TN = 0.
    FP = 0.
    FN = 0.
    data_write = open("./test_data.txt", 'a+')

    # for (images, labels) in data_loader:
    for image_gallery in data_loader:
        img1_name = image_gallery[0][0][0]
        img1 = image_gallery[0][1]
        img2_name = image_gallery[1][0][0]
        img2 = image_gallery[1][1]
        img3_name = image_gallery[2][0][0]
        img3 = image_gallery[2][1]

        image1 = make_variable(img1)
        image2 = make_variable(img2)
        image3 = make_variable(img3)

        # label = make_variable(label)

        feature_concat = torch.cat((image1, image2, image3), 1)

        preds = model(feature_concat)

        probability = torch.nn.functional.softmax(preds[1],
                                                  dim=1)[:,
                                                         1].detach().tolist()

        probability_value = np.array(probability)

        # print("{} {} {} {:.8f} {:.8f} {:.8f}".format(img1_name, img2_name, img3_name,
        #                                              probability1_value[0], probability2_value[0],
        #                                              probability3_value[0]))

        newline = img1_name + " " + img2_name + " " + img3_name + " " + "{:.8f}".format(
            probability_value[0]) + "\n"
        print(newline)
        # print("{} {} {} {:.8f}".format(img1_name, img2_name, img3_name,
        #                                probability_value[0]))
        data_write.write(newline)

    data_write.close()
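
The indexing above (image_gallery[k][0][0] for the file name, image_gallery[k][1] for the tensor) implies each loader item is a list of (metadata, tensor) pairs, one per modality. A dummy item with that assumed layout (file names and shape are hypothetical) makes the unpacking explicit:

import torch

# Assumed item layout, inferred from the indexing above:
# item[k][0][0] -> name of modality k, item[k][1] -> image tensor.
dummy_item = [
    (("color_0001.jpg",), torch.randn(1, 3, 224, 224)),
    (("depth_0001.jpg",), torch.randn(1, 3, 224, 224)),
    (("ir_0001.jpg",), torch.randn(1, 3, 224, 224)),
]
img1_name = dummy_item[0][0][0]  # "color_0001.jpg"
img1 = dummy_item[0][1]          # tensor of shape [1, 3, 224, 224]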
Example #3
def train_resnet80(model, pretrain_loader, val_loader):
    """Pretrain a single ResNet classifier and periodically evaluate and save it."""

    global lr
    global best_prec1

    lr = params.base_lr

    # model = construct_premodel(model, params)
    # model = construct_resnet18(model, params)
    model.train()

    optimizer = torch.optim.Adam(list(model.parameters()),
                                 lr=params.base_lr,
                                 betas=(0.9, 0.99))
    criterion = nn.CrossEntropyLoss().cuda()

    for epoch in range(params.start_epoch,
                       params.start_epoch + params.num_epochs):
        adjust_learning_rate(optimizer, epoch, params.base_lr)

        # train for one epoch
        # train_batch(train_loader, model, criterion, optimizer, epoch)
        for step, (images, labels) in enumerate(pretrain_loader):
            # make images and labels variable
            images = make_variable(images)
            labels = make_variable(labels.squeeze_())

            # zero gradients for optimizer
            optimizer.zero_grad()

            # compute loss for critic
            preds = model(images)
            loss = criterion(preds, labels)

            # optimize source classifier
            loss.backward()
            optimizer.step()

            # print step info
            if ((step + 1) % params.log_step_pre == 0):
                print("Epoch [{}/{}] Step [{}/{}]: loss={}".format(
                    epoch + 1, params.num_epochs, step + 1,
                    len(pretrain_loader), loss.item()))

        if ((epoch + 1) % params.eval_step_pre == 0):
            eval_pretrain(model, val_loader)

        # save model parameters
        if ((epoch + 1) % params.save_step_pre == 0):
            save_model(model, "Resnet18-{}.pt".format(epoch + 1))

    # save final model
    save_model(model, "Resnet18-final.pt")

    return model
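
adjust_learning_rate is another external helper. A common step-decay sketch, assuming the usual divide-by-10-every-N-epochs schedule (the repo's actual decay rule may differ):

def adjust_learning_rate(optimizer, epoch, base_lr, decay_epochs=30):
    # Hypothetical step decay: lr = base_lr * 0.1 ** (epoch // decay_epochs).
    lr = base_lr * (0.1 ** (epoch // decay_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr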
Example #4
def eval_fusionmodal(model, data_loader):
    """Evaluation for target encoder by source classifier on target dataset."""
    # set eval state for Dropout and BN layers
    model.eval()

    # init loss and accuracy
    loss = 0
    acc = 0

    # set loss function
    criterion = nn.CrossEntropyLoss()

    TP = 0.
    TN = 0.
    FP = 0.
    FN = 0.

    # for (images, labels) in data_loader:
    for (image1, image2, image3, label) in data_loader:
        img1_name = image1[0][0]
        img1 = image1[1]
        img2_name = image2[0][0]
        img2 = image2[1]
        img3_name = image3[0][0]
        img3 = image3[1]
        image1 = make_variable(img1)
        image2 = make_variable(img2)
        image3 = make_variable(img3)
        label = make_variable(label)

        feature_concat = torch.cat((image1, image2, image3), 1)

        preds = model(feature_concat)

        probability = torch.nn.functional.softmax(preds, dim=1)[:, 1].detach().tolist()

        probability_value = np.array(probability)

        # print("{} {} {} {:.8f} {:.8f} {:.8f}".format(img1_name, img2_name, img3_name,
        #                                              probability1_value[0], probability2_value[0],
        #                                              probability3_value[0]))

        print("{} {} {} {:.8f}".format(img1_name, img2_name, img3_name,
                                       probability_value[0]))
Example #5
def eval_tgt_robust(encoder, classifier, critic, data_loader):
    """Evaluate model for target domain with attack on labels and domains"""
    # Set eval state for Dropout and BN layers
    encoder.eval()
    classifier.eval()
    critic.eval()

    # Init loss and accuracy
    loss, acc = 0, 0
    test_robust_loss, test_robust_acc = 0, 0

    # Set loss function
    criterion = nn.CrossEntropyLoss()

    # Evaluate network
    for (images, labels) in data_loader:
        images = make_variable(images)
        labels = make_variable(labels)
        domain_tgt = make_variable(torch.ones(images.size(0)).long())

        delta_src = attack_pgd(encoder, classifier, images, labels)
        delta_domain = attack_pgd(encoder, critic, images, domain_tgt)
        delta_src = delta_src.detach()
        delta_domain = delta_domain.detach()

        # Compute loss
        robust_images = normalize(
            torch.clamp(images + delta_src[:images.size(0)] +
                        delta_domain[:images.size(0)],
                        min=params.lower_limit,
                        max=params.upper_limit))
        robust_preds = classifier(encoder(robust_images))
        test_robust_loss += criterion(robust_preds, labels).item()

        out = classifier(encoder(images))
        loss += criterion(out, labels).item()

        test_robust_acc += torch.sum(
            robust_preds.max(1)[1] == labels.data).double()
        acc += torch.sum(out.max(1)[1] == labels.data).double()

    loss /= len(data_loader)
    test_robust_loss /= len(data_loader)
    acc = acc / len(data_loader.dataset)
    test_robust_acc = test_robust_acc / len(data_loader.dataset)

    print(
        "Avg Evaluation Loss: {:.4f}, Avg Evaluation Accuracy: {:.4%}, Avg Evaluation Robust Loss: {:.4f}, "
        "Avg Robust Accuracy: {:.4%}".format(loss, acc, test_robust_loss,
                                             test_robust_acc))
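
attack_pgd above returns a perturbation delta that the caller adds to the clean images before clamping, so a matching sketch has to return delta rather than the adversarial image itself. A minimal L-infinity PGD, with params.epsilon, params.alpha and params.attack_iters as assumed hyper-parameters (only lower_limit/upper_limit appear on this page):

import torch
import torch.nn as nn

def attack_pgd(encoder, head, images, labels):
    # Sketch of L-inf PGD; epsilon, alpha and attack_iters are assumed
    # to live in params alongside lower_limit/upper_limit used above.
    criterion = nn.CrossEntropyLoss()
    delta = torch.zeros_like(images).uniform_(-params.epsilon, params.epsilon)
    delta.requires_grad_(True)
    for _ in range(params.attack_iters):
        loss = criterion(head(encoder(images + delta)), labels)
        grad, = torch.autograd.grad(loss, delta)
        delta = (delta + params.alpha * grad.sign()).clamp(
            -params.epsilon, params.epsilon)
        delta = delta.detach().requires_grad_(True)
    return delta.detach()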
Example #6
def eval_src_score(model, data_loader):
    """Evaluation for target encoder by source classifier on target dataset."""
    # set eval state for Dropout and BN layers
    model.eval()

    # init loss and accuracy
    loss = 0
    acc = 0

    # set loss function
    criterion = nn.CrossEntropyLoss()

    TP = 0.
    TN = 0.
    FP = 0.
    FN = 0.

    for (image1, image2, image3) in data_loader:
        img1_name = image1[0][0]
        img1 = image1[1]
        img2_name = image2[0][0]
        img2 = image2[1]
        img3_name = image3[0][0]
        img3 = image3[1]
        image1 = make_variable(img1)
        image2 = make_variable(img2)
        image3 = make_variable(img3)

        preds1 = model(image1)
        preds2 = model(image2)
        preds3 = model(image3)
        probability1 = torch.nn.functional.softmax(preds1,
                                                   dim=1)[:,
                                                          0].detach().tolist()
        probability2 = torch.nn.functional.softmax(preds2,
                                                   dim=1)[:,
                                                          0].detach().tolist()
        probability3 = torch.nn.functional.softmax(preds3,
                                                   dim=1)[:,
                                                          0].detach().tolist()

        probability = (probability1[0] + probability2[0] + probability3[0]) / 3
Example #7
def eval_src(model, data_loader):
    """Evaluation for target encoder by source classifier on target dataset."""
    # set eval state for Dropout and BN layers
    model.eval()

    # init loss and accuracy
    loss = 0
    acc = 0

    # set loss function
    criterion = nn.CrossEntropyLoss()

    TP = 0.
    TN = 0.
    FP = 0.
    FN = 0.

    # for (images, labels) in data_loader:
    for (image1, image2, image3, label) in data_loader:

        img1_name = image1[0][0]
        img1 = image1[1]
        img2_name = image2[0][0]
        img2 = image2[1]
        img3_name = image3[0][0]
        img3 = image3[1]

        image1 = make_variable(img1)
        image2 = make_variable(img2)
        image3 = make_variable(img3)

        label = make_variable(label)

        feature_concat = torch.cat((image1, image2, image3), 1)

        feat, preds = model(feature_concat)

        probability = torch.nn.functional.softmax(preds,
                                                  dim=1)[:,
                                                         1].detach().tolist()

        probability1_value = np.array(probability)

        pred_cls = preds.data.max(1)[1]
        acc += pred_cls.eq(label.data).cpu().sum()

        target_list = label.cpu().numpy()
        pred_list = pred_cls.cpu().numpy()

        for i in range(len(target_list)):
            if target_list[i] == 1 and pred_list[i] == 1:
                TP += 1
            elif target_list[i] == 0 and pred_list[i] == 0:
                TN += 1
            elif target_list[i] == 1 and pred_list[i] == 0:
                FN += 1
            elif target_list[i] == 0 and pred_list[i] == 1:
                FP += 1

    loss /= len(data_loader)
    acc = (TP + TN) / len(data_loader.dataset)

    print("Avg Loss = {}, Avg Accuracy = {:2%}".format(loss, acc))
    print('TP:{}, TP+FN:{}, TN:{}, TN+FP:{}'.format(TP, TP + FN, TN, TN + FP))

    TP_rate = float(TP / (TP + FN))
    TN_rate = float(TN / (TN + FP))

    APCER = float(FP / (TN + FP))
    NPCER = float(FN / (FN + TP))
    ACER = (APCER + NPCER) / 2

    HTER = 1 - (TP_rate + TN_rate) / 2

    print('APCER:{}, NPCER:{}, HTER:{}, ACER:{}'.format(
        APCER, NPCER, HTER, ACER))
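
The metrics printed at the end follow the standard face anti-spoofing definitions: APCER = FP/(TN+FP) (attacks accepted as genuine), NPCER = FN/(FN+TP) (genuine samples rejected as attacks), ACER is their mean, and HTER is one minus the mean of the true-positive and true-negative rates; with these definitions ACER and HTER coincide. The same tail computation as a self-contained function (helper name is hypothetical):

def antispoof_metrics(TP, TN, FP, FN):
    # Standard anti-spoofing error rates from confusion counts.
    APCER = FP / (TN + FP)   # attack presentations classified as bona fide
    NPCER = FN / (FN + TP)   # bona fide presentations classified as attack
    ACER = (APCER + NPCER) / 2
    HTER = 1 - (TP / (TP + FN) + TN / (TN + FP)) / 2
    return APCER, NPCER, ACER, HTER

# e.g. antispoof_metrics(90., 95., 5., 10.) -> (0.05, 0.10, 0.075, 0.075)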
Example #8
def train_src_threemodal(model1, model2, model3, train_loader1, train_loader2,
                         train_loader3, val_loader):
    """Train one ResNet-34 per modality (color, depth, IR) with focal loss."""

    global lr
    global best_prec1

    lr = params.base_lr

    # model1 = construct_resnet18(model1, params)
    # model2 = construct_resnet18(model2, params)
    # model3 = construct_resnet18(model3, params)
    model1 = construct_resnet34(model1, params)
    model2 = construct_resnet34(model2, params)
    model3 = construct_resnet34(model3, params)

    model1.train()
    model2.train()
    model3.train()

    optimizer1 = torch.optim.Adam(list(model1.parameters()),
                                  lr=params.base_lr,
                                  betas=(0.9, 0.99))
    optimizer2 = torch.optim.Adam(list(model2.parameters()),
                                  lr=params.base_lr,
                                  betas=(0.9, 0.99))
    optimizer3 = torch.optim.Adam(list(model3.parameters()),
                                  lr=params.base_lr,
                                  betas=(0.9, 0.99))
    # criterion = nn.CrossEntropyLoss().cuda()
    focalloss = FocalLoss(gamma=2)

    for epoch in range(params.start_epoch,
                       params.start_epoch + params.num_epochs):
        adjust_learning_rate(optimizer1, epoch, params.base_lr)
        adjust_learning_rate(optimizer2, epoch, params.base_lr)
        adjust_learning_rate(optimizer3, epoch, params.base_lr)
        # train for one epoch
        # train_batch(train_loader, model, criterion, optimizer, epoch)
        for step, (images, labels) in enumerate(train_loader1):
            # make images and labels variable
            images = make_variable(images)
            labels = make_variable(labels.squeeze_())

            # zero gradients for optimizer
            optimizer1.zero_grad()

            # compute loss for critic
            preds = model1(images)
            loss = focalloss(preds, labels)

            # optimize source classifier
            loss.backward()
            optimizer1.step()

            # print step info
            if ((step + 1) % params.log_step_pre == 0):
                print("Color Epoch [{}/{}] Step [{}/{}]: loss={}".format(
                    epoch + 1, params.num_epochs, step + 1, len(train_loader1),
                    loss.item()))

        for step, (images, labels) in enumerate(train_loader2):
            # make images and labels variable
            images = make_variable(images)
            labels = make_variable(labels.squeeze_())

            # zero gradients for optimizer
            optimizer2.zero_grad()

            # compute loss for critic
            preds = model2(images)
            loss = focalloss(preds, labels)

            # optimize source classifier
            loss.backward()
            optimizer2.step()

            # print step info
            if ((step + 1) % params.log_step_pre == 0):
                print("Depth Epoch [{}/{}] Step [{}/{}]: loss={}".format(
                    epoch + 1, params.num_epochs, step + 1, len(train_loader2),
                    loss.item()))

        for step, (images, labels) in enumerate(train_loader3):
            # make images and labels variable
            images = make_variable(images)
            labels = make_variable(labels.squeeze_())

            # zero gradients for optimizer
            optimizer3.zero_grad()

            # compute loss for critic
            preds = model3(images)
            loss = focalloss(preds, labels)

            # optimize source classifier
            loss.backward()
            optimizer3.step()

            # print step info
            if ((step + 1) % params.log_step_pre == 0):
                print("Ir Epoch [{}/{}] Step [{}/{}]: loss={}".format(
                    epoch + 1, params.num_epochs, step + 1, len(train_loader3),
                    loss.item()))

        if ((epoch + 1) % params.eval_step_pre == 0):
            eval_acc(model1, model2, model3, val_loader)

        # save model parameters
        if ((epoch + 1) % params.save_step_pre == 0):
            save_model(model1, "MultiNet-color-{}.pt".format(epoch + 1))
            save_model(model2, "MultiNet-depth-{}.pt".format(epoch + 1))
            save_model(model3, "MultiNet-ir-{}.pt".format(epoch + 1))

    # save final model
    save_model(model1, "MultiNet-color-final.pt")
    save_model(model2, "MultiNet-depth-final.pt")
    save_model(model3, "MultiNet-ir-final.pt")

    return model1, model2, model3
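
FocalLoss(gamma=2) is imported from elsewhere in the repo. A minimal sketch of the standard focal loss of Lin et al., assuming raw logits as input and no per-class alpha weighting:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    # Sketch of the standard focal loss: mean of (1 - p_t)^gamma * CE.
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce)  # model probability of the true class
        return ((1 - pt) ** self.gamma * ce).mean()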
Example #9
def train_feature_fusion(model, train_loader, val_loader):
    """Train the feature-fusion classifier with cross-entropy plus center loss."""

    global lr
    global best_prec1

    lr = params.base_lr

    # model = construct_premodel(model, params)
    # model = construct_resnet18(model, params)
    model = torch.nn.DataParallel(model)
    model.train()

    # optimizer = torch.optim.Adam(
    #     list(model.parameters()),
    #     lr=params.base_lr,
    #     betas=(0.9, 0.99))
    optimizer = torch.optim.SGD(list(model.parameters()),
                                lr=params.base_lr,
                                momentum=0.9,
                                weight_decay=0.0005)

    criterion = nn.CrossEntropyLoss().cuda()
    centerloss = CenterLoss(num_classes=2, feat_dim=2, use_gpu=True)

    optimzer4center = torch.optim.SGD(centerloss.parameters(), lr=0.5)
    loss_weight = 0.1
    # focalloss = FocalLoss(gamma=2)

    for epoch in range(params.start_epoch,
                       params.start_epoch + params.num_epochs):
        adjust_learning_rate(optimizer, epoch, params.base_lr)

        # train for one epoch
        # train_batch(train_loader, model, criterion, optimizer, epoch)
        for step, (image1, image2, image3, label) in enumerate(train_loader):

            img1_name = image1[0][0]
            img1 = image1[1]
            img2_name = image2[0][0]
            img2 = image2[1]
            img3_name = image3[0][0]
            img3 = image3[1]

            image1 = make_variable(img1)
            image2 = make_variable(img2)
            image3 = make_variable(img3)

            label = make_variable(label.squeeze_())

            # img1_array = np.array(image1)
            # img2_array = np.array(image2)
            # img3_array = np.array(image3)
            # print(img1_name, img2_name, img3_name, img1.shape, img2.shape, img3.shape)

            feature_concat = torch.cat((image1, image2, image3), 1)

            feat, preds = model(feature_concat)

            loss = criterion(preds,
                             label) + loss_weight * centerloss(feat, label)
            # loss = criterion(preds, label)
            # loss = focalloss(preds, label)
            optimizer.zero_grad()
            optimzer4center.zero_grad()

            # optimize source classifier
            loss.backward()
            optimizer.step()
            optimzer4center.step()

            # print step info
            if ((step + 1) % params.log_step_pre == 0):
                print("fusion Epoch [{}/{}] Step [{}/{}]: loss={}".format(
                    epoch + 1, params.num_epochs, step + 1, len(train_loader),
                    loss.item()))

        if ((epoch + 1) % params.eval_step_pre == 0):
            eval_src(model, val_loader)

        # save model parameters
        if ((epoch + 1) % params.save_step_pre == 0):
            save_model(model, "MultiNet-fusion-{}.pt".format(epoch + 1))

    # save final model
    save_model(model, "MultiNet-fusion-final.pt")

    return model
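
CenterLoss(num_classes=2, feat_dim=2) keeps one learnable center per class and penalizes the squared distance from each feature to its class center, which is why its parameters get their own SGD optimizer (optimzer4center) above. A minimal sketch consistent with that usage:

import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    # Sketch: mean squared distance between features and their class centers.
    # The centers are learnable, hence the dedicated optimizer above.
    def __init__(self, num_classes=2, feat_dim=2, use_gpu=True):
        super().__init__()
        centers = torch.randn(num_classes, feat_dim)
        self.centers = nn.Parameter(centers.cuda() if use_gpu else centers)

    def forward(self, feat, labels):
        return ((feat - self.centers[labels]) ** 2).sum(dim=1).mean()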
Example #10
def train_feature_fusion(model, train_loader, val_loader):
    """Train the feature-fusion classifier with plain cross-entropy."""

    global lr
    global best_prec1

    lr = params.base_lr

    # model = construct_resnet18(model, params)
    # model = construct_resnet34(model, params)
    model.train()

    optimizer = torch.optim.Adam(
        list(model.parameters()),
        lr=params.base_lr,
        betas=(0.9, 0.99))
    criterion = nn.CrossEntropyLoss().cuda()

    for epoch in range(params.start_epoch, params.start_epoch + params.num_epochs):
        adjust_learning_rate(optimizer, epoch, params.base_lr)

        # train for one epoch
        # train_batch(train_loader, model, criterion, optimizer, epoch)
        for step, (image1, image2, image3, label) in enumerate(train_loader):

            img1_name = image1[0][0]
            img1 = image1[1]
            img2_name = image2[0][0]
            img2 = image2[1]
            img3_name = image3[0][0]
            img3 = image3[1]
            image1 = make_variable(img1)
            image2 = make_variable(img2)
            image3 = make_variable(img3)
            label = make_variable(label.squeeze_())

            # img1_array = np.array(image1)
            # img2_array = np.array(image2)
            # img3_array = np.array(image3)
            # print(img1_name, img2_name, img3_name, img1.shape, img2.shape, img3.shape)

            feature_concat = torch.cat((image1, image2, image3), 1)

            preds = model(feature_concat)

            loss = criterion(preds, label)

            # zero stale gradients, then optimize source classifier
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # print step info
            if ((step + 1) % params.log_step_pre == 0):
                print("fusion Epoch [{}/{}] Step [{}/{}]: loss={}"
                      .format(epoch + 1,
                              params.num_epochs,
                              step + 1,
                              len(train_loader),
                              loss.item()))

        if ((epoch + 1) % params.eval_step_pre == 0):
            eval_src(model, val_loader)

        # save model parameters
        if ((epoch + 1) % params.save_step_pre == 0):
            save_model(model, "MultiNet-fusion-{}.pt".format(epoch + 1))

    # save final model
    save_model(model, "MultiNet-fusion-final.pt")

    return model
Example #11
def train_src_robust(encoder, classifier, data_loader, mode='ADDA'):
    """Train classifier for source domain with robust training for ADDA"""

    # Step 1: Network setup
    # Set train state for both Dropout and BN layers
    encoder.train()
    classifier.train()

    # Set up optimizer and criterion
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(classifier.parameters()),
                           lr=params.learning_rate,
                           weight_decay=params.weight_decay)

    criterion = nn.CrossEntropyLoss()
    num_epochs = params.num_epochs_pre if mode == 'ADDA' else params.num_epochs
    # Step 2: Pretrain the source model
    for epoch in range(num_epochs):

        # Init accuracy and loss
        start_time = time.time()
        train_loss, train_acc, train_n = 0, 0, 0
        train_robust_loss, train_robust_acc = 0, 0

        for step, (images, labels) in enumerate(data_loader):

            # Make images and labels variable
            images = make_variable(images)
            labels = make_variable(labels)

            # Zero gradients for optimizer
            optimizer.zero_grad()

            delta = attack_pgd(encoder, classifier, images, labels)

            # Compute loss for critic with attack img
            robust_images = normalize(
                torch.clamp(images + delta[:images.size(0)],
                            min=params.lower_limit,
                            max=params.upper_limit))
            robust_preds = classifier(encoder(robust_images))
            robust_loss = criterion(robust_preds, labels)

            # Optimize source classifier
            robust_loss.backward()
            optimizer.step()

            # Compute loss for critic with original image
            preds = classifier(encoder(images))
            loss = criterion(preds, labels)

            train_robust_loss += robust_loss.item() * labels.size(0)
            train_robust_acc += torch.sum(
                robust_preds.max(1)[1] == labels).double()
            train_loss += loss.item() * labels.size(0)
            train_acc += torch.sum(preds.max(1)[1] == labels.data).double()
            train_n += labels.size(0)

            # Print step info
            if (step + 1) % params.log_step_pre == 0:
                print(
                    "Epoch [{}/{}] Step [{}/{}]: Avg Training loss: {:.4f} Avg Training Accuracy: {:.4%}"
                    " Avg Robust Training Loss: {:.4f} Avg Robust Training Accuracy: {:.4%}"
                    .format(epoch + 1, num_epochs, step + 1, len(data_loader),
                            train_loss / train_n, train_acc / train_n,
                            train_robust_loss / train_n,
                            train_robust_acc / train_n))

        time_elapsed = time.time() - start_time

        # Eval model (here on the same source loader used for training)
        if (epoch + 1) % params.eval_step_pre == 0:
            eval_tgt(encoder, classifier, data_loader)

        # Save model parameters
        if (epoch + 1) % params.save_step_pre == 0:
            print('Epoch [{}/{}] completed in {:.0f}m {:.0f}s'.format(
                epoch + 1, num_epochs, time_elapsed // 60, time_elapsed % 60))
            root = params.adda_root if mode == 'ADDA' else params.model_root
            save_model(encoder, root,
                       "{}-source-encoder-rb-{}.pt".format(mode, epoch + 1))
            save_model(classifier, root,
                       "{}-source-classifier-rb-{}.pt".format(mode, epoch + 1))

    # Save final model
    root = params.adda_root if mode == 'ADDA' else params.model_root

    save_model(encoder, root, "{}-source-encoder-rb-final.pt".format(mode))
    save_model(classifier, root,
               "{}-source-classifier-rb-final.pt".format(mode))

    return encoder, classifier
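
normalize is applied only after the adversarial perturbation is added and clamped, which implies the attack runs in raw pixel space and the encoder expects mean/std-normalized input. A per-channel sketch, assuming dataset statistics stored as params.mu and params.std tensors of shape [3, 1, 1] (not shown on this page):

def normalize(images):
    # Hypothetical per-channel normalization; params.mu and params.std
    # (shape [3, 1, 1]) are assumptions, mirroring lower_limit/upper_limit.
    return (images - params.mu) / params.std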
Example #12
def train_alda(encoder,
               classifier,
               critic,
               src_data_loader,
               tgt_data_loader,
               tgt_data_loader_eval,
               robust=True):
    """Train encoder for DANN """

    # 1. Network Setup
    encoder.train()
    classifier.train()
    critic.train()

    # Set up optimizer and criterion
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(classifier.parameters()),
                           lr=params.learning_rate,
                           weight_decay=params.weight_decay)
    optimizer_critic = optim.Adam(critic.parameters(),
                                  lr=params.learning_rate,
                                  weight_decay=params.weight_decay)

    # 2. Train network
    for epoch in range(params.num_epochs):

        start_time = time.time()
        total_loss, train_n = 0, 0
        loss_target_value = 0
        train_clf_loss, train_clf_acc, train_clf_n = 0, 0, 0

        # Zip source and target data pair
        len_data_loader = min(len(src_data_loader), len(tgt_data_loader))
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))

        for step, ((images_src, labels_src), (images_tgt, _)) in data_zip:

            p = float(step + epoch * len_data_loader) / \
                params.num_epochs / len_data_loader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # Make images variable
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)
            labels_src = make_variable(labels_src)
            images_concat = torch.cat((images_src, images_tgt), 0)

            # Prepare real and fake label (domain labels)
            domain_src = make_variable(torch.ones(images_src.size(0)).long())
            domain_tgt = make_variable(torch.zeros(images_tgt.size(0)).long())
            domain_concat = torch.cat((domain_src, domain_tgt), 0)

            # Zero gradients for both optimizers
            optimizer.zero_grad()
            optimizer_critic.zero_grad()

            if robust:
                delta_src = attack_pgd(encoder, classifier, images_src,
                                       labels_src)
                robust_src = normalize(
                    torch.clamp(images_src + delta_src[:images_src.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))

            # Train on source domain
            feats = encoder(images_src) if not robust else encoder(robust_src)
            preds_src = classifier(feats)
            preds_tgt = classifier(encoder(images_tgt))
            feats = encoder(images_concat)

            preds_critic = critic(feats, alpha=alpha)

            loss_adv, loss_reg, loss_correct = alda_loss(preds_critic,
                                                         labels_src,
                                                         preds_src,
                                                         preds_tgt,
                                                         threshold=0.6)

            # Compute transfer loss (kept at zero during a 3-epoch warm-up)
            transfer_loss = loss_adv + loss_correct if epoch > 2 else \
                torch.zeros_like(loss_adv)

            # loss_reg is backpropagated only through the discriminator
            set_requires_grad(encoder, requires_grad=False)
            set_requires_grad(classifier, requires_grad=False)
            loss_reg.backward(retain_graph=True)
            set_requires_grad(encoder, requires_grad=True)
            set_requires_grad(classifier, requires_grad=True)

            loss_target_value += transfer_loss.item() * (preds_src.size(0) +
                                                         preds_tgt.size(0))
            train_n += preds_src.size(0) + preds_tgt.size(0)

            # Total objective: source classification loss plus transfer loss
            loss = nn.CrossEntropyLoss()(preds_src, labels_src) + transfer_loss
            total_loss += loss.item() * (preds_src.size(0) + preds_tgt.size(0))

            train_clf_n += preds_src.size(0)
            train_clf_loss += loss_correct.item() * preds_src.size(0)
            train_clf_acc += torch.sum(
                preds_src.max(1)[1] == labels_src.data).double()

            # Optimize model
            loss.backward()
            optimizer.step()
            if epoch > 2:
                optimizer_critic.step()

            if ((step + 1) % params.log_step == 0):
                print(
                    "Epoch [{}/{}] Step [{}/{}] Avg total loss: {:.4f} Avg Transfer Loss: {:.4f}"
                    "  Avg Classification Loss: {:4f} "
                    "Avg Classification Accuracy: {:.4%}".format(
                        epoch + 1, params.num_epochs, step + 1,
                        len_data_loader, total_loss / train_n,
                        loss_target_value / train_n,
                        train_clf_loss / train_clf_n,
                        train_clf_acc / train_clf_n))

        time_elapsed = time.time() - start_time

        # Eval model
        if (epoch + 1) % params.eval_step == 0:
            eval_tgt_robust(encoder, classifier, tgt_data_loader_eval)

        # Save model parameters
        if (epoch + 1) % params.save_step == 0:
            print('Epoch [{}/{}] completed in {:.0f}m {:.0f}s'.format(
                epoch + 1, params.num_epochs, time_elapsed // 60,
                time_elapsed % 60))
            filename = "ALDA-encoder-{}.pt".format(epoch + 1) if not robust \
                else "ALDA-encoder-rb-{}.pt".format(epoch + 1)
            save_model(encoder, params.dann_root, filename)
            filename = "ALDA-classifier-{}.pt".format(epoch + 1) if not robust \
                else "ALDA-classifier-rb-{}.pt".format(epoch + 1)
            save_model(classifier, params.dann_root, filename)
            filename = "ALDA-critic-{}.pt".format(epoch + 1) if not robust \
                else "ALDA-critic-rb-{}.pt".format(epoch + 1)
            save_model(critic, params.dann_root, filename)

    # Save final model
    filename = "ALDA-encoder-final.pt" if not robust else "ALDA-encoder-rb-final.pt"
    save_model(encoder, params.dann_root, filename)
    filename = "ALDA-classifier-final.pt" if not robust else "ALDA-classifier-rb-final.pt"
    save_model(classifier, params.dann_root, filename)
    filename = "ALDA-critic-final.pt" if not robust else "ALDA-critic-rb-final.pt"
    save_model(critic, params.dann_root, filename)

    return encoder, classifier, critic
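
set_requires_grad is used above to freeze the encoder and classifier so that loss_reg's backward pass updates only the discriminator. A one-loop sketch:

def set_requires_grad(model, requires_grad=True):
    # Toggle gradient tracking for every parameter of the module.
    for param in model.parameters():
        param.requires_grad = requires_grad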
Example #13
def train_dann(encoder,
               classifier,
               critic,
               src_data_loader,
               tgt_data_loader,
               tgt_data_loader_eval,
               robust=True):
    """Train encoder for DANN """

    # 1. Network Setup
    encoder.train()
    classifier.train()
    critic.train()

    # Set up optimizer and criterion
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(classifier.parameters()) +
                           list(critic.parameters()),
                           lr=params.learning_rate,
                           weight_decay=params.weight_decay)

    criterion = nn.CrossEntropyLoss()

    # 2. Train network
    for epoch in range(params.num_epochs):

        start_time = time.time()
        total_loss, train_n = 0, 0
        train_clf_loss, train_clf_acc, train_clf_n = 0, 0, 0
        train_domain_loss, train_domain_acc, train_domain_n = 0, 0, 0

        # Zip source and target data pair
        len_data_loader = min(len(src_data_loader), len(tgt_data_loader))
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))

        for step, ((images_src, labels_src), (images_tgt, _)) in data_zip:

            p = float(step + epoch * len_data_loader) / \
                params.num_epochs / len_data_loader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # Update lr
            # update_lr(optimizer, lr_scheduler(p))

            # Make images variable
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)
            labels_src = make_variable(labels_src)
            images_concat = torch.cat((images_src, images_tgt), 0)

            # Prepare real and fake label (domain labels)
            domain_src = make_variable(torch.ones(images_src.size(0)).long())
            domain_tgt = make_variable(torch.zeros(images_tgt.size(0)).long())
            domain_concat = torch.cat((domain_src, domain_tgt), 0)

            # Zero gradients for optimizer
            optimizer.zero_grad()

            if robust:
                delta_src = attack_pgd(encoder, classifier, images_src,
                                       labels_src)
                delta_domain = attack_pgd(encoder, critic, images_concat,
                                          domain_concat)

                robust_src = normalize(
                    torch.clamp(images_src + delta_src[:images_src.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))
                robust_domain = normalize(
                    torch.clamp(images_concat +
                                delta_domain[:images_concat.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))

            # Train on source domain
            feats = encoder(images_src) if not robust else encoder(robust_src)
            preds_src = classifier(feats)
            feats = encoder(images_concat) if not robust else encoder(
                robust_domain)

            preds_domain = critic(feats, alpha=alpha)

            # Compute loss for source classification and domain classification
            loss_src = criterion(preds_src, labels_src)
            loss_domain = criterion(preds_domain, domain_concat)

            loss = loss_src + loss_domain

            train_clf_n += preds_src.size(0)
            train_domain_n += preds_domain.size(0)
            train_n += preds_src.size(0) + preds_domain.size(0)

            total_loss += loss.item() * (preds_src.size(0) +
                                         preds_domain.size(0))
            train_clf_loss += loss_src.item() * preds_src.size(0)
            train_domain_loss += loss_domain.item() * preds_domain.size(0)

            train_domain_acc += torch.sum(
                preds_domain.max(1)[1] == domain_concat.data).double()
            train_clf_acc += torch.sum(
                preds_src.max(1)[1] == labels_src.data).double()

            # Optimize model
            loss.backward()
            optimizer.step()

            if ((step + 1) % params.log_step == 0):
                print(
                    "Epoch [{}/{}] Step [{}/{}] Avg total loss: {:.4f} Avg Domain Loss: {:.4f}"
                    " Avg Domain Accuracy: {:.4%} Avg Classification Loss: {:4f} "
                    "Avg Classification Accuracy: {:.4%}".format(
                        epoch + 1, params.num_epochs, step + 1,
                        len_data_loader, total_loss / train_n,
                        train_domain_loss / train_domain_n,
                        train_domain_acc / train_domain_n,
                        train_clf_loss / train_clf_n,
                        train_clf_acc / train_clf_n))

        time_elapsed = time.time() - start_time

        # Eval model
        if (epoch + 1) % params.eval_step == 0:
            eval_tgt_robust(encoder, classifier, tgt_data_loader_eval)

        # Save model parameters
        if (epoch + 1) % params.save_step == 0:
            print('Epoch [{}/{}] completed in {:.0f}m {:.0f}s'.format(
                epoch + 1, params.num_epochs, time_elapsed // 60,
                time_elapsed % 60))
            filename = "DANN-encoder-{}.pt".format(epoch + 1) if not robust \
                else "DANN-encoder-rb-{}.pt".format(epoch + 1)
            save_model(encoder, params.dann_root, filename)
            filename = "DANN-classifier-{}.pt".format(epoch + 1) if not robust \
                else "DANN-classifier-rb-{}.pt".format(epoch + 1)
            save_model(classifier, params.dann_root, filename)
            filename = "DANN-critic-{}.pt".format(epoch + 1) if not robust \
                else "DANN-critic-rb-{}.pt".format(epoch + 1)
            save_model(critic, params.dann_root, filename)

    # Save final model
    filename = "DANN-encoder-final.pt" if not robust else "DANN-encoder-rb-final.pt"
    save_model(encoder, params.dann_root, filename)
    filename = "DANN-classifier-final.pt" if not robust else "DANN-classifier-rb-final.pt"
    save_model(classifier, params.dann_root, filename)
    filename = "DANN-critic-final.pt" if not robust else "DANN-critic-rb-final.pt"
    save_model(critic, params.dann_root, filename)

    return encoder, classifier, critic
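
Calling critic(feats, alpha=alpha) with the DANN schedule for alpha suggests the critic starts with a gradient reversal layer: identity on the forward pass, gradient scaled by -alpha on the backward pass, so a single backward step trains the critic while adversarially updating the encoder. A minimal autograd sketch of that layer (an assumption about the critic's internals):

import torch

class GradReverse(torch.autograd.Function):
    # Identity on the forward pass; gradient multiplied by -alpha going back.
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg() * ctx.alpha, None

def grad_reverse(x, alpha=1.0):
    return GradReverse.apply(x, alpha)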
Example #14
def train_tgt_wdgrl(encoder,
                    classifier,
                    critic,
                    src_data_loader,
                    tgt_data_loader,
                    tgt_data_loader_eval,
                    robust=False):
    """Train encoder encoder for wdgrl """

    # Set state
    encoder.train()
    classifier.train()
    critic.train()

    # Setup criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(classifier.parameters()),
                           lr=params.learning_rate,
                           weight_decay=params.weight_decay)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=params.learning_rate,
                                  weight_decay=params.weight_decay)
    len_data_loader = min(len(src_data_loader), len(tgt_data_loader))

    # Step 2 Train network
    for epoch in range(params.num_epochs):
        train_acc, train_loss, total_n = 0, 0, 0
        start_time = time.time()
        # Zip source and target data pair
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
        for step, ((images_src, labels_src), (images_tgt, _)) in data_zip:

            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)
            labels_src = make_variable(labels_src)

            if robust:
                # PGD attack on the source images; train the critic on the
                # perturbed source features
                delta_src = attack_pgd(encoder, classifier, images_src,
                                       labels_src)
                robust_src = normalize(
                    torch.clamp(images_src + delta_src[:images_src.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))
                critic_loss = train_critic_wdgrl(encoder, critic,
                                                 critic_optimizer, robust_src,
                                                 images_tgt)
            else:
                critic_loss = train_critic_wdgrl(encoder, critic,
                                                 critic_optimizer, images_src,
                                                 images_tgt)

            feat_src = encoder(images_src) if not robust else encoder(
                robust_src)
            feat_tgt = encoder(images_tgt)

            preds_src = classifier(feat_src)
            clf_loss = criterion(preds_src, labels_src)
            wasserstein_distance = critic(feat_src).mean() - (
                1 + params.beta_ratio) * critic(feat_tgt).mean()

            loss = clf_loss + params.wd_clf * wasserstein_distance
            train_loss += loss.item() * labels_src.size(0) + critic_loss
            total_n += labels_src.size(0)
            train_acc += torch.sum(
                preds_src.max(1)[1] == labels_src.data).double()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ((step + 1) % params.log_step == 0):
                print(
                    "Epoch [{}/{}] Step [{}/{}]: Avg Training loss: {:.4f} Ave Training Accuracy {:.4%}"
                    .format(epoch + 1, params.num_epochs, step + 1,
                            len_data_loader, train_loss / total_n,
                            train_acc / total_n))

        time_elapsed = time.time() - start_time

        # Eval model
        if (epoch + 1) % params.eval_step == 0:
            eval_tgt_robust(encoder, classifier, tgt_data_loader_eval)

        # 2.4 Save model parameters #
        if ((epoch + 1) % params.save_step == 0):
            print('Epoch [{}/{}] completed in {:.0f}m {:.0f}s'.format(
                epoch + 1, params.num_epochs, time_elapsed // 60,
                time_elapsed % 60))
            filename = "WDGRL-encoder-{}.pt".format(epoch + 1) if not robust \
                else "WDGRL-encoder-rb-{}.pt".format(epoch + 1)
            save_model(encoder, params.wdgrl_root, filename)

            filename = "WDGRL-classifier-{}.pt".format(epoch + 1) if not robust \
                else "WDGRL-classifier-rb-{}.pt".format(epoch + 1)
            save_model(classifier, params.wdgrl_root, filename)

    filename = "WDGRL-classifier-final.pt" if not robust else "WDGRL-classifier-rb-final.pt"
    save_model(classifier, params.wdgrl_root, filename)

    filename = "WDGRL-encoder-final.pt" if not robust else "WDGRL-encoder-rb-final.pt"
    save_model(encoder, params.wdgrl_root, filename)

    return encoder, classifier
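
train_critic_wdgrl is defined elsewhere; in the WDGRL paper the critic is trained for several inner steps to maximize the empirical Wasserstein distance minus a gradient penalty on feature interpolates. A sketch under assumed params.critic_iters and params.gamma hyper-parameters:

import torch

def gradient_penalty(critic, feat_src, feat_tgt):
    # Penalize deviation of the critic's gradient norm from 1 on interpolates.
    n = min(feat_src.size(0), feat_tgt.size(0))
    eps = torch.rand(n, 1, device=feat_src.device)
    interp = (eps * feat_src[:n] + (1 - eps) * feat_tgt[:n]).requires_grad_(True)
    grad, = torch.autograd.grad(critic(interp).sum(), interp, create_graph=True)
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()

def train_critic_wdgrl(encoder, critic, critic_optimizer, images_src, images_tgt):
    # Sketch: maximize the Wasserstein estimate minus a gradient penalty.
    # params.critic_iters and params.gamma are assumed hyper-parameters.
    with torch.no_grad():
        feat_src = encoder(images_src)
        feat_tgt = encoder(images_tgt)
    total = 0.0
    for _ in range(params.critic_iters):
        wd = critic(feat_src).mean() - critic(feat_tgt).mean()
        loss = -wd + params.gamma * gradient_penalty(critic, feat_src, feat_tgt)
        critic_optimizer.zero_grad()
        loss.backward()
        critic_optimizer.step()
        total += loss.item()
    return total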
Example #15
def train_tgt_adda(src_encoder,
                   tgt_encoder,
                   classifier,
                   critic,
                   src_data_loader,
                   tgt_data_loader,
                   tgt_data_loader_eval,
                   robust=False):
    """Train adda encoder for target domain """

    # Step 1: Network Setup
    # Set train state for Dropout and BN layers
    src_encoder.eval()
    tgt_encoder.train()
    critic.train()

    # Setup criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer_tgt = optim.Adam(tgt_encoder.parameters(),
                               lr=params.learning_rate,
                               weight_decay=params.weight_decay)
    optimizer_critic = optim.Adam(critic.parameters(),
                                  lr=params.learning_rate,
                                  weight_decay=params.weight_decay)

    len_data_loader = min(len(src_data_loader), len(tgt_data_loader))

    # Step 2 Train network
    for epoch in range(params.num_epochs):

        start_time = time.time()
        train_disc_loss, train_disc_acc, train_n = 0, 0, 0
        # Zip source and target data pair
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))

        for step, ((images_src, _), (images_tgt, _)) in data_zip:

            # 2.1 train discriminator with fixed src_encoder
            # Make images variable
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)

            # Prepare real and fake label (domain labels)
            domain_src = make_variable(torch.ones(images_src.size(0)).long())
            domain_tgt = make_variable(torch.zeros(images_tgt.size(0)).long())
            domain_concat = torch.cat((domain_src, domain_tgt), 0)

            if robust:
                # Attack images with domain labels
                delta_src = attack_pgd(src_encoder, critic, images_src,
                                       domain_src)
                delta_tgt = attack_pgd(tgt_encoder, critic, images_tgt,
                                       domain_tgt)

                robust_src = normalize(
                    torch.clamp(images_src + delta_src[:images_src.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))
                robust_tgt = normalize(
                    torch.clamp(images_tgt + delta_tgt[:images_tgt.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))

            # Zero gradients for optimizer for the discriminator
            optimizer_critic.zero_grad()

            # Extract and concat features
            feat_src = src_encoder(images_src) if not robust else src_encoder(
                robust_src)
            feat_tgt = tgt_encoder(images_tgt) if not robust else tgt_encoder(
                robust_tgt)
            feat_concat = torch.cat((feat_src, feat_tgt), 0)

            # Predict on discriminator
            preds_src_domain = critic(feat_src)
            preds_tgt_domain = critic(feat_tgt)
            # pred_concat = critic(feat_concat)

            # Compute loss for critic
            l1 = criterion(preds_src_domain, domain_src)
            l2 = criterion(preds_tgt_domain, domain_tgt)
            # loss_critic = criterion((pred_concat, domain_concat)
            loss_critic = l1 + l2
            train_disc_loss += loss_critic.item() * domain_concat.size(0)
            # train_disc_acc += torch.sum(pred_concat.max(1)[1] == domain_concat.data).double()
            train_disc_acc += torch.sum(
                preds_src_domain.max(1)[1] == domain_src.data).double()
            train_disc_acc += torch.sum(
                preds_tgt_domain.max(1)[1] == domain_tgt.data).double()
            train_n += domain_concat.size(0)
            loss_critic.backward()
            # Optimize critic
            optimizer_critic.step()

            # 2.2 Train target encoder
            # Zero gradients for optimizer
            optimizer_critic.zero_grad()
            optimizer_tgt.zero_grad()

            # Prepare fake labels (flip labels)
            domain_tgt = make_variable(torch.ones(feat_tgt.size(0)).long())

            if robust:
                # Attack the target images with domain labels
                delta_tgt = attack_pgd(tgt_encoder, critic, images_tgt,
                                       domain_tgt)
                robust_tgt = normalize(
                    torch.clamp(images_tgt + delta_tgt[:images_tgt.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))

            # Extract target features
            feat_tgt = tgt_encoder(images_tgt) if not robust else tgt_encoder(
                robust_tgt)

            # Predict on discriminator
            pred_tgt = critic(feat_tgt)
            # Compute loss for target encoder
            loss_tgt = criterion(pred_tgt, domain_tgt)
            loss_tgt.backward()

            # Optimize target encoder
            optimizer_tgt.step()

            # 2.3 Print step info
            if (step + 1) % params.log_step == 0:
                print(
                    "Epoch [{}/{}] Step [{}/{}]: "
                    "Avg Discriminator Loss: {:.4f} Avg Discriminator Accuracy: {:.4%}"
                    .format(epoch + 1, params.num_epochs, step + 1,
                            len_data_loader, train_disc_loss / train_n,
                            train_disc_acc / train_n))

        time_elapsed = time.time() - start_time

        # Eval model
        if (epoch + 1) % params.eval_step == 0:
            if not robust:
                eval_tgt(tgt_encoder, classifier, tgt_data_loader_eval)
            else:
                eval_tgt_robust(tgt_encoder, classifier, critic,
                                tgt_data_loader_eval)

        # 2.4 Save model parameters #
        if (epoch + 1) % params.save_step == 0:
            print('Epoch [{}/{}] completed in {:.0f}m {:.0f}s'.format(
                epoch + 1, params.num_epochs, time_elapsed // 60,
                time_elapsed % 60))
            filename = "ADDA-critic-{}.pt".format(epoch + 1) if not robust \
                else "ADDA-critic-rb-{}.pt".format(epoch + 1)
            save_model(critic, params.adda_root, filename)

            filename = "ADDA-target-encoder-{}.pt".format(epoch + 1) if not robust \
                else "ADDA-target-encoder-rb-{}.pt".format(epoch + 1)
            save_model(tgt_encoder, params.adda_root, filename)

    filename = "ADDA-critic-final.pt" if not robust else "ADDA-critic-rb-final.pt"
    save_model(critic, params.adda_root, filename)

    filename = "ADDA-target-encoder-final.pt" if not robust else "ADDA-target-encoder-rb-final.pt"
    save_model(tgt_encoder, params.adda_root, filename)

    return tgt_encoder
Example #16
def eval_acc(model1, model2, model3, data_loader):
    """Evaluation for target encoder by source classifier on target dataset."""
    # set eval state for Dropout and BN layers
    model1.eval()
    model2.eval()
    model3.eval()

    # init loss and accuracy
    loss = 0
    acc = 0

    # set loss function
    criterion = nn.CrossEntropyLoss()

    TP = 0.
    TN = 0.
    FP = 0.
    FN = 0.

    # for (images, labels) in data_loader:
    for (image1, image2, image3, label) in data_loader:

        img1_name = image1[0][0]
        img1 = image1[1]
        img2_name = image2[0][0]
        img2 = image2[1]
        img3_name = image3[0][0]
        img3 = image3[1]
        image1 = make_variable(img1)
        image2 = make_variable(img2)
        image3 = make_variable(img3)
        label = make_variable(label)

        preds1 = model1(image1)
        preds2 = model2(image2)
        preds3 = model3(image3)

        probability1 = torch.nn.functional.softmax(preds1,
                                                   dim=1)[:,
                                                          1].detach().tolist()
        probability2 = torch.nn.functional.softmax(preds2,
                                                   dim=1)[:,
                                                          1].detach().tolist()
        probability3 = torch.nn.functional.softmax(preds3,
                                                   dim=1)[:,
                                                          1].detach().tolist()

        probability = (probability1[0] + probability2[0] + probability3[0]) / 3

        probability1_value = np.array(probability1)
        probability2_value = np.array(probability2)
        probability3_value = np.array(probability3)

        # accuracy is taken from model1's (color) predictions alone; the
        # fused probability above is computed but not used for the decision
        pred_cls = preds1.data.max(1)[1]
        acc += pred_cls.eq(label.data).cpu().sum()

        target_list = label.cpu().numpy()
        pred_list = pred_cls.cpu().numpy()

        for i in range(len(target_list)):
            if target_list[i] == 1 and pred_list[i] == 1:
                TP += 1
            elif target_list[i] == 0 and pred_list[i] == 0:
                TN += 1
            elif target_list[i] == 1 and pred_list[i] == 0:
                FN += 1
            elif target_list[i] == 0 and pred_list[i] == 1:
                FP += 1

    loss /= len(data_loader)
    acc = (TP + TN) / len(data_loader.dataset)

    print("Avg Loss = {}, Avg Accuracy = {:2%}".format(loss, acc))
    print('TP:{}, TP+FN:{}, TN:{}, TN+FP:{}'.format(TP, TP + FN, TN, TN + FP))

    TP_rate = float(TP / (TP + FN))
    TN_rate = float(TN / (TN + FP))

    HTER = 1 - (TP_rate + TN_rate) / 2

    print('TP rate:{}, TN rate:{}, HTER:{}'.format(float(TP / (TP + FN)),
                                                   float(TN / (TN + FP)),
                                                   HTER))
Example #17
def train_tgt(src_encoder, tgt_encoder, critic, src_data_loader,
              tgt_data_loader, src_classifier):
    """Train encoder for target domain."""
    ####################
    # 1. setup network #
    ####################

    # set train state for Dropout and BN layers
    tgt_encoder.train()
    critic.train()

    #src_encoder = torch.nn.DataParallel(src_encoder)
    #tgt_encoder = torch.nn.DataParallel(tgt_encoder)
    #critic = torch.nn.DataParallel(critic)

    # setup criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer_tgt = optim.Adam(tgt_encoder.parameters(),
                               lr=params.c_learning_rate,
                               betas=(params.beta1, params.beta2))
    optimizer_critic = optim.Adam(critic.parameters(),
                                  lr=params.d_learning_rate,
                                  betas=(params.beta1, params.beta2))
    len_data_loader = min(len(src_data_loader), len(tgt_data_loader))

    ####################
    # 2. train network #
    ####################

    for epoch in range(params.adapt_num_epochs):
        # zip source and target data pair
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
        for step, ((images_src, _), (images_tgt, _)) in data_zip:
            ###########################
            # 2.1 train discriminator #
            ###########################

            # make images variable
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)

            # zero gradients for optimizer
            optimizer_critic.zero_grad()

            # extract and concat features
            feat_src = src_encoder(images_src)
            feat_tgt = tgt_encoder(images_tgt)
            feat_concat = torch.cat((feat_src, feat_tgt), 0)

            # predict on discriminator
            pred_concat = critic(feat_concat.detach())

            # prepare real and fake label
            label_src = make_variable(torch.ones(feat_src.size(0)).long())
            label_tgt = make_variable(torch.zeros(feat_tgt.size(0)).long())
            label_concat = torch.cat((label_src, label_tgt), 0)

            # compute loss for critic
            loss_critic = criterion(pred_concat, label_concat)
            loss_critic.backward()

            # optimize critic
            optimizer_critic.step()

            pred_cls = torch.squeeze(pred_concat.max(1)[1])
            acc = (pred_cls == label_concat).float().mean()

            ############################
            # 2.2 train target encoder #
            ############################

            # zero gradients for optimizer
            optimizer_critic.zero_grad()
            optimizer_tgt.zero_grad()

            # extract target features
            feat_tgt = tgt_encoder(images_tgt)

            # predict on discriminator
            pred_tgt = critic(feat_tgt)

            # prepare fake labels
            label_tgt = make_variable(torch.ones(feat_tgt.size(0)).long())

            # compute loss for target encoder
            loss_tgt = criterion(pred_tgt, label_tgt)
            loss_tgt.backward()

            # optimize target encoder
            optimizer_tgt.step()

            #######################
            # 2.3 print step info #
            #######################
            if ((step + 1) % params.log_step == 0):
                print("Epoch [{}/{}] Step [{}/{}]:"
                      "d_loss={:.5f} g_loss={:.5f} acc={:.5f}".format(
                          epoch + 1, params.adapt_num_epochs, step + 1,
                          len_data_loader, loss_critic.item(), loss_tgt.item(),
                          acc.item()))

        #############################
        # 2.4 save model parameters #
        #############################

        if ((epoch + 1) % params.save_step == 0):
            torch.save(
                critic.state_dict(),
                os.path.join(params.model_root,
                             "ADDA-critic-{}.pt".format(epoch + 1)))
            torch.save(
                tgt_encoder.state_dict(),
                os.path.join(params.model_root,
                             "ADDA-target-encoder-{}.pt".format(epoch + 1)))
            eval_tgt(tgt_encoder, src_classifier, tgt_data_loader)

    torch.save(critic.state_dict(),
               os.path.join(params.model_root, "ADDA-critic-final.pt"))
    torch.save(tgt_encoder.state_dict(),
               os.path.join(params.model_root, "ADDA-target-encoder-final.pt"))
    return tgt_encoder
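
eval_tgt, the clean-accuracy counterpart of eval_tgt_robust, is called in several examples above but not shown on this page. A minimal sketch consistent with how it is called:

import torch
import torch.nn as nn

def eval_tgt(encoder, classifier, data_loader):
    # Sketch: average cross-entropy loss and accuracy of classifier(encoder(x)).
    encoder.eval()
    classifier.eval()
    criterion = nn.CrossEntropyLoss()
    loss, acc = 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            images = make_variable(images)
            labels = make_variable(labels)
            preds = classifier(encoder(images))
            loss += criterion(preds, labels).item()
            acc += torch.sum(preds.max(1)[1] == labels.data).double()
    loss /= len(data_loader)
    acc = acc / len(data_loader.dataset)
    print("Avg Evaluation Loss: {:.4f}, Avg Evaluation Accuracy: {:.4%}".format(
        loss, acc))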