Beispiel #1
0
def main(args):
    """Run the MNIST -> MNIST_M domain-adaptation experiment.

    MNIST is used as the source domain and MNIST_M as the target domain.
    The function builds train/test loaders for both domains, optionally
    displays sample images, creates the feature extractor and the
    class/domain classifiers, and then, for each of ``params.epochs``
    epochs, calls ``train.train`` followed by ``test.test`` on the test
    loaders.

    Args:
        args: parsed command-line options. ``args.plot`` toggles the image
            previews and the final ``visualizePerformance`` call;
            ``args.training_mode`` is forwarded to ``train.train``.
    """
    source_name, target_name = 'MNIST', 'MNIST_M'

    # Source-domain loaders first, then target-domain loaders.
    source_train_loader = utils.get_train_loader(source_name)
    source_test_loader = utils.get_test_loader(source_name)
    target_train_loader = utils.get_train_loader(target_name)
    target_test_loader = utils.get_test_loader(target_name)

    if args.plot:
        previews = (
            ('Images from training on source domain:', source_train_loader),
            ('Images from test on target domain:', target_test_loader),
        )
        for caption, loader in previews:
            print(caption)
            utils.displayImages(loader)

    # Keep this construction order: if the constructors draw from the global
    # RNG, reordering them would change the initial weights.
    feature_extractor = models.Extractor()
    class_classifier = models.Class_classifier()
    domain_classifier = models.Domain_classifier()
    networks = (feature_extractor, class_classifier, domain_classifier)

    if params.use_gpu:
        for net in networks:
            net.cuda()

    # Both heads are scored with the negative log-likelihood loss.
    class_loss_fn = nn.NLLLoss()
    domain_loss_fn = nn.NLLLoss()

    # One SGD parameter group per network, all sharing lr=0.01 / momentum=0.9.
    optimizer = optim.SGD(
        [{'params': net.parameters()} for net in networks],
        lr=0.01,
        momentum=0.9,
    )

    for epoch in range(params.epochs):
        print('Epoch: {}'.format(epoch))
        train.train(args.training_mode, *networks,
                    class_loss_fn, domain_loss_fn,
                    source_train_loader, target_train_loader,
                    optimizer, epoch)
        test.test(*networks, source_test_loader, target_test_loader)

    if args.plot:
        visualizePerformance(*networks, source_test_loader, target_test_loader)
Beispiel #2
0
# Dataset directories, all rooted at ``data_root``.

# Digit datasets.
mnistm_5_path = data_root + "/MNIST_M_5"
svhn_path = data_root + "/SVHN"
syndig_path = data_root + "/SynthDigits"

# Office domains.
amazon_path = data_root + "/Office/amazon"
dslr_path = data_root + "/Office/dslr"
webcam_path = data_root + "/Office/webcam"

# Office-Home domains.
art_path = data_root + "/Office_Home/Art"
clipart_path = data_root + "/Office_Home/Clipart"
product_path = data_root + "/Office_Home/Product"
realworld_path = data_root + "/Office_Home/Real_World"

# Presumably the directory experiment outputs are written to -- confirm.
save_dir = "./experiment"

# Dataset-specific model components. Keys look like "<source>_<target>"
# dataset pairs, plus a 'ResNet50' backbone entry. Every value is a fresh
# model instance constructed here at import time, in the order listed.
# NOTE(review): building all models (including the ResNet-50 ones) eagerly
# on import may be costly; consider lazy factories if only one is used.
extractor_dict = dict(
    MNIST_MNIST_M=models.Extractor(),
    MNIST_MNIST_M_5=models.Extractor(),
    SVHN_MNIST=models.SVHN_Extractor(),
    SynDig_SVHN=models.SVHN_Extractor(),
    ResNet50=models.Res50_Extractor(),
)

class_dict = dict(
    MNIST_MNIST_M=models.Class_classifier(),
    MNIST_MNIST_M_5=models.Class_classifier(),
    SVHN_MNIST=models.SVHN_Class_classifier(),
    SynDig_SVHN=models.SVHN_Class_classifier(),
    # 65-way head; presumably sized for Office-Home's label set -- confirm.
    ResNet50=models.Res50_Class_classifier(class_num=65),
)

