def run_mnist_prediction(prediction_hparams):
    """Run the plain prediction pipeline on MNIST.

    Builds an ``MNISTDataModule`` rooted at ``DATA_PATH + '/mnist_'``,
    prepares its data, creates the train/val/test dataloaders with the
    datamodule's default arguments and hands them, together with a fresh
    ``SimpleConvNet`` encoder, to ``run_prediction``.

    Args:
        prediction_hparams: Hyper-parameters forwarded unchanged to
            ``run_prediction``.

    Returns:
        Whatever ``run_prediction`` returns.
    """
    # Data: prepare_data() downloads the dataset to the given path.
    mnist = MNISTDataModule(data_dir=DATA_PATH + '/mnist_')
    mnist.prepare_data()

    train_loader = mnist.train_dataloader()
    val_loader = mnist.val_dataloader()
    test_loader = mnist.test_dataloader()

    # Model: the encoder's output width is passed as encoder_out_dim=50.
    encoder = SimpleConvNet()

    return run_prediction(
        prediction_hparams,
        train_loader,
        val_loader,
        test_loader,
        encoder,
        encoder_out_dim=50,
    )
def run_mnist_dvrl(prediction_hparams, dvrl_hparams):
    """Run the DVRL (``run_gumbel``) pipeline on MNIST.

    All three dataloaders use the batch size stored under
    ``dvrl_hparams['outer_batch_size']`` (32 when the key is absent).

    Args:
        prediction_hparams: Hyper-parameters forwarded to ``run_gumbel``.
        dvrl_hparams: Mapping of DVRL hyper-parameters; read for
            ``'outer_batch_size'`` and forwarded to ``run_gumbel``.

    Returns:
        Whatever ``run_gumbel`` returns.
    """
    # Data: prepare_data() downloads the dataset to the given path.
    mnist = MNISTDataModule(data_dir=DATA_PATH + '/mnist_')
    mnist.prepare_data()

    # One lookup shared by the train, val and test loaders.
    outer_batch_size = dvrl_hparams.get('outer_batch_size', 32)
    train_loader = mnist.train_dataloader(batch_size=outer_batch_size)
    val_loader = mnist.val_dataloader(batch_size=outer_batch_size)
    test_loader = mnist.test_dataloader(batch_size=outer_batch_size)

    # Read after the loaders are created, matching the original order.
    val_split = mnist.val_split

    encoder = SimpleConvNet()

    return run_gumbel(
        dvrl_hparams,
        prediction_hparams,
        train_loader,
        val_loader,
        test_loader,
        val_split,
        encoder,
        encoder_out_dim=50,
    )
siamese_net_opt_lr = 1e-3

# MNIST data source; no custom train/val transforms are supplied.
data_dir = '../data/'
train_transforms = None
val_transforms = None
mnist_dm = MNISTDataModule(
    data_dir,
    train_transforms=train_transforms,
    val_transforms=val_transforms,
)

# The classifier is optimised with SGD, the siamese net with Adam.
# Both models are moved to the CUDA device selected by device_id.
net = DualClassifier().cuda(device_id)
net_opt = optim.SGD(net.parameters(), lr=net_opt_lr)
siamese_net = ReversedSiameseNet().cuda(device_id)
siamese_net_opt = optim.Adam(
    siamese_net.parameters(),
    lr=siamese_net_opt_lr,
)

for eid in range(epochs):
    # Per-epoch metric accumulators.
    train_accuracy, train_loss = [], []

    train_batches = mnist_dm.train_dataloader(train_batch_size)
    for batch_id, batch in enumerate(train_batches):
        x = batch[0].cuda(device_id)
        y_true = batch[1].cuda(device_id)
        # train_net = batch_id % 5 == 0

        # Collect per-layer inputs and the y_pred gradient.
        net_opt.zero_grad()
        y_pred, inputs = net(x, siamese_net)
        # The hook built by make_hook receives saved_grad; presumably it
        # stores y_pred's gradient there during backward() - TODO confirm.
        saved_grad = []
        y_pred.register_hook(make_hook(saved_grad))
        loss = fn.cross_entropy(y_pred, y_true)
        loss.backward()
        # Debug output (disabled):
        # print(net.l1.forward_weight.grad[0, 0:10])
        # print(net.l2.forward_weight.grad[0, 0:10])
        # print(net.l3.forward_weight.grad[0, 0:10])
