Example 1
from typing import Tuple

from torch.utils.data import Dataset
from torchbearer.cv_utils import DatasetValidationSplitter


def split(dataset: Dataset,
          split_fraction: float,
          seed: int,
          return_test: bool = True) -> Tuple[Dataset, ...]:
    """
    Splits the dataset into pipelines and validation (and testing if return_test is true)
    :param dataset: Dataset object
    :param split_fraction: Fraction of the whole dataset to be used for validation
    :param seed: Seed used for splitting
    :param return_test: if should split into three parts
    :return:
    """

    splitter = DatasetValidationSplitter(len(dataset),
                                         split_fraction,
                                         shuffle_seed=seed)

    trainset = splitter.get_train_dataset(dataset)
    valset = splitter.get_val_dataset(dataset)

    # Further split the validation set into validation and test parts
    if return_test:
        # Use a high split_fraction so that most of the original validation
        # split becomes the test set (here 70% test, 30% validation)
        valset, testset = split(valset,
                                split_fraction=0.70,
                                seed=seed,
                                return_test=False)
        return trainset, valset, testset
    else:
        return trainset, valset
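
For illustration, a hypothetical call (the dataset variable and values are assumptions, not from the source) that yields a three-way split:

# 80% of the data for training; the remaining 20% is split 30/70 into
# validation and test by the recursive call inside split
trainset, valset, testset = split(dataset, split_fraction=0.2, seed=42)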

Example 2

def run(train_batch_size,
        val_batch_size,
        epochs,
        lr,
        log_interval,
        input_size=10,
        hidden_size=100,
        out_size=4):
    # FuzzBuzzDataset and FuzzBuzzModel are project-specific classes; a
    # hypothetical sketch of them is given after this example
    dataset = FuzzBuzzDataset(input_size)
    splitter = DatasetValidationSplitter(len(dataset), 0.1)
    train_set = splitter.get_train_dataset(dataset)
    val_set = splitter.get_val_dataset(dataset)

    train_loader = DataLoader(train_set,
                              pin_memory=True,
                              batch_size=train_batch_size,
                              shuffle=True,
                              num_workers=2)
    val_loader = DataLoader(val_set,
                            pin_memory=True,
                            batch_size=val_batch_size,
                            shuffle=False,
                            num_workers=2)
    model = FuzzBuzzModel(input_size, hidden_size, out_size)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=lr)
    loss = nn.CrossEntropyLoss()

    trial = Trial(model, optimizer, criterion=loss,
                  metrics=['acc', 'loss']).to(device)
    trial = trial.with_generators(train_generator=train_loader,
                                  val_generator=val_loader)
    trial.run(epochs=epochs)

    trial.evaluate(data_key=VALIDATION_DATA)
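
FuzzBuzzDataset and FuzzBuzzModel are not defined in this excerpt. A minimal, hypothetical sketch of what they might look like, assuming the usual fizzbuzz-as-classification setup (binary-encoded integers in, one of four classes out); the real definitions in the source project may differ:

import torch
import torch.nn as nn
from torch.utils.data import Dataset


class FuzzBuzzDataset(Dataset):
    # Hypothetical: binary-encode each integer; the target is one of four
    # classes (0: number, 1: fizz, 2: buzz, 3: fizzbuzz)
    def __init__(self, input_size):
        self.input_size = input_size

    def __len__(self):
        return 2 ** self.input_size

    def __getitem__(self, i):
        x = torch.tensor([(i >> d) & 1 for d in range(self.input_size)],
                         dtype=torch.float32)
        if i % 15 == 0:
            y = 3
        elif i % 5 == 0:
            y = 2
        elif i % 3 == 0:
            y = 1
        else:
            y = 0
        return x, y


class FuzzBuzzModel(nn.Module):
    # Hypothetical: a single-hidden-layer MLP classifier
    def __init__(self, input_size, hidden_size, out_size):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(input_size, hidden_size),
                                 nn.ReLU(),
                                 nn.Linear(hidden_size, out_size))

    def forward(self, x):
        return self.net(x)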
Example 3
    def splitting(args):
        # func is presumably the wrapped dataset-loading function; this
        # appears to be the inner function of a decorator in the source project
        trainset, testset = func(args)

        # In 'test' mode the test set doubles as the validation set
        if args.fold == 'test':
            return trainset, testset, testset

        # Only the first run generates the folds file; concurrent runs wait
        # briefly for it to be written before loading it
        if args.run_id == 0 and not os.path.exists(args.fold_path):
            gen_folds(args, trainset, len(trainset) // args.n_folds)
        else:
            time.sleep(3)

        folds = np.load(args.fold_path)
        train_ids = folds['train'][int(args.fold)]
        val_ids = folds['test'][int(args.fold)]

        # The 0.1 fraction is only a placeholder; the ids are overwritten
        # with the precomputed fold indices
        splitter = DatasetValidationSplitter(len(trainset), 0.1)
        splitter.train_ids, splitter.valid_ids = train_ids, val_ids

        trainset, valset = (splitter.get_train_dataset(trainset),
                            splitter.get_val_dataset(trainset))
        return trainset, valset, testset
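
The snippet assumes a gen_folds helper that writes the fold indices to the file at args.fold_path as arrays named 'train' and 'test', matching the np.load call above. A minimal sketch of what such a helper might look like (this is an assumption; the source implementation is not shown):

import numpy as np


def gen_folds(args, trainset, fold_size):
    # Hypothetical sketch: shuffle the indices, carve out n_folds
    # equally-sized validation blocks, and save them in the layout the
    # splitting function above expects
    rng = np.random.RandomState(getattr(args, 'seed', 0))
    ids = rng.permutation(len(trainset))

    train_folds, test_folds = [], []
    for i in range(args.n_folds):
        val_ids = ids[i * fold_size:(i + 1) * fold_size]
        train_folds.append(np.setdiff1d(ids, val_ids))
        test_folds.append(val_ids)

    np.savez(args.fold_path, train=np.array(train_folds), test=np.array(test_folds))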
Example 4
def run(count, glimpse_size, memory_size, iteration, device='cuda'):
    base_dir = os.path.join('celeba_' + str(memory_size), str(glimpse_size))
    if not os.path.exists(base_dir):
        os.makedirs(base_dir)

    transform_train = transforms.Compose([
        transforms.ToTensor()
    ])

    dataset = torchvision.datasets.ImageFolder(root='./cropped_celeba/', transform=transform_train)
    splitter = DatasetValidationSplitter(len(dataset), 0.05)
    trainset = splitter.get_train_dataset(dataset)

    # Save the split ids so the same validation split can be restored in a
    # later run (as done in the draw example below)
    torch.save((splitter.train_ids, splitter.valid_ids), os.path.join(base_dir, 'split.dat'))

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=10)

    model = CelebDraw(count, glimpse_size, memory_size)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')

    call_a = callbacks.TensorBoardImages(comment=current_time, name='Prediction', write_each_epoch=True, key=torchbearer.Y_PRED)
    call_a.on_step_training = call_a.on_step_validation  # Hack to make this log training samples
    call_b = callbacks.TensorBoardImages(comment=current_time + '_celeba', name='Target', write_each_epoch=True,
                                key=torchbearer.Y_TRUE)
    call_b.on_step_training = call_b.on_step_validation  # Hack to make this log training samples

    trial = Trial(model, optimizer, nn.MSELoss(reduction='sum'), ['acc', 'loss'], pass_state=True, callbacks=[
        joint_kl_divergence(MU, LOGVAR),
        callbacks.MostRecent(os.path.join(base_dir, 'iter_' + str(iteration) + '.{epoch:02d}.pt')),
        callbacks.GradientClipping(5),
        call_a,
        call_b
    ]).with_generators(train_generator=trainloader).to(device)

    trial.run(100)
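
MU, LOGVAR, and joint_kl_divergence are project-local. As a hedged sketch, a callback with this shape could be built from torchbearer's state keys and the add_to_loss decorator; the names and details here are assumptions, not the source implementation:

import torch
import torchbearer
from torchbearer.callbacks import add_to_loss

# Hypothetical state keys under which the model would store its latent
# distribution parameters during the forward pass
MU = torchbearer.state_key('mu')
LOGVAR = torchbearer.state_key('logvar')


def joint_kl_divergence(mu_key, logvar_key):
    # Standard VAE KL-divergence term, added to the reconstruction loss
    @add_to_loss
    def kl(state):
        mu, logvar = state[mu_key], state[logvar_key]
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return kl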
Example 5
def draw(count, glimpse_size, memory_size, file, device='cuda'):
    base_dir = os.path.join('celeba_' + str(memory_size), str(glimpse_size))

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    dataset = torchvision.datasets.ImageFolder(root='./cropped_celeba/', transform=transform)
    splitter = DatasetValidationSplitter(len(dataset), 0.05)

    # load the ids
    splitter.train_ids, splitter.valid_ids = torch.load(os.path.join(base_dir, 'split.dat'))

    testset = splitter.get_val_dataset(dataset)

    testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True, num_workers=10)

    model = CelebDraw(count, glimpse_size, memory_size, output_stages=True)

    # lr=0: Trial requires an optimizer, but no parameter updates are wanted here
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0)

    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')

    from visualise import StagesGrid

    trial = Trial(model, optimizer, nn.MSELoss(reduction='sum'), ['loss'], pass_state=True, callbacks=[
        callbacks.TensorBoardImages(comment=current_time, nrow=10, num_images=20, name='Prediction', write_each_epoch=True,
                                    key=torchbearer.Y_PRED, pad_value=1),
        callbacks.TensorBoardImages(comment=current_time + '_celeb', nrow=10, num_images=20, name='Target', write_each_epoch=False,
                                    key=torchbearer.Y_TRUE, pad_value=1),
        callbacks.TensorBoardImages(comment=current_time + '_celeb_mask', nrow=10, num_images=20, name='Masked Target', write_each_epoch=False,
                                    key=MASKED_TARGET, pad_value=1),
        StagesGrid('celeb_stages.png', STAGES, 20)
    ])
    trial.load_state_dict(torch.load(os.path.join(base_dir, file)), resume=False)
    trial = trial.with_generators(train_generator=testloader, val_generator=testloader)
    trial = trial.for_train_steps(1).for_val_steps(1).to(device)

    trial.run()  # evaluate() doesn't work with TensorBoard in this torchbearer version; it appears to be fixed in more recent releases
Example 6
    def __len__(self):
        # Tail of a dataset wrapper class (presumably AutoEncoderMNIST, used
        # below); the rest of the class is not included in this excerpt
        return len(self.mnist_dataset)


BATCH_SIZE = 128

transform = transforms.Compose([transforms.ToTensor()])

# Define standard classification mnist dataset with random validation set

dataset = torchvision.datasets.MNIST('./data/mnist',
                                     train=True,
                                     download=True,
                                     transform=transform)
splitter = DatasetValidationSplitter(len(dataset), 0.1)
basetrainset = splitter.get_train_dataset(dataset)
basevalset = splitter.get_val_dataset(dataset)

# Wrap base classification mnist dataset to return the image as the target

trainset = AutoEncoderMNIST(basetrainset)

valset = AutoEncoderMNIST(basevalset)

traingen = torch.utils.data.DataLoader(trainset,
                                       batch_size=BATCH_SIZE,
                                       shuffle=True,
                                       num_workers=8)

valgen = torch.utils.data.DataLoader(valset,