def predict(params):
    """Decode validation inputs with a trained CVAE and report the MSE.

    For every validation batch, a latent sample ``z ~ N(0, I)`` is drawn.
    It is concatenated with the conditioning inputs and passed through
    ``model.decoder``. Per-batch and overall mean squared errors against
    the targets are printed.

    Args:
        params: configuration mapping, also forwarded to ``CVAE`` and
            ``spcamDataset``. Keys read here: ``'norm_fn'``,
            ``'output_vars'``, ``'input_transform'``, ``'var_cut_off'``,
            ``'model_type'`` and ``'device'``. Optional keys:
            ``'checkpoint_path'`` and ``'latent_dim'``. When absent they
            fall back to the previously hard-coded checkpoint file and 16.

    Returns:
        The overall validation MSE (also printed).

    Raises:
        ValueError: if the validation loader yields no batches.
    """
    # NOTE(review): the inverter is never used below. It is kept so a
    # missing/invalid ``norm_fn`` still fails here. Presumably it is meant
    # for mapping predictions back to physical units — confirm.
    inverter = PrectNormalizer(xr.open_dataset(params['norm_fn']),
                               params['output_vars'],
                               params['input_transform'][0],
                               params['input_transform'][1],
                               params['var_cut_off'], params['model_type'])

    device = params['device']
    checkpoint_path = params.get(
        'checkpoint_path', './runs/VAE_model_DNN_classifier_exp_VAE_Exp04.pt')
    latent_dim = params.get('latent_dim', 16)

    # Only the model weights matter for inference. The optimizer state,
    # epoch and loss stored in the checkpoint are deliberately not restored.
    model = CVAE(params)
    # map_location lets a checkpoint saved on GPU load on the configured device.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    valid_dataset = spcamDataset(params, phase="validation")
    valid_loader = DataLoader(valid_dataset,
                              sampler=SubsetSampler(valid_dataset.indices))

    predicted_chunks, actual_chunks = [], []
    with torch.no_grad():
        for data, target in valid_loader:
            # The DataLoader uses its default batch_size=1, so each item gets
            # a leading dimension of size 1. Presumably each dataset item is
            # already a full batch; squeeze that extra dimension exactly once.
            conditioning = data.squeeze(0).to(torch.float32)
            target = target.squeeze(0).to(device=device, dtype=torch.float32)

            # Draw the latent code from the prior and condition on the inputs.
            z = torch.randn(conditioning.shape[0], latent_dim)
            decoder_input = torch.cat((z, conditioning), dim=1).to(device)

            prediction = model.decoder(decoder_input).cpu().numpy()
            actual = target.cpu().numpy()

            print("Batch MSE {}".format(
                metrics.mean_squared_error(actual, prediction)))

            predicted_chunks.append(prediction)
            actual_chunks.append(actual)

    if not predicted_chunks:
        raise ValueError("validation loader yielded no batches")

    mse = metrics.mean_squared_error(np.concatenate(actual_chunks),
                                     np.concatenate(predicted_chunks))

    print("MSE {}".format(mse))
    return mse
Example #2
0
            epochLoss += vaeLoss.data[0]
            epochLoss_kl += klLoss.data[0]
            epochLoss_bce += bceLoss.data[0]
            epochLoss_gen += genLoss.data[0]
            epochLoss_dis += disLoss.data[0]
            epochLoss_class += classLoss.data[0]

            if i % 100 == 0:
                i += 1
                print '[%d, %d] loss: %0.5f, bce: %0.5f, alpha*kl: %0.5f, gen: %0.5f, dis: %0.5f, class: %0.5f, time: %0.3f' % \
               (e, i, epochLoss/i, epochLoss_bce/i, opts.alpha*epochLoss_kl/i, epochLoss_gen/i, epochLoss_dis/i, \
               epochLoss_class/i, time() - TIME)

        #generate samples after each 10 epochs
        if e % 1 == 0:
            cvae.eval()
            dis.eval()

            #Load test data
            testIter = iter(testLoader)
            xTest, yTest = testIter.next()
            yTest = yTest
            if cvae.useCUDA:
                xTest = Variable(xTest).cuda().data
                yTest = Variable(yTest).cuda()
                outputs, outMu, outLogVar, outY = cvae(Variable(xTest))
            else:
                yTest = Variable(yTest)

            print 'saving a set of samples'
            if cvae.useCUDA:
Example #3
0
  model.train()
  train_loss = 0
  
  for x, y in train_dataloader:
    
    x = x.view(-1, input_size).to(device)
    y = utils.y_to_onehot(y, batch_size, num_of_classes).to(device)
    
    optimizer.zero_grad()
    x_mu, x_logvar, z, z_mu, z_logvar = model(x, y)
    loss = model.loss_calc(x, x_mu, z_mu, z_logvar)
    loss.backward()
    train_loss += loss.item()
    optimizer.step()
    
  model.eval()
  test_loss = 0
  
  with torch.no_grad():
    
    for x, y in test_dataloader:
    
      x = x.view(-1, input_size).to(device)
      y = utils.y_to_onehot(y, batch_size, num_of_classes).to(device)

      x_mu, x_logvar, z, z_mu, z_logvar = model(x, y)
      loss = model.loss_calc(x, x_mu, z_mu, z_logvar)
      test_loss += loss.item()
      
  print('Epoch is {}. Train loss = {}. Test loss = {}'.format(epoch, train_loss/len(train_dataset), test_loss/len(test_dataset)))