def plot_means(curr_embed, curr_mapped, curr_labels, curr_images, name,
               with_targets, token_list):
    """Plot, for every label, the level closest to that label's mean embedding.

    For each unique value in ``curr_labels`` the mean of the matching rows of
    ``curr_embed`` is computed, and the sample (over *all* rows of
    ``curr_embed``) with the smallest Euclidean distance to that mean is
    selected. The selected samples' images are rendered side by side, saved as
    a PDF in the wandb run directory and logged to wandb.

    Args:
        curr_embed: 2-D array of embeddings, one row per sample.
        curr_mapped: 2-D array indexed like ``curr_embed``; its rows are what
            gets returned.
        curr_labels: 1-D array of labels, one per sample.
        curr_images: array of one-hot levels indexed like ``curr_embed``.
        name: prefix of the output file name and of the wandb log key.
        with_targets: if truthy, ``"_targets"`` is appended to the name.
        token_list: tokens passed to ``one_hot_to_ascii_level``.

    Returns:
        ``np.ndarray`` holding, per label, the row of ``curr_mapped`` that
        belongs to the selected sample.
    """
    ImgGen = LevelImageGen("./mario/sprites")
    labels = np.unique(curr_labels)

    means = []
    m_imgs = []
    for label in labels:
        m = curr_embed[curr_labels == label, :].mean(0)
        # Index of the sample nearest (L2) to the label mean.
        closest_idx = (np.sum(abs(curr_embed - m)**2, axis=1)**(1. / 2)).argmin()
        means.append(curr_mapped[closest_idx, :])
        m_imgs.append(curr_images[closest_idx, :])
    means = np.array(means)

    # Two inches of width per label. The previous code multiplied the label
    # array by two *inside* len(), which left the width at one inch per label.
    plt.figure(figsize=(len(labels) * 2, 2))
    for i, img in enumerate(m_imgs):
        plt.subplot(1, len(labels), i + 1)
        plt.imshow(
            ImgGen.render(
                one_hot_to_ascii_level(
                    torch.tensor(img).unsqueeze(0), token_list)))
        plt.axis("off")

    tag = f"{name}_imgs{'_targets' if with_targets else ''}"
    figure_path = os.path.join(wandb.run.dir, f"{tag}.pdf")
    plt.tight_layout()
    plt.savefig(figure_path, dpi=300)
    wandb.save(figure_path)
    wandb.log({tag: plt})
    plt.close()
    return means
def new_tile_types(vec, opt):
    """Count generated tokens that never occur in the reference level file.

    The reference level is read from ``opt.input_dir + '/' + opt.input_name``;
    every character of every line (line terminators included) counts as a
    known token. Returns the count as a float.
    """
    generated_rows = one_hot_to_ascii_level(vec.detach(), opt.token_list)

    reference_path = opt.input_dir + '/' + opt.input_name
    with open(reference_path, "r") as f:
        known_tokens = {token for line in f for token in line}

    unseen = sum(1 for row in generated_rows for token in row
                 if token not in known_tokens)
    return float(unseen)
def generate(n, h, w):
    """Sample ``n`` levels from the decoder and write them below ``opt.out_dir``.

    Relies on the module-level names ``mu``, ``logvar``, ``device``, ``model``
    and ``opt``. For every sample the decoder output is reshaped to
    ``(1, len(opt.token_list), h, w)``, converted to a one-hot level via an
    argmax over the token channel, and written to
    ``{opt.out_dir}/txt/run{j}.txt``.

    Args:
        n: number of levels to generate.
        h: level height used for the reshape.
        w: level width used for the reshape.

    Returns:
        When ``n <= 1`` the raw (pre-argmax) decoder output of the first
        sample, or ``None`` if ``n`` is 0. When ``n > 1`` nothing is returned;
        each level is additionally rendered to
        ``{opt.out_dir}/img/run{j}.png``.
    """
    txt_dir = f'{opt.out_dir}/txt'
    img_dir = f'{opt.out_dir}/img'
    num_tokens = len(opt.token_list)

    for j in range(n):
        with torch.no_grad():
            # torch.normal takes a standard deviation, not a log-variance.
            # Passing ``logvar`` directly raises for negative entries and
            # samples from the wrong distribution otherwise.
            std = torch.exp(0.5 * logvar)
            latent = torch.normal(mu, std).to(device)
            samples = model.decode(latent).cpu()
        samples = samples.reshape((1, num_tokens, h, w))

        # One-hot encode the most likely token at every position.
        ind = torch.argmax(samples, dim=1)
        hotenc = torch.zeros_like(samples)
        hotenc.scatter_(1, ind.unsqueeze(1), 1)

        ascii_gen = one_hot_to_ascii_level(hotenc, opt.token_list)
        # ascii_real = one_hot_to_ascii_level(real[9:10,:,:,:], opt.token_list)

        os.makedirs(txt_dir, exist_ok=True)
        with open(f"{txt_dir}/run{j}.txt", 'w') as file:
            file.writelines(f'{row}' for row in ascii_gen)

        gen_level = opt.ImgGen.render(ascii_gen)
        # real_level = opt.ImgGen.render(ascii_real)
        if n > 1:
            os.makedirs(img_dir, exist_ok=True)
            gen_level.save(rf"{img_dir}/run{j}.png")
        else:
            return samples
def normalized_compression_dist(vec, opt):
    """Return the normalized compression distance between *vec* and the reference level.

    The generated level (rows concatenated) and the reference level file at
    ``opt.input_dir + '/' + opt.input_name`` are gzip-compressed separately
    and together, and the distance is

        (C(xy) - min(C(x), C(y))) / max(C(x), C(y))

    where ``C`` is the compressed length in bytes.

    Note: lengths are taken with ``len``. The earlier ``sys.getsizeof`` also
    counted the ``bytes`` object header, which inflated the denominator by a
    constant.
    """
    ascii_level = one_hot_to_ascii_level(vec.detach(), opt.token_list)

    path = opt.input_dir + '/' + opt.input_name
    with open(path, "r") as f:
        y = f.read()
    x = "".join(ascii_level)

    kx = len(gzip.compress(x.encode('utf-8')))
    ky = len(gzip.compress(y.encode('utf-8')))
    kxy = len(gzip.compress((x + y).encode('utf-8')))

    return (kxy - min(kx, ky)) / max(kx, ky)
def num_koopa(vec, opt):
    """Return how many ``"k"`` tokens the ASCII form of *vec* contains."""
    rows = one_hot_to_ascii_level(vec.detach(), opt.token_list)
    return sum(1 for row in rows for token in row if token == "k")
def num_enemies(vec, opt):
    """Return how many tokens of *vec* are in ENEMY_TOKENS or SPECIAL_ENEMY_TOKENS."""
    rows = one_hot_to_ascii_level(vec.detach(), opt.token_list)
    return sum(
        1
        for row in rows
        for token in row
        if token in ENEMY_TOKENS or token in SPECIAL_ENEMY_TOKENS
    )
def spiky(vec, opt):
    """Return, as a float, how many ``'y'`` tokens the ASCII form of *vec* contains."""
    rows = one_hot_to_ascii_level(vec.detach(), opt.token_list)
    count = sum(1 for row in rows for token in row if token == 'y')
    return float(count)
