Example #1
0
def main(n, h, lr, epochs, bs, exp_dir):
    """Run one VAE experiment: configure logging, data, model, then execute.

    Args:
        n, h: VAE architecture sizes passed as VAE(h, n)
            (presumably latent and hidden dims — confirm against VAE).
        lr: Adagrad learning rate.
        epochs: number of training epochs.
        bs: mini-batch size.
        exp_dir: Path under which logs and results are written.
    """
    # ---------------------------------------------------------------
    # Setup
    # ---------------------------------------------------------------
    # Logging: mirror everything to exp_dir/logs/experiment1 and stdout.
    log_dir = Path(exp_dir / 'logs/')
    log_dir.mkdir(exist_ok=True)

    log_handlers = [
        logging.FileHandler(log_dir / 'experiment1'),
        logging.StreamHandler(),
    ]
    logging.basicConfig(level=logging.DEBUG,
                        format='%(levelname)s-%(message)s',
                        handlers=log_handlers)

    # Configurations
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    logging.info(
        'n={},h={},lr={},epochs={},batch size={},dest={},device={}'.format(
            n, h, lr, epochs, bs, exp_dir, device))
    # Fixed seed so runs are reproducible.
    torch.manual_seed(590238490)

    train_loader = MNIST.train(bs)
    test_loader = MNIST.test(bs)
    model = VAE(h, n).to(device)
    optimizer = optim.Adagrad(model.parameters(), lr=lr)

    # ---------------------------------------------------------------
    # Execute experiment
    # ---------------------------------------------------------------
    execute(model, optimizer, epochs, train_loader, test_loader, device, n,
            exp_dir)
Example #2
0
def main(n, h, epochs, bs, learning_rate):
    """Train a VAE on MNIST and return train/test loss curves.

    Args:
        n, h: VAE architecture sizes passed as VAE(h, n)
            (presumably latent and hidden dims — confirm against VAE).
        epochs: number of training epochs.
        bs: mini-batch size.
        learning_rate: Adam learning rate.

    Returns:
        (train_points, train_loss, test_points, test_loss,
        seq_train_loss, seq_test_loss) as produced by MNIST_vae.execute.
    """
    train_loader = mnist_.train(bs)
    test_loader = mnist_.test(bs)
    model = VAE(h, n)
    # NOTE: model stays on CPU; moving it to CUDA was disabled because of an
    # Adagrad issue: https://github.com/pytorch/pytorch/issues/7321
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # BUGFIX: the log line referenced an undefined name `batch_size`
    # (NameError at runtime) — the parameter is called `bs`.
    logging.info('n={},h={},lr={},epochs={},batch size={}'
                 .format(n, h, learning_rate, epochs, bs))

    exp = MNIST_vae(model, optimizer, learning_rate, bs, n, train_loader, test_loader)
    train_points, train_loss, test_points, test_loss, seq_train_loss, seq_test_loss = exp.execute(epochs)

    return train_points, train_loss, test_points, test_loss, seq_train_loss, seq_test_loss
Example #3
0
def main_discrete(base_dir=Path('../data/results')):
    """Run the full suite of comparison experiments under *base_dir*.

    Sets up shared logging, a fixed RNG seed, the MNIST test loader, and
    the compute device, then runs each experiment in sequence.
    """
    # ---------------------------------------------------------------
    # Setup Logger
    # ---------------------------------------------------------------
    # Mirror all output to base_dir/comparison/comparison and stdout.
    log_dir = Path(base_dir / 'comparison/')
    log_dir.mkdir(exist_ok=True)

    handlers = [
        logging.FileHandler(log_dir / 'comparison'),
        logging.StreamHandler(),
    ]
    logging.basicConfig(level=logging.DEBUG,
                        format='%(levelname)s-%(message)s',
                        handlers=handlers)

    # ---------------------------------------------------------------
    # Common configurations
    # ---------------------------------------------------------------
    # Fixed seed so all experiments see identical randomness.
    torch.manual_seed(590238490)
    data = MNIST.test(64, dir='../data')
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Run each comparison experiment in turn on the same data/device.
    for experiment in (exp1, exp1_relu, exp3, exp4, exp6, exp7):
        experiment(base_dir, data, device)
Example #4
0
def get_ordered_digits():
    """Collect one example image per MNIST digit (0-9), sorted by label.

    Scans the test set until one sample of every digit has been seen.

    Returns:
        (images, labels): stacked tensors of the 10 collected test images
        and their labels, ordered by ascending digit.
    """
    data = MNIST.test(64, dir='../data')

    digits_tensors = []
    digits_labels = []
    found = set()  # digit values already collected (set: O(1) membership)
    for x, labels in data:
        for i, digit in enumerate(labels):
            value = digit.item()
            if value not in found:
                found.add(value)
                digits_tensors.append(x[i])
                digits_labels.append(digit)
                # BUGFIX: stop as soon as the 10th distinct digit is
                # collected; the original only checked completeness on the
                # duplicate branch, so it kept scanning the rest of the batch.
                if len(found) == 10:
                    break
        if len(found) == 10:
            break

    torch_digits_tensors = torch.stack(digits_tensors)
    torch_digits_labels = torch.stack(digits_labels)
    # BUGFIX: sort via the already-stacked label tensor instead of the
    # legacy torch.Tensor(list-of-tensors) constructor.
    id_sort = torch.argsort(torch_digits_labels)

    # select digits in ascending order
    torch_digits_labels = torch_digits_labels[id_sort]
    torch_digits_tensors = torch_digits_tensors[id_sort]

    return torch_digits_tensors, torch_digits_labels
Example #5
0
def get_dataloader(conf) -> Tuple[DataLoader, DataLoader]:
    """Build train/validation dataloaders for the configured dataset.

    Args:
        conf: configuration object exposing ``conf.dataset.name``.
            (BUGFIX: the original ``conf: str`` annotation was wrong —
            the code accesses attributes, not a string.)

    Returns:
        (train_dataloader, val_dataloader) for the requested dataset.

    Raises:
        Exception: if ``conf.dataset.name`` is not a supported dataset.
    """
    if conf.dataset.name == "mnist":
        data = MNIST(conf)
        return data.train_dataloader(), data.val_dataloader()
    raise Exception(f"Invalid dataset name: {conf.dataset.name}")
Example #6
0
def get_data():
    """Return the MNIST test loader with a batch size of 4000."""
    return MNIST.test(4000, dir='../data')