Example #1
def train_cleitc(dataloader, seed, **kwargs):
    """

    :param s_dataloaders:
    :param t_dataloaders:
    :param kwargs:
    :return:
    """
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    # get reference encoder
    aux_ae = deepcopy(autoencoder)

    aux_ae.encoder.load_state_dict(torch.load(os.path.join('./model_save/ae5000', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    ae_eval_train_history = defaultdict(list)
    ae_eval_test_history = defaultdict(list)

    if kwargs['retrain_flag']:
        cleit_params = [
            autoencoder.parameters(),
            transmitter.parameters()
        ]
        cleit_optimizer = torch.optim.AdamW(chain(*cleit_params), lr=kwargs['lr'])
        # start autoencoder pretraining
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 1 == 0:  # a modulus of 1 prints the header every epoch; increase it to log less often
                print(f'----Autoencoder Training Epoch {epoch} ----')
            for step, batch in enumerate(dataloader):
                ae_eval_train_history = cleit_train_step(ae=autoencoder,
                                                         reference_encoder=reference_encoder,
                                                         transmitter=transmitter,
                                                         batch=batch,
                                                         device=kwargs['device'],
                                                         optimizer=cleit_optimizer,
                                                         history=ae_eval_train_history)
        torch.save(autoencoder.state_dict(), os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt'))
        torch.save(transmitter.state_dict(), os.path.join(kwargs['model_save_folder'], 'transmitter.pt'))
    else:
        try:
            autoencoder.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt')))
            transmitter.load_state_dict(torch.load(os.path.join(kwargs['model_save_folder'], 'transmitter.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    encoder = EncoderDecoder(encoder=autoencoder.encoder,
                             decoder=transmitter).to(kwargs['device'])

    return encoder, (ae_eval_train_history, ae_eval_test_history)
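Example #1 (and Example #8 below) wraps the pretrained encoder and the transmitter MLP in an `EncoderDecoder` module whose definition is not part of the snippet. A minimal sketch of what the call sites appear to assume, namely a module whose forward pass simply chains the two parts, could look like this; the class body is an illustrative assumption, not the original implementation:

import torch.nn as nn

class EncoderDecoder(nn.Module):
    """Chains a feature encoder with a follow-up mapping (here the transmitter MLP)."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        # encode the input, then map the latent code through the decoder/transmitter
        return self.decoder(self.encoder(x))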
Example #2
def main(epoch_num):
    # download the MNIST dataset
    mnist_train = datasets.MNIST('mnist',
                                 train=True,
                                 transform=transforms.Compose(
                                     [transforms.ToTensor()]),
                                 download=True)
    mnist_test = datasets.MNIST('mnist',
                                train=False,
                                transform=transforms.Compose(
                                    [transforms.ToTensor()]),
                                download=True)

    # load the MNIST dataset
    # batch_size sets the size of each batch; shuffle controls whether the order is randomized.
    # The loader shuffles the data first and then draws batch_size samples per iteration.
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    # inspect the shape of one image batch
    x, label = next(iter(mnist_train))  # fetch the first training batch
    print(' img : ', x.shape)  # img :  torch.Size([32, 1, 28, 28]); each batch holds 32 images of shape (1, 28, 28)

    # setup: build the computation pipeline
    device = torch.device('cuda')
    model = AE().to(device)  # create the AE model and move it to the GPU
    print('The structure of our model is shown below: \n')
    print(model)
    loss_function = nn.MSELoss()  # reconstruction loss
    optimizer = optim.Adam(model.parameters(),
                           lr=1e-3)  # optimize the model parameters with learning rate 0.001

    # start training
    loss_epoch = []
    for epoch in range(epoch_num):
        # each epoch iterates over all batches
        for batch_index, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)
            # forward pass
            x_hat = model(x)  # the model output; calling the model invokes its forward method
            loss = loss_function(x_hat, x)  # compute the loss, i.e. the objective to minimize
            # backward pass
            optimizer.zero_grad()  # reset the gradients, otherwise they accumulate across steps
            loss.backward()  # backpropagate; the gradients are stored on model.parameters
            optimizer.step()  # update the parameters using the gradients from the previous step

        loss_epoch.append(loss.item())
        if epoch % max(epoch_num // 10, 1) == 0:  # max() guards against epoch_num < 10
            print('Epoch [{}/{}] : '.format(epoch, epoch_num), 'loss = ',
                  loss.item())  # loss is a Tensor; .item() extracts the Python float
            # x, _ = next(iter(mnist_test))   # take a batch from the test set
            # with torch.no_grad():
            #     x_hat = model(x)

    return loss_epoch
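Examples #2 and #4 train an `AE` model on MNIST with an MSE reconstruction loss but do not show the class itself. A minimal sketch of a fully connected autoencoder that matches the call sites (input and output of shape [b, 1, 28, 28]) is given below; the layer sizes are illustrative assumptions:

import torch.nn as nn

class AE(nn.Module):
    def __init__(self):
        super().__init__()
        # 784 -> 256 -> 20 -> 256 -> 784; hidden sizes are placeholders
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 256), nn.ReLU(),
            nn.Linear(256, 20), nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(20, 256), nn.ReLU(),
            nn.Linear(256, 28 * 28), nn.Sigmoid(),
        )

    def forward(self, x):
        b = x.size(0)
        z = self.encoder(x.view(b, -1))   # flatten to [b, 784]
        x_hat = self.decoder(z)
        return x_hat.view(b, 1, 28, 28)   # reshape back to image layout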
Example #3
def main():
    mnist_train = datasets.MNIST('mnist',
                                 True,
                                 transform=transforms.Compose(
                                     [transforms.ToTensor()]),
                                 download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist',
                                False,
                                transform=transforms.Compose(
                                    [transforms.ToTensor()]),
                                download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    x, _ = next(iter(mnist_train))  # iterators have no .next() method in Python 3; use the next() builtin
    print('x:', x.shape)

    # device = torch.device('cuda')
    # model = AE().to(device)
    model = AE()
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):

        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            # x = x.to(device)

            x_hat, kld = model(x)
            loss = criteon(x_hat, x)

            if kld is not None:
                elbo = -loss - 1.0 * kld
                loss = -elbo

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # print(epoch, 'loss:', loss.item(), 'kld:', kld.item())
        print(epoch, 'loss', loss.item())

        x, _ = next(iter(mnist_test))
        # x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
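Examples #3 and #5 expect `model(x)` to return a pair `(x_hat, kld)` and build the loss as `-ELBO = MSE + kld`. The model class is again not shown; a hedged sketch of a variational autoencoder whose forward pass returns the reconstruction together with the batch-averaged KL divergence could look like this, with all names and sizes being assumptions:

import torch
import torch.nn as nn

class AE(nn.Module):  # VAE-style: forward returns (reconstruction, KL divergence)
    def __init__(self, latent_dim=10):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU())
        self.mu_head = nn.Linear(256, latent_dim)
        self.logvar_head = nn.Linear(256, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 28 * 28), nn.Sigmoid(),
        )

    def forward(self, x):
        b = x.size(0)
        h = self.encoder(x.view(b, -1))
        mu, logvar = self.mu_head(h), self.logvar_head(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization trick
        x_hat = self.decoder(z).view(b, 1, 28, 28)
        # KL(q(z|x) || N(0, I)), averaged over the batch
        kld = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
        return x_hat, kld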
Example #4
def train():
    mnist_train = datasets.MNIST('../data/mnist',
                                 train=True,
                                 transform=transforms.Compose(
                                     [transforms.ToTensor()]),
                                 download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    mnist_test = datasets.MNIST('../data/mnist',
                                train=False,
                                transform=transforms.Compose(
                                    [transforms.ToTensor()]),
                                download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    # labels are not needed, because this is unsupervised learning
    x, _ = next(iter(mnist_train))
    print('x:', x.shape)

    device = torch.device('cuda')

    model = AE().to(device)
    criteon = nn.MSELoss()  # loss function
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    vis = visdom.Visdom()

    for epoch in range(1000):

        # training loop
        for batchIdx, (x, _) in enumerate(mnist_train):
            # forward pass, x: [b, 1, 28, 28]
            x = x.to(device)
            x_hat = model(x)
            loss = criteon(x_hat, x)

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # print the loss
        print('epoch:', epoch, '  loss:', loss.item())

        # evaluation on the test set
        x, _ = next(iter(mnist_test))
        x = x.to(device)
        with torch.no_grad():  # no gradients are needed for evaluation
            x_hat = model(x)

        vis.images(x, nrow=8, win='x', opts=dict(title='x'))  # plot the input images
        vis.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))  # plot the reconstructions
Example #5
def main():

    mnist_train = DataLoader(datasets.MNIST('../Lesson5/mnist_data',
                                            True,
                                            transform=transforms.Compose(
                                                [transforms.ToTensor()]),
                                            download=True),
                             batch_size=32,
                             shuffle=True)

    mnist_test = DataLoader(datasets.MNIST('../Lesson5/mnist_data',
                                           False,
                                           transforms.Compose(
                                               [transforms.ToTensor()]),
                                           download=True),
                            batch_size=32,
                            shuffle=True)

    x, _ = next(iter(mnist_train))
    print(f'x:{x.shape}')

    device = torch.device('cuda')
    model = AE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):

        for batchidx, (x, _) in enumerate(mnist_train):
            # [b, 1, 28, 28]
            x = x.to(device)

            x_hat, _ = model(x)
            loss = criteon(x_hat, x)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        x, _ = next(iter(mnist_test))
        x = x.to(device)
        with torch.no_grad():
            x_hat, _ = model(x)  # the KL term is not needed for visualization
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
Example #6
def train_ae(dataloader, **kwargs):
    """
    :param s_dataloaders:
    :param t_dataloaders:
    :param kwargs:
    :return:
    """
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    ae_eval_train_history = defaultdict(list)
    ae_eval_test_history = defaultdict(list)

    if kwargs['retrain_flag']:
        ae_optimizer = torch.optim.AdamW(autoencoder.parameters(),
                                         lr=kwargs['lr'])
        # start autoencoder pretraining
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'----Autoencoder Training Epoch {epoch} ----')
            for step, batch in enumerate(dataloader):
                ae_eval_train_history = ae_train_step(
                    ae=autoencoder,
                    batch=batch,
                    device=kwargs['device'],
                    optimizer=ae_optimizer,
                    history=ae_eval_train_history)
        torch.save(autoencoder.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'ae.pt'))
    else:
        try:
            autoencoder.load_state_dict(
                torch.load(os.path.join(kwargs['model_save_folder'], 'ae.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    return autoencoder.encoder, (ae_eval_train_history, ae_eval_test_history)
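A possible call site for train_ae, showing the configuration keys the function reads; the concrete values and the my_dataloader name below are illustrative placeholders, not values from the original project:

config = dict(
    input_dim=1000, latent_dim=128, encoder_hidden_dims=[512, 256], dop=0.1,
    device='cuda', lr=1e-3, train_num_epochs=500,
    retrain_flag=True, model_save_folder='./model_save',
)
encoder, (train_hist, test_hist) = train_ae(dataloader=my_dataloader, **config)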
Example #7
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from ae import AE
from visdom import Visdom

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
batchsz = 128
epochs = 50
lr = 1e-3

train_dataset = datasets.MNIST('../data', transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST('../data', False, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batchsz, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batchsz, shuffle=True)

net = AE()
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, 5)
net.to(device)
criterion.to(device)
train_loss = []
viz = Visdom()
for epoch in range(epochs):
    train_loss.clear()
    net.train()
    for step, (x, _) in enumerate(train_loader):
        x = x.to(device)
        x_hat = net(x)

        loss = criterion(x_hat, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()                 # apply the parameter update
        train_loss.append(loss.item())   # record the batch loss
    scheduler.step()                     # decay the learning rate every 5 epochs (StepLR step_size=5)
Example #8
def train_cleita(dataloader, seed, **kwargs):
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    # get reference encoder
    aux_ae = deepcopy(autoencoder)

    aux_ae.encoder.load_state_dict(
        torch.load(os.path.join('./model_save', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    confounding_classifier = MLP(input_dim=kwargs['latent_dim'],
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(kwargs['device'])

    ae_train_history = defaultdict(list)
    ae_val_history = defaultdict(list)
    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)

    if kwargs['retrain_flag']:
        cleit_params = [autoencoder.parameters(), transmitter.parameters()]
        cleit_optimizer = torch.optim.AdamW(chain(*cleit_params),
                                            lr=kwargs['lr'])
        classifier_optimizer = torch.optim.RMSprop(
            confounding_classifier.parameters(), lr=kwargs['lr'])
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'confounder wgan training epoch {epoch}')
            for step, batch in enumerate(dataloader):
                critic_train_history = critic_train_step(
                    critic=confounding_classifier,
                    ae=autoencoder,
                    reference_encoder=reference_encoder,
                    transmitter=transmitter,
                    batch=batch,
                    device=kwargs['device'],
                    optimizer=classifier_optimizer,
                    history=critic_train_history,
                    # clip=0.1,
                    gp=10.0)
                if (step + 1) % 5 == 0:
                    gen_train_history = gan_gen_train_step(
                        critic=confounding_classifier,
                        ae=autoencoder,
                        transmitter=transmitter,
                        batch=batch,
                        device=kwargs['device'],
                        optimizer=cleit_optimizer,
                        alpha=1.0,
                        history=gen_train_history)

        torch.save(autoencoder.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt'))
        torch.save(transmitter.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'transmitter.pt'))
    else:
        try:
            autoencoder.load_state_dict(
                torch.load(
                    os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt')))
            transmitter.load_state_dict(
                torch.load(
                    os.path.join(kwargs['model_save_folder'],
                                 'transmitter.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    encoder = EncoderDecoder(encoder=autoencoder.encoder,
                             decoder=transmitter).to(kwargs['device'])

    return encoder, (ae_train_history, ae_val_history, critic_train_history,
                     gen_train_history)
Example #9
def main():

    model = AE()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.MSELoss()
    #criteon = nn.CrossEntropyLoss()
    print(model)

    #for epoch in range(epochs):
    for epoch in range(100):
        for step, (x, y) in enumerate(all_loader):
            # x: [b,1,100,100]
            x_hat = model(x, False)
            #print('x shape:',x.shape,'x_hat shape:',x_hat.shape)
            loss = criteon(x_hat, x)
            #backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(epoch, 'loss', loss.item())

        # use the encoder part of the model
        # encoder = []
        # label = []
        encoder = torch.randn(301, 100)  # placeholder buffers; overwritten by the last batch below
        label = torch.randn(301, 2)
        #tmp = torch.randn(2, 1, 100, 100)
        for x, y in encoder_loader:
            #x,y = iter(all_loader).next()
            with torch.no_grad():
                x_encoder = model(x, True)
                # label.append(y)
                # encoder.append(x_encoder)
                label = y
                encoder = x_encoder
        encoder = encoder.numpy()
        label = label.numpy()
        #print(encoder.shape,label.shape)

        # spectral clustering
        if epoch == 0:
            pred = torch.zeros(301, 101)
            pred = pred.numpy()
        pred.T[0][:] = label.T[0]
        print(pred)

        from sklearn.cluster import SpectralClustering
        from sklearn.metrics import adjusted_rand_score
        from sklearn.metrics import normalized_mutual_info_score
        from sklearn.metrics.pairwise import cosine_similarity
        simatrix = 0.5 * cosine_similarity(encoder) + 0.5
        SC = SpectralClustering(affinity='precomputed',
                                assign_labels='discretize',
                                random_state=100)
        label1 = SC.fit_predict(simatrix)

        pred.T[epoch + 1][:] = label1[:]

        print('pred:', pred.shape)
        print(pred)
        if epoch == 99:
            pd.DataFrame(pred).to_csv("pred.csv", index=False, sep=',')
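Example #9 imports adjusted_rand_score and normalized_mutual_info_score but never calls them. If the first column of label holds ground-truth cluster ids (an assumption about this particular dataset), the clustering produced in each epoch could be scored right after label1 is computed, for instance:

from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

true_labels = label.T[0].astype(int)  # assumed integer ground-truth cluster ids
ari = adjusted_rand_score(true_labels, label1)
nmi = normalized_mutual_info_score(true_labels, label1)
print(f'epoch {epoch}: ARI = {ari:.3f}, NMI = {nmi:.3f}')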