def platform_test_vec(vec, token_list):
    """Return, as a float, how many columns differ between the last two rows of the level."""
    rows = one_hot_to_ascii_level(vec.detach(), token_list)
    # The second-to-last row is indexed lazily, column by column, exactly as
    # far as the last row extends.
    differing = sum(
        1 for col, token in enumerate(rows[-1]) if token != rows[-2][col])
    return float(differing)
def num_jumps(vec, token_list):
    """Return the number of gaps between "X" runs in the bottom two rows.

    The last two rows are merged column-wise: a column counts as "X" if either
    row has an "X" there, otherwise as blank. The result is the number of
    separate "X" runs minus one, never less than zero.
    """
    rows = one_hot_to_ascii_level(vec.detach(), token_list)
    merged_level = "".join(
        "X" if rows[-1][col] == "X" or rows[-2][col] == "X" else " "
        for col in range(len(rows[-1]))
    )
    # Whitespace-splitting leaves one entry per contiguous run of "X".
    run_count = len(merged_level.split())
    return max(run_count - 1, 0)
def enemy_on_stairs(vec, opt):
    """Count ENEMY_TOKENS tokens that sit directly above a ``"#"`` token.

    Every row except the last is compared with the row beneath it, so the
    function works for any level height. It used to special-case row index 15
    and therefore only worked for 16-row levels. Columns that exist in only
    one of the two rows are ignored.

    Returns:
        The count as a float.
    """
    score = 0.0
    ascii_level = one_hot_to_ascii_level(vec.detach(), opt.token_list)
    for row, row_below in zip(ascii_level, ascii_level[1:]):
        for token, below in zip(row, row_below):
            if token in ENEMY_TOKENS and below == "#":
                score += 1
    return score
def midair_pipes(vec, opt):
    """Count PIPE_TOKENS tokens that sit directly above a SKY_TOKENS token.

    Every row except the last is compared with the row beneath it, so the
    function works for any level height. It used to special-case row index 15
    and therefore only worked for 16-row levels. Columns that exist in only
    one of the two rows are ignored.

    Returns:
        The count as a float.
    """
    score = 0.0
    ascii_level = one_hot_to_ascii_level(vec.detach(), opt.token_list)
    for row, row_below in zip(ascii_level, ascii_level[1:]):
        for token, below in zip(row, row_below):
            if token in PIPE_TOKENS and below in SKY_TOKENS:
                score += 1
    return score
def hamming_dist(vec, opt):
    """Return the normalized Hamming distance between *vec* and the reference level.

    The reference level is read from ``opt.input_dir + '/' + opt.input_name``.
    Cells are compared position by position, restricted to the first
    ``len(ref_level[0])`` columns, and the mismatch count is divided by
    ``len(ref_level) * len(ref_level[0])``.

    Only positions present in both levels are compared. This replaces the
    former hard-coded skip of cell ``(15, 202)``, which covered a final row
    with no trailing newline but only for one specific level size.
    """
    ascii_level = one_hot_to_ascii_level(vec.detach(), opt.token_list)

    path = opt.input_dir + '/' + opt.input_name
    with open(path, "r") as f:
        ref_level = f.readlines()

    num_rows = len(ref_level)
    num_cols = len(ref_level[0])

    hamming = 0.0
    for gen_row, ref_row in zip(ascii_level, ref_level):
        # zip() stops at the shorter row, so a missing trailing newline on
        # either side is simply not compared.
        for gen_token, ref_token in zip(gen_row[:num_cols], ref_row[:num_cols]):
            if gen_token != ref_token:
                hamming += 1.0

    return hamming / (float(num_rows) * float(num_cols))
def test_playability(vec, token_list):
    """Let the Robin Baumgarten agent play the level encoded by *vec*.

    A Java gateway to the Mario AI Framework is launched, the level is
    converted to ASCII (a Mario token is placed if none exists) and handed to
    ``game.gameLoop``. The Java process is killed and the gateway shut down
    afterwards, whether or not setup or play succeeded.

    Args:
        vec: one-hot level tensor; ``vec.shape[-1]`` / ``vec.shape[-2]`` are
            stored as level length / height.
        token_list: tokens used to decode ``vec``.

    Returns:
        Tuple ``(perc, timeLeft, jumps, max_jump)``: completion percentage
        (int, 0-100), remaining time, number of jumps and maximum x distance
        of a jump, as reported by the game result. If the game loop raises,
        the values not yet obtained stay at 0. Previously that case failed on
        the return statement because the names were unbound.
    """
    render_mario = True

    # NOTE(review): a Tk root is created on every call and never destroyed —
    # confirm whether callers invoke this repeatedly.
    root = Tk(className=" TOAD-GUI")

    level_l = IntVar()
    level_h = IntVar()
    level_l.set(0)
    level_h.set(0)

    placeholder = Image.new(
        'RGB', (890, 256), (255, 255, 255))  # Placeholder image for the preview
    load_string_gen = StringVar()
    load_string_txt = StringVar()
    ImgGen = LevelImageGen(
        os.path.join(os.path.join(os.curdir, "utils"), "sprites"))
    use_gen = BooleanVar()
    use_gen.set(False)
    levelimage = ImageTk.PhotoImage(placeholder)
    level_obj = LevelObject('-', None, levelimage, ['-'], None, None)
    is_loaded = BooleanVar()
    is_loaded.set(False)
    error_msg = StringVar()
    error_msg.set("No Errors")

    # Defaults returned when the level could not be played to the end.
    perc, timeLeft, jumps, max_jump = 0, 0, 0, 0

    # Py4j Java bridge uses Mario AI Framework
    gateway = JavaGateway.launch_gateway(classpath=MARIO_AI_PATH,
                                         die_on_exit=True,
                                         redirect_stdout=sys.stdout,
                                         redirect_stderr=sys.stderr)
    try:
        # Open up game window and assign agent
        game = gateway.jvm.engine.core.MarioGame()
        game.initVisuals(2.0)
        agent = gateway.jvm.agents.robinBaumgarten.Agent()
        game.setAgent(agent)

        # Create a level object to load the level into
        level_obj = LevelObject(0, 0, 0, 0, 0, 0)

        # Convert the level to ascii
        level_obj.ascii_level = one_hot_to_ascii_level(vec.detach(), token_list)

        # Check if a Mario token exists - if not, we need to place one
        m_exists = any('M' in line for line in level_obj.ascii_level)
        if not m_exists:
            level_obj.ascii_level = place_a_mario_token(level_obj.ascii_level)
        level_obj.tokens = token_list

        # ImgGen = MarioLevelGen('utils/sprites')
        img = ImageTk.PhotoImage(ImgGen.render(level_obj.ascii_level))
        level_obj.image = img
        level_obj.scales = None
        level_obj.noises = None
        level_l.set(vec.shape[-1])
        level_h.set(vec.shape[-2])
        is_loaded.set(True)
        use_gen.set(False)
        error_msg.set("Level loaded")

        # Play the level
        try:
            result = game.gameLoop(''.join(level_obj.ascii_level), 20, 0,
                                   render_mario, 1000000)
            perc = int(result.getCompletionPercentage() * 100)
            timeLeft = result.getRemainingTime()  # time remaining in the countdown
            jumps = result.getNumJumps()  # number of jumps performed by mario during the game
            max_jump = result.getMaxXJump()  # maximum x distance traversed by mario during a jump
            error_msg.set("Level Played. Completion Percentage: %d%%" % perc)
        except Exception:
            error_msg.set("Level Play was interrupted.")
            is_loaded.set(True)
    finally:
        # Always tear the JVM down, also when setup before the game loop fails.
        # game.getWindow().dispose()
        gateway.java_process.kill()
        # gateway.close()
        gateway.shutdown()

    is_loaded.set(True)
    # use_gen.set(remember_use_gen)  # only set use_gen to True if it was previously
    return perc, timeLeft, jumps, max_jump
