Ejemplo n.º 1
0
def eval_tgt_robust(encoder, classifier, critic, data_loader):
    """Evaluate model for target domain with attack on labels and domains"""
    # Set eval state for Dropout and BN layers
    encoder.eval()
    classifier.eval()
    critic.eval()

    # Init loss and accuracy
    loss, acc = 0, 0
    test_robust_loss, test_robust_acc = 0, 0

    # Set loss function
    criterion = nn.CrossEntropyLoss()

    # Evaluate network
    for (images, labels) in data_loader:
        images = make_variable(images)
        labels = make_variable(labels)
        domain_tgt = make_variable(torch.ones(images.size(0)).long())

        delta_src = attack_pgd(encoder, classifier, images, labels)
        delta_domain = attack_pgd(encoder, critic, images, domain_tgt)
        delta_src = delta_src.detach()
        delta_domain = delta_domain.detach()

        # Compute loss
        robust_images = normalize(
            torch.clamp(images + delta_src[:images.size(0)] +
                        delta_domain[:images.size(0)],
                        min=params.lower_limit,
                        max=params.upper_limit))
        robust_preds = classifier(encoder(robust_images))
        test_robust_loss += criterion(robust_preds, labels).item()

        out = classifier(encoder(images))
        loss += criterion(out, labels).item()

        test_robust_acc += torch.sum(
            robust_preds.max(1)[1] == labels.data).double()
        acc += torch.sum(out.max(1)[1] == labels.data).double()

    loss /= len(data_loader)
    test_robust_loss /= len(data_loader)
    acc = acc / len(data_loader.dataset)
    test_robust_acc = test_robust_acc / len(data_loader.dataset)

    print(
        "Avg Evaluation Loss: {:.4f}, Avg Evaluation Accuracy: {:.4%}, Ave Evaluation Robust Loss: {:.4f}, "
        "Ave Robust Accuracy: {:.4%}".format(loss, acc, test_robust_loss,
                                             test_robust_acc))
