예제 #1
0
def main():
    """Run WDGRL domain adaptation from SVHN (source) to MNIST (target).

    Seeds the RNGs, builds train/test loaders for both domains, initialises the
    encoder, critic and classifier from the ``params`` checkpoint paths, trains
    the critic and then the encoder/classifier unless usable pretrained weights
    are flagged, and finally evaluates on both test sets.
    """
    init_random_seed(params.manual_seed)

    # Train/test loaders: SVHN is the source domain, MNIST the target.
    svhn_train = get_svhn(split='train', download=True)
    svhn_test = get_svhn(split='test', download=True)
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)

    # Build models (construction order kept: weight init may consume RNG).
    encoder = model_init(Encoder(), params.encoder_wdgrl_path)
    disc = Discriminator(in_dims=params.d_in_dims,
                         h_dims=params.d_h_dims,
                         out_dims=params.d_out_dims)
    critic = model_init(disc, params.disc_wdgrl_path)
    classifier = model_init(Classifier(), params.clf_wdgrl_path)

    # Stage 1: critic.
    print("====== Training critic ======")
    critic_ready = critic.pretrained and params.model_trained
    if not critic_ready:
        critic = train_critic_wdgrl(encoder, critic, svhn_train, mnist_train)

    # Stage 2: encoder + classifier against the trained critic.
    print("====== Training encoder for both SVHN and MNIST domains ======")
    models_ready = (encoder.pretrained and classifier.pretrained
                    and params.model_trained)
    if not models_ready:
        encoder, classifier = train_tgt_wdgrl(encoder, classifier, critic,
                                              svhn_train, mnist_train,
                                              robust=False)

    # Stage 3: evaluate on both test sets.
    print("====== Evaluating classifier for encoded SVHN and MNIST domains ======")
    for banner, loader in (("-------- SVHN domain --------", svhn_test),
                           ("-------- MNIST adaption --------", mnist_test)):
        print(banner)
        eval_tgt(encoder, classifier, loader)
def main():
    """Run RevGrad (gradient-reversal) adaptation from MNIST to USPS.

    Seeds the RNGs, builds train/test loaders for both domains, initialises the
    encoder, domain critic and classifier from ``params`` checkpoint paths,
    trains all three jointly unless usable pretrained weights are flagged, and
    evaluates on the MNIST and USPS test sets.
    """
    init_random_seed(params.manual_seed)

    # MNIST is the labelled source domain, USPS the target.
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)
    usps_train = get_usps(train=True, download=True)
    usps_test = get_usps(train=False, download=True)

    encoder = model_init(Encoder(), params.tgt_encoder_revgrad_path)
    # NOTE: the params attribute really is spelled ``disc_revgard_path``.
    critic = model_init(Discriminator(), params.disc_revgard_path)
    classifier = model_init(Classifier(), params.clf_revgrad_path)

    print("====== Training source encoder and classifier in MNIST and USPS domains ======")
    skip_training = (encoder.pretrained and classifier.pretrained
                     and critic.pretrained and params.model_trained)
    if not skip_training:
        encoder, classifier, critic = train_revgrad(
            encoder, classifier, critic, mnist_train, usps_train, robust=False)

    print("====== Evaluating classifier for encoded MNIST and USPS domain ======")
    for banner, loader in (("-------- MNIST domain --------", mnist_test),
                           ("-------- USPS adaption --------", usps_test)):
        print(banner)
        eval_tgt(encoder, classifier, loader)
def main():
    """Run DANN adaptation from SVHN (source) to MNIST (target).

    Seeds the RNGs, builds train/test loaders, initialises encoder, critic and
    classifier from ``params`` checkpoint paths, trains them jointly with
    ``train_dann`` unless usable pretrained weights are flagged, and evaluates
    on the SVHN and MNIST test sets.
    """
    init_random_seed(params.manual_seed)

    svhn_train = get_svhn(split='train', download=True)
    svhn_test = get_svhn(split='test', download=True)
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)

    # NOTE(review): the encoder uses the *_rb_* (robust) checkpoint path while
    # the critic/classifier use non-robust paths and training is non-robust —
    # confirm this pairing is intended.
    encoder = model_init(Encoder(), params.tgt_encoder_dann_rb_path)
    disc = Discriminator(in_dims=params.d_in_dims,
                         h_dims=params.d_h_dims,
                         out_dims=params.d_out_dims)
    critic = model_init(disc, params.disc_dann_path)
    classifier = model_init(Classifier(), params.clf_dann_path)

    print(
        "====== Training source encoder and classifier in SVHN and MNIST domains ======"
    )
    skip_training = (encoder.pretrained and classifier.pretrained
                     and critic.pretrained and params.model_trained)
    if not skip_training:
        encoder, classifier, critic = train_dann(encoder, classifier, critic,
                                                 svhn_train, mnist_train,
                                                 mnist_test, robust=False)

    print(
        "====== Evaluating classifier for encoded SVHN and MNIST domain ======"
    )
    for banner, loader in (("-------- SVHN domain --------", svhn_test),
                           ("-------- MNIST adaption --------", mnist_test)):
        print(banner)
        eval_tgt(encoder, classifier, loader)