rand_network.eval()

# Load the pickled archive.
# NOTE(review): read_pickle unpickles arbitrary objects — only load archives
# from a trusted source.
df = pandas.read_pickle(s_dir_name + "/archive.zip")

if opt.all:
    solutions = []
    bcs = []
    objs = []
    for _, row in df.iterrows():
        # Every column from "solution_0" onwards belongs to the latent vector.
        solutions.append(np.array(row.loc["solution_0":]))
        # Label-based scalar lookups; the previous ``row.loc[[...]][0]`` relied
        # on positional fallback indexing of a label-indexed Series.
        bcs.append((row.loc["behavior_0"], row.loc["behavior_1"]))
        objs.append(row.loc["objective"])

    for solution, bc, obj in zip(solutions, bcs, objs):
        # NOTE(review): the flattened source had an ``except: continue`` with
        # no visible ``try``; it is assumed to guard the whole per-elite
        # pipeline — confirm.
        try:
            solution = torch.from_numpy(solution).float()
            noise = rand_network(solution).detach()
            levels = generate_samples_cmaes(generators,
                                            noise_maps,
                                            reals,
                                            noise_amplitudes,
                                            noise,
                                            opt,
                                            in_s=in_s,
                                            scale_v=opt.scale_v,
                                            scale_h=opt.scale_h,
                                            save_dir=s_dir_name,
                                            num_samples=1)
            level = levels[0]
            ascii_level = one_hot_to_ascii_level(level, opt.token_list)
            img = opt.ImgGen.render(ascii_level)
            img.save("%s/elite_%.3f_%.3f_score_%.2f.png" %
                     (s_dir_name, bc[0], bc[1], obj))
        except Exception:
            # Best effort: skip elites that fail to render, but no longer
            # swallow KeyboardInterrupt / SystemExit.
            continue
# print('Reconstructed Images')
# plt.figure(3)
# n = 12
# for i in range(n):
#     ax = plt.subplot(n,1,i+1)
#     plt.imshow(output[0,i,:,:])
# plt.show()

# One-hot encode the most likely token per position of the first batch element.
# scatter_ along the channel axis sets the same entries as the former
# per-pixel Python loop.
ind = torch.argmax(output, dim=1)
hotenc = torch.zeros_like(output)
hotenc[0].scatter_(0, ind[0].unsqueeze(0), 1)

ascii_gen = one_hot_to_ascii_level(hotenc, opt.token_list)
ascii_real = one_hot_to_ascii_level(real, opt.token_list)
gen_level = opt.ImgGen.render(ascii_gen)
real_level = opt.ImgGen.render(ascii_real)

# NOTE(review): the backslash makes this a directory path only on Windows;
# elsewhere it is a file name containing a backslash — confirm target platform.
gen_level.save(rf"Gen_Levels\{n_epochs}.png")
# real_level.save(r"Gen_Levels\real.png")

