Example #1
# MNISTDataModule, SimpleConvNet, DATA_PATH and run_prediction are
# defined elsewhere in the source project.
def run_mnist_prediction(prediction_hparams):
    # DATA
    datamodule = MNISTDataModule(data_dir=DATA_PATH + '/mnist_')
    datamodule.prepare_data()  # downloads data to given path
    train_dataloader = datamodule.train_dataloader()
    val_dataloader = datamodule.val_dataloader()
    test_dataloader = datamodule.test_dataloader()

    encoder_model = SimpleConvNet()
    return run_prediction(prediction_hparams, train_dataloader, val_dataloader, test_dataloader,
                          encoder_model, encoder_out_dim=50)
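A minimal usage sketch for this helper; the hparam keys below are illustrative guesses, not documented by the source project:

# Hypothetical invocation of run_mnist_prediction.
if __name__ == '__main__':
    prediction_hparams = {'learning_rate': 1e-3, 'max_epochs': 10}  # illustrative keys
    results = run_mnist_prediction(prediction_hparams)
    print(results)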
Example #2
# MNISTDataModule, SimpleConvNet, DATA_PATH and run_gumbel are defined
# elsewhere in the source project.
def run_mnist_dvrl(prediction_hparams, dvrl_hparams):
    # DATA
    datamodule = MNISTDataModule(data_dir=DATA_PATH + '/mnist_')
    datamodule.prepare_data()  # downloads data to given path
    batch_size = dvrl_hparams.get('outer_batch_size', 32)
    train_dataloader = datamodule.train_dataloader(batch_size=batch_size)
    val_dataloader = datamodule.val_dataloader(batch_size=batch_size)
    test_dataloader = datamodule.test_dataloader(batch_size=batch_size)

    val_split = datamodule.val_split  # validation split passed through to run_gumbel
    encoder_model = SimpleConvNet()
    return run_gumbel(dvrl_hparams,
                      prediction_hparams,
                      train_dataloader,
                      val_dataloader,
                      test_dataloader,
                      val_split,
                      encoder_model,
                      encoder_out_dim=50)
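The same kind of hedged sketch for the DVRL variant; only 'outer_batch_size' is actually read by the code above, the remaining keys are illustrative:

# Hypothetical invocation of run_mnist_dvrl.
if __name__ == '__main__':
    dvrl_hparams = {'outer_batch_size': 32}       # the only key read above
    prediction_hparams = {'learning_rate': 1e-3}  # illustrative
    results = run_mnist_dvrl(prediction_hparams, dvrl_hparams)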
Example #3
            # End of the first phase: backprop and step the siamese net's
            # optimizer (fn is torch.nn.functional).
            loss = fn.cross_entropy(y_pred, y_true)
            loss.backward()
            siamese_net_opt.step()

            # Second phase: update the main net while the siamese net is frozen.
            net.train()
            siamese_net.eval()
            net_opt.zero_grad()
            y_pred, _ = net(x, siamese_net)
            loss = fn.cross_entropy(y_pred, y_true)
            train_loss.append(loss.detach().cpu().numpy())
            accuracy = (torch.argmax(y_pred, 1) == y_true).float().mean()
            train_accuracy.append(accuracy.detach().cpu().numpy())

            loss.backward()
            net_opt.step()

        print(f"{eid} | Train: loss={np.mean(train_loss):.6f}, acc={np.mean(train_accuracy):.4f}")

        net.eval()
        val_accuracy, val_loss = [], []
        with torch.no_grad():  # validation needs no gradients
            for batch in mnist_dm.val_dataloader(val_batch_size):
                x, y_true = batch[0].cuda(device_id), batch[1].cuda(device_id)

                y_pred, _ = net(x, siamese_net)
                loss = fn.cross_entropy(y_pred, y_true)
                val_loss.append(loss.cpu().numpy())
                accuracy = (torch.argmax(y_pred, 1) == y_true).float().mean()
                val_accuracy.append(accuracy.cpu().numpy())

        print(f"{eid} | Val: loss={np.mean(val_loss):.6f}, acc={np.mean(val_accuracy):.4f}")
Example #4
import torch
import pytorch_lightning as pl

device_id = None  # set to a GPU index to train on that device

latest_checkpoint = ""  # not used in this excerpt

if __name__ == '__main__':
    train_batch_size = 1024
    val_batch_size = 2048
    data_dir = '../data/'
    train_transforms = None
    val_transforms = None
    # MNISTDataModule, DeepELM and TrainPipe are defined elsewhere in
    # the source project.
    mnist_dm = MNISTDataModule(data_dir,
                               train_transforms=train_transforms,
                               val_transforms=val_transforms)
    mlp = DeepELM(1)
    train_pipe = TrainPipe(model=mlp)

    # gpus expects a list of device indices (None = CPU) in this
    # pre-2.0 Lightning API.
    classic_trainer = pl.Trainer(
        gpus=([device_id] if device_id is not None else None), max_epochs=1000)
    print("TRAINING MLP")
    classic_trainer.fit(
        train_pipe,
        train_dataloader=mnist_dm.train_dataloader(train_batch_size),
        val_dataloaders=mnist_dm.val_dataloader(val_batch_size))

    # print("SAVE MLP")
    # torch.save(train_pipe.state_dict(), 'model_v1')

    # print("TEST MLP")
    # classic_trainer.test(train_pipe, datamodule=mnist_dm)
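If the commented-out save step is enabled, the state_dict can later be restored into a fresh pipeline; a minimal sketch, assuming the DeepELM/TrainPipe constructors shown above:

# Hypothetical restore-and-test; 'model_v1' matches the commented-out
# save path above.
restored = TrainPipe(model=DeepELM(1))
restored.load_state_dict(torch.load('model_v1'))
classic_trainer.test(restored, datamodule=mnist_dm)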