コード例 #1
0
def main():
    """Run the WDGRL experiment on SVHN and MNIST and report test results.

    The function seeds the RNG from ``params.manual_seed``, builds the
    train/test loaders for both datasets, initialises the encoder, critic
    and classifier from the checkpoint paths held in ``params``, trains the
    critic and then the encoder/classifier pair unless their ``pretrained``
    flags together with ``params.model_trained`` say this can be skipped,
    and finally evaluates the encoder/classifier on each test loader.
    """
    init_random_seed(params.manual_seed)

    # Data loaders: train and test split for each dataset.
    svhn_train = get_svhn(split="train", download=True)
    svhn_test = get_svhn(split="test", download=True)
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)

    # WDGRL components, each passed through model_init with its checkpoint path.
    encoder = model_init(Encoder(), params.encoder_wdgrl_path)
    discriminator = Discriminator(
        in_dims=params.d_in_dims,
        h_dims=params.d_h_dims,
        out_dims=params.d_out_dims,
    )
    critic = model_init(discriminator, params.disc_wdgrl_path)
    classifier = model_init(Classifier(), params.clf_wdgrl_path)

    # Stage 1: critic. Skipped only if it is pretrained AND the run is
    # marked as already trained.
    print("====== Training critic ======")
    critic_ready = critic.pretrained and params.model_trained
    if not critic_ready:
        critic = train_critic_wdgrl(encoder, critic, svhn_train, mnist_train)

    # Stage 2: encoder and classifier, trained against the critic.
    print("====== Training encoder for both SVHN and MNIST domains ======")
    if not encoder.pretrained or not classifier.pretrained or not params.model_trained:
        encoder, classifier = train_tgt_wdgrl(
            encoder, classifier, critic, svhn_train, mnist_train, robust=False
        )

    # Stage 3: evaluation on the held-out split of each dataset.
    print("====== Evaluating classifier for encoded SVHN and MNIST domains ======")
    evaluations = (
        ("-------- SVHN domain --------", svhn_test),
        ("-------- MNIST adaption --------", mnist_test),
    )
    for banner, loader in evaluations:
        print(banner)
        eval_tgt(encoder, classifier, loader)
def main():
    """Run the RevGrad experiment on MNIST and USPS and report test results.

    Seeds the RNG, builds train/test loaders for MNIST and USPS, initialises
    encoder, critic and classifier from the RevGrad checkpoint paths in
    ``params``, trains all three jointly with ``train_revgrad`` unless every
    model is flagged ``pretrained`` and ``params.model_trained`` is set, and
    then evaluates the encoder/classifier on both test loaders.
    """
    init_random_seed(params.manual_seed)

    # Data loaders: train and test split for each dataset.
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)
    usps_train = get_usps(train=True, download=True)
    usps_test = get_usps(train=False, download=True)

    # RevGrad components, each passed through model_init with its checkpoint path.
    encoder = model_init(Encoder(), params.tgt_encoder_revgrad_path)
    critic = model_init(Discriminator(), params.disc_revgard_path)
    classifier = model_init(Classifier(), params.clf_revgrad_path)

    # Joint training; skipped only when all three models are pretrained and
    # the run is marked as already trained.
    print("====== Training source encoder and classifier in MNIST and USPS domains ======")
    all_ready = (
        encoder.pretrained
        and classifier.pretrained
        and critic.pretrained
        and params.model_trained
    )
    if not all_ready:
        encoder, classifier, critic = train_revgrad(
            encoder, classifier, critic, mnist_train, usps_train, robust=False
        )

    # Evaluation on the held-out split of each dataset.
    print("====== Evaluating classifier for encoded MNIST and USPS domain ======")
    evaluations = (
        ("-------- MNIST domain --------", mnist_test),
        ("-------- USPS adaption --------", usps_test),
    )
    for banner, loader in evaluations:
        print(banner)
        eval_tgt(encoder, classifier, loader)
