# Depth passed to zx_to_npx when converting the z sizes into npx / npy.
depth = 5

# Integer-valued command-line options.
zy, zx_sample, zy_sample = (int(value) for value in (opt.zy, opt.zx_sample, opt.zy_sample))
npx, npy = (zx_to_npx(z_size, depth) for z_size in (zx, zy))
batch_size = int(opt.batchSize)

print(npx, npy)

if opt.data_iter == 'from_ti':
    # Training-image directory. Alternative location previously used:
    # 'D:/gan_for_gradient_based_inv/training/ti/'
    # NOTE(review): hard-coded absolute path; consider making it an option.
    texture_dir = 'C:/Users/Fleford/PycharmProjects/gan_for_gradient_based_inv/training/ti/'
    data_iter = get_texture2D_iter(
        texture_dir,
        npx=npx,
        npy=npy,
        mirror=False,
        batch_size=batch_size,
        n_channel=nc,
    )


# custom weights initialization called on netG and netD
def weights_init(m):
    """Apply DCGAN-style initialisation to a single module.

    Meant to be used as ``model.apply(weights_init)``:

    * modules whose class name contains ``'Conv'`` get weights ~ N(0, 0.02)
      (their bias, if any, is left untouched);
    * modules whose class name contains ``'BatchNorm'`` get weights
      ~ N(1, 0.02) and a zero bias.

    All other modules are left unchanged. Parameters that are missing or
    ``None`` -- e.g. ``BatchNorm2d(affine=False)``, or a container class such
    as ``ConvBlock`` that owns no ``weight`` itself -- are skipped instead of
    raising ``AttributeError``.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)


# Build the network from the `netG` callable (presumably the generator class,
# per the weights_init note above).
# NOTE(review): this rebinds `netG` from the callable to its result, so the
# constructor is no longer reachable under this name afterwards; consider a
# distinct name (e.g. `net_g`) if it is needed again.
netG = netG(nc, nz, ngf, gfs, ngpu)
def _create_log_folder(name, config):
    """Create a timestamped trial folder and seed it with run metadata.

    Creates ``trial_<name>_<date>_<hour>_<minute>`` with ``checkpoint`` and
    ``sample`` sub-folders. It then writes ``str(config)`` to a
    ``train_config_*.txt`` file and a CSV ``train_log_*.txt`` containing only
    its header row. Finally it copies ``train.py``, ``progan_modules.py`` and
    ``utils.py`` from the working directory into the folder.

    Returns:
        (log_folder, log_file_name): the folder path and the path of the CSV
        training log.

    Raises:
        FileExistsError: if the folder already exists (e.g. two runs started
            in the same minute); ``os.mkdir`` keeps runs from sharing a folder.
    """
    import os
    from datetime import datetime
    from shutil import copy

    now = datetime.now()
    stamp = '%s_%s_%d_%d' % (name, now.date(), now.hour, now.minute)
    post_fix = stamp + '.txt'
    log_folder = 'trial_' + stamp

    os.mkdir(log_folder)
    os.mkdir(log_folder + '/checkpoint')
    os.mkdir(log_folder + '/sample')

    config_file_name = os.path.join(log_folder, 'train_config_' + post_fix)
    with open(config_file_name, 'w') as config_file:
        config_file.write(str(config))

    log_file_name = os.path.join(log_folder, 'train_log_' + post_fix)
    with open(log_file_name, 'w') as log_file:
        log_file.write('g,d,cntxt_loss,ds_cntxt_loss\n')

    copy('train.py', log_folder + '/train_%s.py' % post_fix)
    copy('progan_modules.py', log_folder + '/model_%s.py' % post_fix)
    copy('utils.py', log_folder + '/utils_%s.py' % post_fix)
    return log_folder, log_file_name


def train(generator, discriminator, init_step, loader, total_iter=600000, max_step=6):
    """Train the conditional generator / discriminator pair with WGAN-GP.

    Training proceeds through growth steps ``init_step .. max_step``. Each
    step lasts ``total_iter // max_step`` iterations, during which ``alpha``
    ramps from 0 to 1. The discriminator is updated every iteration.

    Every ``n_critic`` iterations the generator is updated with the
    adversarial loss plus ``log`` of a masked squared-error "context" loss
    against the full-resolution real batch. ``g_running`` then tracks the
    generator via ``accumulate``.

    Args:
        generator: called as ``generator(z, cond_array, step=, alpha=)``.
        discriminator: called as ``discriminator(x, step=, alpha=)``.
        init_step: growth step to start from.
        loader: unused; batches come from ``get_texture2D_iter('ti/')``.
        total_iter: iteration budget across all steps.
        max_step: final growth step.

    Relies on module-level names: ``trial_name``, ``args``, ``device``,
    ``batch_size``, ``input_code_size``, ``n_critic``, ``d_optimizer``,
    ``g_optimizer``, ``g_running``, ``accumulate``, ``generate_condition``,
    ``get_texture2D_iter``, ``grad``, ``utils``, ``tqdm`` and ``F``.

    Side effects: creates a trial folder with samples (every 1000 iterations),
    checkpoints (every 10000) and a CSV loss log (every 500).
    """
    step = init_step
    phase_len = total_iter // max_step  # iterations spent at each growth step
    total_iter_remain = total_iter - phase_len * (step - 1)

    pbar = tqdm(range(total_iter_remain))

    disc_loss_val = 0
    gen_loss_val = 0
    grad_loss_val = 0

    log_folder, log_file_name = _create_log_folder(trial_name, args)

    alpha = 0
    # 0-dim tensors: backward(gradient) requires the gradient to have the same
    # shape as the output, and every loss backpropagated below is a scalar.
    one = torch.tensor(1.0, device=device)
    mone = -one
    iteration = 0

    # Fixed reference batch (5 x 10 grid) used for the periodic sample images.
    data_iter_sample = get_texture2D_iter('ti/', batch_size=5 * 10)
    real_image_raw_res_sample = torch.Tensor(next(data_iter_sample)).to(device)
    cond_array_sample, _cond_mask_sample = generate_condition(real_image_raw_res_sample)

    # Give the first half of the sample batch the same condition as entry 0.
    for idx in range(len(cond_array_sample) // 2):
        cond_array_sample[idx] = cond_array_sample[0]

    data_iter = get_texture2D_iter('ti/', batch_size=batch_size)

    # Placeholder values reported until the first generator update runs.
    cntxt_loss = 69.0
    ds_cntxt_loss = 69.0

    log_every = 500
    # Generator updates per logging window; at least 1 to avoid dividing by 0
    # when n_critic > log_every.
    g_steps_per_log = max(1, log_every // n_critic)

    for i in pbar:
        discriminator.zero_grad()

        alpha = min(1, (2 / phase_len) * iteration)

        if iteration > phase_len:
            alpha = 0
            iteration = 0
            step += 1

            if step > max_step:
                # NOTE(review): this alpha = 1 lasts a single iteration; alpha
                # is recomputed from the reset `iteration` on the next pass,
                # so the last step fades in again. Confirm this is intended.
                alpha = 1
                step = max_step

        # Average-pool the raw batch with kernel = stride = 2 ** (6 - step).
        # Assuming 128x128 raw images (the size the context loss interpolates
        # to), this yields 2 ** (step + 1) pixels per side -- TODO confirm.
        real_image_raw_res = torch.Tensor(next(data_iter)).to(device)
        kernel_width = 2 ** (6 - step)
        real_image = F.avg_pool2d(real_image_raw_res, kernel_size=kernel_width, stride=kernel_width)

        iteration += 1

        ### 1. train Discriminator
        b_size = real_image.size(0)
        real_predict = discriminator(real_image, step=step, alpha=alpha)
        # Drift term 0.001 * E[D(x)^2] keeps critic outputs from drifting.
        real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
        real_predict.backward(mone)

        gen_z = torch.randn(b_size, input_code_size).to(device)
        cond_array, cond_mask = generate_condition(real_image_raw_res)

        fake_image = generator(gen_z, cond_array, step=step, alpha=alpha)
        fake_predict = discriminator(fake_image.detach(), step=step, alpha=alpha)
        fake_predict = fake_predict.mean()
        fake_predict.backward(one)

        # Gradient penalty on random interpolates between real and fake.
        eps = torch.rand(b_size, 1, 1, 1).to(device)
        x_hat = eps * real_image.data + (1 - eps) * fake_image.detach().data
        x_hat.requires_grad = True
        hat_predict = discriminator(x_hat, step=step, alpha=alpha)
        grad_x_hat = grad(outputs=hat_predict.sum(), inputs=x_hat, create_graph=True)[0]
        grad_penalty = ((grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
        grad_penalty = 10 * grad_penalty
        grad_penalty.backward(one)
        grad_loss_val += grad_penalty.item()
        disc_loss_val += (real_predict - fake_predict).item()

        d_optimizer.step()

        ### 2. train Generator (every n_critic iterations)
        if (i + 1) % n_critic == 0:
            generator.zero_grad()
            discriminator.zero_grad()

            predict = discriminator(fake_image, step=step, alpha=alpha)

            # Context loss: squared error against the raw-resolution real batch,
            # counted only where cond_mask is non-zero.
            fake_image_upsampled = F.interpolate(fake_image, size=(128, 128), mode="nearest")
            context_loss_array = ((fake_image_upsampled - real_image_raw_res) ** 2) * cond_mask
            context_loss_sum = torch.sum(context_loss_array)
            ds_cntxt_loss = context_loss_sum.item()
            # Clamp before log: a zero sum (empty mask / exact match) would give
            # -inf and NaN gradients.
            context_loss_value = context_loss_sum.clamp(min=1e-12).log()

            loss = -predict.mean() + 1.0 * context_loss_value
            gen_loss_val += loss.item()
            cntxt_loss = context_loss_value.item()

            loss.backward()
            g_optimizer.step()
            accumulate(g_running, generator)

        if (i + 1) % 1000 == 0 or i == 0:
            with torch.no_grad():
                images = g_running(torch.randn(5 * 10, input_code_size).to(device),
                                   cond_array_sample, step=step, alpha=alpha).data.cpu()
                images = F.interpolate(images, size=(128, 128), mode="nearest")
                # NOTE(review): newer torchvision names this argument `value_range`.
                utils.save_image(
                    images,
                    f'{log_folder}/sample/{str(i + 1).zfill(6)}.png',
                    nrow=10,
                    normalize=True,
                    range=(-1, 1))

        if (i + 1) % 10000 == 0 or i == 0:
            tag = str(i + 1).zfill(6)
            try:
                torch.save(g_running.state_dict(), f'{log_folder}/checkpoint/{tag}_g.model')
                torch.save(discriminator.state_dict(), f'{log_folder}/checkpoint/{tag}_d.model')
            except Exception as err:
                # Checkpointing stays best-effort, but failures are reported
                # (and Ctrl+C is no longer swallowed as a bare except did).
                print(f'Warning: failed to save checkpoint {tag}: {err}')

        if (i + 1) % log_every == 0:
            g_loss_avg = gen_loss_val / g_steps_per_log
            d_loss_avg = disc_loss_val / log_every
            state_msg = (f'{i + 1}; G: {g_loss_avg:.3f}; D: {d_loss_avg:.3f};'
                         f' Grad: {grad_loss_val / log_every:.3f}; Alpha: {alpha:.3f}; Step: {step:.3f}; Iteration: {iteration:.3f};'
                         f' Context Loss: {cntxt_loss:.3f};' f' DS Context Loss: {ds_cntxt_loss:.3f};')
            print(real_image.shape)

            with open(log_file_name, 'a+') as log_file:
                log_file.write("%.5f,%.5f,%.5f,%.5f\n" % (
                    g_loss_avg, d_loss_avg, cntxt_loss, ds_cntxt_loss))

            disc_loss_val = 0
            gen_loss_val = 0
            grad_loss_val = 0

            print(state_msg)
input_code_size = 128

generator = Generator(in_channel=64,
                      input_code_dim=input_code_size,
                      pixel_norm=False,
                      tanh=False).to(device)

# Load pretrained generator weights. map_location remaps tensors onto `device`,
# so a checkpoint saved on a GPU also loads when running on CPU.
# Earlier checkpoint: 'trial_test18_2020-10-12_22_29/checkpoint/160000_g.model'
generator.load_state_dict(
    torch.load('trial_test18_2020-10-18_17_37/checkpoint/140000_g.model',
               map_location=device))

# Latent input for the generator.
gen_z = torch.randn(b_size, input_code_size).to(device)

# Draw one real batch and derive its condition array / mask.
data_iter = get_texture2D_iter('ti/', batch_size=b_size)
real_image_raw_res = torch.Tensor(next(data_iter)).to(device)
cond_array, cond_mask = generate_condition(real_image_raw_res)

cond_downsampler = torch.nn.MaxPool2d((8, 8), stride=(8, 8))

# Replace every batch entry with cond_array[0] so the whole batch shares one
# condition (the loop index is not named `slice`, to avoid shadowing the builtin).
one_cond_array = torch.zeros_like(cond_array)
for idx in range(len(cond_array)):
    one_cond_array[idx] = cond_array[0]
cond_array = one_cond_array

# Build one_cond_mask: a copy of cond_mask in which every batch entry equals
# cond_mask[0] (same shape/dtype/device via zeros_like).
# NOTE(review): the loop variable `slice` shadows the builtin at module scope.
one_cond_mask = torch.zeros_like(cond_mask)
for slice in range(len(cond_mask)):
    one_cond_mask[slice] = cond_mask[0]