Example #1
    def testNormalize(self):
        df = dt.load_data(params.global_params['db_path'],
                          'C',
                          index_col='date',
                          from_date=20100101,
                          to_date=20100405,
                          limit=30)
        df = df['close']
        look_back = 20
        look_ahead = 1
        coeff = 3.0
        print(df.tail())
        data = df.values
        print('data.shape', data.shape)

        _, y_data = dt.normalize(data,
                                 look_back=look_back,
                                 look_ahead=look_ahead,
                                 alpha=coeff)

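        # A minimal round-trip check: denormalize() should map the normalized targets
        # back to the original scale so they overlay the raw series plotted below.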
        tmp = dt.denormalize(y_data, data, look_back, look_ahead, coeff)
        print('denorm.shape', tmp.shape)

        plt.plot(data[look_back:], label='actual')
        plt.plot(tmp, label='denorm')
        plt.legend(loc='upper left')
        plt.show()
Example #2
def main(args):
    # Init tensorboard
    writer = SummaryWriter('./runs/' + args.runname + str(args.trialnumber))
    model_name = 'VanillaDMM'

    # Set evaluation log file
    evaluation_logpath = './logs/{}/evaluation_result.log'.format(
        model_name.lower())
    log_evaluation(evaluation_logpath,
                   'Evaluation Trial - {}\n'.format(args.trialnumber))

    # Constants
    time_length = 30
    input_length_for_pred = 20
    pred_length = time_length - input_length_for_pred
    train_batch_size = 16
    valid_batch_size = 1

    # For model
    input_channels = 1
    z_channels = 50
    emission_channels = [64, 32]
    transition_channels = 64
    encoder_channels = [32, 64]
    rnn_input_dim = 256
    rnn_channels = 128
    kernel_size = 3
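    # Note: the pred_length computed above is overridden to 0 here; the prediction
    # horizon used at evaluation time is set separately (val_pred_length below).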
    pred_length = 0

    # Device checking
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Make dataset
    logging.info("Generate data")
    train_datapath = args.datapath / 'train'
    valid_datapath = args.datapath / 'valid'
    train_dataset = DiffusionDataset(train_datapath)
    valid_dataset = DiffusionDataset(valid_datapath)

    # Create data loaders from pickle data
    logging.info("Generate data loaders")
    train_dataloader = DataLoader(
        train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=8)
    valid_dataloader = DataLoader(
        valid_dataset, batch_size=valid_batch_size, num_workers=4)

    # Training parameters
    width = 100
    height = 100
    input_dim = width * height

    # Create model
    logging.warning("Generate model")
    logging.warning(input_dim)
    pred_input_dim = 10
    dmm = DMM(input_channels=input_channels, z_channels=z_channels,
              emission_channels=emission_channels, transition_channels=transition_channels,
              encoder_channels=encoder_channels, rnn_input_dim=rnn_input_dim,
              rnn_channels=rnn_channels, kernel_size=kernel_size,
              height=height, width=width, pred_input_dim=pred_input_dim,
              num_layers=1, rnn_dropout_rate=0.0, num_iafs=0, iaf_dim=50,
              use_cuda=use_cuda)

    # Initialize model
    logging.info("Initialize model")
    epochs = args.endepoch
    learning_rate = 0.0001
    beta1 = 0.9
    beta2 = 0.999
    clip_norm = 10.0
    lr_decay = 1.0
    weight_decay = 0
    adam_params = {"lr": learning_rate, "betas": (beta1, beta2),
                   "clip_norm": clip_norm, "lrd": lr_decay,
                   "weight_decay": weight_decay}
    adam = ClippedAdam(adam_params)
    elbo = Trace_ELBO()
    svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)

    # saves the model and optimizer states to disk
    save_model = Path('./checkpoints/' + model_name)

    def save_checkpoint(epoch):
        save_dir = save_model / '{}.model'.format(epoch)
        save_opt_dir = save_model / '{}.opt'.format(epoch)
        logging.info("saving model to %s..." % save_dir)
        torch.save(dmm.state_dict(), save_dir)
        logging.info("saving optimizer states to %s..." % save_opt_dir)
        adam.save(save_opt_dir)
        logging.info("done saving model and optimizer checkpoints to disk.")

    # Starting epoch
    start_epoch = args.startepoch

    # loads the model and optimizer states from disk
    if start_epoch != 0:
        load_opt = './checkpoints/' + model_name + \
            '/e{}-i188-opt-tn{}.opt'.format(start_epoch - 1, args.trialnumber)
        load_model = './checkpoints/' + model_name + \
            '/e{}-i188-tn{}.pt'.format(start_epoch - 1, args.trialnumber)

        def load_checkpoint():
            # assert exists(load_opt) and exists(load_model), \
            #     "--load-model and/or --load-opt misspecified"
            logging.info("loading model from %s..." % load_model)
            dmm.load_state_dict(torch.load(load_model, map_location=device))
            # logging.info("loading optimizer states from %s..." % load_opt)
            # adam.load(load_opt)
            # logging.info("done loading model and optimizer states.")

        if load_model != '':
            logging.info('Load checkpoint')
            load_checkpoint()

    # Validation only?
    validation_only = args.validonly

    # Train the model
    if not validation_only:
        logging.info("Training model")
        annealing_epochs = 1000
        minimum_annealing_factor = 0.2
        N_train_size = 3000
        N_mini_batches = int(N_train_size / train_batch_size +
                             int(N_train_size % train_batch_size > 0))
        for epoch in tqdm(range(start_epoch, epochs), desc='Epoch', leave=True):
            r_loss_train = 0
            dmm.train(True)
            idx = 0
            mov_avg_loss = 0
            mov_data_len = 0
            for which_mini_batch, data in enumerate(tqdm(train_dataloader, desc='Train', leave=True)):
                if annealing_epochs > 0 and epoch < annealing_epochs:
                    # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
                    min_af = minimum_annealing_factor
                    annealing_factor = min_af + (1.0 - min_af) * \
                        (float(which_mini_batch + epoch * N_mini_batches + 1) /
                         float(annealing_epochs * N_mini_batches))
                else:
                    # by default the KL annealing factor is unity
                    annealing_factor = 1.0

                data['observation'] = normalize(
                    data['observation'].unsqueeze(2).to(device))
                batch_size, length, _, w, h = data['observation'].shape
                data_reversed = reverse_sequences(data['observation'])
                data_mask = torch.ones(
                    batch_size, length, input_channels, w, h).to(device)

                loss = svi.step(data['observation'],
                                data_reversed, data_mask, annealing_factor)

                # Running losses
                mov_avg_loss += loss
                mov_data_len += batch_size

                r_loss_train += loss
                idx += 1

            # Average losses
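            # svi.step returns the loss for the whole mini-batch, so dividing the running
            # total by (number of sequences x sequence length) gives a rough per-frame loss.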
            train_loss_avg = r_loss_train / (len(train_dataset) * time_length)
            writer.add_scalar('Loss/train', train_loss_avg, epoch)
            logging.info("Epoch: %d, Training loss: %1.5f",
                         epoch, train_loss_avg)

            # Time to time evaluation
            if epoch == epochs - 1:
                for temp_pred_length in [20]:
                    r_loss_valid = 0
                    r_loss_loc_valid = 0
                    r_loss_scale_valid = 0
                    r_loss_latent_valid = 0
                    dmm.train(False)
                    val_pred_length = temp_pred_length
                    val_pred_input_length = 10
                    with torch.no_grad():
                        for i, data in enumerate(tqdm(valid_dataloader, desc='Eval', leave=True)):
                            data['observation'] = normalize(
                                data['observation'].unsqueeze(2).to(device))
                            batch_size, length, _, w, h = data['observation'].shape
                            data_reversed = reverse_sequences(
                                data['observation'])
                            data_mask = torch.ones(
                                batch_size, length, input_channels, w, h).to(device)

                            pred_tensor = data['observation'][:, :input_length_for_pred, :, :, :]
                            pred_tensor_reversed = reverse_sequences(
                                pred_tensor)
                            pred_tensor_mask = torch.ones(
                                batch_size, input_length_for_pred, input_channels, w, h).to(device)

                            ground_truth = data['observation'][:, input_length_for_pred:, :, :, :]

                            val_nll = svi.evaluate_loss(
                                data['observation'], data_reversed, data_mask)

                            preds, _, loss_loc, loss_scale = do_prediction_rep_inference(
                                dmm, pred_tensor_mask, val_pred_length, val_pred_input_length, data['observation'])

                            ground_truth = denormalize(
                                data['observation'].squeeze().cpu().detach()
                            )
                            pred_with_input = denormalize(
                                torch.cat(
                                    [data['observation'][:, :-val_pred_length, :, :, :].squeeze(),
                                     preds.squeeze()], dim=0
                                ).cpu().detach()
                            )

                            # Running losses
                            r_loss_valid += val_nll
                            r_loss_loc_valid += loss_loc
                            r_loss_scale_valid += loss_scale

                    # Average losses
                    valid_loss_avg = r_loss_valid / \
                        (len(valid_dataset) * time_length)
                    valid_loss_loc_avg = r_loss_loc_valid / \
                        (len(valid_dataset) * val_pred_length * width * height)
                    valid_loss_scale_avg = r_loss_scale_valid / \
                        (len(valid_dataset) * val_pred_length * width * height)
                    writer.add_scalar('Loss/test', valid_loss_avg, epoch)
                    writer.add_scalar(
                        'Loss/test_obs', valid_loss_loc_avg, epoch)
                    writer.add_scalar('Loss/test_scale',
                                      valid_loss_scale_avg, epoch)
                    logging.info("Validation loss: %1.5f", valid_loss_avg)
                    logging.info("Validation obs loss: %1.5f",
                                 valid_loss_loc_avg)
                    logging.info("Validation scale loss: %1.5f",
                                 valid_loss_scale_avg)
                    log_evaluation(evaluation_logpath, "Validation obs loss for {}s pred {}: {}\n".format(
                        val_pred_length, args.trialnumber, valid_loss_loc_avg))
                    log_evaluation(evaluation_logpath, "Validation scale loss for {}s pred {}: {}\n".format(
                        val_pred_length, args.trialnumber, valid_loss_scale_avg))

            # Save model
            if epoch % 50 == 0 or epoch == epochs - 1:
                torch.save(dmm.state_dict(), args.modelsavepath / model_name /
                           'e{}-i{}-tn{}.pt'.format(epoch, idx, args.trialnumber))
                adam.save(args.modelsavepath / model_name /
                          'e{}-i{}-opt-tn{}.opt'.format(epoch, idx, args.trialnumber))

    # Last validation after training
    test_samples_indices = range(100)
    total_n = 0
    if validation_only:
        r_loss_loc_valid = 0
        r_loss_scale_valid = 0
        r_loss_latent_valid = 0
        dmm.train(False)
        val_pred_length = args.validpredlength
        val_pred_input_length = 10
        with torch.no_grad():
            for i in tqdm(test_samples_indices, desc='Valid', leave=True):
                # Data processing
                data = valid_dataset[i]
                if torch.isnan(torch.sum(data['observation'])):
                    print("Skip {}".format(i))
                    continue
                else:
                    total_n += 1
                data['observation'] = normalize(
                    data['observation'].unsqueeze(0).unsqueeze(2).to(device))
                batch_size, length, _, w, h = data['observation'].shape
                data_reversed = reverse_sequences(data['observation'])
                data_mask = torch.ones(
                    batch_size, length, input_channels, w, h).to(device)

                # Prediction
                pred_tensor_mask = torch.ones(
                    batch_size, input_length_for_pred, input_channels, w, h).to(device)
                preds, _, loss_loc, loss_scale = do_prediction_rep_inference(
                    dmm, pred_tensor_mask, val_pred_length, val_pred_input_length, data['observation'])

                ground_truth = denormalize(
                    data['observation'].squeeze().cpu().detach()
                )
                pred_with_input = denormalize(
                    torch.cat(
                        [data['observation'][:, :-val_pred_length, :, :, :].squeeze(),
                         preds.squeeze()], dim=0
                    ).cpu().detach()
                )

                # Save samples
                if i < 5:
                    save_dir_samples = Path('./samples/more_variance_long')
                    with open(save_dir_samples / '{}-gt-test.pkl'.format(i), 'wb') as fout:
                        pickle.dump(ground_truth, fout)
                    with open(save_dir_samples / '{}-vanilladmm-pred-test.pkl'.format(i), 'wb') as fout:
                        pickle.dump(pred_with_input, fout)

                # Running losses
                r_loss_loc_valid += loss_loc
                r_loss_scale_valid += loss_scale
                r_loss_latent_valid += np.sum(
                    (preds.squeeze().detach().cpu().numpy()
                     - data['latent'][time_length - val_pred_length:, :, :].detach().cpu().numpy()) ** 2)

        # Average losses
        test_samples_indices = range(total_n)
        print(total_n)
        valid_loss_loc_avg = r_loss_loc_valid / \
            (total_n * val_pred_length * width * height)
        valid_loss_scale_avg = r_loss_scale_valid / \
            (total_n * val_pred_length * width * height)
        valid_loss_latent_avg = r_loss_latent_valid / \
            (total_n * val_pred_length * width * height)
        logging.info("Validation obs loss for %ds pred VanillaDMM: %f",
                     val_pred_length, valid_loss_loc_avg)
        logging.info("Validation latent loss: %f", valid_loss_latent_avg)

        with open('VanillaDMMResult.log', 'a+') as fout:
            validation_log = 'Pred {}s VanillaDMM: {}\n'.format(
                val_pred_length, valid_loss_loc_avg)
            fout.write(validation_log)
Example #3
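# Directional accuracy: the product of the predicted and actual differences is
# positive only when both moves have the same sign.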
res = p_diff * a_diff
res = np.where(res > 0., 1., 0.)
cum = np.sum(res)

my_accuracy = cum / res.shape[0]

res = np.where(res == 0.0, -1.0, 0.85)
res = np.cumsum(res)

plt.plot(res)
plt.show()

print('test accuracy', my_accuracy)

y_test = dt.denormalize(y_test, test_data, look_back, look_ahead, alpha)
predictions = dt.denormalize(predictions, test_data, look_back, look_ahead,
                             alpha)

test_dates = dts.int2dates(test_dates)

#for p, a, d in zip(predictions, y_test, test_dates):
#    print(d, p, a)

plt.plot(test_dates[-100:], y_test[-100:], label='actual')
plt.plot(test_dates[-100:], predictions[-100:], label='predictions')
plt.legend()
plt.show()
''' Print model output '''
if history is not None:
    print(history.history.keys())
Example #4
                def validate_dataset(dataloader, name='test'):
                    for temp_pred_length in validation_pred_lengths:
                        r_loss_valid = 0
                        r_loss_loc_valid = 0
                        r_loss_scale_valid = 0
                        val_pred_length = temp_pred_length
                        val_pred_input_length = 10
                        dmm.train(False)
                        dmm.convlstm.eval()
                        with torch.no_grad():
                            for i, data in enumerate(
                                    tqdm(dataloader, desc=name, leave=True)):
                                data['observation'] = normalize(
                                    data['observation'].unsqueeze(2).to(
                                        device), data_min_val, data_max_val)
                                batch_size, length, _, w, h = data['observation'].shape
                                data_reversed = reverse_sequences(
                                    data['observation'])
                                data_mask = torch.ones(
                                    batch_size, length, input_channels, w, h).to(device)

                                pred_tensor = data['observation'][:, :input_length_for_pred, :, :, :]
                                pred_tensor_reversed = reverse_sequences(
                                    pred_tensor)
                                pred_tensor_mask = torch.ones(
                                    batch_size, input_length_for_pred, input_channels, w, h).to(device)

                                ground_truth = data['observation'][:, input_length_for_pred:, :, :, :]

                                val_nll = svi.evaluate_loss(
                                    data['observation'], data_reversed,
                                    data_mask)
                                preds, _, loss_loc, loss_scale = do_prediction_rep_inference(
                                    dmm, pred_tensor_mask, val_pred_length,
                                    val_pred_input_length, data['observation'])

                                ground_truth = denormalize(
                                    data['observation'].squeeze().cpu().detach(),
                                    data_min_val, data_max_val)
                                pred_with_input = denormalize(
                                    torch.cat(
                                        [data['observation'][:, :-val_pred_length, :, :, :].squeeze(),
                                         preds.squeeze()], dim=0).cpu().detach(),
                                    data_min_val, data_max_val)

                                # Running losses
                                r_loss_valid += val_nll
                                r_loss_loc_valid += loss_loc
                                r_loss_scale_valid += loss_scale

                            # Average losses
                            valid_loss_avg = r_loss_valid / \
                                (len(valid_dataset) * time_length)
                            valid_loss_loc_avg = r_loss_loc_valid / \
                                (len(valid_dataset) *
                                 val_pred_length * width * height)
                            writer.add_scalar('Loss/{}'.format(name),
                                              valid_loss_avg, epoch)
                            writer.add_scalar('Loss/{}_obs'.format(name),
                                              valid_loss_loc_avg, epoch)
                            logging.info("Validation loss: %1.5f",
                                         valid_loss_avg)
                            logging.info("Validation obs loss: %1.5f",
                                         valid_loss_loc_avg)
                            log_evaluation(
                                evaluation_logpath,
                                "Validation obs loss for {}s pred {} ConvDMM: {}\n"
                                .format(val_pred_length, args.trialnumber,
                                        valid_loss_loc_avg))
def main(args):
    # Init tensorboard
    writer = SummaryWriter('./runs/' + args.runname + str(args.trialnumber))

    # Set evaluation log file
    evaluation_logpath = './logs/convlstm/evaluation_result.log'
    log_evaluation(evaluation_logpath,
                   'Evaluation Trial - {}\n'.format(args.trialnumber))

    # Constants
    time_length = 30
    input_length = 20
    pred_length = time_length - input_length
    validation_pred_lengths = [5, 10, 15, 20]
    train_batch_size = 16
    valid_batch_size = 1
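    # Fixed value range passed to normalize()/denormalize(); assumed here to be the
    # min-max scaling bounds for the observations.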
    min_val, max_val = 0, 1000

    # Device checking
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Make dataset
    logging.info("Generate data")
    train_datapath = args.datapath / 'train'
    valid_datapath = args.datapath / 'valid'
    train_dataset = DiffusionDataset(train_datapath)
    valid_dataset = DiffusionDataset(valid_datapath)

    # Create data loaders from pickle data
    logging.info("Generate data loaders")
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=train_batch_size,
                                  shuffle=True,
                                  num_workers=4)
    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=valid_batch_size,
                                  num_workers=4,
                                  shuffle=False)

    # Training parameters
    row_dim = 100
    col_dim = 100
    input_dim = row_dim * col_dim

    # Create model
    logging.warning("Generate model")
    logging.warning(input_dim)
    input_channels = 1
    hidden_channels = [64]
    kernel_size = 3
    pred_input_dim = 10
    model = ConvLSTM(input_channels, hidden_channels, kernel_size,
                     pred_input_dim, device)

    # Initialize model
    logging.info("Initialize model")
    epochs = args.endepoch
    learning_rate = 0.001

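    # Unlike the DMM examples above, the ConvLSTM baseline is trained directly with
    # MSE on the predicted frames rather than with an SVI/ELBO objective.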
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Starting epoch
    start_epoch = args.startepoch

    # Load model
    if start_epoch != 0:
        load_model = Path('./checkpoints/ConvLSTM/e{}-i188-tn{}.pt'.format(
            start_epoch - 1, args.trialnumber))
        if load_model is not None:
            logging.info("loading model from %s..." % load_model)
            model.load_state_dict(torch.load(load_model, map_location=device))

    # Validation only?
    validation_only = args.validonly

    # Train the model
    if not validation_only:
        logging.info("Training model")
        train_losses = []
        valid_losses = []
        for epoch in tqdm(range(start_epoch, epochs), desc='Epoch',
                          leave=True):
            r_loss_train = 0
            model.train(True)
            idx = 0
            mov_avg_loss = 0
            for data in tqdm(train_dataloader, desc='Train', leave=True):
                optimizer.zero_grad()

                data['observation'] = normalize(
                    data['observation'].unsqueeze(2).to(device), min_val,
                    max_val)
                batch_size, length, _, w, h = data['observation'].shape

                preds, _ = model(
                    data['observation'][:, :-pred_length, :, :, :],
                    pred_length)

                loss = criterion(
                    preds, data['observation'][:, input_length:, :, :, :])
                loss.backward()
                optimizer.step()

                # Running losses
                mov_avg_loss += loss.item() * data['observation'].size(0)

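                # Every third iteration, log the average of the last three
                # (batch-size-weighted) losses to training.log as a progress trace.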
                if idx % 3 == 2:
                    with open('training.log', 'a+') as fout:
                        temp_avg = mov_avg_loss / 3
                        training_log = 'Epoch {}: Iter {}: {}\n'.format(
                            epoch, idx, temp_avg)
                        fout.write(training_log)
                    mov_avg_loss = 0

                r_loss_train += loss.item() * data['observation'].size(0)
                idx += 1

            # Average losses
            train_loss_avg = r_loss_train / len(train_dataset)
            writer.add_scalar('Loss/train', train_loss_avg, epoch)
            logging.info("Epoch: %d, Training loss: %1.5f", epoch,
                         train_loss_avg)

            # Time to time evaluation
            if epoch == epochs - 1:
                for temp_pred_length in validation_pred_lengths:
                    r_loss_valid = 0
                    r_loss_loc_valid = 0
                    val_input_length = time_length - temp_pred_length
                    model.train(False)
                    with torch.no_grad():
                        for i, data in enumerate(
                                tqdm(valid_dataloader,
                                     desc='Valid',
                                     leave=True)):
                            data['observation'] = normalize(
                                data['observation'].unsqueeze(2).to(device),
                                min_val, max_val)
                            batch_size, length, _, w, h = data['observation'].shape

                            preds = model.forward_with_gt(
                                data['observation'], temp_pred_length)

                            loss = criterion(
                                preds,
                                data['observation'][:, val_input_length:, :, :, :])

                            ground_truth = denormalize(
                                data['observation'].squeeze().cpu().detach(),
                                min_val, max_val)
                            pred_with_input = denormalize(
                                torch.cat(
                                    [data['observation'][:, :-temp_pred_length, :, :, :].squeeze(),
                                     preds.squeeze()], dim=0).cpu().detach(),
                                min_val, max_val)

                            # Running losses
                            r_loss_valid += loss.item() * \
                                data['observation'].size(0)
                            r_loss_loc_valid += np.sum(
                                (preds.detach().cpu().numpy()
                                 - data['observation'][:, val_input_length:, :, :, :].detach().cpu().numpy()) ** 2)

                    # Average losses
                    valid_loss_avg = r_loss_valid / len(valid_dataset)
                    valid_loss_loc_avg = r_loss_loc_valid / \
                        (len(valid_dataset) * temp_pred_length * col_dim * row_dim)
                    writer.add_scalar('Loss/test_obs', valid_loss_loc_avg,
                                      epoch)
                    writer.add_scalar('Loss/test', valid_loss_avg, epoch)
                    logging.info("Validation loss: %1.5f", valid_loss_avg)
                    logging.info("Validation obs loss: %1.5f",
                                 valid_loss_loc_avg)
                    log_evaluation(
                        evaluation_logpath,
                        "Validation obs loss for {}s pred {}: {}\n".format(
                            temp_pred_length, args.trialnumber,
                            valid_loss_loc_avg))

                # Save model
                torch.save(
                    model.state_dict(), args.modelsavepath /
                    'e{}-i{}-tn{}.pt'.format(epoch, idx, args.trialnumber))

    # Last validation after training
    test_samples_indices = range(len(valid_dataset))
    total_n = len(valid_dataset)
    if validation_only:
        r_loss_valid = 0
        r_loss_loc_valid = 0
        r_loss_latent_valid = 0
        model.train(False)
        pred_length = args.validpredlength
        input_length = time_length - pred_length
        with torch.no_grad():
            for i in tqdm(test_samples_indices, desc='Valid', leave=True):
                # Data processing
                data = valid_dataset[i]
                data['observation'] = normalize(
                    data['observation'].unsqueeze(0).unsqueeze(2).to(device),
                    min_val, max_val)
                batch_size, length, _, w, h = data['observation'].shape

                preds = model.forward_with_gt(data['observation'], pred_length)

                loss = criterion(
                    preds, data['observation'][:, input_length:, :, :, :])

                ground_truth = denormalize(
                    data['observation'].squeeze().cpu().detach(), min_val,
                    max_val)
                pred_with_input = denormalize(
                    torch.cat(
                        [data['observation'][:, :-pred_length, :, :, :].squeeze(),
                         preds.squeeze()], dim=0).cpu().detach(),
                    min_val, max_val)

                # Running losses
                r_loss_valid += loss.item() * data['observation'].size(0)
                r_loss_loc_valid += np.sum(
                    (preds.detach().cpu().numpy()
                     - data['observation'][:, input_length:, :, :, :].detach().cpu().numpy()) ** 2)
                r_loss_latent_valid += np.sum(
                    (preds.squeeze().detach().cpu().numpy()
                     - data['latent'][input_length:, :, :].detach().cpu().numpy()) ** 2)

        # Average losses
        print(total_n)

        valid_loss_avg = r_loss_valid / total_n
        valid_loss_loc_avg = r_loss_loc_valid / \
            (total_n * pred_length * col_dim * row_dim)
        valid_loss_latent_avg = r_loss_latent_valid / \
            (total_n * pred_length * col_dim * row_dim)
        logging.info("Validation loss: %f", valid_loss_avg)
        logging.info("Validation obs loss: %f", valid_loss_loc_avg)
        logging.info("Validation latent loss: %f", valid_loss_latent_avg)