def main():
    """Run the robust WDGRL experiment on MNIST and USPS and report results.

    Seeds the RNG, builds train/test loaders for MNIST and USPS, initialises
    encoder, critic and classifier from the ``*_wdgrl_rb_path`` checkpoints
    in ``params``, trains the encoder/classifier with ``train_tgt_wdgrl``
    (``robust=True``) unless both are flagged ``pretrained`` and
    ``params.model_trained`` is set, and evaluates with ``eval_tgt_robust``
    on both test loaders. The critic is only initialised here; no separate
    critic-training call is made in this function.
    """
    init_random_seed(params.manual_seed)

    # Data loaders: train and test split for each dataset.
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)
    usps_train = get_usps(train=True, download=True)
    usps_test = get_usps(train=False, download=True)

    # Robust-WDGRL components, each passed through model_init with its path.
    encoder = model_init(Encoder(), params.encoder_wdgrl_rb_path)
    critic = model_init(Discriminator(), params.disc_wdgrl_rb_path)
    classifier = model_init(Classifier(), params.clf_wdgrl_rb_path)

    # Encoder/classifier training; the USPS test loader is handed to the
    # trainer as an extra positional argument.
    print("====== Robust Training encoder for both MNIST and USPS domains ======")
    if not encoder.pretrained or not classifier.pretrained or not params.model_trained:
        encoder, classifier = train_tgt_wdgrl(
            encoder,
            classifier,
            critic,
            mnist_train,
            usps_train,
            usps_test,
            robust=True,
        )

    # Robust evaluation on the held-out split of each dataset.
    print("====== Evaluating classifier for encoded MNIST and USPS domains ======")
    evaluations = (
        ("-------- MNIST domain --------", mnist_test),
        ("-------- USPS adaption --------", usps_test),
    )
    for banner, loader in evaluations:
        print(banner)
        eval_tgt_robust(encoder, classifier, loader)
def main():
    """Run robust ADDA with MNIST as source and USPS as target.

    Steps, in order:
      1. Seed the RNG from ``params.manual_seed``.
      2. Build train/test loaders for MNIST (source) and USPS (target).
      3. Initialise source encoder, target encoder, critic and classifier
         from the ``*_adda_rb_path`` checkpoints in ``params``.
      4. Train the source encoder and classifier with ``train_src_robust``
         unless both are ``pretrained`` and ``params.model_trained`` is set,
         then evaluate them on the MNIST test loader.
      5. If the target encoder is not ``pretrained``, copy the source
         encoder's weights into it, then adapt it with ``train_tgt_adda``
         (``robust=True``) unless it and the critic are already trained.
      6. Evaluate both the source-only and the adapted encoder on the USPS
         test loader.
    """
    # init random seed
    init_random_seed(params.manual_seed)

    # Load dataset
    # The source domain must be MNIST: every banner below reports MNIST for
    # these loaders, and loading USPS here would make the source and target
    # domains identical, so there would be nothing to adapt.
    mnist_data_loader = get_mnist(train=True, download=True)
    mnist_data_loader_eval = get_mnist(train=False, download=True)
    usps_data_loader = get_usps(train=True, download=True)
    usps_data_loader_eval = get_usps(train=False, download=True)

    # Model init ADDA
    src_encoder = model_init(Encoder(), params.src_encoder_adda_rb_path)
    tgt_encoder = model_init(Encoder(), params.tgt_encoder_adda_rb_path)
    critic = model_init(Discriminator(), params.disc_adda_rb_path)
    clf = model_init(Classifier(), params.clf_adda_rb_path)

    # Train source model for adda
    print(
        "====== Robust training source encoder and classifier in MNIST domain ======"
    )
    if not (src_encoder.pretrained and clf.pretrained
            and params.model_trained):
        src_encoder, clf = train_src_robust(src_encoder, clf,
                                            mnist_data_loader)

    # Eval source model
    print("====== Evaluating classifier for MNIST domain ======")
    eval_tgt(src_encoder, clf, mnist_data_loader_eval)

    # Train target encoder
    print("====== Robust training encoder for USPS domain ======")
    # Initialize target encoder's weights with those of the source encoder
    if not tgt_encoder.pretrained:
        tgt_encoder.load_state_dict(src_encoder.state_dict())

    if not (tgt_encoder.pretrained and critic.pretrained
            and params.model_trained):
        tgt_encoder = train_tgt_adda(src_encoder,
                                     tgt_encoder,
                                     clf,
                                     critic,
                                     mnist_data_loader,
                                     usps_data_loader,
                                     usps_data_loader_eval,
                                     robust=True)

    # Eval target encoder on test set of target dataset
    print("====== Ealuating classifier for encoded USPS domain ======")
    print("-------- Source only --------")
    eval_tgt(src_encoder, clf, usps_data_loader_eval)
    print("-------- Domain adaption --------")
    eval_tgt(tgt_encoder, clf, usps_data_loader_eval)
