def train_conv_vae_lstm_mdn():
    """Train the LSTM on ConvVAE latents gathered from random gameplay.

    Plays a single episode of ``SonicTheHedgehog2-Genesis`` with randomly
    sampled actions. Every ``batch_size`` frames are preprocessed with
    ``dumyshape_gray``, encoded by the ConvVAE, and the resulting latent
    sequence drives one next-step-prediction update of the LSTM using its
    MSE loss. Weight files are loaded at start-up when they exist.

    The ConvVAE itself is *not* optimised here (its training step is
    disabled), so the latents are detached before they reach the LSTM.

    NOTE(review): the periodic checkpoint writes only the ConvVAE weights;
    the LSTM save is disabled, so the LSTM updates made here are never
    persisted. Confirm whether that is intentional.
    """
    print("\n\n\n\n\n")

    # Alternative games left over from earlier experiments:
    # SonicTheHedgehog-Genesis, SonicAndKnuckles-Genesis,
    # SonicTheHedgehog3-Genesis, SonicAndKnuckles3-Genesis.
    env_name = "SonicTheHedgehog2-Genesis"

    # "gray_edges" variants of these files exist for the edge-detected
    # preprocessing (dumyshape_gray_edges); this run uses plain grayscale.
    conv_vae_filename = "weights/conv_vae_gray.pkl"
    lstm_mdn_filename = "weights/lstm_mdn_gray.pkl"

    latent_size = 1024
    batch_size = 50
    # One latent is consumed as the "future" target, hence the -1.
    sequence_len = batch_size - 1
    checkpoint_every = 8000  # environment steps between weight saves

    env = retro.make(env_name)
    try:
        conv_vae = ConvVAE((1, 128, 128), latent_size)
        # Only needed by the disabled ConvVAE training step further down.
        conv_vae_optimizer = optim.Adam(conv_vae.parameters(), lr=0.00025)
        if os.path.exists(conv_vae_filename):
            print("loading conv vae weights")
            conv_vae.load_state_dict(torch.load(conv_vae_filename))

        lstm_mdn = LSTM(vector_size=latent_size)
        lstm_mdn_optimizer = optim.Adam(lstm_mdn.parameters(), lr=0.00025)
        if os.path.exists(lstm_mdn_filename):
            print("loading lstm mdn weights")
            lstm_mdn.load_state_dict(torch.load(lstm_mdn_filename))

        frame_buffer = []
        latent_vector = None

        for episode in range(1, 2):
            img = env.reset()
            step = 0

            while True:
                # assumes dumyshape_gray returns a (1, 128, 128) array,
                # matching the ConvVAE input shape above -- TODO confirm
                frame = torch.FloatTensor(dumyshape_gray(img))
                frame_buffer.append(frame.unsqueeze(0))  # add batch axis

                # ---- encode a full batch of frames ----
                if len(frame_buffer) == batch_size:
                    frames = torch.cat(frame_buffer)
                    _deconv_img, _mu, _logvar, z = conv_vae(frames)

                    # Detach: the LSTM loss must not backpropagate into the
                    # ConvVAE, whose gradients are never stepped or zeroed
                    # in this function.
                    z = z.detach()

                    if latent_vector is None:
                        latent_vector = z
                    else:
                        latent_vector = torch.cat((latent_vector, z), dim=0)

                    # ConvVAE training step, currently disabled:
                    # conv_vae_loss = conv_vae.conv_vae_loss(
                    #     _deconv_img, frames, _mu, _logvar)
                    # conv_vae_optimizer.zero_grad()
                    # conv_vae_loss.backward()
                    # conv_vae_optimizer.step()

                    frame_buffer = []

                # ---- LSTM next-step prediction update ----
                if (latent_vector is not None
                        and latent_vector.size(0) >= sequence_len + 1):
                    print("lstm mdn training", step, latent_vector.size())

                    # Window of sequence_len + 1 latents; dropping the last
                    # row gives the inputs and dropping the first row gives
                    # the targets, so prediction[t] lines up with
                    # latent[t + 1].
                    window = latent_vector[:sequence_len + 1, :]
                    inputs = window[:-1, :].unsqueeze(0)
                    targets = window[1:, :].unsqueeze(0)

                    z_t = lstm_mdn.predict(inputs)
                    lstm_mse_loss = lstm_mdn.mse_loss_function(
                        actual=targets, prediction=z_t)

                    lstm_mdn_optimizer.zero_grad()
                    lstm_mse_loss.backward()
                    lstm_mdn_optimizer.step()

                    # Start the next sequence from a clean buffer and state.
                    latent_vector = None
                    lstm_mdn.reset_states()

                # ---- advance the environment with a random action ----
                action = env.action_space.sample()
                # presumably forces the RIGHT button so the agent keeps
                # moving forward -- verify against the Genesis button layout
                action[7] = 1
                img, _reward, done, _info = env.step(action)
                step += 1

                # ---- periodic checkpoint ----
                if step >= checkpoint_every:
                    print("saving weights")
                    torch.save(conv_vae.state_dict(), conv_vae_filename)
                    # torch.save(lstm_mdn.state_dict(), lstm_mdn_filename)
                    step = 0

                if done:
                    break
    finally:
        # Always release the emulator, even when training raises.
        env.close()
