Beispiel #1
0
def init_classifier(dataset):
    """
    Build the classifier network associated with a dataset.

    Only the 'mnist' case is handled in this variant: it yields a LeNet5
    instance. Any other dataset name yields None.
    """
    # The lookup result is not used below. The call is kept because it may
    # reject unknown dataset names (assumption -- confirm in data.to_dataset).
    data.to_dataset(dataset)
    return lenet.LeNet5() if dataset == 'mnist' else None
Beispiel #2
0
def init_generator(dataset):
    """
    Build the generator network associated with a dataset.

    'mnist', 'fashion' and 'svhn' get a models.ImageGenerator built from the
    dataset's ``ny`` and ``nc`` attributes. Every other dataset gets a
    models.DenseGenerator built from ``ny`` and ``nx`` with n_layers=2.
    """
    info = data.to_dataset(dataset)
    if dataset not in ('mnist', 'fashion', 'svhn'):
        return models.DenseGenerator(info.ny, info.nx, n_layers=2)
    return models.ImageGenerator(info.ny, info.nc)
Beispiel #3
0
def init_classifier(dataset):
    """
    Build the classifier network associated with a dataset.

    'mnist' gets a LeNet5, 'svhn' and 'fashion' get a ResNet built from the
    dataset's ``nc`` and ``ny`` attributes, and every other dataset gets an
    MLP built from ``nx`` and ``ny``.
    """
    info = data.to_dataset(dataset)
    if dataset == 'mnist':
        return lenet.LeNet5()
    if dataset in ('svhn', 'fashion'):
        return resnet.ResNet(info.nc, info.ny)
    return linear.MLP(info.nx, info.ny)
Beispiel #4
0
def sample_random_data(dataset, num_data, dist, device):
    """
    Draw artificial samples from a simple distribution.

    The result has shape ``(num_data, *size)``, where ``size`` is taken from
    the dataset description, and is created on ``device``. 'normal' draws
    from torch.randn; 'uniform' fills a float tensor via ``uniform_(-1, 1)``.
    Any other ``dist`` raises ValueError carrying the offending value.
    """
    # The shape is resolved first, so the dataset lookup happens before
    # ``dist`` is checked.
    shape = (num_data, *data.to_dataset(dataset).size)

    if dist == 'uniform':
        samples = torch.zeros(shape, dtype=torch.float, device=device)
        # uniform_ fills in place and returns the same tensor.
        return samples.uniform_(-1, 1)
    if dist == 'normal':
        return torch.randn(shape, device=device)
    raise ValueError(dist)
Beispiel #5
0
def prepare_teacher(dataset):
    """
    Assemble data loaders and hyperparameters for training a teacher network.

    Returns a dict with the keys lrn_rate, save_every, min_epochs,
    val_epochs, max_epochs, trn_data, val_data, test_data, trn_loss_func and
    test_loss_func. The same CrossEntropyLoss instance serves as both the
    training and the test loss.
    """
    # NOTE(review): for mnist/fashion/svhn, min_epochs and val_epochs (10000)
    # exceed max_epochs -- presumably this switches off validation-driven
    # stopping; confirm against the training loop.
    if dataset == 'mnist':
        config = dict(lrn_rate=1e-5, save_every=10,
                      min_epochs=10000, val_epochs=10000, max_epochs=200)
    elif dataset == 'fashion':
        config = dict(lrn_rate=1e-2, save_every=20,
                      min_epochs=10000, val_epochs=10000, max_epochs=100)
    elif dataset == 'svhn':
        config = dict(lrn_rate=1e-4, save_every=10,
                      min_epochs=10000, val_epochs=10000, max_epochs=50)
    else:
        config = dict(lrn_rate=1e-3, save_every=0,
                      min_epochs=100, val_epochs=20, max_epochs=1000)

    batch_size = 64
    loaders = data.to_dataset(dataset).to_loaders(batch_size)
    trn_data, val_data, test_data = loaders
    criterion = nn.CrossEntropyLoss().to(DEVICE)

    config.update(
        trn_data=trn_data,
        val_data=val_data,
        test_data=test_data,
        trn_loss_func=criterion,
        test_loss_func=criterion)
    return config