# # y = torch.randn(5,5)
# # print(y)
# print(y.shape)
# z = y.unfold(0,3,2)
# print(z)
# print(z.shape)
def train_single_scale(D, G, reals, generators, noise_maps, input_from_prev_scale, noise_amplitudes, opt):
    """ Train one scale. D and G are the current discriminator and generator, reals are the scaled versions of the
    original level, generators and noise_maps contain information from previous scales and will receive information in
    this scale, input_from_previous_scale holds the noise map and images from the previous scale, noise_amplitudes hold
    the amplitudes for the noise in all the scales. opt is a namespace that holds all necessary parameters.

    When ``opt.cgan`` is set, a PCA_Detector is built on the preprocessed real level and its (scaled, interpolated)
    detection map is concatenated as extra channel(s) to what the discriminator sees, for real and fake levels alike.

    After training, GradCAM attributions for D and G are written as PNG files into the 'gradcam' folder, and the
    networks and z_opt are saved below opt.outf.

    Returns the tuple (z_opt, input_from_prev_scale, G, divergences).

    NOTE(review): block nesting was reconstructed from a flattened source. In particular, the post-training
    divergence, the GradCAM section and the network saving are assumed to run once, after the epoch loop — confirm.
    NOTE(review): `real_detection_map` and `divergences` are only assigned inside `if opt.cgan:` but are used
    unconditionally further down, so this function looks usable only with opt.cgan enabled — confirm.
    """
    # The scale index is implied by how many generators are already trained.
    current_scale = len(generators)
    real = reals[current_scale]

    keepSky = False
    kernel_dims = (2, 2)

    # Initialize real detector
    real0 = preprocess(opt, real, keepSky)
    N, C, H, W = real0.shape  # only H and W are used below

    # NOTE(review): `scale` is not referenced again in this function.
    scale = opt.scales[current_scale] if current_scale < len(opt.scales) else 1

    if opt.cgan:
        detector = PCA_Detector(opt, 'real', real0, kernel_dims)
        real_detection_map = detector(real0)
        detection_scale = 0.1
        real_detection_map *= detection_scale
        # The discriminator input is the level plus the detection map, resized to the preprocessed level size.
        real1 = torch.cat([real, F.interpolate(real_detection_map, (H, W))], dim=1)
        divergences = []
    else:
        real1 = real

    if opt.game == 'mario':
        token_group = MARIO_TOKEN_GROUPS
    else:  # if opt.game == 'mariokart':
        token_group = MARIOKART_TOKEN_GROUPS

    nzx = real.shape[2]  # Noise size x
    nzy = real.shape[3]  # Noise size y

    padsize = int(1 * opt.num_layer)  # As kernel size is always 3 currently, padsize goes up by one per layer

    if not opt.pad_with_noise:
        pad_noise = nn.ZeroPad2d(padsize)
        pad_image = nn.ZeroPad2d(padsize)
    else:
        pad_noise = nn.ReflectionPad2d(padsize)
        pad_image = nn.ReflectionPad2d(padsize)

    # setup optimizer
    optimizerD = optim.Adam(D.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(G.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600, 2500], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600, 2500], gamma=opt.gamma)

    # z_opt is the fixed noise map used for the reconstruction loss.
    if current_scale == 0:  # Generate new noise
        z_opt = generate_spatial_noise([1, opt.nc_current, nzx, nzy], device=opt.device)
        z_opt = pad_noise(z_opt)
    else:  # Add noise to previous output
        z_opt = torch.zeros([1, opt.nc_current, nzx, nzy]).to(opt.device)
        z_opt = pad_noise(z_opt)

    logger.info("Training at scale {}", current_scale)
    for epoch in tqdm(range(opt.niter)):
        # Global step across scales, used as the wandb step.
        step = current_scale * opt.niter + epoch
        noise_ = generate_spatial_noise([1, opt.nc_current, nzx, nzy], device=opt.device)
        noise_ = pad_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            D.zero_grad()

            output = D(real1).to(opt.device)

            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)

            # train with fake
            if (j == 0) & (epoch == 0):
                if current_scale == 0:  # If we are in the lowest scale, noise is generated from scratch
                    prev = torch.zeros(1, opt.nc_current, nzx, nzy).to(opt.device)
                    input_from_prev_scale = prev
                    prev = pad_image(prev)
                    z_prev = torch.zeros(1, opt.nc_current, nzx, nzy).to(opt.device)
                    z_prev = pad_noise(z_prev)
                    opt.noise_amp = 1
                else:  # First step in NOT the lowest scale
                    # We need to adapt our inputs from the previous scale and add noise to it
                    prev = draw_concat(generators, noise_maps, reals, noise_amplitudes, input_from_prev_scale,
                                       "rand", pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        prev = group_to_token(prev, opt.token_list, token_group)

                    prev = interpolate(prev, real1.shape[-2:], mode="bilinear", align_corners=False)
                    prev = pad_image(prev)
                    z_prev = draw_concat(generators, noise_maps, reals, noise_amplitudes, input_from_prev_scale,
                                         "rec", pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        z_prev = group_to_token(z_prev, opt.token_list, token_group)

                    z_prev = interpolate(z_prev, real1.shape[-2:], mode="bilinear", align_corners=False)
                    # real1[:, :-1] drops the last channel.
                    # NOTE(review): presumably this strips the appended detection map, which assumes the map has
                    # exactly one channel and that opt.cgan is set — confirm.
                    opt.noise_amp = update_noise_amplitude(z_prev, real1[:, :-1], opt)
                    z_prev = pad_image(z_prev)
            else:  # Any other step
                prev = draw_concat(generators, noise_maps, reals, noise_amplitudes, input_from_prev_scale,
                                   "rand", pad_noise, pad_image, opt)

                # For the seeding experiment, we need to transform from token_groups to the actual token
                if current_scale == (opt.token_insert + 1):
                    prev = group_to_token(prev, opt.token_list, token_group)

                prev = interpolate(prev, real1.shape[-2:], mode="bilinear", align_corners=False)
                prev = pad_image(prev)

            # After creating our correct noise input, we feed it to the generator:
            noise = opt.noise_amp * noise_ + prev
            # Both branches of the temperature expression evaluate to 1.
            fake = G(noise.detach(), prev, temperature=1 if current_scale != opt.token_insert else 1)
            fake0 = preprocess(opt, fake, keepSky)

            if opt.cgan:
                Nf, Cf, Hf, Wf = fake0.shape
                fake_detection_map = detector(fake0) * detection_scale
                fake1 = torch.cat([fake, F.interpolate(fake_detection_map, (Hf, Wf))], dim=1)
            else:
                fake1 = fake

            # Then run the result through the discriminator
            output = D(fake1.detach())
            errD_fake = output.mean()

            # Backpropagation
            errD_fake.backward(retain_graph=False)

            # Gradient Penalty
            gradient_penalty = calc_gradient_penalty(D, real1, fake1, opt.lambda_grad, opt.device)
            gradient_penalty.backward(retain_graph=False)

            # Logging:
            if step % 10 == 0:
                wandb.log({f"D(G(z))@{current_scale}": errD_fake.item(),
                           f"D(x)@{current_scale}": -errD_real.item(),
                           f"gradient_penalty@{current_scale}": gradient_penalty.item()
                           },
                          step=step, sync=False)
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            G.zero_grad()
            fake = G(noise.detach(), prev.detach(), temperature=1 if current_scale != opt.token_insert else 1)
            fake0 = preprocess(opt, fake, keepSky)
            Nf, Cf, Hf, Wf = fake0.shape

            if opt.cgan:
                fake_detection_map = detector(fake0) * detection_scale
                fake1 = torch.cat([fake, F.interpolate(fake_detection_map, (Hf, Wf))], dim=1)
            else:
                fake1 = fake

            output = D(fake1)

            errG = -output.mean()
            errG.backward(retain_graph=False)
            if opt.alpha != 0:  # i. e. we are trying to find an exact recreation of our input in the lat space
                Z_opt = opt.noise_amp * z_opt + z_prev
                G_rec = G(Z_opt.detach(), z_prev, temperature=1 if current_scale != opt.token_insert else 1)
                rec_loss = opt.alpha * F.mse_loss(G_rec, real)
                if opt.cgan:
                    # The divergence of the reconstruction is added to the reconstruction loss.
                    div = divergence(real_detection_map, preprocess(opt, G_rec, keepSky))
                    rec_loss += div
                rec_loss.backward(retain_graph=False)  # TODO: Check for unexpected argument retain_graph=True
                rec_loss = rec_loss.detach()
            else:  # We are not trying to find an exact recreation
                rec_loss = torch.zeros([])
                Z_opt = z_opt

            optimizerG.step()

        # More Logging:
        # NOTE(review): runs regardless of opt.cgan although real_detection_map / divergences only exist with cgan.
        div = divergence(real_detection_map, preprocess(opt, fake, keepSky))
        divergences.append(div)
        # logger.info("divergence(fake) = {}", div)
        if step % 10 == 0:
            wandb.log({f"noise_amplitude@{current_scale}": opt.noise_amp,
                       f"rec_loss@{current_scale}": rec_loss.item()},
                      step=step, sync=False, commit=True)

        # Rendering and logging images of levels
        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            if opt.token_insert >= 0 and opt.nc_current == len(token_group):
                # Use the first token of each group as that group's representative.
                token_list = [list(group.keys())[0] for group in token_group]
            else:
                token_list = opt.token_list

            # The [:, :-1] slices drop the last channel before decoding (see the note on real1[:, :-1] above).
            img = opt.ImgGen.render(one_hot_to_ascii_level(fake1[:, :-1].detach(), token_list))
            img2 = opt.ImgGen.render(
                one_hot_to_ascii_level(
                    G(Z_opt.detach(), z_prev, temperature=1 if current_scale != opt.token_insert else 1).detach(),
                    token_list))
            real_scaled = one_hot_to_ascii_level(real1[:, :-1].detach(), token_list)
            img3 = opt.ImgGen.render(real_scaled)
            wandb.log({f"G(z)@{current_scale}": wandb.Image(img),
                       f"G(z_opt)@{current_scale}": wandb.Image(img2),
                       f"real@{current_scale}": wandb.Image(img3)},
                      sync=False, commit=False)

            real_scaled_path = os.path.join(wandb.run.dir, f"real@{current_scale}.txt")
            with open(real_scaled_path, "w") as f:
                f.writelines(real_scaled)
            wandb.save(real_scaled_path)

        # Learning Rate scheduler step
        schedulerD.step()
        schedulerG.step()

    if opt.cgan:
        # One more divergence entry, computed on the fixed noise map z_opt.
        div = divergence(real_detection_map, preprocess(opt, z_opt, keepSky))
        divergences.append(div)

    # visualization config
    folder_name = 'gradcam'
    # Takes the part of the file name between the first "_" and the extension.
    level_name = opt.input_name.rsplit(".", 1)[0].split("_", 1)[1]

    # GradCAM on D
    camD = LayerGradCam(D, D.tail)
    # real0 is re-bound here: from the preprocessed tensor to the rendered level image as a numpy array.
    real0 = one_hot_to_ascii_level(real, opt.token_list)
    real0 = opt.ImgGen.render(real0)
    real0 = np.array(real0)
    attr = camD.attribute(real1, target=(0, 0, 0), relu_attributions=True)
    # Upsample the attribution to the rendered image size so it can be overlaid.
    attr = LayerAttribution.interpolate(attr, (real0.shape[0], real0.shape[1]), 'bilinear')
    attr = attr.permute(2, 3, 1, 0).squeeze(3)
    attr = attr.detach().cpu().numpy()
    fig, ax = plt.subplots(1, 1)
    # NOTE(review): this only sets an attribute on the Figure object; it looks like it does not resize the figure
    # (matplotlib resizes via set_size_inches or the figsize argument of subplots) — confirm intent.
    fig.figsize = (10, 1)
    ax.imshow(rgb2gray(real0), cmap='gray', vmin=0, vmax=1)
    im = ax.imshow(attr, cmap='jet', alpha=0.5)
    ax.axis('off')
    fig.colorbar(im, ax=ax, location='bottom', shrink=0.85)
    plt.suptitle(f'cGAN {level_name} D(x)@{current_scale} ({step})')
    # NOTE(review): the backslash is a directory separator only on Windows — confirm target platform.
    plt.savefig(rf'{folder_name}\{level_name}_D_{current_scale}_{step}.png', bbox_inches='tight', pad_inches=0.1)
    # plt.show()
    plt.close()

    # GradCAM on G
    # Human-readable labels for the token characters, used as row labels in the G figure.
    # The keys '!' and 'k' appear twice with identical values.
    token_names = {
        'M': 'Mario start',
        'F': 'Mario finish',
        'y': 'spiky',
        'Y': 'winged spiky',
        'k': 'green koopa',
        'K': 'winged green koopa',
        '!': 'coin [?]',
        '#': 'pyramid',
        '-': 'sky',
        '1': 'invis. 1 up',
        '2': 'invis. coin',
        'L': '1 up',
        '?': 'special [?]',
        '@': 'special [?]',
        'Q': 'coin [?]',
        '!': 'coin [?]',
        'C': 'coin brick',
        'S': 'normal brick',
        'U': 'mushroom brick',
        'X': 'ground',
        'E': 'goomba',
        'g': 'goomba',
        'k': 'green koopa',
        '%': 'platform',
        '|': 'platform bg',
        'r': 'red koopa',
        'R': 'winged red koopa',
        'o': 'coin',
        't': 'pipe',
        'T': 'plant pipe',
        '*': 'bullet bill',
        '<': 'pipe top left',
        '>': 'pipe top right',
        '[': 'pipe left',
        ']': 'pipe right',
        'B': 'bullet bill head',
        'b': 'bullet bill body',
        'D': 'used block',
    }

    def wrappedG(z):
        # Single-argument wrapper around G, with z_opt passed as G's second argument.
        return G(z, z_opt)

    camG = LayerGradCam(wrappedG, G.tail[0])
    z_cam = generate_spatial_noise([1, opt.nc_current, nzx, nzy], device=opt.device)
    z_cam = pad_noise(z_cam)
    attrs = []
    # One attribution map per output channel, always targeting position (0, 0).
    for i in range(opt.nc_current):
        attr = camG.attribute(z_cam, target=(i, 0, 0), relu_attributions=True)
        attr = LayerAttribution.interpolate(attr, (real0.shape[0], real0.shape[1]), 'bilinear')
        attr = attr.permute(2, 3, 1, 0).squeeze(3)
        attr = attr.detach().cpu().numpy()
        attrs.append(attr)
    fig, axs = plt.subplots(opt.nc_current, 1)
    # NOTE(review): same as above — assigning fig.figsize looks like a no-op; confirm.
    fig.figsize = (10, opt.nc_current)
    for i in range(opt.nc_current):
        axs[i].axis('off')
        # Raises KeyError if a token in opt.token_list has no entry in token_names.
        axs[i].text(-0.1, 0.5, token_names[opt.token_list[i]], rotation=0,
                    verticalalignment='center', horizontalalignment='right',
                    transform=axs[i].transAxes)
        axs[i].imshow(rgb2gray(real0), cmap='gray', vmin=0, vmax=1)
        im = axs[i].imshow(attrs[i], cmap='jet', alpha=0.5)
    # The colorbar is built from the last channel's image only.
    fig.colorbar(im, ax=axs, shrink=0.85)
    plt.suptitle(f'cGAN {level_name} G(z)@{current_scale} ({step})')
    plt.savefig(rf'{folder_name}\{level_name}_G_{current_scale}_{step}.png', bbox_inches='tight', pad_inches=0.1)
    # plt.show()
    plt.close()

    # Save networks
    torch.save(z_opt, "%s/z_opt.pth" % opt.outf)
    save_networks(G, D, z_opt, opt)
    wandb.save(opt.outf)
    return z_opt, input_from_prev_scale, G, divergences
loss, mu, logvar = train(epochs)
# if loss <= MAX_LOSS:
#     break

# Draw all latent samples first, then decode them.
with torch.no_grad():
    samples = []
    for _ in range(NUM_SAMPLES):
        std = logvar.exp().sqrt()
        samples.append(torch.normal(mu, std))
    samples = [model.decode(latent) for latent in samples]


def one_hot(x):
    """One-hot encode index tensor *x* into C classes on the GPU, class axis moved to dim 1."""
    encoded = F.one_hot(x, C).cuda()
    return encoded.transpose(2, 3).transpose(1, 2)


# Most likely token per position -> one-hot level -> ASCII rows -> image.
hotenc = [one_hot(torch.argmax(sample, dim=1)) for sample in samples]
ascii_gens = [one_hot_to_ascii_level(level, opt.token_list) for level in hotenc]
ascii_real = one_hot_to_ascii_level(real, opt.token_list)
gen_levels = [opt.ImgGen.render(rows) for rows in ascii_gens]
real_level = opt.ImgGen.render(ascii_real)

for idx, rendered in enumerate(gen_levels):
    rendered.save(rf"VAE_Gen_patches\{EPOCHS}-{idx}.png")
def train(real, opt):
    """ Wrapper function for training. Calculates necessary scales then calls train_single_scale on each.

    Args:
        real: the original level tensor; it is used as the last entry of the list of scaled levels.
        opt: namespace holding all parameters. This function sets ``opt.num_scales``, ``opt.stop_scale``,
            ``opt.outf`` and, for scales up to and including ``opt.token_insert``, ``opt.nc_current``.

    Returns:
        Tuple ``(generators, noise_maps, reals, noise_amplitudes)`` with one entry per trained scale
        (``reals`` holds one level per scale).

    Side effects: after every scale the accumulated lists are saved as ``.pth`` files below ``opt.out_`` and
    registered with wandb; the per-scale divergences are saved to ``divergences.pth`` at the end.
    """
    generators = []
    noise_maps = []
    noise_amplitudes = []

    scales = [[x, x] for x in opt.scales]
    opt.num_scales = len(scales)

    # Any game other than 'mario' is treated as 'mariokart'.
    if opt.game == 'mario':
        token_group = MARIO_TOKEN_GROUPS
        scaled_list = special_mario_downsampling(opt.num_scales, scales, real, opt.token_list)
    else:  # if opt.game == 'mariokart':
        token_group = MARIOKART_TOKEN_GROUPS
        scaled_list = special_mariokart_downsampling(opt.num_scales, scales, real, opt.token_list)

    reals = [*scaled_list, real]

    # If (experimental) token grouping feature is used:
    if opt.token_insert >= 0:
        reals = [(token_to_group(r, opt.token_list, token_group) if i < opt.token_insert else r)
                 for i, r in enumerate(reals)]
        # An extra, grouped copy of the scale at index token_insert is inserted in front of it.
        reals.insert(opt.token_insert,
                     token_to_group(reals[opt.token_insert], opt.token_list, token_group))
    input_from_prev_scale = torch.zeros_like(reals[0])

    stop_scale = len(reals)
    opt.stop_scale = stop_scale

    # Log the original input level as an image
    img = opt.ImgGen.render(one_hot_to_ascii_level(real, opt.token_list))
    wandb.log({"real": wandb.Image(img)}, commit=False)
    os.makedirs("%s/state_dicts" % (opt.out_), exist_ok=True)

    # Training Loop
    divergences = []
    for current_scale in range(0, stop_scale):
        opt.outf = "%s/%d" % (opt.out_, current_scale)
        # exist_ok tolerates an existing directory but, unlike the former blanket `except OSError: pass`, still
        # surfaces real failures (e.g. permissions) here instead of later at save time.
        os.makedirs(opt.outf, exist_ok=True)

        # If we are seeding, we need to adjust the number of channels
        if current_scale < (opt.token_insert + 1):  # (stop_scale - 1):
            opt.nc_current = len(token_group)

        # Initialize models
        D, G = init_models(opt)
        # If we are seeding, the weights after the seed need to be adjusted
        if current_scale == (opt.token_insert + 1):  # (stop_scale - 1):
            D, G = restore_weights(D, G, current_scale, opt)

        # Actually train the current scale
        z_opt, input_from_prev_scale, G, divs = train_single_scale(
            D, G, reals, generators, noise_maps, input_from_prev_scale, noise_amplitudes, opt)

        # Reset grads and save current scale
        G = reset_grads(G, False)
        G.eval()
        D = reset_grads(D, False)
        D.eval()

        generators.append(G)
        noise_maps.append(z_opt)
        noise_amplitudes.append(opt.noise_amp)
        divergences.append(divs)

        # Checkpoint everything trained so far, overwriting the previous scale's files.
        torch.save(noise_maps, "%s/noise_maps.pth" % (opt.out_))
        torch.save(generators, "%s/generators.pth" % (opt.out_))
        torch.save(reals, "%s/reals.pth" % (opt.out_))
        torch.save(noise_amplitudes, "%s/noise_amplitudes.pth" % (opt.out_))
        torch.save(opt.num_layer, "%s/num_layer.pth" % (opt.out_))
        torch.save(opt.token_list, "%s/token_list.pth" % (opt.out_))
        wandb.save("%s/*.pth" % opt.out_)

        torch.save(G.state_dict(), "%s/state_dicts/G_%d.pth" % (opt.out_, current_scale))
        wandb.save("%s/state_dicts/*.pth" % opt.out_)

        del D, G

    torch.save(torch.tensor(divergences), "%s/divergences.pth" % opt.out_)
    return generators, noise_maps, reals, noise_amplitudes
def generate_samples_cmaes(generators, noise_maps, reals, noise_amplitudes, cmaes_noise, opt, in_s=None,
                           scale_v=1.0, scale_h=1.0, current_scale=0, gen_start_scale=0, num_samples=50,
                           render_images=True, save_tensors=False, save_dir="random_samples"):
    """
    Generate samples given a pretrained TOAD-GAN (generators, noise_maps, reals, noise_amplitudes).
    Uses namespace "opt" that needs to be parsed.
    "cmaes_noise" is a flat noise tensor that is consumed front to back: every sample at every scale takes the next
    channels * height * width values as its noise map (instead of drawing random noise).
    "in_s" can be used as a starting image in any scale set with "current_scale".
    "gen_start_scale" sets the scale generation is to be started in.
    "num_samples" is the number of different samples to be generated.
    "render_images" defines if images are to be rendered (takes space and time if many samples are generated).
    "save_tensors" defines if tensors are to be saved (can be needed for token insertion *experimental*).
    "save_dir" is the path the samples are saved in.

    Returns the list of level tensors generated at the last processed scale (one per sample). Files are only
    written for the scale with index len(reals) - 1.
    """

    # Holds images generated in current scale
    images_cur = []

    # Check which game we are using for token groups
    if opt.game == 'mario':
        token_groups = MARIO_TOKEN_GROUPS
    elif opt.game == 'mariokart':
        token_groups = MARIOKART_TOKEN_GROUPS
    else:
        # Keep the existing fallback (no token groups), but report it: the old code built a NameError here
        # and threw it away.
        token_groups = []
        logger.warning("name of --game not recognized. Supported: mario, mariokart")

    # Input as it was before the last token <-> group conversion; stays None if no conversion ever happened.
    old_in_s = None

    # Main sampling loop
    for G, Z_opt, noise_amp in zip(generators, noise_maps, noise_amplitudes):

        if current_scale >= len(generators):
            break  # if we do not start at current_scale=0 we need this

        logger.info("Generating samples at scale {}", current_scale)

        # Padding (should be chosen according to what was trained with)
        n_pad = int(1 * opt.num_layer)
        if not opt.pad_with_noise:
            m = nn.ZeroPad2d(int(n_pad))  # pad with zeros
        else:
            m = nn.ReflectionPad2d(int(n_pad))  # pad with reflected noise

        # Calculate shapes to generate
        if 0 < gen_start_scale <= current_scale:
            # Special case! Can have a wildly different shape through in_s
            scale_v = in_s.shape[-2] / (noise_maps[gen_start_scale - 1].shape[-2] - n_pad * 2)
            scale_h = in_s.shape[-1] / (noise_maps[gen_start_scale - 1].shape[-1] - n_pad * 2)
        nzx = (Z_opt.shape[-2] - n_pad * 2) * scale_v
        nzy = (Z_opt.shape[-1] - n_pad * 2) * scale_h
        height, width = int(round(nzx)), int(round(nzy))

        # Save list of images of previous scale and clear current images
        images_prev = images_cur
        images_cur = []

        # Token insertion (Experimental feature! Generator needs to be trained with it)
        if current_scale < (opt.token_insert + 1):
            channels = len(token_groups)
            if in_s is not None and in_s.shape[1] != channels:
                old_in_s = in_s
                in_s = token_to_group(in_s, opt.token_list, token_groups)
        else:
            channels = len(opt.token_list)
            if in_s is not None and in_s.shape[1] != channels:
                old_in_s = in_s
                in_s = group_to_token(in_s, opt.token_list, token_groups)

        # If in_s is none or filled with zeros reshape to correct size with channels
        if in_s is None:
            in_s = torch.zeros(reals[0].shape[0], channels, *reals[0].shape[2:]).to(opt.device)
        elif in_s.sum() == 0:
            in_s = torch.zeros(1, channels, *in_s.shape[-2:]).to(opt.device)

        # Number of noise values one sample consumes at this scale.
        noise_size = channels * height * width

        # Generate num_samples samples in current scale
        for n in tqdm(range(0, num_samples, 1)):

            # Get noise image: take the next chunk of the supplied noise vector
            # z_curr = generate_spatial_noise([1, channels, int(round(nzx)), int(round(nzy))], device=opt.device)
            z_curr = cmaes_noise[:noise_size]
            cmaes_noise = cmaes_noise[noise_size:]
            z_curr = z_curr.reshape(1, channels, height, width)
            z_curr = m(z_curr)

            # Set up previous image I_prev
            if (not images_prev) or current_scale == 0:  # if there is no "previous" image
                I_prev = in_s
            else:
                I_prev = images_prev[n]

                # Transform to token groups if there is token insertion
                if current_scale == (opt.token_insert + 1):
                    I_prev = group_to_token(I_prev, opt.token_list, token_groups)

            I_prev = interpolate(I_prev, [height, width], mode='bilinear', align_corners=False)
            I_prev = m(I_prev)

            # We take the optimized noise map Z_opt as an input if we start generating on later scales
            if current_scale < gen_start_scale:
                z_curr = Z_opt

            # Define correct token list (dependent on token insertion)
            if opt.token_insert >= 0 and z_curr.shape[1] == len(token_groups):
                token_list = [list(group.keys())[0] for group in token_groups]
            else:
                token_list = opt.token_list

            ###########
            # Generate!
            ###########
            z_in = noise_amp * z_curr + I_prev
            G.eval()
            with torch.no_grad():
                I_curr = G(z_in.detach(), I_prev, temperature=1)

            # Allow road insertion in mario kart levels
            if opt.game == 'mariokart':
                if current_scale == 0 and opt.seed_road is not None:
                    for token in token_list:
                        token_idx = token_list.index(token)
                        if token == 'R':  # Road map!
                            tmp = opt.seed_road.clone().to(opt.device)
                            I_curr[0, token_idx] = tmp
                        elif token in ['O', 'Q', 'C', '<']:  # Tokens that can only appear on roads
                            I_curr[0, token_idx] *= opt.seed_road.to(opt.device)
                        else:  # Other tokens like walls
                            I_curr[0, token_idx] = torch.min(I_curr[0, token_idx],
                                                             1 - opt.seed_road.to(opt.device))

            # Save all scales
            # if True:
            # Save scale 0 and last scale
            # if current_scale == 0 or current_scale == len(reals) - 1:
            # Save only last scale
            if current_scale == len(reals) - 1:
                dir2save = opt.out_ + '/' + save_dir

                # Make directories. exist_ok already covers reruns; other OSErrors are left to propagate here
                # instead of resurfacing at the first write below.
                os.makedirs(dir2save, exist_ok=True)
                if render_images:
                    os.makedirs("%s/img" % dir2save, exist_ok=True)
                if save_tensors:
                    os.makedirs("%s/torch" % dir2save, exist_ok=True)
                os.makedirs("%s/txt" % dir2save, exist_ok=True)

                # Convert to ascii level
                level = one_hot_to_ascii_level(I_curr.detach(), token_list)

                # Render and save level image
                if render_images:
                    img = opt.ImgGen.render(level)
                    img.save("%s/img/%d_sc%d.png" % (dir2save, n, current_scale))

                # Save level txt
                with open("%s/txt/%d_sc%d.txt" % (dir2save, n, current_scale), "w") as f:
                    f.writelines(level)

                # Save torch tensor
                if save_tensors:
                    torch.save(I_curr, "%s/torch/%d_sc%d.pt" % (dir2save, n, current_scale))

                # Token insertion render (experimental!)
                # Skipped when no input was ever converted; old_in_s used to be unbound in that case.
                if opt.token_insert >= 0 and current_scale >= 1 and old_in_s is not None:
                    if old_in_s.shape[1] == len(token_groups):
                        token_list = [list(group.keys())[0] for group in token_groups]
                    else:
                        token_list = opt.token_list
                    level = one_hot_to_ascii_level(old_in_s.detach(), token_list)
                    img = opt.ImgGen.render(level)
                    img.save("%s/img/%d_sc%d.png" % (dir2save, n, current_scale - 1))

            # Append current image
            images_cur.append(I_curr)

        # Go to next scale
        current_scale += 1

    # return I_curr.detach()  # return last generated image (usually unused)
    return images_cur
def train_single_scale(D, G, reals, generators, noise_maps, input_from_prev_scale, noise_amplitudes, opt):
    """Train the discriminator/generator pair of a single scale.

    The scale index is ``len(generators)``, i.e. the number of generators
    that were already trained. Every epoch runs ``opt.Dsteps`` discriminator
    updates followed by ``opt.Gsteps`` generator updates, logs scalars to
    wandb every 10th global step and renders preview images every 500th
    epoch and on the last epoch.

    Args:
        D: Discriminator of the current scale (updated in place).
        G: Generator of the current scale (updated in place). Called as
            ``G(noise, prev, temperature=...)``.
        reals: Sequence of scaled versions of the original level;
            ``reals[current_scale]`` is the training target here.
        generators: Generators of the previous scales. Only its length is
            used directly; otherwise it is forwarded to ``draw_concat``.
        noise_maps: Noise maps of the previous scales, forwarded to
            ``draw_concat``.
        input_from_prev_scale: Input coming from the previous scale. At
            scale 0 it is replaced by an all-zero tensor during the first
            discriminator step.
        noise_amplitudes: Noise amplitudes of the previous scales, forwarded
            to ``draw_concat``.
        opt: Namespace with all hyper-parameters. Read here: ``game``,
            ``num_layer``, ``pad_with_noise``, ``lr_d``, ``lr_g``, ``beta1``,
            ``gamma``, ``nc_current``, ``device``, ``niter``, ``Dsteps``,
            ``Gsteps``, ``token_insert``, ``token_list``, ``lambda_grad``,
            ``alpha``, ``ImgGen`` and ``outf``. ``opt.noise_amp`` is
            overwritten during the first discriminator step.

    Returns:
        Tuple ``(z_opt, input_from_prev_scale, G)``: the padded fixed noise
        map of this scale, the (possibly replaced) input from the previous
        scale and the trained generator.

    Side effects:
        Writes ``z_opt.pth`` into ``opt.outf``, calls ``save_networks``,
        writes ``real@<scale>.txt`` into the wandb run directory and logs to
        wandb.

    Note:
        ``opt.Dsteps`` and ``opt.Gsteps`` must both be at least 1; the
        generator update and the logging rely on names that are first bound
        inside those loops.
    """
    current_scale = len(generators)
    real = reals[current_scale]

    if opt.game == 'mario':
        token_group = MARIO_TOKEN_GROUPS
    else:  # if opt.game == 'mariokart':
        token_group = MARIOKART_TOKEN_GROUPS

    nzx = real.shape[2]  # Noise size x
    nzy = real.shape[3]  # Noise size y

    # As kernel size is always 3 currently, padsize goes up by one per layer
    padsize = int(1 * opt.num_layer)

    # NOTE: despite its name, opt.pad_with_noise selects reflection padding
    # (instead of zero padding) for both the noise and the image.
    if not opt.pad_with_noise:
        pad_noise = nn.ZeroPad2d(padsize)
        pad_image = nn.ZeroPad2d(padsize)
    else:
        pad_noise = nn.ReflectionPad2d(padsize)
        pad_image = nn.ReflectionPad2d(padsize)

    # The generator temperature was previously written as
    # `1 if current_scale != opt.token_insert else 1`, which is 1 in both
    # branches; it is kept as one named constant so it can be tuned in one place.
    temperature = 1

    # setup optimizer
    optimizerD = optim.Adam(D.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(G.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    # The schedulers are stepped once per epoch, so the milestones are epochs.
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD,
                                                      milestones=[1600, 2500],
                                                      gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG,
                                                      milestones=[1600, 2500],
                                                      gamma=opt.gamma)

    if current_scale == 0:  # Generate new noise
        z_opt = generate_spatial_noise([1, opt.nc_current, nzx, nzy], device=opt.device)
        z_opt = pad_noise(z_opt)
    else:  # Higher scales use an all-zero fixed noise map
        z_opt = torch.zeros([1, opt.nc_current, nzx, nzy]).to(opt.device)
        z_opt = pad_noise(z_opt)

    # NOTE(review): the brace-style placeholder presumes a loguru-like logger - confirm.
    logger.info("Training at scale {}", current_scale)
    for epoch in tqdm(range(opt.niter)):
        # Global step, so that the wandb x-axis keeps increasing across scales.
        step = current_scale * opt.niter + epoch
        noise_ = generate_spatial_noise([1, opt.nc_current, nzx, nzy], device=opt.device)
        noise_ = pad_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) - D(G(z)) (minus the gradient penalty)
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            D.zero_grad()

            output = D(real).to(opt.device)

            errD_real = -output.mean()
            errD_real.backward(retain_graph=True)

            # train with fake
            if j == 0 and epoch == 0:
                if current_scale == 0:  # If we are in the lowest scale, noise is generated from scratch
                    prev = torch.zeros(1, opt.nc_current, nzx, nzy).to(opt.device)
                    input_from_prev_scale = prev
                    prev = pad_image(prev)
                    z_prev = torch.zeros(1, opt.nc_current, nzx, nzy).to(opt.device)
                    z_prev = pad_noise(z_prev)
                    opt.noise_amp = 1
                else:  # First step in NOT the lowest scale
                    # We need to adapt our inputs from the previous scale and add noise to it
                    prev = draw_concat(generators, noise_maps, reals, noise_amplitudes,
                                       input_from_prev_scale, "rand", pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        prev = group_to_token(prev, opt.token_list, token_group)

                    prev = interpolate(prev, real.shape[-2:], mode="bilinear", align_corners=False)
                    prev = pad_image(prev)
                    z_prev = draw_concat(generators, noise_maps, reals, noise_amplitudes,
                                         input_from_prev_scale, "rec", pad_noise, pad_image, opt)

                    # For the seeding experiment, we need to transform from token_groups to the actual token
                    if current_scale == (opt.token_insert + 1):
                        z_prev = group_to_token(z_prev, opt.token_list, token_group)

                    z_prev = interpolate(z_prev, real.shape[-2:], mode="bilinear", align_corners=False)
                    # The amplitude is derived from the unpadded reconstruction, so pad afterwards.
                    opt.noise_amp = update_noise_amplitude(z_prev, real, opt)
                    z_prev = pad_image(z_prev)
            else:  # Any other step: only the random input is redrawn, z_prev is reused
                prev = draw_concat(generators, noise_maps, reals, noise_amplitudes,
                                   input_from_prev_scale, "rand", pad_noise, pad_image, opt)

                # For the seeding experiment, we need to transform from token_groups to the actual token
                if current_scale == (opt.token_insert + 1):
                    prev = group_to_token(prev, opt.token_list, token_group)

                prev = interpolate(prev, real.shape[-2:], mode="bilinear", align_corners=False)
                prev = pad_image(prev)

            # After creating our correct noise input, we feed it to the generator:
            noise = opt.noise_amp * noise_ + prev
            fake = G(noise.detach(), prev, temperature=temperature)

            # Then run the result through the discriminator
            # (detached, so this backward pass does not reach G)
            output = D(fake.detach())
            errD_fake = output.mean()

            # Backpropagation
            errD_fake.backward(retain_graph=False)

            # Gradient Penalty
            gradient_penalty = calc_gradient_penalty(D, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward(retain_graph=False)

            # Logging:
            if step % 10 == 0:
                wandb.log(
                    {
                        f"D(G(z))@{current_scale}": errD_fake.item(),
                        f"D(x)@{current_scale}": -errD_real.item(),
                        f"gradient_penalty@{current_scale}": gradient_penalty.item()
                    },
                    step=step,
                    sync=False)
            optimizerD.step()

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            G.zero_grad()
            # Reuses `noise` and `prev` from the last discriminator step.
            fake = G(noise.detach(), prev.detach(), temperature=temperature)
            output = D(fake)

            errG = -output.mean()
            errG.backward(retain_graph=False)
            if opt.alpha != 0:  # i. e. we are trying to find an exact recreation of our input in the lat space
                Z_opt = opt.noise_amp * z_opt + z_prev
                G_rec = G(Z_opt.detach(), z_prev, temperature=temperature)
                rec_loss = opt.alpha * F.mse_loss(G_rec, real)
                rec_loss.backward(retain_graph=False)  # TODO: Check for unexpected argument retain_graph=True
                rec_loss = rec_loss.detach()
            else:  # We are not trying to find an exact recreation
                rec_loss = torch.zeros([])
                Z_opt = z_opt

            optimizerG.step()

        # More Logging:
        if step % 10 == 0:
            wandb.log(
                {
                    f"noise_amplitude@{current_scale}": opt.noise_amp,
                    f"rec_loss@{current_scale}": rec_loss.item()
                },
                step=step,
                sync=False,
                commit=True)

        # Rendering and logging images of levels
        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            # With token insertion active and one channel per token group,
            # each group is rendered by its first key.
            if opt.token_insert >= 0 and opt.nc_current == len(token_group):
                token_list = [list(group.keys())[0] for group in token_group]
            else:
                token_list = opt.token_list

            img = opt.ImgGen.render(one_hot_to_ascii_level(fake.detach(), token_list))

            # Preview only: no gradients are needed for this forward pass.
            with torch.no_grad():
                reconstruction = G(Z_opt.detach(), z_prev, temperature=temperature)
            img2 = opt.ImgGen.render(
                one_hot_to_ascii_level(reconstruction.detach(), token_list))

            real_scaled = one_hot_to_ascii_level(real.detach(), token_list)
            img3 = opt.ImgGen.render(real_scaled)
            wandb.log(
                {
                    f"G(z)@{current_scale}": wandb.Image(img),
                    f"G(z_opt)@{current_scale}": wandb.Image(img2),
                    f"real@{current_scale}": wandb.Image(img3)
                },
                sync=False,
                commit=False)

            real_scaled_path = os.path.join(wandb.run.dir, f"real@{current_scale}.txt")
            with open(real_scaled_path, "w") as f:
                f.writelines(real_scaled)
            wandb.save(real_scaled_path)

        # Learning Rate scheduler step
        schedulerD.step()
        schedulerG.step()

    # Save networks
    torch.save(z_opt, "%s/z_opt.pth" % opt.outf)
    save_networks(G, D, z_opt, opt)
    wandb.save(opt.outf)
    return z_opt, input_from_prev_scale, G