def prepare_list_pics():
    """Train the ConvVAE and LSTM on pre-recorded frames stored on disk.

    For every sub-directory of ``/opt/Projects/dataset/sonic`` the files
    directly inside it are split into consecutive batches of ``batch_size``
    file names, and each full batch is passed to
    ``train_conv_lstm_on_pics`` together with both models and their
    optimisers. This is repeated for 200 epochs per sub-directory; after
    every epoch both weight files are rewritten and the LSTM state is reset.

    A trailing partial batch (fewer than ``batch_size`` files) is not used.
    Entries of the base directory that are not directories are skipped.
    Existing weight files are loaded at start-up when present.
    """
    # Other weight sets used previously: *_SonicAndKnuckles.pkl and
    # *_gray_edges.pkl.
    conv_vae_filename = "weights/conv_vae_gray.pkl"
    lstm_mdn_filename = "weights/lstm_mdn_gray.pkl"
    base_dir = "/opt/Projects/dataset/sonic"

    batch_size = 4000
    epochs_per_dir = 200

    conv_vae = ConvVAE((1, 128, 128), 1024)
    conv_vae_optimizer = optim.Adam(conv_vae.parameters(), lr=0.00025)
    if os.path.exists(conv_vae_filename):
        print("loading conv vae weights")
        conv_vae.load_state_dict(torch.load(conv_vae_filename))

    lstm_mdn = LSTM(vector_size=1024)
    lstm_mdn_optimizer = optim.Adam(lstm_mdn.parameters(), lr=0.00025)
    if os.path.exists(lstm_mdn_filename):
        print("loading lstm mdn weights")
        lstm_mdn.load_state_dict(torch.load(lstm_mdn_filename))

    for subdir in os.listdir(base_dir):
        src_dir = os.path.join(base_dir, subdir)
        if not os.path.isdir(src_dir):
            # Stray files next to the dataset folders hold no frames.
            continue

        # Only the first os.walk entry is needed: the files directly inside
        # src_dir. The listing does not depend on the epoch, so it is
        # computed once per directory.
        # NOTE(review): os.walk returns names in arbitrary order. If
        # train_conv_lstm_on_pics relies on frames being in temporal order,
        # sort here according to the file naming scheme -- confirm.
        list_of_files = next(os.walk(src_dir), (src_dir, [], []))[2]

        # Start index of the last *full* batch; negative when the directory
        # holds fewer than batch_size files, which yields no batches.
        last_start = len(list_of_files) - batch_size

        for epoch in range(epochs_per_dir):
            # "+ 1" keeps the final batch when the file count is an exact
            # multiple of batch_size.
            for start in range(0, last_start + 1, batch_size):
                batch_list = list_of_files[start:start + batch_size]
                train_conv_lstm_on_pics(conv_vae, conv_vae_optimizer,
                                        lstm_mdn, lstm_mdn_optimizer,
                                        src_dir, batch_list)

            print(epoch)
            print("saving conv vae weights")
            torch.save(conv_vae.state_dict(), conv_vae_filename)
            print("saving lstm mdn weights")
            torch.save(lstm_mdn.state_dict(), lstm_mdn_filename)
            print("\n")

            lstm_mdn.reset_states()  # after the epoch