コード例 #1
0
def test(encoder_model, epoch, x64_test, x16_test):
    """Plot conditional generations and two-scale reconstructions for an epoch.

    Relies on module-level ``encoder``, ``decoder``, ``device``,
    ``X_l_encoder``, ``upsample``, ``plot_generation``,
    ``plot_reconstruction`` and ``output_dir``.

    Generation: the coarse latent of test sample 111 (from
    ``X_l_encoder(encoder_model, ...)``) is repeated for 9 samples,
    upsampled with ``upsample(..., 16)``, concatenated along dim 1 with 9
    random ``(1, 16, 16)`` latents, decoded, and passed to
    ``plot_generation``.

    Reconstruction: test samples 10, 30, 50 and 100 are encoded at both
    scales (coarse via ``X_l_encoder``, fine via ``encoder``), decoded, and
    plotted with the originals stacked before the reconstructions.

    Args:
        encoder_model: model forwarded to ``X_l_encoder``. NOTE(review): its
            train/eval mode is not changed here; confirm that ``X_l_encoder``
            or the caller puts it in eval mode.
        epoch: epoch index forwarded to the plotting helpers.
        x64_test: high-resolution test inputs, indexable by a list of ints.
        x16_test: low-resolution test inputs aligned with ``x64_test``.
    """
    encoder.eval()
    decoder.eval()

    sample_idx = [10, 30, 50, 100]

    # Pure inference: no_grad stops autograd from recording (and keeping
    # alive) a graph for every forward pass below.
    with torch.no_grad():
        # --- Generation conditioned on a single coarse latent ---
        z = torch.randn(9, 1, 16, 16).to(device)
        _, z16, _ = X_l_encoder(encoder_model,
                                (torch.FloatTensor(x16_test[111:112])).to(device))
        # Broadcast the one coarse latent across all 9 samples; z16 must
        # broadcast against (9, 1, 4, 4).
        z16 = (torch.ones([9, 1, 4, 4])).to(device) * z16
        z16 = upsample(z16, 16)
        z_h = torch.cat((z16, z), 1)
        gen_imgs = decoder(z_h)
        samples = np.squeeze(gen_imgs.detach().cpu().numpy())
        plot_generation(samples, epoch, output_dir)

        # --- Two-scale reconstruction of fixed test samples ---
        real_imgs = x64_test[sample_idx]
        real_x16 = x16_test[sample_idx]
        _, z16, _ = X_l_encoder(encoder_model,
                                (torch.FloatTensor(real_x16)).to(device))
        z16 = upsample(z16, 16)
        real_imgs = (torch.FloatTensor(real_imgs)).to(device)
        encoded_imgs, _, _ = encoder(real_imgs)
        z_h = torch.cat((z16, encoded_imgs), 1)
        decoded_imgs = decoder(z_h)
        samples_gen = np.squeeze(decoded_imgs.detach().cpu().numpy())
        samples_real = np.squeeze(real_imgs.detach().cpu().numpy())
        # Originals first, then their reconstructions.
        samples = np.vstack((samples_real, samples_gen))
        plot_reconstruction(samples, epoch, output_dir)
コード例 #2
0
ファイル: VAE_train_16.py プロジェクト: xiayzh/MH-MDGM
def test(epoch, x_test):
    """Plot random generations and reconstructions for one epoch.

    Uses module-level ``encoder``, ``decoder``, ``device``, ``output_dir``,
    ``plot_generation`` and ``plot_reconstruction``.

    Args:
        epoch: epoch index forwarded to the plotting helpers.
        x_test: test inputs, indexable by a list of ints; samples 10, 30,
            50 and 100 are encoded, decoded and plotted.
    """
    encoder.eval()
    decoder.eval()

    # Pure inference: no_grad stops autograd from recording (and keeping
    # alive) a graph for every forward pass below.
    with torch.no_grad():
        # 9 samples decoded from standard-normal latents of shape (1, 4, 4).
        z = torch.randn(9, 1, 4, 4).to(device)
        imgs = decoder(z)
        samples = np.squeeze(imgs.detach().cpu().numpy())
        # NOTE(review): the meaning of the trailing ``1`` is defined by
        # plot_generation, which is not visible here.
        plot_generation(samples, epoch, output_dir, 1)

        real_imgs = x_test[[10, 30, 50, 100]]
        real_imgs = (torch.FloatTensor(real_imgs)).to(device)
        # Only the first of the encoder's three outputs is decoded.
        encoded_imgs, _, _ = encoder(real_imgs)
        decoded_imgs = decoder(encoded_imgs)
        samples_gen = np.squeeze(decoded_imgs.detach().cpu().numpy())
        samples_real = np.squeeze(real_imgs.detach().cpu().numpy())

        # Originals first, then their reconstructions.
        samples = np.vstack((samples_real, samples_gen))
        plot_reconstruction(samples, epoch, output_dir)
