Example #1
import argparse
import json
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pyro
import torch
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import Adam
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader

# Project-local imports (module paths are assumptions):
# from model import VAE
# from data import create_ticker_dataset

def train():
    parser = argparse.ArgumentParser(description='Train VAE.')
    parser.add_argument('-c', '--config', default='train_config.json', help='Config file.')
    args = parser.parse_args()
    print(args)
    with open(args.config) as f:
        c = json.load(f)
    print(c)

    pyro.clear_param_store()

    # TODO: Move to config file.
    lookback = 50
    max_n_files = None

    train_start_date = datetime.strptime(c['train_start_date'], '%Y/%m/%d')
    train_end_date = datetime.strptime(c['train_end_date'], '%Y/%m/%d')
    val_start_date = datetime.strptime(c['val_start_date'], '%Y/%m/%d')
    val_end_date = datetime.strptime(c['val_end_date'], '%Y/%m/%d')
    min_sequence_length_train = 2 * (c['series_length'] + lookback)
    min_sequence_length_test = 2 * (c['series_length'] + lookback)

    out_path = Path(c['out_dir'])
    out_path.mkdir(parents=True, exist_ok=True)

    dataset_train = create_ticker_dataset(c['in_dir'], c['series_length'], lookback, min_sequence_length_train,
                                          start_date=train_start_date, end_date=train_end_date,
                                          normalised_returns=c['normalised_returns'], max_n_files=max_n_files)
    dataset_val = create_ticker_dataset(c['in_dir'], c['series_length'], lookback, min_sequence_length_test,
                                        start_date=val_start_date, end_date=val_end_date, fixed_start_date=True,
                                        normalised_returns=c['normalised_returns'], max_n_files=max_n_files)
    train_loader = DataLoader(dataset_train, batch_size=c['batch_size'], shuffle=True, num_workers=0, drop_last=True)
    val_loader = DataLoader(dataset_val, batch_size=c['batch_size'], shuffle=False, num_workers=0, drop_last=True)
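    # NOTE: with drop_last=True the loaders skip the final partial batch, so
    # the epoch losses below are summed over N_mini_batches * batch_size
    # samples but normalised by the full len(dataset).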

    N_train_data = len(dataset_train)
    N_val_data = len(dataset_val)
    N_mini_batches = N_train_data // c['batch_size']
    N_train_time_slices = c['batch_size'] * N_mini_batches

    print(f'N_train_data: {N_train_data}, N_val_data: {N_val_data}')

    # setup the VAE
    vae = VAE(c['series_length'], z_dim=c['z_dim'], hidden_dims=c['hidden_dims'], use_cuda=c['cuda'])

    # setup the optimizer
    adam_args = {"lr": c['learning_rate']}
    optimizer = Adam(adam_args)
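    # NOTE: pyro.optim.Adam is a PyroOptim wrapper that builds one
    # torch.optim.Adam per parameter and is checkpointed via save()/load()
    # (or get_state()/set_state()) rather than a plain state_dict(), which is
    # why the optimizer-state lines below stay commented out.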

    # setup the inference algorithm
    elbo = JitTrace_ELBO() if c['jit'] else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)
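    # svi.step() performs one gradient update and returns an estimate of the
    # negative ELBO for the mini-batch; svi.evaluate_loss() returns the same
    # estimate without updating any parameters.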

    if c['checkpoint_load']:
        checkpoint = torch.load(c['checkpoint_load'])
        vae.load_state_dict(checkpoint['model_state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    train_elbo = []
    val_elbo = []
    # training loop
    for epoch in range(c['n_epochs']):
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch x returned
        # by the data loader
        for batch in train_loader:
            x = batch['series']
            # if on GPU put mini-batch into CUDA memory
            if c['cuda']:
                x = x.cuda()
            # do ELBO gradient and accumulate loss
            epoch_loss += svi.step(x.float())

        # report training diagnostics
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = epoch_loss / normalizer_train
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

        torch.save({
            'epoch': epoch,
            'model_state_dict': vae.state_dict(),
            # 'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': epoch_loss  # summed (unnormalised) epoch loss
        }, out_path / c['checkpoint_save'].format(epoch))

        if epoch % c['val_frequency'] == 0:
            # initialize loss accumulator
            val_loss = 0.
            # compute the loss over the entire test set
            for i, batch in enumerate(val_loader):
                x = batch['series']
                # if on GPU put mini-batch into CUDA memory
                if c['cuda']:
                    x = x.cuda()
                x = x.float()
                # compute ELBO estimate and accumulate loss
                val_loss += svi.evaluate_loss(x)

                if i == 0:
                    # Visualise first batch.
                    x_reconst = vae.reconstruct_img(x)
                    x = x.cpu().numpy()
                    x_reconst = x_reconst.cpu().detach().numpy()

                    n = min(5, x.shape[0])
                    fig, axes = plt.subplots(n, 1, squeeze=False)
                    for s in range(n):
                        ax = axes[s, 0]
                        ax.plot(x[s])
                        ax.plot(x_reconst[s])
                    fig.savefig(out_path / f'val_{epoch:03d}.png')
                    plt.close(fig)

            # report test diagnostics
            normalizer_val = len(val_loader.dataset)
            total_epoch_loss_val = val_loss / normalizer_val
            val_elbo.append(total_epoch_loss_val)
            print("[epoch %03d]  average val loss: %.4f" % (epoch, total_epoch_loss_val))

            # t-SNE of the latent space over the validation set.
            all_z_latents = []
            for batch in val_loader:
                x = batch['series']

                if c['cuda']:
                    x = x.cuda()

                z_loc, z_scale, z = vae.encode_x(x.float())
                all_z_latents.append(z.detach().cpu().numpy())

            all_latents = np.concatenate(all_z_latents, axis=0)

            # Run t-SNE with 2 output dimensions.
            model_tsne = TSNE(n_components=2, random_state=0)
            z_embed = model_tsne.fit_transform(all_latents)
            # Plot t-SNE embedding.
            fig = plt.figure()
            plt.scatter(z_embed[:, 0], z_embed[:, 1], s=10)

            fig.savefig(out_path / f'tsne_{epoch:03d}.png')
            plt.close(fig)

    print('Finished training.')
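
The config file itself never appears in these examples. For reference, the snippet below writes a train_config.json that satisfies every key Example #1 reads; the key names come from the code above, while all values are illustrative assumptions rather than settings recovered from the original project.

import json

example_config = {
    'in_dir': 'data/tickers',          # directory read by create_ticker_dataset
    'out_dir': 'runs/vae',
    'train_start_date': '2010/01/01',  # parsed with '%Y/%m/%d'
    'train_end_date': '2016/12/31',
    'val_start_date': '2017/01/01',
    'val_end_date': '2018/12/31',
    'series_length': 100,
    'normalised_returns': True,
    'batch_size': 64,
    'z_dim': 2,
    'hidden_dims': [128, 64],
    'learning_rate': 1e-3,
    'n_epochs': 100,
    'val_frequency': 5,
    'cuda': False,
    'jit': False,
    'checkpoint_load': None,           # falsy, so training starts from scratch
    'checkpoint_save': 'checkpoint_{:03d}.pt',  # formatted with the epoch number
}

with open('train_config.json', 'w') as f:
    json.dump(example_config, f, indent=2)
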
Example #2
import argparse
import json
from pathlib import Path

import numpy as np
import pyro
import torch
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import Adam
from torch.utils.data import DataLoader

# Project-local imports (module paths are assumptions):
# from model import VAE
# from data import create_ranking_dataset

def train():
    parser = argparse.ArgumentParser(description='Train VAE.')
    parser.add_argument('-c',
                        '--config',
                        default='train_config.json',
                        help='Config file.')
    args = parser.parse_args()
    print(args)
    with open(args.config) as f:
        c = json.load(f)
    print(c)

    # clear param store
    pyro.clear_param_store()

    input_dim = 1
    max_n_examples = None

    out_path = Path(c['out_dir'])
    out_path.mkdir(parents=True, exist_ok=True)

    # Toggle: load the real ranking file, or build a small synthetic dataset.
    use_ranking_file = False
    if use_ranking_file:
        dataset_train = create_ranking_dataset(c['training_filename'],
                                               0.0,
                                               0.7,
                                               max_n_examples=max_n_examples)
        dataset_val = create_ranking_dataset(c['training_filename'],
                                             0.7,
                                             1.0,
                                             max_n_examples=max_n_examples)
        train_loader = DataLoader(dataset_train,
                                  batch_size=c['batch_size'],
                                  shuffle=True,
                                  num_workers=3,
                                  drop_last=True)
        val_loader = DataLoader(dataset_val,
                                batch_size=c['batch_size'],
                                shuffle=False,
                                num_workers=3,
                                drop_last=True)
    else:
        dataset_train = create_ranking_dataset(n_examples=100)
        dataset_val = create_ranking_dataset(n_examples=100)
        train_loader = DataLoader(dataset_train,
                                  batch_size=8,
                                  shuffle=True,
                                  num_workers=3,
                                  drop_last=True)
        val_loader = DataLoader(dataset_val,
                                batch_size=8,
                                shuffle=False,
                                num_workers=3,
                                drop_last=True)
        c['num_features'] = 2

    N_train_data = len(dataset_train)
    N_val_data = len(dataset_val)
    N_mini_batches = N_train_data // c['batch_size']
    N_train_time_slices = c['batch_size'] * N_mini_batches

    print(f'N_train_data: {N_train_data}, N_val_data: {N_val_data}')

    # setup the VAE
    vae = VAE(c['num_features'], z_dim=c['z_dim'], use_cuda=c['cuda'])

    # setup the optimizer
    adam_args = {"lr": c['learning_rate']}
    optimizer = Adam(adam_args)

    # setup the inference algorithm
    elbo = JitTrace_ELBO() if c['jit'] else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    if c['checkpoint_load']:
        checkpoint = torch.load(c['checkpoint_load'])
        vae.load_state_dict(checkpoint['model_state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    train_elbo = []
    val_elbo = []
    # training loop
    num_epochs = c['n_epochs']
    for epoch in range(num_epochs):
        print(f'Starting epoch {epoch} of {num_epochs}.')
        # initialize loss accumulator
        epoch_loss = 0.
        # do a training epoch over each mini-batch x returned
        # by the data loader
        for batch_num, batch in enumerate(train_loader):
            print(f'Batch {batch_num} of {N_mini_batches}.')
            features_1 = batch['features_1']
            features_2 = batch['features_2']
            target_class = batch['target_class']
            if c['cuda']:
                features_1 = features_1.cuda()
                features_2 = features_2.cuda()
                target_class = target_class.cuda()
            # do ELBO gradient and accumulate loss
            epoch_loss += svi.step(features_1.float(), features_2.float(),
                                   target_class.float())

        # report training diagnostics
        normalizer_train = len(train_loader.dataset)
        total_epoch_loss_train = epoch_loss / normalizer_train
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" %
              (epoch, total_epoch_loss_train))

        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': vae.state_dict(),
                # 'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': epoch_loss
            },
            out_path / c['checkpoint_save'].format(epoch))

        if epoch % c['val_frequency'] == 0:
            print('Evaluating validation data.')
            val_loss = 0.
            num_correct = 0
            num_val_examples = 0

            # Compute the loss over the validation set.
            for i, batch in enumerate(val_loader):
                features_1 = batch['features_1']
                features_2 = batch['features_2']
                target_class = batch['target_class']
                # if on GPU put mini-batch into CUDA memory
                if c['cuda']:
                    features_1 = features_1.cuda()
                    features_2 = features_2.cuda()
                    target_class = target_class.cuda()
                features_1 = features_1.float()
                features_2 = features_2.float()
                target_class = target_class.float()
                # compute ELBO estimate and accumulate loss
                val_loss += svi.evaluate_loss(features_1, features_2,
                                              target_class)

                # Ranking accuracy: predict class 1 when the difference of the
                # latent locations maps above 0.5 through a sigmoid (this
                # assumes z_dim == 1, so z_loc has shape (batch, 1)).
                z_1_loc, z_1_scale, z_1_sample = vae.encode_x(features_1)
                z_2_loc, z_2_scale, z_2_sample = vae.encode_x(features_2)
                pred = (torch.sigmoid(z_1_loc - z_2_loc) >= 0.5).float()
                target_class = torch.unsqueeze(target_class, dim=1)

                pred = pred.cpu().detach().numpy()
                target_class = target_class.cpu().detach().numpy()
                pred_correct = (pred == target_class).astype(int).flatten()
                num_correct += np.sum(pred_correct)
                num_val_examples += pred_correct.shape[0]

                # Optionally inspect the latent locations behind a prediction:
                # print(f'target: {target_class}, z_loc: '
                #       f'{torch.cat([z_1_loc, z_2_loc], dim=1)}')

            # report test diagnostics
            normalizer_val = len(val_loader.dataset)
            total_epoch_loss_val = val_loss / normalizer_val
            val_elbo.append(total_epoch_loss_val)
            print("[epoch %03d]  average val loss: %.4f" %
                  (epoch, total_epoch_loss_val))

            accuracy = num_correct / num_val_examples
            print(f'Accuracy: {accuracy:.4f}')

    print('Finished training.')
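
Neither example defines the VAE class it trains. Below is a minimal sketch of the interface both scripts call (model, guide, encode_x, reconstruct_img), modelled on Pyro's VAE tutorial. The MLP architecture and the fixed observation noise are assumptions, and Example #2's ranking variant would further need model and guide signatures that accept (features_1, features_2, target_class).

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist


class VAE(nn.Module):
    def __init__(self, input_dim, z_dim=2, hidden_dims=(64,), use_cuda=False):
        super().__init__()
        # Encoder MLP with separate heads for the posterior loc and scale.
        dims = [input_dim] + list(hidden_dims)
        enc = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            enc += [nn.Linear(d_in, d_out), nn.ReLU()]
        self.encoder = nn.Sequential(*enc)
        self.enc_loc = nn.Linear(dims[-1], z_dim)
        self.enc_log_scale = nn.Linear(dims[-1], z_dim)
        # Decoder MLP mirroring the encoder.
        dims = [z_dim] + list(hidden_dims)[::-1]
        dec = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            dec += [nn.Linear(d_in, d_out), nn.ReLU()]
        dec.append(nn.Linear(dims[-1], input_dim))
        self.decoder = nn.Sequential(*dec)
        self.z_dim = z_dim
        if use_cuda:
            self.cuda()

    def model(self, x):
        # Generative model p(x | z) p(z): standard normal prior over z and a
        # Gaussian likelihood with fixed observation noise (an assumption).
        pyro.module('decoder', self.decoder)
        with pyro.plate('data', x.shape[0]):
            z_loc = x.new_zeros((x.shape[0], self.z_dim))
            z_scale = x.new_ones((x.shape[0], self.z_dim))
            z = pyro.sample('latent', dist.Normal(z_loc, z_scale).to_event(1))
            x_loc = self.decoder(z)
            pyro.sample('obs', dist.Normal(x_loc, 0.1).to_event(1), obs=x)

    def guide(self, x):
        # Amortised variational posterior q(z | x).
        pyro.module('encoder', self.encoder)
        pyro.module('enc_loc', self.enc_loc)
        pyro.module('enc_log_scale', self.enc_log_scale)
        with pyro.plate('data', x.shape[0]):
            z_loc, z_scale = self._encode(x)
            pyro.sample('latent', dist.Normal(z_loc, z_scale).to_event(1))

    def _encode(self, x):
        h = self.encoder(x)
        return self.enc_loc(h), torch.exp(self.enc_log_scale(h))

    def encode_x(self, x):
        # Returns (z_loc, z_scale, z_sample), as both validation loops expect.
        z_loc, z_scale = self._encode(x)
        return z_loc, z_scale, dist.Normal(z_loc, z_scale).sample()

    def reconstruct_img(self, x):
        # Encode, sample a latent, decode; used for the validation plots.
        _, _, z = self.encode_x(x)
        return self.decoder(z)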