Example #1
    def eval(predict):
        multi_step_errors = []
        one_step_errors = []
        for delta, stop in tqdm(
                bitmap.generate_test_set(set_size=args.test_size,
                                         seed=args.test_seed)):
            start = predict(delta, stop)
            # Multi-step error: how well the predicted start explains `stop`
            # after `delta` steps.
            multi_step_errors.append(1 - score(delta, start, stop))

            # One-step error: reuse the prediction when delta is already 1.
            one_step_start = start if delta == 1 else predict(1, stop)
            one_step_errors.append(1 - score(1, one_step_start, stop))
        return (np.mean(multi_step_errors), np.var(multi_step_errors),
                np.mean(one_step_errors), np.var(one_step_errors))
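A hypothetical invocation, assuming `bitmap`, `score`, and `args` from the surrounding module are in scope; the baseline predictor below simply returns the stop board and is only meant to show the calling convention:

def identity_predict(delta, stop):
    # Baseline guess: the start board equals the stop board.
    return stop

multi_mean, multi_var, one_mean, one_var = eval(identity_predict)
print(f'multi-step error: mean {multi_mean:.4f}, var {multi_var:.4f}')
print(f'one-step error:   mean {one_mean:.4f}, var {one_var:.4f}')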
Example #2
    def eval(predict):
        multi_step_errors = []
        one_step_errors = []
        for batch in tqdm(grouper(
                bitmap.generate_test_set(set_size=args.test_size,
                                         seed=args.test_seed), 100)):
            deltas, stops = zip(*batch)

            delta_batch = np.array(deltas)
            stop_batch = np.array(stops)
            # Use the arrays built above (the raw tuples were being passed before).
            start_batch = np.asarray(predict(delta_batch, stop_batch))

            for delta, start, stop in zip(delta_batch, start_batch, stop_batch):
                multi_step_errors.append(1 - score(delta, start, stop))

            # Reuse multi-step predictions where delta == 1; compare the array
            # (the tuple comparison `deltas == 1` is always False, so nothing
            # was ever reused) and broadcast the mask over the board dims.
            one_deltas = np.ones_like(delta_batch)
            mask = (delta_batch == 1).reshape(-1, *([1] * (start_batch.ndim - 1)))
            one_step_start = np.where(mask, start_batch, predict(one_deltas, stop_batch))
            for delta, start, stop in zip(one_deltas, one_step_start, stop_batch):
                one_step_errors.append(1 - score(delta, start, stop))
        return (np.mean(multi_step_errors), np.var(multi_step_errors),
                np.mean(one_step_errors), np.var(one_step_errors))
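`grouper` is not defined in this snippet. A minimal sketch in the spirit of the itertools recipes, yielding successive tuples of up to `n` items (the final batch may be shorter, which `zip(*batch)` above handles fine; requires Python 3.8+ for the walrus operator):

from itertools import islice

def grouper(iterable, n):
    # Yield tuples of n consecutive items; the last tuple may be shorter.
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch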
Example #3
    # Use deltas as indices into preds.
    # TODO: I'm sure there's some clever way to do the same using numpy
    # indexing/slicing (see the vectorized sketch after this example).
    final_pred_batch = []
    for i in range(deltas_batch.size):
        final_pred_batch.append(
            preds[np.squeeze(deltas_batch[i])][i].detach().numpy())

    return final_pred_batch


def cnnify_batch(batches):
    # Add a channel axis so each (N, H, W) batch becomes (N, 1, H, W),
    # the NCHW layout that Conv2d expects.
    return (np.expand_dims(batch, 1) for batch in batches)


val_set = bitmap.generate_test_set(set_size=opt.batchSize, seed=9568382)
deltas_val, stops_val = cnnify_batch(zip(*val_set))
ones_val = np.ones_like(deltas_val)

#


# custom weights initialization, applied with net.apply(weights_init)
# (the standard DCGAN recipe; here the nets are netG and netF)
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)
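On the TODO above about numpy indexing: a sketch of a vectorized replacement for the per-sample loop, assuming the per-step predictions in `preds` can be stacked into a single array of shape (max_delta + 1, batch, ...):

preds_arr = np.stack([p.detach().numpy() for p in preds])
deltas_flat = np.squeeze(deltas_batch)
# Pair each sample index with its own delta: entry i is preds_arr[deltas_flat[i], i].
final_pred_batch = preds_arr[deltas_flat, np.arange(deltas_flat.size)]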
Example #4
        tens = torch.Tensor(preds[-1])
        # Optionally threshold to hard 0/1 boards: (net(tens) > 0.5).float()
        pred_batch = net(tens)
        preds.append(pred_batch)

    # Use deltas as indices into preds.
    # TODO: same as in Example #3 -- see the vectorized numpy indexing sketch there.
    final_pred_batch = []
    for i in range(deltas_batch.size):
        final_pred_batch.append(preds[np.squeeze(deltas_batch[i])][i])

    return final_pred_batch

def cnnify_batch(batches):
    return (np.expand_dims(batch, 1) for batch in batches)

val_set = bitmap.generate_test_set(set_size=100, seed=9568382)
deltas_val, stops_val = cnnify_batch(zip(*val_set))
ones_val = np.ones_like(deltas_val)

multi_step_errors = []
one_step_errors = []
best_multi_step_error = 1.0
best_multi_step_idx = -1
best_one_step_error = 1.0
best_one_step_idx = -1

for i, batch in tqdm(enumerate(grouper(
        bitmap.generate_inf_cases(True, 432341, return_one_but_last=True),
        2048))):
    deltas, one_but_lasts, stops = zip(*batch)

    deltas_batch = np.expand_dims(deltas, 1)
    one_but_lasts_batch = torch.Tensor(np.expand_dims(one_but_lasts, 1))
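This example only shows the tail of the multi-step predictor; a sketch of the full function it appears to belong to, assuming `net` steps a batch of stop boards back by one generation (all names besides `net` are assumptions):

def predict(net, deltas_batch, stops_batch):
    # preds[d] holds the prediction d steps back; preds[0] is the stop board.
    preds = [torch.Tensor(stops_batch)]
    for _ in range(int(np.max(deltas_batch))):
        preds.append(net(preds[-1]))
    # Per sample, pick the prediction that is exactly delta steps back.
    return [preds[np.squeeze(deltas_batch[i])][i].detach().numpy()
            for i in range(deltas_batch.size)]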
Example #5
def train(
        gen_arch,
        fwd_arch,
        device,
        writer,
        batchSize,
        niter,
        lr,
        beta1,
        dry_run,
        outf,
        workers=1,
        start_iter=0,
        gen_path=None,
        fwd_path=None,
        ngpu=1,
        nz=100,
        epoch_samples=64*100,
        learn_forward=True,
        sigmoid=True):

    os.makedirs(outf, exist_ok=True)

    gen_path = gen_path if start_iter == 0 else os.path.join(outf, f'netG_epoch_{start_iter-1}.pth')
    fwd_path = fwd_path if start_iter == 0 else os.path.join(outf, f'netF_epoch_{start_iter-1}.pth')

    # Prediction threshold
    pred_th = 0.5 if sigmoid else 0.0

    dataset = DataGenerator(823131 + start_iter, sigmoid)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize,
                                             shuffle=False, num_workers=int(workers))

    val_size = 1024
    val_set = bitmap.generate_test_set(set_size=val_size, seed=9568382)
    deltas_val, stops_val = cnnify_batch(zip(*val_set))
    ones_val = np.ones_like(deltas_val)
    noise_val = torch.randn(val_size, nz, 1, 1, device=device)

    netG = get_generator_net(gen_arch).to(device)
    init_model(netG, gen_path)

    netF = get_forward_net(fwd_arch).to(device)
    init_model(netF, fwd_path)

    criterion = nn.BCELoss()

    fixed_noise = torch.randn(batchSize, nz, 1, 1, device=device)
    fixed_ones = np.ones((batchSize,), dtype=np.int64)  # np.int was removed in NumPy 1.24

    # setup optimizers (netF plays the GAN discriminator's role, hence "optimizerD")
    optimizerD = optim.Adam(netF.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    scores = []
    for i in range(5):
        noise = torch.randn(val_size, nz, 1, 1, device=device)
        one_step_pred_batch = (netG(torch.Tensor(stops_val).to(device), noise) > pred_th).cpu()
        model_scores = scoring.score_batch(ones_val, np.array(one_step_pred_batch, dtype=bool), stops_val)
        scores.append(model_scores)

    zeros = np.zeros_like(one_step_pred_batch, dtype=bool)
    zeros_scores = scoring.score_batch(ones_val, zeros, stops_val)
    scores.append(zeros_scores)

    # Per-sample best across the five noise draws and the all-zeros baseline.
    best_scores = np.max(scores, axis=0)

    # Note: model_scores here is from the last noise draw only.
    print(f'Mean error one step: model {1 - np.mean(model_scores)}, '
          f'zeros {1 - np.mean(zeros_scores)}, ensemble {1 - np.mean(best_scores)}')

    epoch = start_iter
    samples_in_epoch = 0
    samples_before = start_iter * epoch_samples
    # Initialized in case the loop exits (e.g. dry_run) before an epoch completes.
    one_step_mean_err = float('nan')
    for j, data in enumerate(dataloader, 0):
        i = start_iter * epoch_samples // batchSize + j
        ############################
        # (1) Update F (forward) network -- in the original GAN, it's a "D" network (discriminator)
        # Original comment: Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real starting board -- data set provides ground truth
        netF.zero_grad()
        start_real_cpu = data[0].to(device)
        stop_real_cpu = data[1].to(device)
        batch_size = start_real_cpu.size(0)

        output = netF(start_real_cpu)
        errD_real = criterion(output, stop_real_cpu)
        if learn_forward:
            errD_real.backward()
        D_x = (output.round().eq(stop_real_cpu)).sum().item() / output.numel()

        # train with fake -- use simulator (life_step) to generate ground truth
        # TODO: replace with fixed forward model (should be faster, in batches and on GPU)
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(stop_real_cpu, noise)
        fake_np = (fake > pred_th).detach().cpu().numpy()
        fake_next_np = life_step(fake_np)
        fake_next = torch.tensor(fake_next_np, dtype=torch.float32).to(device)

        output = netF(fake.detach())
        errD_fake = criterion(output, fake_next)
        if learn_forward:
            errD_fake.backward()
        D_G_z1 = (output.round().eq(fake_next)).sum().item() / output.numel()
        errD = errD_real + errD_fake
        if learn_forward:
            optimizerD.step()

        # just for reporting...
        true_stop_np = (stop_real_cpu > pred_th).detach().cpu().numpy()
        fake_scores = scoring.score_batch(fixed_ones, fake_np, true_stop_np, show_progress=False)
        fake_mae = 1 - fake_scores.mean()
        fake_density = fake_np.mean()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        output = netF(fake)
        errG = criterion(output, stop_real_cpu)
        errG.backward()
        D_G_z2 = (output.round().eq(fake_next)).sum().item() / output.numel()
        optimizerG.step()

        samples_in_epoch += batch_size
        s = samples_before + samples_in_epoch
        writer.add_scalar('Loss/forward', errD.item(), i)
        writer.add_scalar('Loss/gen', errG.item(), i)
        writer.add_scalar('MAE/train', fake_mae.item(), i)
        print('[%d/%d][%d] Loss_F: %.4f Loss_G: %.4f fwd acc(real): %.2f fwd acc(fake): %.2f / %.2f, fake dens: %.2f, MAE: %.4f'
              % (epoch, start_iter+niter, i,
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2, fake_density, fake_mae))
        if samples_in_epoch >= epoch_samples:
            """
            multi_step_pred_batch = predict(netG, deltas_val, stops_val, fixed_noise)
            multi_step_mean_err = 1 - np.mean(scoring.score_batch(deltas_val, np.array(multi_step_pred_batch, dtype=np.bool), stops_val))
            """

            one_step_pred_batch = (netG(torch.Tensor(stops_val).to(device), noise_val) > pred_th).detach().cpu().numpy()
            one_step_mean_err = 1 - np.mean(scoring.score_batch(ones_val, np.array(one_step_pred_batch, dtype=np.bool), stops_val))
            print(f'Mean error: one step {one_step_mean_err}')
            writer.add_scalar('MAE/val', one_step_mean_err, epoch)

            vutils.save_image(start_real_cpu,
                              '%s/real_samples.png' % outf,
                              normalize=True)
            fake = netG(stop_real_cpu, fixed_noise).detach()
            vutils.save_image(fake,
                              '%s/fake_samples_epoch_%03d.png' % (outf, epoch),
                              normalize=True)

            grid = vutils.make_grid(start_real_cpu)
            writer.add_image('real', grid, epoch)
            grid = vutils.make_grid(fake)
            writer.add_image('fake', grid, epoch)

            # do checkpointing
            torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (outf, epoch))
            torch.save(netF.state_dict(), '%s/netF_epoch_%d.pth' % (outf, epoch))
            epoch += 1
            samples_in_epoch = 0

        if epoch - start_iter >= niter:
            break
        if dry_run:
            break

    return one_step_mean_err
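A hypothetical call, assuming SummaryWriter from torch.utils.tensorboard and that 'dcgan' is a name accepted by get_generator_net/get_forward_net; lr and beta1 use the usual DCGAN defaults:

from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
writer = SummaryWriter(log_dir='runs/life_gan')
final_err = train(gen_arch='dcgan', fwd_arch='dcgan', device=device,
                  writer=writer, batchSize=64, niter=25, lr=0.0002,
                  beta1=0.5, dry_run=False, outf='./out')
writer.close()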