# CUDA device index for training; None leaves the Trainer's gpus argument unset.
device_id = None
latest_checkpoint = ""

if __name__ == '__main__':
    train_batch_size = 1024
    val_batch_size = 2048

    # MNIST data source; no custom train/val transforms are supplied.
    data_dir = '../data/'
    train_transforms = None
    val_transforms = None
    mnist_dm = MNISTDataModule(
        data_dir,
        train_transforms=train_transforms,
        val_transforms=val_transforms,
    )

    # Wrap the DeepELM model in the training pipeline.
    mlp = DeepELM(1)
    train_pipe = TrainPipe(model=mlp)

    classic_trainer = pl.Trainer(
        gpus=(None if device_id is None else [device_id]),
        max_epochs=1000,
    )

    print("TRAINING MLP")
    classic_trainer.fit(
        train_pipe,
        train_dataloader=mnist_dm.train_dataloader(train_batch_size),
        val_dataloaders=mnist_dm.val_dataloader(val_batch_size),
    )

    # Saving and testing are currently disabled:
    # print("SAVE MLP")
    # torch.save(train_pipe.state_dict(), 'model_v1')
    # print("TEST MLP")
    # classic_trainer.test(train_pipe, datamodule=mnist_dm)
# Lateral-CNN MNIST script fragment. The original line breaks and indentation
# were lost, so the statements below sit on one physical line; data_dir,
# train_transforms, device_id and latest_checkpoint are defined outside this
# fragment.
#
# Visible flow, in statement order:
#   1. Build an MNISTDataModule, wrap LateralCNN(1) in a TrainPipe and create
#      a pl.Trainer with max_epochs=10 (gpus=[device_id] only when device_id
#      is not None).
#   2. When latest_checkpoint is falsy: fit the model on the datamodule, call
#      start_lateral_training(), feed every training batch through
#      validation_step(batch, None) (moving both batch elements to device_id
#      when it is set), call finish_lateral_training() and save the state
#      dict to 'lateral_model_v1'.
#   3. Otherwise: load the state dict from latest_checkpoint.
#   4. Call enable_laterals() and print train_pipe.model.conv1.laterals.
#
# NOTE(review): with the indentation gone it cannot be determined whether
# step 4 belongs to the else branch or runs after the whole if/else - confirm
# against the original file before restoring the formatting.
val_transforms = None mnist_dm = MNISTDataModule(data_dir, train_transforms=train_transforms, val_transforms=val_transforms) lt_cnn_model = LateralCNN(1) train_pipe = TrainPipe(model=lt_cnn_model) classic_trainer = pl.Trainer( gpus=([device_id] if device_id is not None else None), max_epochs=10) if not latest_checkpoint: print("TRAINING CLASSIC CNN") classic_trainer.fit(train_pipe, datamodule=mnist_dm) print('COLLECTING FEATURE MAPS') train_pipe.start_lateral_training() for batch in mnist_dm.train_dataloader(): if device_id is not None: batch = (batch[0].cuda(device_id), batch[1].cuda(device_id)) train_pipe.validation_step(batch, None) print("CALCULATING LATERAL LAYERS") train_pipe.finish_lateral_training() print("SAVE LATERAL CNN") torch.save(train_pipe.state_dict(), 'lateral_model_v1') else: train_pipe.load_state_dict(torch.load(latest_checkpoint)) train_pipe.enable_laterals() print(train_pipe.model.conv1.laterals)