def main():
    """Robust ADDA: adapt an MNIST-trained model to the USPS domain.

    Steps:
      1. Robustly (PGD-adversarially) pre-train a source encoder + classifier
         on MNIST, unless usable pretrained weights are flagged.
      2. Initialise the target encoder from the source encoder (if the target
         encoder has no pretrained weights of its own).
      3. Robustly train the target encoder against a domain critic using MNIST
         (source) and USPS (target) batches.
      4. Evaluate the source-only and adapted encoders on the USPS test set.
    """
    # init random seed
    init_random_seed(params.manual_seed)

    # Load dataset. The source domain is MNIST: these two loaders must come
    # from get_mnist, otherwise source pre-training and the "Source only"
    # baseline would both silently run on USPS (target) data.
    mnist_data_loader = get_mnist(train=True, download=True)
    mnist_data_loader_eval = get_mnist(train=False, download=True)
    usps_data_loader = get_usps(train=True, download=True)
    usps_data_loader_eval = get_usps(train=False, download=True)

    # Model init ADDA (robust checkpoints)
    src_encoder = model_init(Encoder(), params.src_encoder_adda_rb_path)
    tgt_encoder = model_init(Encoder(), params.tgt_encoder_adda_rb_path)
    critic = model_init(Discriminator(), params.disc_adda_rb_path)
    clf = model_init(Classifier(), params.clf_adda_rb_path)

    # Train source model for adda
    print(
        "====== Robust training source encoder and classifier in MNIST domain ======"
    )
    if not (src_encoder.pretrained and clf.pretrained
            and params.model_trained):
        src_encoder, clf = train_src_robust(src_encoder, clf,
                                            mnist_data_loader)

    # Eval source model
    print("====== Evaluating classifier for MNIST domain ======")
    eval_tgt(src_encoder, clf, mnist_data_loader_eval)

    # Train target encoder
    print("====== Robust training encoder for USPS domain ======")
    # ADDA starts the target encoder from the source encoder's weights.
    if not tgt_encoder.pretrained:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.pretrained and critic.pretrained
            and params.model_trained):
        tgt_encoder = train_tgt_adda(src_encoder,
                                     tgt_encoder,
                                     clf,
                                     critic,
                                     mnist_data_loader,
                                     usps_data_loader,
                                     usps_data_loader_eval,
                                     robust=True)

    # Eval on the target test set: source-only baseline vs adapted encoder
    # (both share the source-trained classifier).
    print("====== Evaluating classifier for encoded USPS domain ======")
    print("-------- Source only --------")
    eval_tgt(src_encoder, clf, usps_data_loader_eval)
    print("-------- Domain adaption --------")
    eval_tgt(tgt_encoder, clf, usps_data_loader_eval)
