Ejemplo n.º 1
0
def main(data_dir='/tmp', epochs=10000, M=20, lr=1e-2,
         batch_size=512, beta=1.0, map_est_hypers=False,
         seed=None):
    """Train on the two ToyDataset tasks one after another.

    Task ``t`` trains on classes ``2t`` and ``2t + 1``. Validation and test
    are filtered to every class seen so far (``range(2t + 2)``). Both are
    built as clones of the training data, so they contain the same points
    as the training set.

    Args:
        data_dir: Not read by this function; ``ToyDataset()`` is built
            without it.
        epochs, lr, beta, batch_size: Forwarded unchanged to ``train``.
        M: Forwarded to ``train`` multiplied by ``t + 1`` for task ``t``.
        map_est_hypers: Coerced to ``bool`` and forwarded to ``train``.
        seed: Passed to ``set_seeds`` before anything else runs.
    """
    set_seeds(seed)

    wandb.init(tensorboard=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger = SummaryWriter(log_dir=wandb.run.dir)

    # Close the writer even when dataset setup or training raises, so the
    # event file is flushed and its handle released.
    try:
        toy_train = ToyDataset()
        toy_val = ToyDataset(X=toy_train.data.clone(), Y=toy_train.targets.clone())
        toy_test = ToyDataset(X=toy_train.data.clone(), Y=toy_train.targets.clone())

        prev_params = []
        for t in range(2):
            toy_train.filter_by_class([2 * t, 2 * t + 1])
            toy_val.filter_by_class(range(2 * t + 2))
            toy_test.filter_by_class(range(2 * t + 2))

            state_dict = train(t, toy_train, toy_val, toy_test,
                               epochs=epochs, M=M * (t + 1), lr=lr, beta=beta,
                               batch_size=batch_size,
                               map_est_hypers=bool(map_est_hypers),
                               prev_params=prev_params, logger=logger,
                               device=device)

            # Only the most recent task's result is kept: the list is
            # replaced rather than appended to.
            prev_params = [state_dict]
    finally:
        logger.close()
Ejemplo n.º 2
0
def split_mnist(data_dir=None, epochs=500, M=60, lr=3e-3,
                batch_size=512, beta=10.0, map_est_hypers=False,
                seed=None):
    """Train sequentially on five SplitMNIST tasks.

    Task ``t`` trains on classes ``2t`` and ``2t + 1``. Validation and test
    are filtered to every class seen so far (``range(2t + 2)``). A random
    10000 samples of the ``train=True`` split are held out for validation.
    The test set is the ``train=False`` split.

    Args:
        data_dir: Dataset root. When falsy, ``$USER_DATADIR`` is used, or
            ``'/tmp'`` if that variable is unset.
        epochs, M, lr, beta, batch_size: Forwarded unchanged to ``train``.
        map_est_hypers: Coerced to ``bool`` and forwarded to ``train``.
        seed: Passed to ``set_seeds`` before the random split is drawn.
    """
    data_dir = data_dir or os.environ.get('USER_DATADIR', default='/tmp')
    set_seeds(seed)

    wandb.init(tensorboard=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger = SummaryWriter(log_dir=wandb.run.dir)

    # Close the writer even when dataset setup or training raises, so the
    # event file is flushed and its handle released.
    try:
        root = f'{data_dir}'
        mnist_train = SplitMNIST(root, train=True)
        mnist_val = SplitMNIST(root, train=True)
        mnist_test = SplitMNIST(root, train=False)

        # The last 10000 entries of a random permutation form the validation
        # split; the rest stay in the training split.
        idx = torch.randperm(len(mnist_train))
        train_idx, val_idx = idx[:-10000], idx[-10000:]
        mnist_train.filter_by_idx(train_idx)
        mnist_val.filter_by_idx(val_idx)

        # Each task's train() result is appended, so train() sees the results
        # of all earlier tasks.
        prev_params = []
        for t in range(5):
            mnist_train.filter_by_class([2 * t, 2 * t + 1])
            mnist_val.filter_by_class(range(2 * t + 2))
            mnist_test.filter_by_class(range(2 * t + 2))

            state_dict = train(t, mnist_train, mnist_val, mnist_test,
                               epochs=epochs, M=M, lr=lr, beta=beta,
                               batch_size=batch_size,
                               map_est_hypers=bool(map_est_hypers),
                               prev_params=prev_params, logger=logger,
                               device=device)

            prev_params.append(state_dict)
    finally:
        logger.close()
Ejemplo n.º 3
0
def permuted_mnist(data_dir='/tmp',
                   n_tasks=10,
                   epochs=1000,
                   M=100,
                   lr=3.7e-3,
                   batch_size=512,
                   beta=1.64,
                   seed=None):
    """Train sequentially on ``n_tasks`` PermutedMNIST tasks.

    The first task uses the identity permutation (``torch.arange(784)``).
    The remaining ``n_tasks - 1`` permutations come from
    ``PermutedMNIST.create_tasks``. A random 10000 samples of the
    ``train=True`` split are held out for validation. When training task
    ``t``, validation and test are the concatenation of the permuted
    datasets for tasks ``0..t``.

    Args:
        data_dir: Dataset root passed to every ``PermutedMNIST``.
        n_tasks: Total number of tasks, including the unpermuted one.
        epochs, M, lr, beta, batch_size: Forwarded unchanged to ``train``.
        seed: Passed to ``set_seeds`` before tasks and the split are drawn.
    """
    set_seeds(seed)

    wandb.init(tensorboard=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger = SummaryWriter(log_dir=wandb.run.dir)

    # Close the writer even when dataset setup or training raises, so the
    # event file is flushed and its handle released.
    try:
        root = f'{data_dir}'

        ## NOTE: First task is unpermuted MNIST.
        tasks = [torch.arange(784)] + PermutedMNIST.create_tasks(n=n_tasks - 1)

        # Only the size of the full training split is needed here. Each task
        # builds its own datasets below, so this instance is not filtered or
        # kept.
        n_train = len(PermutedMNIST(root, train=True))
        idx = torch.randperm(n_train)
        train_idx, val_idx = idx[:-10000], idx[-10000:]

        mnist_val = []
        mnist_test = []

        prev_params = []
        for t, task in enumerate(tasks):
            mnist_train = PermutedMNIST(root, train=True)
            mnist_train.filter_by_idx(train_idx)
            mnist_train.set_task(task)

            mnist_val.append(PermutedMNIST(root, train=True))
            mnist_val[-1].filter_by_idx(val_idx)
            mnist_val[-1].set_task(task)

            mnist_test.append(PermutedMNIST(root, train=False))
            mnist_test[-1].set_task(task)

            # Evaluation covers every task seen so far.
            state_dict = train(t,
                               mnist_train,
                               ConcatDataset(mnist_val),
                               ConcatDataset(mnist_test),
                               epochs=epochs,
                               M=M,
                               lr=lr,
                               beta=beta,
                               batch_size=batch_size,
                               prev_params=prev_params,
                               logger=logger,
                               device=device)

            prev_params.append(state_dict)
    finally:
        logger.close()
Ejemplo n.º 4
0
def toy(data_dir=None,
        epochs=5000,
        M=20,
        lr=1e-2,
        batch_size=512,
        beta=1.0,
        ep_var_mean=True,
        map_est_hypers=False,
        dkl=False,
        seed=None):
    """Train on the two ToyDataset tasks one after another.

    Task ``t`` trains on classes ``2t`` and ``2t + 1``. Validation and test
    are filtered to every class seen so far (``range(2t + 2)``). Both are
    built as clones of the training data, so they contain the same points
    as the training set.

    Args:
        data_dir: Not read by this function; ``ToyDataset()`` is built
            without a data directory.
        epochs, M, lr, beta, batch_size: Forwarded unchanged to ``train``.
        ep_var_mean, map_est_hypers, dkl: Coerced to ``bool`` and forwarded
            to ``train``.
        seed: Passed to ``set_seeds`` before anything else runs.
    """
    set_seeds(seed)

    wandb.init(tensorboard=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger = SummaryWriter(log_dir=wandb.run.dir)

    # Close the writer even when dataset setup or training raises, so the
    # event file is flushed and its handle released.
    try:
        toy_train = ToyDataset()
        toy_val = ToyDataset(X=toy_train.data.clone(), Y=toy_train.targets.clone())
        toy_test = ToyDataset(X=toy_train.data.clone(),
                              Y=toy_train.targets.clone())

        prev_params = []
        for t in range(2):
            toy_train.filter_by_class([2 * t, 2 * t + 1])
            toy_val.filter_by_class(range(2 * t + 2))
            toy_test.filter_by_class(range(2 * t + 2))

            # NOTE(review): patience=-1 presumably disables early stopping in
            # train(); confirm there.
            state_dict = train(t,
                               toy_train,
                               toy_val,
                               toy_test,
                               epochs=epochs,
                               M=M,
                               lr=lr,
                               beta=beta,
                               batch_size=batch_size,
                               ep_var_mean=bool(ep_var_mean),
                               map_est_hypers=bool(map_est_hypers),
                               dkl=bool(dkl),
                               prev_params=prev_params,
                               logger=logger,
                               device=device,
                               patience=-1)

            prev_params.append(state_dict)
    finally:
        logger.close()