Example #1
def eval(data_loader, k):
    # Evaluation NLL with k dequantization samples per image; relies on the
    # module-level fgen, device, preprocess, n_bits, n_bins and nx.
    fgen.eval()
    nent = 0
    nll_mc = 0
    nll_iw = 0
    num_insts = 0
    for i, (data, _) in enumerate(data_loader):
        data = data.to(device, non_blocking=True)
        # [batch, channels, H, W]
        batch, c, h, w = data.size()
        x = preprocess(data, n_bits)
        # [batch, k]
        noise, log_probs_posterior = fgen.dequantize(x, nsamples=k)
        # [batch, k, channels, H, W]
        data = preprocess(data, n_bits, noise)
        # [batch * k, channels, H, W] -> [batch * k] -> [batch, k]
        log_probs = fgen.log_probability(data.view(batch * k, c, h, w)).view(batch, k)
        # [batch, k]
        log_iw = log_probs - log_probs_posterior

        num_insts += batch
        nent += log_probs_posterior.mean(dim=1).sum().item()
        nll_mc -= log_iw.mean(dim=1).sum().item()
        nll_iw += (math.log(k) - torch.logsumexp(log_iw, dim=1)).sum().item()

    nent = nent / num_insts
    nepd = nent / (nx * np.log(2.0))
    # log(n_bins / 2) per dimension converts the density of the rescaled
    # continuous inputs back into a discrete NLL
    nll_mc = nll_mc / num_insts + np.log(n_bins / 2.) * nx
    bpd_mc = nll_mc / (nx * np.log(2.0))
    nll_iw = nll_iw / num_insts + np.log(n_bins / 2.) * nx
    bpd_iw = nll_iw / (nx * np.log(2.0))

    print('Avg NLL: {:.2f}, NENT: {:.2f}, IW: {:.2f}, BPD: {:.4f}, NEPD: {:.4f}, BPD_IW: {:.4f}'.format(
        nll_mc, nent, nll_iw, bpd_mc, nepd, bpd_iw))
    return nll_mc, nent, nll_iw, bpd_mc, nepd, bpd_iw
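
The two bounds computed above differ only in how the k log importance weights are reduced: nll_mc averages them, while nll_iw applies a log-mean-exp (IWAE-style), which by Jensen's inequality is never looser. A minimal self-contained sketch of that reduction on dummy log-weights, independent of fgen:

import math

import torch

# dummy log importance weights log p(x, u) - log q(u | x), shape [batch, k]
log_iw = torch.randn(8, 64)

# Monte-Carlo bound: negate the average of the log-weights.
nll_mc = -log_iw.mean(dim=1)
# Importance-weighted bound: -log(1/k * sum exp(log_iw)), computed stably via
# logsumexp; it is never larger than nll_mc.
nll_iw = math.log(log_iw.size(1)) - torch.logsumexp(log_iw, dim=1)

assert (nll_iw <= nll_mc + 1e-6).all()
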
Example #2
def reconstruct(epoch):
    # Push a random batch of test images through the flow's encode/decode
    # pair and save originals next to their reconstructions.
    print('reconstruct')
    fgen.eval()
    n = 16
    np.random.shuffle(test_index)
    img, _ = get_batch(test_data, test_index[:n])
    img = preprocess(img.to(device), n_bits)

    z, _ = fgen.encode(img)
    img_recon, _ = fgen.decode(z)

    abs_err = (img_recon - img).abs()
    print('Err: max {:.4f}, mean {:.4f}'.format(abs_err.max().item(), abs_err.mean().item()))

    img = postprocess(img, n_bits)
    img_recon = postprocess(img_recon, n_bits)
    comparison = torch.cat([img, img_recon], dim=0).cpu()
    # interleave so each original is followed immediately by its reconstruction
    reorder_index = torch.from_numpy(np.array([[i + j * n for j in range(2)] for i in range(n)])).view(-1)
    comparison = comparison[reorder_index]
    image_file = 'reconstruct{}.png'.format(epoch)
    save_image(comparison, os.path.join(result_path, image_file), nrow=16)
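
Both eval and reconstruct lean on preprocess and postprocess helpers that are not part of this excerpt. Judging from the log(n_bins / 2.) * nx correction in eval, they presumably follow the common Glow-style convention of quantizing to n_bits and rescaling to [-1, 1]; the following is a hypothetical sketch under that assumption, not the repository's actual code:

import torch

# Hypothetical preprocess/postprocess, assuming Glow-style rescaling to
# [-1, 1]; the real helpers are defined elsewhere in the codebase.
def preprocess(img, n_bits, noise=None):
    n_bins = 2. ** n_bits
    # img assumed in [0, 1]; keep only n_bits of precision per channel
    x = torch.floor(img * 255. / 2 ** (8 - n_bits))
    if noise is not None:
        # broadcast to [batch, k, C, H, W], one copy per dequantization sample
        x = x.unsqueeze(1) + noise
    return x / (n_bins / 2.) - 1.

def postprocess(x, n_bits):
    n_bins = 2. ** n_bits
    x = torch.floor((x + 1.) * (n_bins / 2.))
    return torch.clamp(x * 2 ** (8 - n_bits) / 255., min=0., max=1.)
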
Example #3
def train(epoch, k):
    # One training epoch: variational dequantization with k posterior samples,
    # gradient accumulation over batch_steps chunks, and skipping of NaN gradients.
    print('Epoch: %d (lr=%.6f (%s), patient=%d)' % (epoch, lr, opt, patient))
    fgen.train()
    nll = 0
    nent = 0
    num_insts = 0
    num_back = 0
    num_nans = 0
    start_time = time.time()
    for batch_idx, (data, _) in enumerate(train_loader):
        batch_size = len(data)
        optimizer.zero_grad()
        data = data.to(device, non_blocking=True)
        nll_batch = 0
        nent_batch = 0
        data_list = [data, ] if batch_steps == 1 else data.chunk(batch_steps, dim=0)
        for data in data_list:
            x = preprocess(data, n_bits)
            # [batch, k]
            noise, log_probs_posterior = fgen.dequantize(x, nsamples=k)
            # [batch, k] -> [1]
            log_probs_posterior = log_probs_posterior.mean(dim=1).sum()
            # [batch, k, channels, H, W] -> [batch, channels, H, W]
            # only the first of the k noise samples is used for the model term
            data = preprocess(data, n_bits, noise[:, 0:1]).squeeze(1)
            log_probs = fgen.log_probability(data).sum()
            # negative dequantization lower bound, averaged over the full batch
            loss = (log_probs_posterior - log_probs) / batch_size
            loss.backward()
            with torch.no_grad():
                nll_batch -= log_probs.item()
                nent_batch += log_probs_posterior.item()

        if grad_clip > 0:
            grad_norm = clip_grad_norm_(fgen.parameters(), grad_clip)
        else:
            grad_norm = total_grad_norm(fgen.parameters())

        if math.isnan(grad_norm):
            num_nans += 1
        else:
            optimizer.step()
            scheduler.step()
            # exponentialMovingAverage(fgen, fgen_shadow, polyak_decay)
            num_insts += batch_size
            nll += nll_batch
            nent += nent_batch

        if batch_idx % args.log_interval == 0:
            sys.stdout.write("\b" * num_back)
            sys.stdout.write(" " * num_back)
            sys.stdout.write("\b" * num_back)
            train_nent = nent / num_insts
            train_nll = nll / num_insts + train_nent + np.log(n_bins / 2.) * nx
            bits_per_pixel = train_nll / (nx * np.log(2.0))
            nent_per_pixel = train_nent / (nx * np.log(2.0))
            curr_lr = scheduler.get_lr()[0]
            log_info = '[{}/{} ({:.0f}%) lr={:.6f}, nans={}] NLL: {:.2f}, BPD: {:.4f}, NENT: {:.2f}, NEPD: {:.4f}'.format(
                batch_idx * batch_size, len(train_index), 100. * batch_idx * batch_size / len(train_index), curr_lr, num_nans,
                train_nll, bits_per_pixel, train_nent, nent_per_pixel)
            sys.stdout.write(log_info)
            sys.stdout.flush()
            num_back = len(log_info)

    sys.stdout.write("\b" * num_back)
    sys.stdout.write(" " * num_back)
    sys.stdout.write("\b" * num_back)
    train_nent = nent / num_insts
    train_nll = nll / num_insts + train_nent + np.log(n_bins / 2.) * nx
    bits_per_pixel = train_nll / (nx * np.log(2.0))
    nent_per_pixel = train_nent / (nx * np.log(2.0))
    print('Average NLL: {:.2f}, BPD: {:.4f}, NENT: {:.2f}, NEPD: {:.4f}, time: {:.1f}s'.format(
        train_nll, bits_per_pixel, train_nent, nent_per_pixel, time.time() - start_time))
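
total_grad_norm is called above when gradient clipping is disabled, but its definition is not shown. A plausible stand-in (an assumption; the repository's own helper may differ) is the global L2 norm over all parameter gradients, i.e. the same quantity clip_grad_norm_ reports:

import torch

# Assumed implementation of total_grad_norm: global L2 norm of all gradients,
# returned as a plain float so math.isnan works on it.
def total_grad_norm(parameters, norm_type=2.):
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return 0.
    return torch.norm(torch.stack([torch.norm(g, norm_type) for g in grads]), norm_type).item()
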
Example #4
    with open(os.path.join(model_path, 'config.json'), 'w') as f:
        json.dump(params, f, indent=2)
    if dequant == 'uniform':
        fgen = FlowGenModel.from_params(params).to(device)
    elif dequant == 'variational':
        fgen = VDeQuantFlowGenModel.from_params(params).to(device)
    else:
        raise ValueError('unknown dequantization method: %s' % dequant)
    # initialize
    fgen.eval()
    init_batch_size = 512
    init_iter = 1
    print('init: {} instances with {} iterations'.format(init_batch_size, init_iter))
    for _ in range(init_iter):
        init_index = np.random.choice(train_index, init_batch_size, replace=False)
        init_data, _ = get_batch(train_data, init_index)
        init_data = preprocess(init_data.to(device), n_bits)
        fgen.init(init_data, init_scale=1.0)
    # create shadow model for EMA (exponential moving average); currently disabled
    # params = json.load(open(args.config, 'r'))
    # fgen_shadow = FlowGenModel.from_params(params).to(device)
    # exponentialMovingAverage(fgen, fgen_shadow, polyak_decay, init=True)

    fgen.to_device(device)
    optimizer = get_optimizer(lr, fgen.parameters())
    lmbda = lambda step: step / float(warmups) if step < warmups else step_decay ** (step - warmups)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lmbda)
    # step once up front so training does not start at the warmup multiplier
    # lmbda(0) = 0
    scheduler.step()

    start_epoch = 1
    patient = 0
    best_epoch = 0
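
The LambdaLR multiplier defined above warms up linearly for warmups steps and then decays geometrically by step_decay per step; lmbda(0) is zero, hence the initial scheduler.step() noted above. A quick standalone check with hypothetical values (the real warmups and step_decay come from the training configuration):

# hypothetical values; the real ones come from the training configuration
warmups, step_decay = 50, 0.999998

lmbda = lambda step: step / float(warmups) if step < warmups else step_decay ** (step - warmups)

print(lmbda(0))        # 0.0   -- learning rate would start at zero
print(lmbda(25))       # 0.5   -- halfway through the linear warmup
print(lmbda(50))       # 1.0   -- full base learning rate
print(lmbda(500_000))  # ~0.37 -- slow geometric decay afterwards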