예제 #5
0
def train_src_robust(encoder, classifier, data_loader, mode='ADDA'):
    """Adversarially (PGD) train an encoder + classifier on a labelled domain.

    Each batch is perturbed with ``attack_pgd`` and the model is optimised on
    the perturbed images only; clean-image loss/accuracy are tracked purely for
    logging.

    Args:
        encoder: Feature-extractor module; trained in place.
        classifier: Classifier head applied to ``encoder`` output; trained in
            place.
        data_loader: Iterable of ``(images, labels)`` batches.
        mode: ``'ADDA'`` uses ``params.num_epochs_pre`` and saves under
            ``params.adda_root``; any other value uses ``params.num_epochs``
            and ``params.model_root``. Also used as the checkpoint filename
            prefix.

    Returns:
        ``(encoder, classifier)`` — the same module objects that were passed in.
    """
    # Step 1: Network setup
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(classifier.parameters()),
                           lr=params.learning_rate,
                           weight_decay=params.weight_decay)
    criterion = nn.CrossEntropyLoss()
    num_epochs = params.num_epochs_pre if mode == 'ADDA' else params.num_epochs
    root = params.adda_root if mode == 'ADDA' else params.model_root

    # Step 2: Train the source model
    for epoch in range(num_epochs):
        # (Re)enter train mode every epoch: the periodic eval_tgt call below may
        # switch the modules to eval mode (TODO confirm), which would otherwise
        # freeze BN statistics / disable dropout for all remaining epochs.
        encoder.train()
        classifier.train()

        # Init accuracy and loss
        start_time = time.time()
        train_loss, train_acc, train_n = 0, 0, 0
        train_robust_loss, train_robust_acc = 0, 0

        for step, (images, labels) in enumerate(data_loader):
            images = make_variable(images)
            labels = make_variable(labels)

            # Craft the perturbation BEFORE zeroing gradients: attack_pgd may
            # back-propagate through the model and leave gradients on its
            # parameters, which must not leak into the update below.
            delta = attack_pgd(encoder, classifier, images, labels)
            optimizer.zero_grad()

            # Loss on the adversarial images
            robust_images = normalize(
                torch.clamp(images + delta[:images.size(0)],
                            min=params.lower_limit,
                            max=params.upper_limit))
            robust_preds = classifier(encoder(robust_images))
            robust_loss = criterion(robust_preds, labels)

            # Optimize on the robust loss only
            robust_loss.backward()
            optimizer.step()

            # Clean-image metrics are for logging only; no autograd graph needed.
            with torch.no_grad():
                preds = classifier(encoder(images))
                loss = criterion(preds, labels)

            batch_n = labels.size(0)
            train_robust_loss += robust_loss.item() * batch_n
            train_robust_acc += torch.sum(
                robust_preds.max(1)[1] == labels).double()
            train_loss += loss.item() * batch_n
            train_acc += torch.sum(preds.max(1)[1] == labels).double()
            train_n += batch_n

            # Print step info
            if (step + 1) % params.log_step_pre == 0:
                print(
                    "Epoch [{}/{}] Step [{}/{}]: Avg Training loss: {:.4f} Avg Training Accuracy: {:.4%}"
                    " Avg Robust Training Loss: {:.4f} Avg Robust Training Accuracy: {:.4%}"
                    .format(epoch + 1, num_epochs, step + 1, len(data_loader),
                            train_loss / train_n, train_acc / train_n,
                            train_robust_loss / train_n,
                            train_robust_acc / train_n))

        time_elapsed = time.time() - start_time

        # Eval model. NOTE(review): this evaluates on the training loader;
        # no held-out loader is available to this function.
        if (epoch + 1) % params.eval_step_pre == 0:
            eval_tgt(encoder, classifier, data_loader)

        # Save model parameters
        if (epoch + 1) % params.save_step_pre == 0:
            print('Epoch [{}/{}] completed in {:.0f}m {:.0f}s'.format(
                epoch + 1, num_epochs, time_elapsed // 60, time_elapsed % 60))
            save_model(encoder, root,
                       "{}-source-encoder-rb-{}.pt".format(mode, epoch + 1))
            save_model(classifier, root,
                       "{}-source-classifier-rb-{}.pt".format(mode, epoch + 1))

    # Save final model
    save_model(encoder, root, "{}-source-encoder-rb-final.pt".format(mode))
    save_model(classifier, root,
               "{}-source-classifier-rb-final.pt".format(mode))

    return encoder, classifier
예제 #6
0
def train_tgt_adda(src_encoder,
                   tgt_encoder,
                   classifier,
                   critic,
                   src_data_loader,
                   tgt_data_loader,
                   tgt_data_loader_eval,
                   robust=False):
    """Train the ADDA target encoder adversarially against a domain critic.

    For every paired (source, target) batch:
      1. Update ``critic`` to predict domain 1 for source features and domain 0
         for target features (encoders are not updated in this step).
      2. Update ``tgt_encoder`` so that the critic predicts domain 1 for target
         features (inverted labels).
    With ``robust=True`` the inputs of both steps are replaced by PGD-perturbed
    images crafted against the corresponding (encoder, critic) pair.

    ``src_encoder`` and ``classifier`` are never optimised here; ``classifier``
    is only used for the periodic evaluation.

    Args:
        src_encoder: Source-domain encoder, used in eval mode.
        tgt_encoder: Target-domain encoder; trained in place.
        classifier: Classifier head used for evaluation only.
        critic: Domain discriminator; trained in place.
        src_data_loader: Iterable of ``(images, _)`` source batches.
        tgt_data_loader: Iterable of ``(images, _)`` target batches (labels
            ignored).
        tgt_data_loader_eval: Loader passed to the periodic evaluation.
        robust: Enable PGD adversarial training and the ``-rb`` checkpoint
            names.

    Returns:
        The trained ``tgt_encoder``.
    """
    # Setup criterion and optimizers
    criterion = nn.CrossEntropyLoss()
    optimizer_tgt = optim.Adam(tgt_encoder.parameters(),
                               lr=params.learning_rate,
                               weight_decay=params.weight_decay)
    optimizer_critic = optim.Adam(critic.parameters(),
                                  lr=params.learning_rate,
                                  weight_decay=params.weight_decay)

    # zip() stops at the shorter loader, so this is the real step count.
    len_data_loader = min(len(src_data_loader), len(tgt_data_loader))
    # Checkpoint name infix: "ADDA-critic-3.pt" vs "ADDA-critic-rb-3.pt".
    suffix = "-rb" if robust else ""

    def _perturb(encoder, images, domain):
        """Return normalised PGD-perturbed copies of ``images`` vs (encoder, critic)."""
        delta = attack_pgd(encoder, critic, images, domain)
        return normalize(
            torch.clamp(images + delta[:images.size(0)],
                        min=params.lower_limit,
                        max=params.upper_limit))

    for epoch in range(params.num_epochs):
        # (Re)set modes each epoch: the periodic evaluation below may switch
        # modules to eval mode (TODO confirm for eval_tgt / eval_tgt_robust).
        src_encoder.eval()
        tgt_encoder.train()
        critic.train()

        start_time = time.time()
        train_disc_loss, train_disc_acc, train_n = 0, 0, 0

        for step, ((images_src, _), (images_tgt, _)) in enumerate(
                zip(src_data_loader, tgt_data_loader)):
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)

            # ---- 2.1 Train the critic with fixed encoders ----
            domain_src = make_variable(torch.ones(images_src.size(0)).long())
            domain_tgt = make_variable(torch.zeros(images_tgt.size(0)).long())
            batch_n = domain_src.size(0) + domain_tgt.size(0)

            if robust:
                inputs_src = _perturb(src_encoder, images_src, domain_src)
                inputs_tgt = _perturb(tgt_encoder, images_tgt, domain_tgt)
            else:
                inputs_src, inputs_tgt = images_src, images_tgt

            optimizer_critic.zero_grad()

            # Encoders are not updated in this step, so don't build their graph;
            # only the critic should receive gradients from loss_critic.
            with torch.no_grad():
                feat_src = src_encoder(inputs_src)
                feat_tgt = tgt_encoder(inputs_tgt)

            preds_src_domain = critic(feat_src)
            preds_tgt_domain = critic(feat_tgt)

            loss_critic = (criterion(preds_src_domain, domain_src) +
                           criterion(preds_tgt_domain, domain_tgt))
            loss_critic.backward()
            optimizer_critic.step()

            train_disc_loss += loss_critic.item() * batch_n
            train_disc_acc += torch.sum(
                preds_src_domain.max(1)[1] == domain_src).double()
            train_disc_acc += torch.sum(
                preds_tgt_domain.max(1)[1] == domain_tgt).double()
            train_n += batch_n

            # ---- 2.2 Train the target encoder with inverted domain labels ----
            domain_tgt_flipped = make_variable(
                torch.ones(images_tgt.size(0)).long())

            if robust:
                inputs_tgt = _perturb(tgt_encoder, images_tgt,
                                      domain_tgt_flipped)

            # Zero AFTER the attack: attack_pgd may back-propagate through the
            # model and leave gradients on tgt_encoder that would otherwise be
            # added to this update.
            optimizer_critic.zero_grad()
            optimizer_tgt.zero_grad()

            pred_tgt = critic(tgt_encoder(inputs_tgt))
            loss_tgt = criterion(pred_tgt, domain_tgt_flipped)
            loss_tgt.backward()
            optimizer_tgt.step()

            # 2.3 Print step info
            if (step + 1) % params.log_step == 0:
                print(
                    "Epoch [{}/{}] Step [{}/{}]: "
                    "Avg Discriminator Loss: {:.4f} Avg Discriminator Accuracy: {:.4%}"
                    .format(epoch + 1, params.num_epochs, step + 1,
                            len_data_loader, train_disc_loss / train_n,
                            train_disc_acc / train_n))

        time_elapsed = time.time() - start_time

        # Eval model
        if (epoch + 1) % params.eval_step == 0:
            if not robust:
                eval_tgt(tgt_encoder, classifier, tgt_data_loader_eval)
            else:
                eval_tgt_robust(tgt_encoder, classifier, critic,
                                tgt_data_loader_eval)

        # 2.4 Save model parameters
        if (epoch + 1) % params.save_step == 0:
            print('Epoch [{}/{}] completed in {:.0f}m {:.0f}s'.format(
                epoch + 1, params.num_epochs, time_elapsed // 60,
                time_elapsed % 60))
            save_model(critic, params.adda_root,
                       "ADDA-critic{}-{}.pt".format(suffix, epoch + 1))
            save_model(tgt_encoder, params.adda_root,
                       "ADDA-target-encoder{}-{}.pt".format(suffix, epoch + 1))

    save_model(critic, params.adda_root,
               "ADDA-critic{}-final.pt".format(suffix))
    save_model(tgt_encoder, params.adda_root,
               "ADDA-target-encoder{}-final.pt".format(suffix))

    return tgt_encoder
def main():
    """DANN adaptation MNIST -> USPS followed by pseudo-label self-training.

    1. Train (or load) a DANN encoder/classifier/critic on MNIST and USPS,
       then evaluate both test sets with ``eval_tgt_robust``.
    2. Pseudo-label the USPS training set with the adapted model.
    3. Train fresh models on the pseudo-labelled USPS set, first with standard
       and then with robust training, evaluating each on the real USPS labels.
    """
    init_random_seed(params.manual_seed)

    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)
    usps_train = get_usps(train=True, download=True)
    usps_test = get_usps(train=False, download=True)

    # ---- Stage 1: DANN ----
    encoder = model_init(Encoder(), params.tgt_encoder_dann_rb_path)
    critic = model_init(Discriminator(), params.disc_dann_rb_path)
    classifier = model_init(Classifier(), params.clf_dann_rb_path)

    print(
        "====== Robust Training source encoder and classifier in MNIST and USPS domains ======"
    )
    dann_ready = (encoder.pretrained and classifier.pretrained
                  and critic.pretrained and params.model_trained)
    if not dann_ready:
        # NOTE(review): the banner and the *_rb_* checkpoint paths suggest
        # robust training, yet robust=False is passed here — confirm intent.
        encoder, classifier, critic = train_dann(encoder, classifier, critic,
                                                 mnist_train, usps_train,
                                                 usps_test, robust=False)

    print(
        "====== Evaluating classifier for encoded MNIST and USPS domains ======"
    )
    for banner, loader in (("-------- MNIST domain --------", mnist_test),
                           ("-------- USPS adaption --------", usps_test)):
        print(banner)
        eval_tgt_robust(encoder, classifier, critic, loader)

    # ---- Stage 2: pseudo-label the USPS training set ----
    print("====== Pseudo labeling on USPS domain ======")
    pseudo_label(encoder, classifier, "usps_train_pseudo", usps_train)

    # ---- Stage 3a: standard training on pseudo labels (fresh models) ----
    encoder = model_init(Encoder(), params.tgt_encoder_path)
    classifier = model_init(Classifier(), params.clf_path)

    usps_pseudo = get_usps(train=True, download=True, get_pseudo=True)

    print("====== Standard training on USPS domain with pseudo labels ======")
    if not (encoder.pretrained and classifier.pretrained):
        # Return value discarded; presumably trains the modules in place —
        # confirm against train_src_adda.
        train_src_adda(encoder, classifier, usps_pseudo, mode='ADV')
    print("====== Evaluating on USPS domain with real labels ======")
    eval_tgt(encoder, classifier, usps_test)

    # ---- Stage 3b: robust training on pseudo labels (fresh models) ----
    encoder = model_init(Encoder(), params.tgt_encoder_rb_path)
    classifier = model_init(Classifier(), params.clf_rb_path)
    print("====== Robust training on USPS domain with pseudo labels ======")
    if not (encoder.pretrained and classifier.pretrained):
        train_src_robust(encoder, classifier, usps_pseudo, mode='ADV')
    print("====== Evaluating on USPS domain with real labels ======")
    eval_tgt(encoder, classifier, usps_test)