Exemplo n.º 1
0
def run_mnist_prediction(prediction_hparams):
    """Run ``run_prediction`` on MNIST with a freshly built ``SimpleConvNet``.

    The MNIST data module is rooted at ``DATA_PATH + '/mnist_'``. Its
    ``prepare_data`` step runs first. The train/val/test loaders are then
    requested using the data module's default arguments.

    Args:
        prediction_hparams: Forwarded unchanged to ``run_prediction``.

    Returns:
        Whatever ``run_prediction`` returns.
    """
    mnist = MNISTDataModule(data_dir=DATA_PATH + '/mnist_')
    mnist.prepare_data()  # downloads the data into data_dir

    loaders = (
        mnist.train_dataloader(),
        mnist.val_dataloader(),
        mnist.test_dataloader(),
    )
    # encoder_out_dim=50 presumably matches SimpleConvNet's output size — confirm.
    return run_prediction(prediction_hparams, *loaders, SimpleConvNet(),
                          encoder_out_dim=50)
Exemplo n.º 2
0
def run_mnist_dvrl(prediction_hparams, dvrl_hparams):
    """Run ``run_gumbel`` on MNIST with a freshly built ``SimpleConvNet``.

    All three loaders (train/val/test) share a single batch size. It is read
    from ``dvrl_hparams['outer_batch_size']`` and falls back to 32 when the
    key is absent.

    Args:
        prediction_hparams: Forwarded unchanged to ``run_gumbel``.
        dvrl_hparams: Forwarded unchanged to ``run_gumbel``; also supplies
            the optional ``'outer_batch_size'`` entry.

    Returns:
        Whatever ``run_gumbel`` returns.
    """
    mnist = MNISTDataModule(data_dir=DATA_PATH + '/mnist_')
    mnist.prepare_data()  # downloads the data into data_dir

    outer_bs = dvrl_hparams.get('outer_batch_size', 32)
    train_dl, val_dl, test_dl = [
        make_loader(batch_size=outer_bs)
        for make_loader in (mnist.train_dataloader,
                            mnist.val_dataloader,
                            mnist.test_dataloader)
    ]

    return run_gumbel(dvrl_hparams, prediction_hparams,
                      train_dl, val_dl, test_dl,
                      mnist.val_split, SimpleConvNet(),
                      encoder_out_dim=50)
Exemplo n.º 3
0
    # NOTE(review): fragment. The enclosing function header and the rest of the
    # loop body are outside this view. net_opt_lr, epochs, device_id,
    # train_batch_size and make_hook are defined elsewhere.
    siamese_net_opt_lr = 1e-3
    data_dir = '../data/'
    train_transforms = None
    val_transforms = None
    mnist_dm = MNISTDataModule(data_dir, train_transforms=train_transforms, val_transforms=val_transforms)

    # Main classifier, optimized with plain SGD at net_opt_lr.
    net = DualClassifier().cuda(device_id)
    net_opt = optim.SGD(net.parameters(), lr=net_opt_lr)

    # Auxiliary network, optimized with Adam. It is passed into net's forward
    # pass below.
    siamese_net = ReversedSiameseNet().cuda(device_id)
    siamese_net_opt = optim.Adam(siamese_net.parameters(), lr=siamese_net_opt_lr)

    for eid in range(epochs):

        # Per-epoch metric accumulators; presumably filled further down the
        # loop body (not visible here).
        train_accuracy, train_loss = [], []
        # A fresh training DataLoader is requested from the data module every epoch.
        for batch_id, batch in enumerate(mnist_dm.train_dataloader(train_batch_size)):
            x, y_true = batch[0].cuda(device_id), batch[1].cuda(device_id)

            # train_net = batch_id % 5 == 0

            # collect per-layer inputs and y_pred gradient
            net_opt.zero_grad()
            # net returns (logits, inputs). Per the comment above, inputs are
            # presumably the per-layer inputs — confirm against DualClassifier.
            y_pred, inputs = net(x, siamese_net)
            saved_grad = []
            # The hook on y_pred fires during loss.backward(). Presumably
            # make_hook stores dLoss/dy_pred into saved_grad — TODO confirm.
            y_pred.register_hook(make_hook(saved_grad))
            loss = fn.cross_entropy(y_pred, y_true)
            # NOTE(review): only net_opt is zeroed above. If net's forward is
            # differentiable w.r.t. siamese_net, this backward also accumulates
            # into siamese_net's grads. No optimizer step is visible in this
            # fragment either — confirm.
            loss.backward()

            # print(net.l1.forward_weight.grad[0, 0:10])
            # print(net.l2.forward_weight.grad[0, 0:10])
            # print(net.l3.forward_weight.grad[0, 0:10])
