Example No. 1
0
def generate_contaminated_data(
        eps, num_data, theta=None,
        type_cont="gauss_5",
        coord_median_as_origin=True):
    """Sample num_data points from N(theta, I) and replace an eps-fraction
    with contamination of the given type.

    Each row is independently contaminated with probability `eps`; the
    contaminated rows are drawn via `generate_dirty_data`. If
    `coord_median_as_origin` is set, both the data and theta are shifted so
    the coordinate-wise median of the data sits at the origin.

    Returns a (data, theta) pair, both shifted when recentering is enabled.
    """
    if theta is None:
        # Default center: the origin in a 100-dimensional space.
        theta = torch.zeros(100)

    # Bernoulli(eps) mask selecting which rows get replaced by outliers.
    contaminated_mask = torch.rand(num_data) < eps
    num_contaminated = contaminated_mask.sum().item()
    outliers = generate_dirty_data(type_cont, theta, num_contaminated)

    # Clean samples: standard Gaussian noise around theta, on theta's device.
    samples = torch.randn((num_data, len(theta)), device=theta.device) + theta
    samples[contaminated_mask] = outliers

    if coord_median_as_origin:
        from utils import coord_median
        # Recenter so the coordinate-wise median is the origin; shift theta
        # by the same amount so the target stays consistent with the data.
        center = coord_median(samples)
        samples = samples - center
        theta = theta - center
    return samples, theta
Example No. 2
0
    # NOTE(review): fragment of a larger training-setup function — `theta`,
    # `data`, `args`, and `device` are defined above this chunk.
    theta = theta.to(device)

    # Batches of the real (contaminated) data; single-process loading.
    data_loader = torch.utils.data.DataLoader(TensorDataset(data),
                                              batch_size=args.real_batch_size,
                                              shuffle=True,
                                              num_workers=0)

    noise_generator = NoiseGenerator().to(device)
    '''
    We recommend not using coordinate-wise median as initialization.
    The global minimum of Wasserstein GAN has mean square error very close to the coordinate-wise median,
    thus we prefer the training starting from somewhere else in order to see the progress of training.
    '''
    # Initialize away from the coordinate-wise median (scaled by 1.3) per the
    # note above, so training progress toward the optimum is visible.
    generator = Generator(
        p=args.p,
        initializer=1.3 * coord_median(data_loader.dataset.tensors[0]),
        # 0.5 * torch.ones(args.p),
    ).to(device)

    # Entropy-regularized OT solver used as the WGAN critic surrogate.
    sinkhorn = SinkhornIteration(lam=args.lam,
                                 max_iter=args.sinkhorn_max_iter,
                                 device=device,
                                 const=args.const,
                                 thres=args.thres)
    g_optim = torch.optim.SGD(generator.parameters(),
                              lr=args.g_sgd_lr,
                              momentum=args.g_sgd_momentum)

    # Distance from the generator's current estimate to the true theta.
    print('initial dist {:.4f}'.format(
        torch.norm(generator.eta - theta).item()))
Example No. 3
0
    # NOTE(review): fragment of a larger training function — `data`, `args`,
    # `device`, and `theta` are defined above this chunk, and the epoch loop
    # continues past it.
    data_loader = torch.utils.data.DataLoader(TensorDataset(data),
                                              batch_size=args.real_batch_size,
                                              shuffle=True,
                                              num_workers=0)

    noise_generator = NoiseGenerator().to(device)
    '''
    Do not use coordinate-wise median as initialization.
    The global minimum of MMD GAN has mean square error very close to the coordinate-wise median,
    thus we prefer the training starting from somewhere else in order to see the progress of training.
    '''
    # Initialize away from the coordinate-wise median (scaled by 1.5) per the
    # note above, so training progress toward the optimum is visible.
    generator = Generator(
        p=args.p,
        # initializer=torch.ones(args.p),
        initializer=1.5 * coord_median(data),
    ).to(device)

    # Maximum mean discrepancy criterion with a fixed kernel bandwidth.
    mmd = MMD(sigma=args.sigma, device=device)

    g_optim = torch.optim.SGD(generator.parameters(),
                              lr=args.g_sgd_lr,
                              momentum=args.g_sgd_momentum)

    # Distance from the generator's current estimate to the true theta.
    print('initial dist {:.4f}'.format(
        torch.norm(generator.eta - theta).item()))

    # Track the trajectory of eta estimates, starting from the initializer.
    lst_eta = [generator.get_numpy_eta()]

    for i in range(args.num_epoch):
        total_loss = 0