コード例 #5
0
def main():
    """Run the robust DANN experiment on SVHN and MNIST and report results.

    Seeds the RNG, builds train/test loaders for SVHN and MNIST, initialises
    encoder, critic and classifier from the ``*_dann_rb_path`` checkpoints in
    ``params``, trains all three with ``train_dann`` (``robust=True``) unless
    every model is flagged ``pretrained`` and ``params.model_trained`` is
    set, and evaluates with ``eval_tgt_robust`` on both test loaders.
    """
    init_random_seed(params.manual_seed)

    # Data loaders: train and test split for each dataset.
    svhn_train = get_svhn(split="train", download=True)
    svhn_test = get_svhn(split="test", download=True)
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)

    # DANN components; the discriminator dimensions come from params.
    encoder = model_init(Encoder(), params.tgt_encoder_dann_rb_path)
    discriminator = Discriminator(
        in_dims=params.d_in_dims,
        h_dims=params.d_h_dims,
        out_dims=params.d_out_dims,
    )
    critic = model_init(discriminator, params.disc_dann_rb_path)
    classifier = model_init(Classifier(), params.clf_dann_rb_path)

    # Joint training; the MNIST test loader is handed to the trainer as an
    # extra positional argument.
    print("====== Training source encoder and classifier in SVHN and MNIST domains ======")
    all_ready = (
        encoder.pretrained
        and classifier.pretrained
        and critic.pretrained
        and params.model_trained
    )
    if not all_ready:
        encoder, classifier, critic = train_dann(
            encoder,
            classifier,
            critic,
            svhn_train,
            mnist_train,
            mnist_test,
            robust=True,
        )

    # Robust evaluation on the held-out split of each dataset.
    print("====== Evaluating classifier for encoded SVHN and MNIST domains ======")
    evaluations = (
        ("-------- SVHN domain --------", svhn_test),
        ("-------- MNIST adaption --------", mnist_test),
    )
    for banner, loader in evaluations:
        print(banner)
        eval_tgt_robust(encoder, classifier, loader)
コード例 #6
0
    num_epochs = 1000
    log_step = 10  # iters
    save_step = 500
    eval_step = 1  # epochs

    manual_seed = 8888
    alpha = 0

    # params for optimizing models
    lr = 2e-4


# Script-level setup: build the config object, seed the RNG, create the
# source/target data loaders and instantiate the DANN model before training.
params = Config()

# init random seed
init_random_seed(params.manual_seed)

# load dataset
# One loader per domain; both use the same dataset root and batch size.
src_data_loader = get_data_loader(params.src_dataset, params.dataset_root,
                                  params.batch_size)
tgt_data_loader = get_data_loader(params.tgt_dataset, params.dataset_root,
                                  params.batch_size)

# load dann model
# NOTE(review): restore=None presumably means "do not load a checkpoint" --
# confirm against init_model.
dann = init_model(net=AlexModel(), restore=None)

# train dann model
print("Start training dann model.")