Exemplo n.º 4
0
# GPU index handed to the Lightning Trainer; None yields gpus=None (no GPU).
device_id = None

# Not read by the training entry point below.
latest_checkpoint = ""


def _main():
    """Fit a ``DeepELM(1)`` model on MNIST through ``TrainPipe`` and Lightning.

    Training batches hold 1024 samples and validation batches 2048. The
    trainer runs for at most 1000 epochs.
    """
    train_bs, val_bs = 1024, 2048
    datamodule = MNISTDataModule('../data/',
                                 train_transforms=None,
                                 val_transforms=None)
    pipe = TrainPipe(model=DeepELM(1))

    gpus = None if device_id is None else [device_id]
    trainer = pl.Trainer(gpus=gpus, max_epochs=1000)

    print("TRAINING MLP")
    trainer.fit(
        pipe,
        train_dataloader=datamodule.train_dataloader(train_bs),
        val_dataloaders=datamodule.val_dataloader(val_bs),
    )
    # Saving the weights (torch.save(pipe.state_dict(), 'model_v1')) and
    # running trainer.test(pipe, datamodule=datamodule) are currently disabled.


if __name__ == '__main__':
    _main()
Exemplo n.º 5
0
    # NOTE(review): fragment. data_dir, train_transforms, device_id and
    # latest_checkpoint are defined above this view.
    val_transforms = None
    mnist_dm = MNISTDataModule(data_dir,
                               train_transforms=train_transforms,
                               val_transforms=val_transforms)
    lt_cnn_model = LateralCNN(1)
    train_pipe = TrainPipe(model=lt_cnn_model)

    classic_trainer = pl.Trainer(
        gpus=([device_id] if device_id is not None else None), max_epochs=10)
    # Either train from scratch (empty latest_checkpoint) or restore a saved
    # TrainPipe state dict.
    if not latest_checkpoint:
        print("TRAINING CLASSIC CNN")
        classic_trainer.fit(train_pipe, datamodule=mnist_dm)

        print('COLLECTING FEATURE MAPS')
        # The whole training set is replayed through validation_step (with
        # batch_idx=None) between start_lateral_training() and
        # finish_lateral_training(). Presumably TrainPipe records feature maps
        # in this mode — TODO confirm.
        # NOTE(review): this loop runs with autograd enabled and never calls
        # eval(). If validation_step only records activations, wrapping the
        # loop in torch.no_grad() would cut memory use — confirm.
        train_pipe.start_lateral_training()
        for batch in mnist_dm.train_dataloader():
            # Batches are moved to the same GPU index given to the Trainer above.
            if device_id is not None:
                batch = (batch[0].cuda(device_id), batch[1].cuda(device_id))

            train_pipe.validation_step(batch, None)

        print("CALCULATING LATERAL LAYERS")
        train_pipe.finish_lateral_training()

        print("SAVE LATERAL CNN")
        # The full TrainPipe state dict is written to ./lateral_model_v1.
        torch.save(train_pipe.state_dict(), 'lateral_model_v1')
    else:
        # NOTE(review): torch.load has no map_location. A checkpoint saved
        # from GPU tensors will fail to load on a CPU-only host; consider
        # map_location='cpu'.
        train_pipe.load_state_dict(torch.load(latest_checkpoint))
        # Presumably re-activates the lateral layers for a restored model
        # (the training path goes through finish_lateral_training instead)
        # — TODO confirm.
        train_pipe.enable_laterals()

    print(train_pipe.model.conv1.laterals)