コード例 #3
0
ファイル: train.py プロジェクト: Ugenteraan/DeepCaps
def train(img_size,
          device=torch.device('cpu'),
          learning_rate=1e-3,
          num_epochs=500,
          decay_step=5,
          gamma=0.98,
          num_classes=10,
          lambda_=0.5,
          m_plus=0.9,
          m_minus=0.1,
          checkpoint_folder=None,
          checkpoint_name=None,
          load_checkpoint=False,
          graphs_folder=None):
    '''
    Train and evaluate the DeepCaps model.

    Each epoch makes one pass over the module-level ``train_loader`` (with
    Adam updates) and one pass over ``test_loader`` (evaluation only). For
    both passes it records the mean per-batch loss and accuracy. The model's
    state dict is saved whenever the test accuracy exceeds the best seen so
    far.

    Args:
        img_size: image height and width passed to ``DeepCapsModel``.
        device: torch device for the model and every batch.
        learning_rate: Adam learning rate.
        num_epochs: number of train/test epochs.
        decay_step, gamma: StepLR settings. They currently have no effect
            because the scheduler is disabled.
        num_classes: number of classes (model output and one-hot width).
        lambda_, m_plus, m_minus: forwarded to ``deepcaps.loss``.
        checkpoint_folder, checkpoint_name: concatenated as plain strings
            (not path-joined) to form the checkpoint path, so the folder
            should end with a separator. If either is None, no checkpoint is
            loaded or saved.
        load_checkpoint: if True and the checkpoint file exists, load it
            before training. A load failure prints the error and exits
            with status 1.
        graphs_folder: if given, every 5 epochs the loss/accuracy curves and
            the reconstructions of the last test batch are plotted there.
    '''
    if checkpoint_folder is not None and checkpoint_name is not None:
        checkpoint_path = checkpoint_folder + checkpoint_name
    else:
        # Previously a TypeError; now checkpointing is simply skipped.
        checkpoint_path = None

    deepcaps = DeepCapsModel(num_class=num_classes,
                             img_height=img_size,
                             img_width=img_size,
                             device=device).to(device)  # initialize model

    # Optionally resume from an existing checkpoint.
    if (load_checkpoint and checkpoint_path is not None
            and os.path.exists(checkpoint_path)):
        try:
            # map_location lets a checkpoint saved on another device
            # (e.g. a GPU) be loaded onto ``device``.
            deepcaps.load_state_dict(
                torch.load(checkpoint_path, map_location=device))
            print("Checkpoint loaded!")
        except Exception as e:
            print(e)
            # Non-zero status so the shell or caller sees the failure.
            sys.exit(1)

    optimizer = torch.optim.Adam(deepcaps.parameters(), lr=learning_rate)
    # NOTE: decay_step/gamma belong to a StepLR scheduler that is disabled.

    best_accuracy = 0

    training_loss_list = []
    training_acc_list = []
    testing_loss_list = []
    testing_acc_list = []

    for epoch_idx in range(num_epochs):

        print(
            f"Training and testing for epoch {epoch_idx} began with LR : {get_learning_rate(optimizer)}"
        )

        # ---- Training pass ----
        batch_loss = 0
        batch_accuracy = 0
        batch_idx = 0

        deepcaps.train()
        # Wrapping the loader (not enumerate) lets tqdm show the total length.
        for batch_idx, (train_data, labels) in enumerate(tqdm(train_loader)):

            data, labels = train_data.to(device), labels.to(device)
            onehot_label = onehot_encode(
                labels, num_classes=num_classes,
                device=device)  # convert the labels into one-hot vectors

            optimizer.zero_grad()

            outputs, _, reconstructed, indices = deepcaps(data, onehot_label)
            loss = deepcaps.loss(x=outputs,
                                 reconstructed=reconstructed,
                                 data=data,
                                 labels=onehot_label,
                                 lambda_=lambda_,
                                 m_plus=m_plus,
                                 m_minus=m_minus)

            loss.backward()
            optimizer.step()

            batch_loss += loss.item()
            batch_accuracy += accuracy_calc(predictions=indices, labels=labels)

        # batch_idx + 1 == number of batches seen (1 if the loader is empty).
        epoch_accuracy = batch_accuracy / (batch_idx + 1)
        avg_batch_loss = batch_loss / (batch_idx + 1)
        print(
            f"Epoch : {epoch_idx}, Training Accuracy : {epoch_accuracy}, Training Loss : {avg_batch_loss}"
        )

        training_loss_list.append(avg_batch_loss)
        training_acc_list.append(epoch_accuracy)

        # ---- Evaluation pass ----
        batch_loss = 0
        batch_accuracy = 0
        batch_idx = 0

        deepcaps.eval()
        # Evaluation only: no_grad avoids building autograd graphs.
        with torch.no_grad():
            for batch_idx, (test_data, labels) in enumerate(tqdm(test_loader)):

                data, labels = test_data.to(device), labels.to(device)
                onehot_label = onehot_encode(labels,
                                             num_classes=num_classes,
                                             device=device)

                outputs, _, reconstructed, indices = deepcaps(
                    data, onehot_label)
                loss = deepcaps.loss(x=outputs,
                                     reconstructed=reconstructed,
                                     data=data,
                                     labels=onehot_label,
                                     lambda_=lambda_,
                                     m_plus=m_plus,
                                     m_minus=m_minus)

                batch_loss += loss.item()
                batch_accuracy += accuracy_calc(predictions=indices,
                                                labels=labels)

        epoch_accuracy = batch_accuracy / (batch_idx + 1)
        avg_batch_loss = batch_loss / (batch_idx + 1)
        print(
            f"Epoch : {epoch_idx}, Testing Accuracy : {epoch_accuracy}, Testing Loss : {avg_batch_loss}"
        )

        testing_loss_list.append(avg_batch_loss)
        testing_acc_list.append(epoch_accuracy)

        if graphs_folder is not None and epoch_idx % 5 == 0:
            plot_loss_acc(path=graphs_folder,
                          num_epoch=epoch_idx,
                          train_accuracies=training_acc_list,
                          train_losses=training_loss_list,
                          test_accuracies=testing_acc_list,
                          test_losses=testing_loss_list)

            # Uses the last batch of the evaluation pass.
            plot_reconstruction(path=graphs_folder,
                                num_epoch=epoch_idx,
                                original_images=data.detach(),
                                reconstructed_images=reconstructed.detach(),
                                predicted_classes=indices.detach(),
                                true_classes=labels.detach())

        # Keep only the best model by test accuracy. The original never
        # updated best_accuracy, so it saved on every epoch.
        if best_accuracy < epoch_accuracy:
            best_accuracy = epoch_accuracy
            if checkpoint_path is not None:
                torch.save(deepcaps.state_dict(), checkpoint_path)
                print("Saved model at epoch %d" % (epoch_idx))
コード例 #4
0
# Persist the reconstruction-loss history, then plot it.
with open('reconst_losses', 'wb') as f:
    pickle.dump(rec_losses, f)
plot.plot_reconst_losses(rec_losses, 'reconst_losses.png')

# Persist the train and validation total-loss histories (one file each),
# then plot both curves together.
for loss_file, loss_history in (('train_total_losses', train_total_losses),
                                ('valid_total_losses', valid_total_losses)):
    with open(loss_file, 'wb') as f:
        pickle.dump(loss_history, f)
plot.plot_tv_losses(train_total_losses, valid_total_losses,
                    'tv_total_losses.png')

# Five reconstruction figures (rec1.png .. rec5.png), then five generation
# figures (gen1.png .. gen5.png).
for fig_no in range(1, 6):
    plot.plot_reconstruction('rec%d.png' % fig_no, model, test_loader)
for fig_no in range(1, 6):
    plot.plot_generation('gen%d.png' % fig_no, model)