# Exemplo n.º 1
# 0
def sample_kegnet_data(dataset, num_data, generators, device):
    """
    Sample artificial data using generator networks.

    :param dataset: name of the dataset the generators were trained for.
    :param num_data: number of (noise, label) pairs to feed each generator.
    :param generators: iterable of checkpoint paths, one per generator.
    :param device: torch device on which the generators run.
    :return: tensor of generated images from all generators, concatenated
        along dim 0 (``num_data * len(generators)`` samples in total).
    """
    # Restore each generator from its checkpoint and freeze it for inference.
    gen_models = []
    for path in generators:
        generator = init_generator(dataset).to(device)
        utils.load_checkpoints(generator, path, device)
        generator.eval()
        gen_models.append(generator)

    # All generators share label/noise dimensions; read them from the first.
    ny = gen_models[0].num_classes
    nz = gen_models[0].num_noises
    noises = sample_noises(size=(num_data, nz))
    labels_in = sample_labels(num_data, ny, dist='onehot')
    # The same (noise, label) batches are reused for every generator.
    loader = DataLoader(TensorDataset(noises, labels_in), batch_size=256)

    images_list = []
    for generator in gen_models:  # FIX: enumerate() index was never used
        batches = []
        for z, y in loader:
            z = z.to(device)
            y = y.to(device)
            # detach() so sampled images carry no autograd history.
            batches.append(generator(y, z).detach())
        # FIX: torch.cat accepts lists directly; tuple() wrapper was redundant.
        images_list.append(torch.cat(batches, dim=0))
    return torch.cat(images_list, dim=0)
# Exemplo n.º 2
# 0
def main(dataset, cls_path, out_path, index=0):
    """
    Main function for training a generator.

    Trains a generator (plus an auxiliary decoder) against a pre-trained
    classifier, logging losses each epoch and periodically saving
    checkpoints and sample images.

    :param dataset: dataset name selecting the per-dataset hyper-parameters.
    :param cls_path: checkpoint path of the pre-trained classifier.
    :param out_path: directory that receives logs, images, and checkpoints.
    :param index: run index; offsets the random seed for reproducibility.
    :return: path of the last saved generator checkpoint (or None).
    """
    global DEVICE
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    utils.set_seed(seed=2019 + index)

    num_epochs = 200
    save_every = 100
    viz_every = 10

    # At least one checkpoint must be written during training.
    assert num_epochs >= save_every

    # Per-dataset hyper-parameters: (decoder layers, learning rate, alpha, beta).
    hparams = {
        'mnist': (1, 1e-3, 1, 0),
        'fashion': (3, 1e-2, 1, 10),
        'svhn': (3, 1e-2, 1, 1),
    }
    dec_layers, lrn_rate, alpha, beta = hparams.get(dataset, (2, 1e-4, 1, 0))

    classifier = cls_utils.init_classifier(dataset).to(DEVICE)
    generator = gen_utils.init_generator(dataset).to(DEVICE)
    utils.load_checkpoints(classifier, cls_path, DEVICE)

    num_noises = generator.num_noises
    num_features = data.to_dataset(dataset).nx
    decoder = models.Decoder(num_features, num_noises, dec_layers).to(DEVICE)

    networks = (generator, classifier, decoder)

    path_loss = os.path.join(out_path, 'loss-gen.txt')
    dir_model = os.path.join(out_path, 'generator')
    path_model = None

    os.makedirs(os.path.join(out_path, 'images'), exist_ok=True)
    with open(path_loss, 'w') as f:
        f.write('Epoch\tClsLoss\tDecLoss\tDivLoss\tLossSum\tAccuracy\n')

    # Classifier (KLD), decoder (L2), and diversity (L1) losses, in order.
    losses = (
        gen_loss.ReconstructionLoss(method='kld').to(DEVICE),
        gen_loss.ReconstructionLoss(method='l2').to(DEVICE),
        gen_loss.DiversityLoss(metric='l1').to(DEVICE),
    )

    # Generator and decoder are optimized jointly.
    trainable = list(generator.parameters()) + list(decoder.parameters())
    optimizer = optim.Adam(trainable, lrn_rate)

    for epoch in range(1, num_epochs + 1):
        trn_acc, trn_losses = update(networks, losses, optimizer, alpha, beta)

        # Append one tab-separated row of losses plus accuracy per epoch.
        with open(path_loss, 'a') as f:
            f.write(f'{epoch:3d}')
            f.write(''.join(f'\t{value:.8f}' for value in trn_losses))
            f.write(f'\t{trn_acc:.8f}\n')

        if viz_every > 0 and epoch % viz_every == 0:
            img_path = os.path.join(out_path, f'images/images-{epoch:03d}.png')
            gen_utils.visualize_images(generator, img_path, DEVICE)

        if epoch % save_every == 0:
            path_model = f'{dir_model}-{epoch:03d}.pth.tar'
            utils.save_checkpoints(generator, path_model)

    print(f'Finished training the generator (index={index}).')
    return path_model
