Example #1
def main():
    rootdir = "/mnt/xxx/dataset/office_caltech_10/"
    tr.manual_seed(1)
    domain_str = ['webcam', 'dslr']
    X_s = data_loader.load_train(root_dir=rootdir,
                                 domain=domain_str[0],
                                 batch_size=BATCH_SIZE[0])
    X_t = data_loader.load_test(root_dir=rootdir,
                                domain=domain_str[1],
                                batch_size=BATCH_SIZE[1])

    # train and test
    start = time.time()
    mmd_type = ['mmd', 'jmmd', 'jpmmd', 'djpmmd']
    for mt in mmd_type:
        print('-' * 10 + domain_str[0] + ' -->  ' + domain_str[1] + '-' * 10)
        print('MMD loss type: ' + mt + '\n')
        acc, loss = {}, {}
        train_acc = []
        test_acc = []
        train_loss = []
        test_loss = []
        # pseudo-labels for the target data (assumed 14 batches of size 64),
        # consumed by the joint/dynamic MMD variants inside model_train
        y_pse = tr.zeros(14, 64).long().cuda()

        mdl = DaNN.DaNN(n_input=28 * 28, n_hidden=256, n_class=10)
        mdl = mdl.to(DEVICE)

        # optimization
        opt_Adam = optim.Adam(mdl.parameters(), lr=LEARNING_RATE)

        for ep in tqdm(range(1, N_EPOCH + 1)):
            tmp_train_acc, tmp_train_loss, mdl = \
                model_train(model=mdl, optimizer=opt_Adam, epoch=ep, data_src=X_s, data_tar=X_t, y_pse=y_pse,
                            mmd_type=mt)
            tmp_test_acc, tmp_test_loss = model_test(mdl, X_t, ep)
            train_acc.append(tmp_train_acc)
            test_acc.append(tmp_test_acc)
            train_loss.append(tmp_train_loss)
            test_loss.append(tmp_test_loss)
        acc['train'], acc['test'] = train_acc, test_acc
        loss['train'], loss['test'] = train_loss, test_loss

        # visualize
        plt.plot(acc['train'], label='train-' + mt)
        plt.plot(acc['test'], label='test-' + mt, ls='--')

    plt.title(domain_str[0] + ' to ' + domain_str[1])
    plt.xticks(np.linspace(1, N_EPOCH, num=5, dtype=int))  # np.int8 would overflow for N_EPOCH > 127
    plt.xlim(1, N_EPOCH)
    plt.ylim(0, 100)
    plt.legend(loc='upper right')
    plt.xlabel("epochs")
    plt.ylabel("accuracy")
    plt.savefig(domain_str[0] + '_' + domain_str[1] + "_acc.jpg")
    plt.close()

    # time and save model
    end = time.time()
    print("Total run time: %.2f" % float(end - start))
Example #2

def test(model, data_tar, e):
    # ... evaluation loop (truncated in this listing) accumulates
    # `correct` and `total_loss_test` over the target loader ...
    accuracy = correct * 100. / len(data_tar.dataset)
    res = 'Test: total loss: {:.6f}, correct: [{}/{}], testing accuracy: {:.4f}%'.format(
        total_loss_test, correct, len(data_tar.dataset), accuracy)
    tqdm.write(res)
    RESULT_TEST.append([e, total_loss_test, accuracy])
    log_test.write(res + '\n')
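
The train function invoked in the __main__ block below is likewise missing from the listing. A hypothetical sketch of one epoch, assuming the mmd_linear helper sketched above, the DEVICE constant used elsewhere in this file, and an invented trade-off weight LAMBDA (epoch is accepted only to match the call site):

import torch.nn as nn

LAMBDA = 0.25  # hypothetical trade-off weight, not taken from the source

def train(model, optimizer, epoch, data_src, data_tar):
    # One pass over the source loader: cross-entropy on the source
    # predictions plus a weighted MMD term aligning the two domains.
    model.train()
    criterion = nn.CrossEntropyLoss()
    iter_tar = iter(data_tar)
    for x_src, y_src in data_src:
        try:
            x_tar, _ = next(iter_tar)
        except StopIteration:  # target loader may be shorter; restart it
            iter_tar = iter(data_tar)
            x_tar, _ = next(iter_tar)
        x_src = x_src.view(-1, 28 * 28).to(DEVICE)
        x_tar = x_tar.view(-1, 28 * 28).to(DEVICE)
        y_src = y_src.to(DEVICE)
        y_pred, f_src, f_tar = model(x_src, x_tar)
        loss = criterion(y_pred, y_src) + LAMBDA * mmd_linear(f_src, f_tar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model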


if __name__ == '__main__':
    rootdir = '../../../data/office_caltech_10/'
    torch.manual_seed(1)
    data_src = data_loader.load_data(root_dir=rootdir,
                                     domain='amazon',
                                     batch_size=BATCH_SIZE[0])
    data_tar = data_loader.load_test(root_dir=rootdir,
                                     domain='webcam',
                                     batch_size=BATCH_SIZE[1])
    model = DaNN.DaNN(n_input=28 * 28, n_hidden=256, n_class=10)
    model = model.to(DEVICE)
    optimizer = optim.SGD(model.parameters(),
                          lr=LEARNING_RATE,
                          momentum=MOMEMTUN,
                          weight_decay=L2_WEIGHT)
    for e in tqdm(range(1, N_EPOCH + 1)):
        model = train(model=model,
                      optimizer=optimizer,
                      epoch=e,
                      data_src=data_src,
                      data_tar=data_tar)
        test(model, data_tar, e)
    torch.save(model, 'model_dann.pkl')
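
Both examples construct DaNN.DaNN(n_input=28 * 28, n_hidden=256, n_class=10) without showing the class. Below is a minimal sketch consistent with how it is called: a forward pass over a source and a target batch, returning source predictions plus the hidden features that feed the MMD term. The layer names are assumptions:

import torch.nn as nn

class DaNN(nn.Module):
    # Hypothetical reconstruction: a single hidden layer whose activations
    # are exposed for both domains so an MMD loss can align them.
    def __init__(self, n_input=28 * 28, n_hidden=256, n_class=10):
        super(DaNN, self).__init__()
        self.layer_input = nn.Linear(n_input, n_hidden)
        self.relu = nn.ReLU()
        self.layer_hidden = nn.Linear(n_hidden, n_class)

    def forward(self, src, tar):
        x_src = self.relu(self.layer_input(src))
        x_tar = self.relu(self.layer_input(tar))
        y_src = self.layer_hidden(x_src)
        return y_src, x_src, x_tar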
Example #3
        2. The sample classes must be the same.
    Additionally, the number of samples per class may differ; that is, the
    two datasets may contain different numbers of samples.
    '''
    # src_dir = './source_datasets.mat'
    # tar_dir = './target_datasets.mat'
    src_dir = './source_datasets_Eq_shuffle.mat'
    tar_dir = './target_datasets_Eq_shuffle.mat'
    torch.manual_seed(1)
    data_src = data_loader.load_data(root_dir=src_dir,
                                     domain='source_data_train',
                                     batch_size=BATCH_SIZE[0])

    # root_dir=src_dir, domain='source_data', batch_size=BATCH_SIZE[0])
    # root_dir=src_dir, domain='source_data_device', batch_size=BATCH_SIZE[0])
    data_tar = data_loader.load_test(root_dir=tar_dir,
                                     domain='target_data_train',
                                     batch_size=BATCH_SIZE[1])

    # root_dir=tar_dir, domain='target_data', batch_size=BATCH_SIZE[1])
    # root_dir=tar_dir, domain='target_data_device', batch_size=BATCH_SIZE[1])
    # print(data_src.size())
    model = CNN1d.CNN1d(n_hidden=100, n_class=Classes)
    # print the model architecture
    print(model)
    model = model.to(DEVICE)

    # define the optimizer
    optimizer = optim.Adamax(
        model.parameters(),
        lr=LEARNING_RATE,
        # momentum=MOMEMTUN,  # Adamax takes no momentum argument
    )
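
CNN1d.CNN1d is not included in the listing either. A hypothetical 1-D convolutional classifier with the same constructor signature might look like the following (the single input channel and layer sizes are assumptions; adaptive pooling keeps it independent of the signal length):

import torch.nn as nn

class CNN1d(nn.Module):
    # Illustrative sketch only; the real architecture is not shown above.
    def __init__(self, n_hidden=100, n_class=10):
        super(CNN1d, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=7, stride=2), nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=5, stride=2), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # pool to length 1 for any input length
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32, n_hidden), nn.ReLU(),
            nn.Linear(n_hidden, n_class),
        )

    def forward(self, x):  # x: (batch, 1, signal_length)
        return self.classifier(self.features(x))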
Example #4

# f = gzip.open(filename, 'rb')
# training_data, validation_data, test_data = cPickle.load(f)
# f.close()
from __future__ import print_function
import data_loader

tmp_data = data_loader.load_test()
training_data = tmp_data

print(training_data[0][0])
print(training_data[1][1])

#### Libraries

# Standard library
import cPickle  # Python 2 pickle module; this snippet targets Python 2
import gzip
import os.path
import random

# Third-party libraries
import numpy as np

print("Expanding the MNIST training set")

if False:
    print("The expanded training set already exists.  Exiting.")
else:
    # The original script loaded MNIST from the gzipped pickle here; this
    # variant already obtained training_data from data_loader above.
    # f = gzip.open("../data/mnist.pkl.gz", 'rb')
    # training_data, validation_data, test_data = cPickle.load(f)
    # f.close()
    pass  # keeps the otherwise comment-only else branch syntactically valid
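
The expansion step announced by "Expanding the MNIST training set" is truncated out of this listing. A common way to expand MNIST, and a sketch of what such a script typically does, is to add one-pixel-shifted copies of every image (the helper below is hypothetical):

import numpy as np

def expand_image(flat_image):
    # Return four copies of a flattened 28x28 image, each shifted by one
    # pixel (up, down, left, right) with the wrapped row/column zeroed.
    image = flat_image.reshape(28, 28)
    expanded = []
    for axis, shift in [(0, 1), (0, -1), (1, 1), (1, -1)]:
        moved = np.roll(image, shift, axis=axis)
        edge = 0 if shift == 1 else -1
        if axis == 0:
            moved[edge, :] = 0.0
        else:
            moved[:, edge] = 0.0
        expanded.append(moved.reshape(784))
    return expanded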