Example #1
import torch

import dataloader    # project-local module providing mnist_dataloader
import main_models   # project-local module providing Classifier, Encoder and DCD

# 'opt' (a dict of hyper-parameters such as 'batch_size' and 'n_epoches_1') is
# assumed to be defined earlier in the original file.

use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
torch.manual_seed(1)
if use_cuda:
    torch.cuda.manual_seed(1)

#--------------pretrain g and h for step 1---------------------------------
train_dataloader = dataloader.mnist_dataloader(batch_size=opt['batch_size'],
                                               train=True)
test_dataloader = dataloader.mnist_dataloader(batch_size=opt['batch_size'],
                                              train=False)

classifier = main_models.Classifier()
encoder = main_models.Encoder()
discriminator = main_models.DCD(input_features=128)

classifier.to(device)
encoder.to(device)
discriminator.to(device)
loss_fn = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(classifier.parameters()))

for epoch in range(opt['n_epoches_1']):
    for data, labels in train_dataloader:
        data = data.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
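
The snippet above is cut off right after optimizer.zero_grad(). A minimal sketch of how the rest of the step-1 pretraining loop presumably continues, mirroring the equivalent loop in Example #2 (this continuation is an assumption, not part of the original excerpt):

        # assumed continuation of the loop body (mirrors step 1 of Example #2)
        y_pred = classifier(encoder(data))   # encode source data, then classify
        loss = loss_fn(y_pred, labels)       # cross-entropy against source labels
        loss.backward()
        optimizer.step()
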
Example #2
import os
import logging

import numpy as np
import torch
import torch.optim as optim

# Extractor, Classifier, Critic2, InceptionV1, InceptionV1s, main_models,
# dataloader and set_log_config are project-local modules.


def train_fada(config):
    if config['network'] == 'inceptionv1':
        extractor = InceptionV1(num_classes=32, dilation=config['dilation'])
    elif config['network'] == 'inceptionv1s':
        extractor = InceptionV1s(num_classes=32, dilation=config['dilation'])
    else:
        extractor = Extractor(n_flattens=config['n_flattens'],
                              n_hiddens=config['n_hiddens'])
    classifier = Classifier(n_flattens=config['n_flattens'],
                            n_hiddens=config['n_hiddens'],
                            n_class=config['n_class'])

    critic = Critic2(n_flattens=config['n_flattens'],
                     n_hiddens=config['n_hiddens'])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    extractor = extractor.to(device)
    classifier = classifier.to(device)
    critic = critic.to(device)

    criterion = torch.nn.CrossEntropyLoss()
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    res_dir = os.path.join(config['res_dir'],
                           'snr{}-lr{}'.format(config['snr'], config['lr']))
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    set_log_config(res_dir)
    logging.debug('train_fada')
    logging.debug(extractor)
    logging.debug(classifier)
    logging.debug(critic)
    logging.debug(config)

    optimizer = optim.Adam([{
        'params': extractor.parameters()
    }, {
        'params': classifier.parameters()
    }, {
        'params': critic.parameters()
    }],
                           lr=config['lr'])

    # TODO
    discriminator = main_models.DCD(input_features=128).to(device)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.001)

    # pretrain on source samples
    #--------------pretrain g and h for step 1---------------------------------
    for epoch in range(config['n_epoches_1']):
        for data, labels in config['source_train_loader']:
            data = data.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            y_pred = classifier(extractor(data))

            loss = loss_class(y_pred, labels)
            loss.backward()

            optimizer.step()

        acc = 0
        for data, labels in config['target_test_loader']:
            data = data.to(device)
            labels = labels.to(device)
            y_test_pred = classifier(extractor(data))
            acc += (torch.max(y_test_pred,
                              1)[1] == labels).float().mean().item()

        accuracy = round(acc / float(len(config['target_test_loader'])), 3)

        print("step1----Epoch %d/%d  accuracy: %.3f " %
              (epoch + 1, config['n_epoches_1'], accuracy))

    #-----------------train DCD for step 2--------------------------------

    # X_s,Y_s=dataloader.sample_data()
    # X_t,Y_t=dataloader.create_target_samples(config['n_target_samples'])

    for epoch in range(config['n_epoches_2']):
        # for data,labels in config['source_train_loader']:
        iter_source = iter(config['source_train_loader'])
        iter_target = iter(config['target_train_loader'])
        len_source_loader = len(config['source_train_loader'])
        len_target_loader = len(config['target_train_loader'])
        num_iter = len_source_loader
        for i in range(1, num_iter + 1):
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            if i % len_target_loader == 0:
                iter_target = iter(config['target_train_loader'])
            if torch.cuda.is_available():
                data_source, label_source = data_source.cuda(
                ), label_source.cuda()
                data_target, label_target = data_target.cuda(
                ), label_target.cuda()

        groups, aa = dataloader.sample_groups(data_source,
                                              label_source,
                                              data_target,
                                              label_target,
                                              seed=epoch)
        # groups, aa = dataloader.sample_groups(X_s,Y_s,X_t,Y_t,seed=epoch)
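        # (FADA) groups presumably holds the four pair types G1-G4: same-class
        # source/source, same-class source/target, different-class source/source
        # and different-class source/target pairs; the DCD below is trained to
        # tell these four types apart.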

        n_iters = 4 * len(groups[1])
        index_list = torch.randperm(n_iters)
        mini_batch_size = 40  # training on mini-batches is more stable

        loss_mean = []

        X1 = []
        X2 = []
        ground_truths = []
        for index in range(n_iters):

            # recover which of the four groups this shuffled index falls into,
            # and the pair's position within that group
            ground_truth = index_list[index] // len(groups[1])

            x1, x2 = groups[ground_truth][index_list[index] -
                                          len(groups[1]) * ground_truth]
            X1.append(x1)
            X2.append(x2)
            ground_truths.append(ground_truth)

            #select data for a mini-batch to train
            if (index + 1) % mini_batch_size == 0:
                X1 = torch.stack(X1)
                X2 = torch.stack(X2)
                ground_truths = torch.LongTensor(ground_truths)
                X1 = X1.to(device)
                X2 = X2.to(device)
                ground_truths = ground_truths.to(device)

                optimizer_D.zero_grad()
                X_cat = torch.cat([extractor(X1), extractor(X2)], 1)
                # detach the concatenated features so this step updates only the DCD
                y_pred = discriminator(X_cat.detach())
                loss = loss_domain(y_pred, ground_truths)
                loss.backward()
                optimizer_D.step()
                loss_mean.append(loss.item())
                X1 = []
                X2 = []
                ground_truths = []

        print("step2----Epoch %d/%d loss:%.3f" %
              (epoch + 1, config['n_epoches_2'], np.mean(loss_mean)))
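
For reference, train_fada takes a single config dict that carries both the hyper-parameters and the data loaders read above. A minimal sketch of how such a dict might be assembled, using only keys the function actually accesses (the concrete values and loader names are placeholders, not taken from the original project):

source_train_loader = ...   # torch DataLoader over labelled source-domain data
target_train_loader = ...   # torch DataLoader over the few labelled target samples
target_test_loader = ...    # torch DataLoader over target-domain test data

config = {
    'network': 'inceptionv1',     # or 'inceptionv1s'; anything else selects Extractor
    'dilation': 1,
    'n_flattens': 1024,
    'n_hiddens': 512,
    'n_class': 32,
    'lr': 1e-3,
    'snr': 10,
    'res_dir': './results',
    'n_epoches_1': 10,
    'n_epoches_2': 100,
    'source_train_loader': source_train_loader,
    'target_train_loader': target_train_loader,
    'target_test_loader': target_test_loader,
}
train_fada(config)
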
Example #3
File: main.py  Project: zhuzhenxi/aitom
import os

import torch
from torch.utils.data import DataLoader

# CECT_dataset and main_models come from the aitom project's own modules; 'opt'
# and 'device' are assumed to be defined earlier in main.py (a sketch of that
# setup follows the excerpt).

torch.backends.cudnn.deterministic = True
cectA_dataset = CECT_dataset(path=opt['tar_data'])
cectA_dataloader = DataLoader(dataset=cectA_dataset,
                              batch_size=opt['batch_size'],
                              shuffle=True)
cectB_dataset = CECT_dataset(path=opt['src_data'])
cectB_dataloader = DataLoader(dataset=cectB_dataset,
                              batch_size=opt['batch_size'],
                              shuffle=True)

stage = opt['load_model']
classifier = main_models.Classifier(opt)
encoder1_S = main_models.Encoder1(opt)
encoder1_T = main_models.Encoder1(opt)
encoder2 = main_models.Encoder2(opt)
discriminator = main_models.DCD(opt)
conv_discriminator = main_models.CONV_DCD(opt)

classifier.to(device)
encoder1_S.to(device)
encoder1_T.to(device)
encoder2.to(device)
discriminator.to(device)
conv_discriminator.to(device)

if not os.path.exists('results'):
    os.mkdir('results')
if stage == 1:
    print('We will skip stage 1, stage 2 will be processed.')
    encoder1_S.load_state_dict(torch.load('results/encoder1S_stage1.pt'))
    encoder1_T.load_state_dict(torch.load('results/encoder1T_stage1.pt'))
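
This excerpt uses opt and device without defining them; they are presumably set up earlier in main.py. A sketch of what that setup might look like, following the same pattern as Example #1 (the keys mirror those read by the excerpt, but the values are placeholders, not taken from the aitom project):

# assumed earlier setup in main.py (placeholder values)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
opt = {
    'src_data': './data/source',   # path to source-domain CECT data
    'tar_data': './data/target',   # path to target-domain CECT data
    'batch_size': 32,
    'load_model': 0,               # 1 skips stage 1 and loads the saved stage-1 encoders
}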