Ejemplo n.º 2
0
def train_src_robust(encoder, classifier, data_loader, mode='ADDA'):
    """Train classifier for source domain with robust training for ADDA"""

    # Step 1: Network setup
    # Set train state for both Dropout and BN layers
    encoder.train()
    classifier.train()

    # Set up optimizer and criterion
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(classifier.parameters()),
                           lr=params.learning_rate,
                           weight_decay=params.weight_decay)

    criterion = nn.CrossEntropyLoss()
    num_epochs = params.num_epochs_pre if mode == 'ADDA' else params.num_epochs
    # Step 2: Pretrain the source model
    for epoch in range(num_epochs):

        # Init accuracy and loss
        start_time = time.time()
        train_loss, train_acc, train_n = 0, 0, 0
        train_robust_loss, train_robust_acc = 0, 0

        for step, (images, labels) in enumerate(data_loader):

            # Make images and labels variable
            images = make_variable(images)
            labels = make_variable(labels)

            # Zero gradients for optimizer
            optimizer.zero_grad()

            delta = attack_pgd(encoder, classifier, images, labels)

            # Compute loss for critic with attack img
            robust_images = normalize(
                torch.clamp(images + delta[:images.size(0)],
                            min=params.lower_limit,
                            max=params.upper_limit))
            robust_preds = classifier(encoder(robust_images))
            robust_loss = criterion(robust_preds, labels)

            # Optimize source classifier
            robust_loss.backward()
            optimizer.step()

            # Compute loss for critic with original image
            preds = classifier(encoder(images))
            loss = criterion(preds, labels)

            train_robust_loss += robust_loss.item() * labels.size(0)
            train_robust_acc += torch.sum(
                robust_preds.max(1)[1] == labels).double()
            train_loss += loss.item() * labels.size(0)
            train_acc += torch.sum(preds.max(1)[1] == labels.data).double()
            train_n += labels.size(0)

            # Print step info
            if (step + 1) % params.log_step_pre == 0:
                print(
                    "Epoch [{}/{}] Step [{}/{}]: Avg Training loss: {:.4f} Avg Training Accuracy: {:.4%}"
                    " Avg Robust Training Loss: {:.4f} Avg Robust Training Accuracy: {:.4%}"
                    .format(epoch + 1, num_epochs, step + 1, len(data_loader),
                            train_loss / train_n, train_acc / train_n,
                            train_robust_loss / train_n,
                            train_robust_acc / train_n))

        time_elapsed = time.time() - start_time

        # Eval model on test set
        if (epoch + 1) % params.eval_step_pre == 0:
            eval_tgt(encoder, classifier, data_loader)

        # Save model parameters
        if (epoch + 1) % params.save_step_pre == 0:
            print('Epoch [{}/{}] completed in {:.0f}m {:.0f}s'.format(
                epoch + 1, num_epochs, time_elapsed // 60, time_elapsed % 60))
            root = params.adda_root if mode == 'ADDA' else params.model_root
            save_model(encoder, root,
                       "{}-source-encoder-rb-{}.pt".format(mode, epoch + 1))
            save_model(classifier, root,
                       "{}-source-classifier-rb-{}.pt".format(mode, epoch + 1))

    # Save final model
    root = params.adda_root if mode == 'ADDA' else params.model_root

    save_model(encoder, root, "{}-source-encoder-rb-final.pt".format(mode))
    save_model(classifier, root,
               "{}-source-classifier-rb-final.pt".format(mode))

    return encoder, classifier
Ejemplo n.º 3
0
def train_alda(encoder,
               classifier,
               critic,
               src_data_loader,
               tgt_data_loader,
               tgt_data_loader_eval,
               robust=True):
    """Train encoder for DANN """

    # 1. Network Setup
    encoder.train()
    classifier.train()
    critic.train()

    # Set up optimizer and criterion
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(classifier.parameters()),
                           lr=params.learning_rate,
                           weight_decay=params.weight_decay)
    optimizer_critic = optim.Adam(critic.parameters(),
                                  lr=params.learning_rate,
                                  weight_decay=params.weight_decay)

    # 2. Train network
    for epoch in range(params.num_epochs):

        start_time = time.time()
        total_loss, train_n = 0, 0
        loss_target_value = 0
        train_clf_loss, train_clf_acc, train_clf_n = 0, 0, 0

        # Zip source and target data pair
        len_data_loader = min(len(src_data_loader), len(tgt_data_loader))
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))

        for step, ((images_src, labels_src), (images_tgt, _)) in data_zip:

            p = float(step + epoch * len_data_loader) / \
                params.num_epochs / len_data_loader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # Make images variable
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)
            labels_src = make_variable(labels_src)
            images_concat = torch.cat((images_src, images_tgt), 0)

            # Prepare real and fake label (domain labels)
            domain_src = make_variable(torch.ones(images_src.size(0)).long())
            domain_tgt = make_variable(torch.zeros(images_tgt.size(0)).long())
            domain_concat = torch.cat((domain_src, domain_tgt), 0)

            # Zero gradients for optimizer
            optimizer.zero_grad()

            if robust:
                delta_src = attack_pgd(encoder, classifier, images_src,
                                       labels_src)
                robust_src = normalize(
                    torch.clamp(images_src + delta_src[:images_src.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))

            # Train on source domain
            feats = encoder(images_src) if not robust else encoder(robust_src)
            preds_src = classifier(feats)
            preds_tgt = classifier(encoder(images_tgt))
            feats = encoder(images_concat)

            preds_critic = critic(feats, alpha=alpha)

            loss_adv, loss_reg, loss_correct = alda_loss(preds_critic,
                                                         labels_src,
                                                         preds_src,
                                                         preds_tgt,
                                                         threshold=0.6)

            # Computer loss for source classification and domain classification
            transfer_loss = loss_adv + loss_correct if epoch > 2 else 0

            # Loss_reg is only backward to the discrinminator
            set_requires_grad(encoder, requires_grad=False)
            set_requires_grad(classifier, requires_grad=False)
            loss_reg.backward(retain_graph=True)
            set_requires_grad(encoder, requires_grad=True)
            set_requires_grad(classifier, requires_grad=True)

            loss_target_value += transfer_loss.item() * (preds_src.size(0) +
                                                         preds_tgt.size(0))
            train_n += preds_src.size(0) + preds_tgt.size(0)

            loss = preds_src + transfer_loss  # Loss_func.Square(softmax_output) + transfer_loss
            total_loss += loss.item() * (preds_src.size(0) + preds_tgt.size(0))

            train_clf_n += preds_src.size(0)
            train_clf_loss += loss_correct.item() * preds_src.size(0)
            train_clf_acc += torch.sum(
                preds_src.max(1)[1] == labels_src.data).double()

            # Optimize model
            loss.backward()
            optimizer.step()
            if epoch > 2:
                optimizer_critic.step()

            if ((step + 1) % params.log_step == 0):
                print(
                    "Epoch [{}/{}] Step [{}/{}] Avg total loss: {:.4f} Avg Transfer Loss: {:.4f}"
                    "  Avg Classification Loss: {:4f} "
                    "Avg Classification Accuracy: {:.4%}".format(
                        epoch + 1, params.num_epochs, step + 1,
                        len_data_loader, total_loss / train_n,
                        loss_target_value / train_n,
                        train_clf_loss / train_clf_n,
                        train_clf_acc / train_clf_n))

        time_elapsed = start_time - time.time()

        # Eval model
        if (epoch + 1) % params.eval_step == 0:
            eval_tgt_robust(encoder, classifier, tgt_data_loader_eval)

        # Save model parameters
        if (epoch + 1) % params.save_step == 0:
            print('Epoch [{}/{}] completed in {:.0f}m {:.0f}s'.format(
                epoch + 1, params.num_epochs, time_elapsed // 60,
                time_elapsed % 60))
            filename = "ALDA-encoder-{}.pt".format(epoch + 1) if not robust \
                else "ALDA-encoder-rb-{}.pt".format(epoch + 1)
            save_model(encoder, params.dann_root, filename)
            filename = "ALDA-classifier-{}.pt".format(epoch + 1) if not robust \
                else "ALDA-classifier-rb-{}.pt".format(epoch + 1)
            save_model(classifier, params.dann_root, filename)
            filename = "ALDA-critic-{}.pt".format(epoch + 1) if not robust \
                else "ALDA-critic-rb-{}.pt".format(epoch + 1)
            save_model(critic, params.dann_root, filename)

    # Save final model
    filename = "ALDA-encoder-final.pt" if not robust else "ALDA-encoder-rb-final.pt"
    save_model(encoder, params.dann_root, filename)
    filename = "ALDA-classifier-final.pt" if not robust else "ALDA-classifier-rb-final.pt"
    save_model(classifier, params.dann_root, filename)
    filename = "ALDA-critic-final.pt" if not robust else "ALDA-critic-rb-final.pt"
    save_model(critic, params.dann_root, filename)

    return encoder, classifier, critic
Ejemplo n.º 4
0
def train_dann(encoder,
               classifier,
               critic,
               src_data_loader,
               tgt_data_loader,
               tgt_data_loader_eval,
               robust=True):
    """Train encoder for DANN """

    # 1. Network Setup
    encoder.train()
    classifier.train()
    critic.train()

    # Set up optimizer and criterion
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(classifier.parameters()) +
                           list(critic.parameters()),
                           lr=params.learning_rate,
                           weight_decay=params.weight_decay)

    criterion = nn.CrossEntropyLoss()

    # 2. Train network
    for epoch in range(params.num_epochs):

        start_time = time.time()
        total_loss, train_n = 0, 0
        train_clf_loss, train_clf_acc, train_clf_n = 0, 0, 0
        train_domain_loss, train_domain_acc, train_domain_n = 0, 0, 0

        # Zip source and target data pair
        len_data_loader = min(len(src_data_loader), len(tgt_data_loader))
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))

        for step, ((images_src, labels_src), (images_tgt, _)) in data_zip:

            p = float(step + epoch * len_data_loader) / \
                params.num_epochs / len_data_loader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # Update lr
            # update_lr(optimizer, lr_scheduler(p))

            # Make images variable
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)
            labels_src = make_variable(labels_src)
            images_concat = torch.cat((images_src, images_tgt), 0)

            # Prepare real and fake label (domain labels)
            domain_src = make_variable(torch.ones(images_src.size(0)).long())
            domain_tgt = make_variable(torch.zeros(images_tgt.size(0)).long())
            domain_concat = torch.cat((domain_src, domain_tgt), 0)

            # Zero gradients for optimizer
            optimizer.zero_grad()

            if robust:
                delta_src = attack_pgd(encoder, classifier, images_src,
                                       labels_src)
                delta_domain = attack_pgd(encoder, critic, images_concat,
                                          domain_concat)

                robust_src = normalize(
                    torch.clamp(images_src + delta_src[:images_src.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))
                robust_domain = normalize(
                    torch.clamp(images_concat +
                                delta_domain[:images_concat.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))

            # Train on source domain
            feats = encoder(images_src) if not robust else encoder(robust_src)
            preds_src = classifier(feats)
            feats = encoder(images_concat) if not robust else encoder(
                robust_domain)

            preds_domain = critic(feats, alpha=alpha)

            # Computer loss for source classification and domain classification
            loss_src = criterion(preds_src, labels_src)
            loss_domain = criterion(preds_domain, domain_concat)

            loss = loss_src + loss_domain

            train_clf_n += preds_src.size(0)
            train_domain_n += preds_domain.size(0)
            train_n += train_clf_n + train_domain_n

            total_loss += loss.item() * (preds_src.size(0) +
                                         preds_domain.size(0))
            train_clf_loss += loss_src.item() * preds_src.size(0)
            train_domain_loss += loss_domain.item()

            train_domain_acc += torch.sum(
                preds_domain.max(1)[1] == domain_concat.data).double()
            train_clf_acc += torch.sum(
                preds_src.max(1)[1] == labels_src.data).double()

            # Optimize model
            loss.backward()
            optimizer.step()

            if ((step + 1) % params.log_step == 0):
                print(
                    "Epoch [{}/{}] Step [{}/{}] Avg total loss: {:.4f} Avg Domain Loss: {:.4f}"
                    " Avg Domain Accuracy: {:.4%} Avg Classification Loss: {:4f} "
                    "Avg Classification Accuracy: {:.4%}".format(
                        epoch + 1, params.num_epochs, step + 1,
                        len_data_loader, total_loss / train_n,
                        train_domain_loss / train_domain_n,
                        train_domain_acc / train_domain_n,
                        train_clf_loss / train_clf_n,
                        train_clf_acc / train_clf_n))

        time_elapsed = start_time - time.time()

        # Eval model
        if (epoch + 1) % params.eval_step == 0:
            eval_tgt_robust(encoder, classifier, tgt_data_loader_eval)

        # Save model parameters
        if (epoch + 1) % params.save_step == 0:
            print('Epoch [{}/{}] completed in {:.0f}m {:.0f}s'.format(
                epoch + 1, params.num_epochs, time_elapsed // 60,
                time_elapsed % 60))
            filename = "DANN-encoder-{}.pt".format(epoch + 1) if not robust \
                else "DANN-encoder-rb-{}.pt".format(epoch + 1)
            save_model(encoder, params.dann_root, filename)
            filename = "DANN-classifier-{}.pt".format(epoch + 1) if not robust \
                else "DANN-classifier-rb-{}.pt".format(epoch + 1)
            save_model(classifier, params.dann_root, filename)
            filename = "DANN-critic-{}.pt".format(epoch + 1) if not robust \
                else "DANN-critic-rb-{}.pt".format(epoch + 1)
            save_model(critic, params.dann_root, filename)

    # Save final model
    filename = "DANN-encoder-final.pt" if not robust else "DANN-encoder-rb-final.pt"
    save_model(encoder, params.dann_root, filename)
    filename = "DANN-classifier-final.pt" if not robust else "DANN-classifier-rb-final.pt"
    save_model(classifier, params.dann_root, filename)
    filename = "DANN-critic-final.pt" if not robust else "DANN-critic-rb-final.pt"
    save_model(critic, params.dann_root, filename)

    return encoder, classifier, critic
Ejemplo n.º 5
0
def train_tgt_wdgrl(encoder,
                    classifier,
                    critic,
                    src_data_loader,
                    tgt_data_loader,
                    tgt_data_loader_eval,
                    robust=False):
    """Train encoder encoder for wdgrl """

    # Set state
    encoder.train()
    classifier.train()
    critic.train()

    # Setup criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(classifier.parameters()),
                           lr=params.learning_rate,
                           weight_decay=params.weight_decay)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=params.learning_rate,
                                  weight_decay=params.weight_decay)
    len_data_loader = min(len(src_data_loader), len(tgt_data_loader))

    # Step 2 Train network
    for epoch in range(params.num_epochs):
        train_acc, train_loss, total_n = 0, 0, 0
        start_time = time.time()
        # Zip source and target data pair
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
        for step, ((images_src, labels_src), (images_tgt, _)) in data_zip:

            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)
            labels_src = make_variable(labels_src)

            if robust:
                # PDG attack on the source image
                delta_src = attack_pgd(encoder, classifier, images_src,
                                       labels_src)

                robust_src = normalize(
                    torch.clamp(images_src + delta_src[:images_src.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))

            if robust:
                critic_loss = train_critic_wdgrl(encoder, critic,
                                                 critic_optimizer, robust_src,
                                                 images_tgt)
            else:
                critic_loss = train_critic_wdgrl(encoder, critic,
                                                 critic_optimizer, images_src,
                                                 images_tgt)

            feat_src = encoder(images_src) if not robust else encoder(
                robust_src)
            feat_tgt = encoder(images_tgt)

            preds_src = classifier(feat_src)
            clf_loss = criterion(preds_src, labels_src)
            wasserstein_distance = critic(feat_src).mean() - (
                1 + params.beta_ratio) * critic(feat_tgt).mean()

            loss = clf_loss + params.wd_clf * wasserstein_distance
            train_loss += loss.item() * labels_src.size(0) + critic_loss
            total_n += labels_src.size(0)
            train_acc += torch.sum(
                preds_src.max(1)[1] == labels_src.data).double()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ((step + 1) % params.log_step == 0):
                print(
                    "Epoch [{}/{}] Step [{}/{}]: Avg Training loss: {:.4f} Ave Training Accuracy {:.4%}"
                    .format(epoch + 1, params.num_epochs, step + 1,
                            len_data_loader, train_loss / total_n,
                            train_acc / total_n))

        time_elapsed = time.time() - start_time

        # Eval model
        if (epoch + 1) % params.eval_step == 0:
            eval_tgt_robust(encoder, classifier, tgt_data_loader_eval)

        # 2.4 Save model parameters #
        if ((epoch + 1) % params.save_step == 0):
            print('Epoch [{}/{}] completed in {:.0f}m {:.0f}s'.format(
                epoch + 1, params.num_epochs, time_elapsed // 60,
                time_elapsed % 60))
            filename = "WDGRL-encoder-{}.pt".format(epoch + 1) if not robust \
                else "WDGRL-encoder-rb-{}.pt".format(epoch + 1)
            save_model(encoder, params.wdgrl_root, filename)

            filename = "WDGRL-classifier-{}.pt".format(epoch + 1) if not robust \
                else "WDGRL-classifier-rb-{}.pt".format(epoch + 1)
            save_model(classifier, params.wdgrl_root, filename)

    filename = "WDGRL-classifier-final.pt" if not robust else "WDGRL-classifier-rb-final.pt"
    save_model(classifier, params.wdgrl_root, filename)

    filename = "WDGRL-encoder-final.pt" if not robust else "WDGRL-encoder-rb-final.pt"
    save_model(encoder, params.wdgrl_root, filename)

    return encoder, classifier
Ejemplo n.º 6
0
def train_tgt_adda(src_encoder,
                   tgt_encoder,
                   classifier,
                   critic,
                   src_data_loader,
                   tgt_data_loader,
                   tgt_data_loader_eval,
                   robust=False):
    """Train adda encoder for target domain """

    # Step 1: Network Setup
    # Set train state for Dropout and BN layers
    src_encoder.eval()
    tgt_encoder.train()
    critic.train()

    # Setup criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer_tgt = optim.Adam(tgt_encoder.parameters(),
                               lr=params.learning_rate,
                               weight_decay=params.weight_decay)
    optimizer_critic = optim.Adam(critic.parameters(),
                                  lr=params.learning_rate,
                                  weight_decay=params.weight_decay)

    len_data_loader = min(len(src_data_loader), len(tgt_data_loader))

    # Step 2 Train network
    for epoch in range(params.num_epochs):

        start_time = time.time()
        train_disc_loss, train_disc_acc, train_n = 0, 0, 0
        # Zip source and target data pair
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))

        for step, ((images_src, _), (images_tgt, _)) in data_zip:

            # 2.1 train discriminator with fixed src_encoder
            # Make images variable
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)

            # Prepare real and fake label (domain labels)
            domain_src = make_variable(torch.ones(images_src.size(0)).long())
            domain_tgt = make_variable(torch.zeros(images_tgt.size(0)).long())
            domain_concat = torch.cat((domain_src, domain_tgt), 0)

            if robust:
                # Attack images with domain labels
                delta_src = attack_pgd(src_encoder, critic, images_src,
                                       domain_src)
                delta_tgt = attack_pgd(tgt_encoder, critic, images_tgt,
                                       domain_tgt)

                robust_src = normalize(
                    torch.clamp(images_src + delta_src[:images_src.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))
                robust_tgt = normalize(
                    torch.clamp(images_tgt + delta_tgt[:images_tgt.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))

            # Zero gradients for optimizer for the discriminator
            optimizer_critic.zero_grad()

            # Extract and concat features
            feat_src = src_encoder(images_src) if not robust else src_encoder(
                robust_src)
            feat_tgt = tgt_encoder(images_tgt) if not robust else tgt_encoder(
                robust_tgt)
            feat_concat = torch.cat((feat_src, feat_tgt), 0)

            # Predict on discriminator
            preds_src_domain = critic(feat_src)
            preds_tgt_domain = critic(feat_tgt)
            # pred_concat = critic(feat_concat)

            # Compute loss for critic
            l1 = criterion(preds_src_domain, domain_src)
            l2 = criterion(preds_tgt_domain, domain_tgt)
            # loss_critic = criterion((pred_concat, domain_concat)
            loss_critic = l1 + l2
            train_disc_loss += loss_critic.item() * domain_concat.size(0)
            # train_disc_acc += torch.sum(pred_concat.max(1)[1] == domain_concat.data).double()
            train_disc_acc += torch.sum(
                preds_src_domain.max(1)[1] == domain_src.data).double()
            train_disc_acc += torch.sum(
                preds_tgt_domain.max(1)[1] == domain_tgt.data).double()
            train_n += domain_concat.size(0)
            loss_critic.backward()
            # Optimize critic
            optimizer_critic.step()

            # 2.2 Train target encoder
            # Zero gradients for optimizer
            optimizer_critic.zero_grad()
            optimizer_tgt.zero_grad()

            # Prepare fake labels (flip labels)
            domain_tgt = make_variable(torch.ones(feat_tgt.size(0)).long())

            if robust:
                # Attack the target images with domain labels
                delta_tgt = attack_pgd(tgt_encoder, critic, images_tgt,
                                       domain_tgt)
                robust_tgt = normalize(
                    torch.clamp(images_tgt + delta_tgt[:images_tgt.size(0)],
                                min=params.lower_limit,
                                max=params.upper_limit))

            # Extract target features
            feat_tgt = tgt_encoder(images_tgt) if not robust else tgt_encoder(
                robust_tgt)

            # Predict on discriminator
            pred_tgt = critic(feat_tgt)
            # Compute loss for target encoder
            loss_tgt = criterion(pred_tgt, domain_tgt)
            loss_tgt.backward()

            # Optimize target encoder
            optimizer_tgt.step()

            # 2.3 Print step info
            if (step + 1) % params.log_step == 0:
                print(
                    "Epoch [{}/{}] Step [{}/{}]: "
                    "Avg Discriminator Loss: {:.4f} Avg Discriminator Accuracy: {:.4%}"
                    .format(epoch + 1, params.num_epochs, step + 1,
                            len_data_loader, train_disc_loss / train_n,
                            train_disc_acc / train_n))

        time_elapsed = time.time() - start_time

        # Eval model
        if (epoch + 1) % params.eval_step == 0:
            if not robust:
                eval_tgt(tgt_encoder, classifier, tgt_data_loader_eval)
            else:
                eval_tgt_robust(tgt_encoder, classifier, critic,
                                tgt_data_loader_eval)

        # 2.4 Save model parameters #
        if (epoch + 1) % params.save_step == 0:
            print('Epoch [{}/{}] completec in {:.0f}m {:.0f}s'.format(
                epoch + 1, params.num_epochs, time_elapsed // 60,
                time_elapsed % 60))
            filename = "ADDA-critic-{}.pt".format(epoch + 1) if not robust \
                else "ADDA-critic-rb-{}.pt".format(epoch + 1)
            save_model(critic, params.adda_root, filename)

            filename = "ADDA-target-encoder-{}.pt".format(epoch + 1) if not robust \
                else "ADDA-target-encoder-rb-{}.pt".format(epoch + 1)
            save_model(tgt_encoder, params.adda_root, filename)

    filename = "ADDA-critic-final.pt" if not robust else "ADDA-critic-rb-final.pt"
    save_model(critic, params.adda_root, filename)

    filename = "ADDA-target-encoder-final.pt" if not robust else "ADDA-target-encoder-rb-final.pt"
    save_model(tgt_encoder, params.adda_root, filename)

    return tgt_encoder