domain_dict = {
Beispiel #3
0
def main(args):
    """Configure the shared ``params`` module from ``args`` and run training.

    Steps:

    1. Copy ``max_epoch``, ``training_mode``, ``modality`` and
       ``extractor_layers`` from ``args`` into ``params``. Then pick the
       class-classifier output size from the source domain.
    2. Build train/valid/test loaders for the source and target domains.
    3. Create the feature extractor, class classifier and domain classifier.
       Move them to the GPU when ``params.use_gpu`` is set.
    4. Train for ``params.epochs`` epochs with a single Adam optimizer
       (learning rate ``args.lr``) over all three networks:
       - after every epoch, call ``test.test`` on the validation loaders;
       - after the final epoch, additionally call it on the test loaders
         with ``mode='test'``.

    Args:
        args: parsed command-line namespace. Reads ``max_epoch``,
            ``training_mode``, ``source_domain``, ``target_domain``,
            ``modality``, ``extractor_layers`` and ``lr``.
    """
    # Set global parameters.
    params.epochs = args.max_epoch
    params.training_mode = args.training_mode
    source_domain = args.source_domain
    print("source domain is: ", source_domain)
    target_domain = args.target_domain
    print("target domain is: ", target_domain)

    params.modality = args.modality
    print("modality is :", params.modality)
    params.extractor_layers = args.extractor_layers
    print("number of layers in feature extractor: ", params.extractor_layers)
    lr = args.lr

    # Class-classifier output size for each supported source dataset.
    # A source domain not listed here leaves params.output_dim at whatever
    # value the params module already holds.
    # TODO: derive these sizes from the dataset definitions instead of
    # hard-coding them.
    output_dims = {'iemocap': 4, 'mosei': 6}
    if source_domain in output_dims:
        params.output_dim = output_dims[source_domain]

    # Prepare the source and target data.
    src_train_dataloader = dataloaders.get_train_loader(source_domain)
    src_test_dataloader = dataloaders.get_test_loader(source_domain)
    src_valid_dataloader = dataloaders.get_valid_loader(source_domain)
    tgt_train_dataloader = dataloaders.get_train_loader(target_domain)
    tgt_test_dataloader = dataloaders.get_test_loader(target_domain)
    tgt_valid_dataloader = dataloaders.get_valid_loader(target_domain)

    # Debug output; presumably the per-modality feature dimensions -- confirm.
    print(params.mod_dim)

    # Init models (their configuration is read from `params`, set above --
    # presumably; confirm in the models module).
    feature_extractor = models.Extractor()
    class_classifier = models.Class_classifier()
    domain_classifier = models.Domain_classifier()

    if params.use_gpu:
        feature_extractor.cuda()
        class_classifier.cuda()
        domain_classifier.cuda()

    # BCEWithLogitsLoss applies the sigmoid itself, so both heads should
    # return raw logits.
    class_criterion = nn.BCEWithLogitsLoss()
    domain_criterion = nn.BCEWithLogitsLoss()

    # One Adam optimizer, one parameter group per network.
    optimizer = optim.Adam([{
        'params': feature_extractor.parameters()
    }, {
        'params': class_classifier.parameters()
    }, {
        'params': domain_classifier.parameters()
    }], lr=lr)

    for epoch in range(params.epochs):
        print('Epoch: {}'.format(epoch))
        train.train(args.training_mode, feature_extractor, class_classifier,
                    domain_classifier, class_criterion, domain_criterion,
                    src_train_dataloader, tgt_train_dataloader, optimizer,
                    epoch)
        # Validation after every epoch.
        test.test(feature_extractor, class_classifier, domain_classifier,
                  src_valid_dataloader, tgt_valid_dataloader, epoch)
        # The held-out test split is only scored once, after the last epoch.
        if epoch == params.epochs - 1:
            test.test(feature_extractor,
                      class_classifier,
                      domain_classifier,
                      src_test_dataloader,
                      tgt_test_dataloader,
                      epoch,
                      mode='test')