Beispiel #6
0
def prepare_student(model, dataset, data_dist, generators=None):
    """
    Prepare datasets and hyperparameters for training a student network.

    Args:
        model: forwarded unchanged to ``prepare_data``.
        dataset: dataset name; 'mnist' selects a smaller learning rate.
        data_dist: source of the training data, one of 'kegnet', 'normal'
            or 'uniform'.
        generators: forwarded unchanged to ``prepare_data``.

    Returns:
        A dict with the keys lrn_rate, save_every, min_epochs, val_epochs,
        max_epochs, trn_data, val_data, test_data, trn_loss_func and
        test_loss_func.

    Raises:
        ValueError: if ``data_dist`` is not one of the supported values. The
            message names the rejected value.
    """
    batch_size = 64
    num_batches = 100
    save_every = -1
    val_epochs = 50
    min_epochs = 100

    # Validate data_dist before any data is prepared, so a bad value fails
    # fast with a message that says what was passed.
    if data_dist == 'kegnet':
        max_epochs = 400
        lrn_rate = 1e-5 if dataset == 'mnist' else 1e-4
    elif data_dist in ('normal', 'uniform'):
        max_epochs = 1000
        lrn_rate = 1e-6 if dataset == 'mnist' else 1e-4
    else:
        raise ValueError(f'Unknown data distribution: {data_dist!r}')

    trn_data = prepare_data(
        model, data_dist, dataset, batch_size, num_batches, generators)
    # Only the validation and test loaders of the real dataset are used; the
    # training loader is discarded in favour of the prepared data above.
    _, val_data, test_data = data.to_dataset(dataset).to_loaders(batch_size)
    trn_loss_func = cls_loss.KLDivLoss().to(DEVICE)
    test_loss_func = nn.CrossEntropyLoss().to(DEVICE)

    return dict(
        lrn_rate=lrn_rate,
        save_every=save_every,
        min_epochs=min_epochs,
        val_epochs=val_epochs,
        max_epochs=max_epochs,
        trn_data=trn_data,
        val_data=val_data,
        test_data=test_data,
        trn_loss_func=trn_loss_func,
        test_loss_func=test_loss_func)
Beispiel #7
0
def main(dataset, cls_path, out_path, index=0):
    """
    Train a generator against a classifier loaded from ``cls_path``.

    Sets the module-level DEVICE, seeds with ``2019 + index``, and runs the
    training loop. Per-epoch losses go to ``<out_path>/loss-gen.txt``, sample
    images to ``<out_path>/images/`` and generator checkpoints to
    ``<out_path>/generator-<epoch>.pth.tar``.

    Returns the path of the most recently saved generator checkpoint, or
    None if no checkpoint was written.
    """
    global DEVICE
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    utils.set_seed(seed=2019 + index)

    num_epochs, save_every, viz_every = 200, 100, 10
    assert num_epochs >= save_every

    # Per-dataset settings: decoder depth, learning rate, and the two loss
    # weights handed to update().
    if dataset == 'mnist':
        dec_layers, lrn_rate, alpha, beta = 1, 1e-3, 1, 0
    elif dataset == 'fashion':
        dec_layers, lrn_rate, alpha, beta = 3, 1e-2, 1, 10
    elif dataset == 'svhn':
        dec_layers, lrn_rate, alpha, beta = 3, 1e-2, 1, 1
    else:
        dec_layers, lrn_rate, alpha, beta = 2, 1e-4, 1, 0

    cls_network = cls_utils.init_classifier(dataset).to(DEVICE)
    gen_network = gen_utils.init_generator(dataset).to(DEVICE)
    # Only the classifier is restored from disk; the generator starts fresh.
    utils.load_checkpoints(cls_network, cls_path, DEVICE)

    noise_dim = gen_network.num_noises
    input_dim = data.to_dataset(dataset).nx
    dec_network = models.Decoder(input_dim, noise_dim, dec_layers).to(DEVICE)

    networks = (gen_network, cls_network, dec_network)

    path_loss = os.path.join(out_path, 'loss-gen.txt')
    dir_model = os.path.join(out_path, 'generator')
    last_checkpoint = None

    # Creating the images directory also creates out_path itself, so the
    # loss file can be opened right after.
    os.makedirs(os.path.join(out_path, 'images'), exist_ok=True)
    with open(path_loss, 'w') as f:
        f.write('Epoch\tClsLoss\tDecLoss\tDivLoss\tLossSum\tAccuracy\n')

    losses = (
        gen_loss.ReconstructionLoss(method='kld').to(DEVICE),
        gen_loss.ReconstructionLoss(method='l2').to(DEVICE),
        gen_loss.DiversityLoss(metric='l1').to(DEVICE),
    )

    # The classifier's parameters are deliberately not handed to the
    # optimizer: only the generator and decoder are optimized.
    params = [*gen_network.parameters(), *dec_network.parameters()]
    optimizer = optim.Adam(params, lrn_rate)

    for epoch in range(1, num_epochs + 1):
        trn_acc, trn_losses = update(networks, losses, optimizer, alpha, beta)

        # The log is reopened in append mode every epoch, so each row is on
        # disk as soon as the epoch finishes.
        with open(path_loss, 'a') as f:
            row = (
                f'{epoch:3d}'
                + ''.join(f'\t{loss:.8f}' for loss in trn_losses)
                + f'\t{trn_acc:.8f}\n'
            )
            f.write(row)

        if viz_every > 0 and epoch % viz_every == 0:
            image_path = os.path.join(
                out_path, f'images/images-{epoch:03d}.png')
            gen_utils.visualize_images(gen_network, image_path, DEVICE)

        if epoch % save_every == 0:
            last_checkpoint = f'{dir_model}-{epoch:03d}.pth.tar'
            utils.save_checkpoints(gen_network, last_checkpoint)

    print(f'Finished training the generator (index={index}).')
    return last_checkpoint