# Exemplo n.º 3
# 0
def main(dataset, data_dist, path_out, index=0, load=None, generators=None,
         option=None):
    """
    Main function for training a classifier network.

    Trains either a teacher (on real data) or a student (on data from
    KegNet generators or random distributions, after compression), with
    early stopping on validation accuracy.

    :param dataset: dataset name used to initialize the classifier.
    :param data_dist: one of 'real', 'kegnet', 'uniform', or 'normal'.
    :param path_out: directory that receives logs and checkpoints.
    :param index: run index; offsets the random seed for reproducibility.
    :param load: optional checkpoint path to initialize the model from.
    :param generators: generator checkpoints (used for student training).
    :param option: compression option tag used in output file names.
    """
    global DEVICE
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    utils.set_seed(seed=2019 + index)

    path_loss = os.path.join(path_out, f'loss-{data_dist}-{option}.txt')
    path_model = os.path.join(path_out, 'classifier')
    path_comp = os.path.join(path_out, f'compression-{option}.txt')
    os.makedirs(path_out, exist_ok=True)

    model = cls_utils.init_classifier(dataset).to(DEVICE)
    if load is not None:
        utils.load_checkpoints(model, load, DEVICE)

    if data_dist == 'real':
        params = prepare_teacher(dataset)
    elif data_dist in ('kegnet', 'uniform', 'normal'):
        params = prepare_student(model, dataset, data_dist, generators)
        compress_classifier(model, option, path_comp)
    else:
        # FIX: was a bare ValueError() with no message, which made failures
        # impossible to diagnose from the traceback alone.
        raise ValueError(f'Unknown data distribution: {data_dist}')

    lrn_rate = params['lrn_rate']
    save_every = params['save_every']
    min_epochs = params['min_epochs']
    val_epochs = params['val_epochs']
    max_epochs = params['max_epochs']
    trn_data = params['trn_data']
    val_data = params['val_data']
    test_data = params['test_data']
    trn_loss_func = params['trn_loss_func']
    test_loss_func = params['test_loss_func']

    optimizer = optim.Adam(model.parameters(), lrn_rate)

    with open(path_loss, 'w') as f:
        f.write('Epoch\tTrnLoss\tTrnAccuracy\tValLoss\tValAccuracy\t'
                'TestLoss\tTestAccuracy\tIsBest\n')

    best_acc, best_epoch = 0, 0
    for epoch in range(max_epochs + 1):
        # Epoch 0 is evaluation-only: it records the untrained baseline.
        if epoch > 0:
            update_classifier(model, trn_data, trn_loss_func, optimizer)
        trn_loss, trn_acc = eval_classifier(model, trn_data, trn_loss_func)
        val_loss, val_acc = eval_classifier(model, val_data, test_loss_func)
        test_loss, test_acc = eval_classifier(model, test_data, test_loss_func)

        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch

        # Early stopping: give up once val_epochs epochs pass without a new
        # best, but never before min_epochs.
        # NOTE(review): the break fires before this epoch's row is logged or
        # its checkpoint saved — presumably intentional; confirm.
        if epoch > max(best_epoch + val_epochs, min_epochs):
            break

        if epoch > 0:
            if epoch == best_epoch:
                p = f'{path_model}-best.pth.tar'
                utils.save_checkpoints(model, p)
            if save_every > 0 and epoch % save_every == 0:
                p = f'{path_model}-{epoch:03d}.pth.tar'
                utils.save_checkpoints(model, p)

        with open(path_loss, 'a') as f:
            f.write(f'{epoch:3d}\t')
            f.write(f'{trn_loss:.8f}\t{trn_acc:.8f}\t')
            f.write(f'{val_loss:.8f}\t{val_acc:.8f}\t')
            f.write(f'{test_loss:.8f}\t{test_acc:.8f}')
            if epoch == best_epoch:
                f.write('\tBEST')
            f.write('\n')

    print(f'Finished training the classifier (index={index}).')