if not (dann.restored and params.dann_restore):
    dann = train_dann(dann, params, src_data_loader, tgt_data_loader,
コード例 #7
0
    model_restore = os.path.join(
        model_root, src_dataset + "-final.pt")

    # params for training network
    num_gpu = 1
    num_epochs = 200
    log_step = 10  # iter
    eval_step = 1  # epoch
    save_step = 500
    manual_seed = None

params = Config()

# Script entry point: seed the RNG, pick the compute device and build the
# train/eval data loaders for the source and target datasets.
if __name__ == '__main__':
    # init random seed
    init_random_seed(params.manual_seed)

    # init device
    # Use the first CUDA device when one is available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load dataset
    # Training loaders use params.batch_size; evaluation loaders use
    # params.batch_size_eval and are requested with train=False.
    src_data_loader = get_data_loader(
        params.src_dataset, dataset_root=params.dataset_root, batch_size=params.batch_size, train=True)
    src_data_loader_eval = get_data_loader(
        params.src_dataset, dataset_root=params.dataset_root, batch_size=params.batch_size_eval, train=False)
    tgt_data_loader = get_data_loader(
        params.tgt_dataset, dataset_root=params.dataset_root, batch_size=params.batch_size, train=True)
    tgt_data_loader_eval = get_data_loader(
        params.tgt_dataset, dataset_root=params.dataset_root, batch_size=params.batch_size_eval, train=False)

    # load models
def main():
    """Run DANN on MNIST and USPS, pseudo-label USPS, then retrain on it.

    Phases, in order:
      1. Seed the RNG and build train/test loaders for MNIST and USPS.
      2. Initialise encoder, critic and classifier from the
         ``*_dann_rb_path`` checkpoints and train them with ``train_dann``
         unless all are ``pretrained`` and ``params.model_trained`` is set;
         evaluate with ``eval_tgt_robust`` on both test loaders.
      3. Call ``pseudo_label`` on the USPS training loader under the name
         ``"usps_train_pseudo"`` and reload USPS with ``get_pseudo=True``.
      4. Initialise a fresh encoder/classifier pair, train it on the
         pseudo-labelled loader with ``train_src_adda`` and evaluate it on
         the USPS test loader.
      5. Repeat step 4 with another fresh pair and ``train_src_robust``.
    """
    init_random_seed(params.manual_seed)

    # Data loaders: train and test split for each dataset.
    mnist_train = get_mnist(train=True, download=True)
    mnist_test = get_mnist(train=False, download=True)
    usps_train = get_usps(train=True, download=True)
    usps_test = get_usps(train=False, download=True)

    # ---- Phase 1: DANN ----
    encoder = model_init(Encoder(), params.tgt_encoder_dann_rb_path)
    critic = model_init(Discriminator(), params.disc_dann_rb_path)
    classifier = model_init(Classifier(), params.clf_dann_rb_path)

    # NOTE(review): the banner, the *_rb_path checkpoints and the
    # eval_tgt_robust calls all say "robust", yet robust=False is passed to
    # train_dann -- confirm which one is intended.
    print("====== Robust Training source encoder and classifier in MNIST and USPS domains ======")
    all_ready = (
        encoder.pretrained
        and classifier.pretrained
        and critic.pretrained
        and params.model_trained
    )
    if not all_ready:
        encoder, classifier, critic = train_dann(
            encoder,
            classifier,
            critic,
            mnist_train,
            usps_train,
            usps_test,
            robust=False,
        )

    # Robust evaluation; the critic is passed along with the loader here.
    print("====== Evaluating classifier for encoded MNIST and USPS domains ======")
    evaluations = (
        ("-------- MNIST domain --------", mnist_test),
        ("-------- USPS adaption --------", usps_test),
    )
    for banner, loader in evaluations:
        print(banner)
        eval_tgt_robust(encoder, classifier, critic, loader)

    # ---- Phase 2: pseudo-label the USPS training data ----
    print("====== Pseudo labeling on USPS domain ======")
    pseudo_label(encoder, classifier, "usps_train_pseudo", usps_train)

    # ---- Phase 3: fresh models trained on the pseudo labels ----
    # The names are rebound on purpose so the DANN models are released.
    encoder = model_init(Encoder(), params.tgt_encoder_path)
    classifier = model_init(Classifier(), params.clf_path)

    # Pseudo-labelled USPS training loader.
    usps_pseudo = get_usps(train=True, download=True, get_pseudo=True)

    # NOTE(review): the banner says "Standard training" while mode='ADV' is
    # passed -- confirm against train_src_adda.
    print("====== Standard training on USPS domain with pseudo labels ======")
    if not encoder.pretrained or not classifier.pretrained:
        train_src_adda(encoder, classifier, usps_pseudo, mode="ADV")
    print("====== Evaluating on USPS domain with real labels ======")
    eval_tgt(encoder, classifier, usps_test)

    # ---- Phase 4: same again with the robust trainer ----
    encoder = model_init(Encoder(), params.tgt_encoder_rb_path)
    classifier = model_init(Classifier(), params.clf_rb_path)
    print("====== Robust training on USPS domain with pseudo labels ======")
    if not encoder.pretrained or not classifier.pretrained:
        train_src_robust(encoder, classifier, usps_pseudo, mode="ADV")
    print("====== Evaluating on USPS domain with real labels ======")
    eval_tgt(encoder, classifier, usps_test)