Example #1
import os

import retro
import torch
import torch.optim as optim

# ConvVAE, LSTM and the dumyshape_* preprocessing helpers are
# project-local modules; their import paths depend on the repo layout.


def train_conv_vae_lstm_mdn():

    print("\n\n\n\n\n")
    # env_name = "SonicTheHedgehog-Genesis" # None
    env_name = "SonicTheHedgehog2-Genesis"
    # env_name = "SonicAndKnuckles-Genesis"
    # env_name = "SonicTheHedgehog3-Genesis"
    # env_name = "SonicAndKnuckles3-Genesis"

    # conv_vae_filename = "weights/conv_vae_gray_edges.pkl" # 1, 1024
    # lstm_mdn_filename = "weights/lstm_mdn_gray_edges.pkl" # 1024
    conv_vae_filename = "weights/conv_vae_gray.pkl"  # 1, 1024
    lstm_mdn_filename = "weights/lstm_mdn_gray.pkl"  # 1024

    env = retro.make(env_name)
    # print(env.observation_space) # Box(224, 320, 3)
    # print(env.action_space) # MultiBinary(12)
    # print(env.action_space.sample()) # [1 1 1 0 1 0 1 0 0 1 1 1]

    conv_vae_buffer = []
    latent_vector = None
    batch_size = 50
    sequence_len = batch_size - 1

    conv_vae = ConvVAE((1, 128, 128), 1024)
    conv_vae_optimizer = optim.Adam(conv_vae.parameters(), lr=0.00025)
    if os.path.exists(conv_vae_filename):
        print("loading conv vae weights")
        conv_vae.load_state_dict(torch.load(conv_vae_filename))

    lstm_mdn = LSTM(vector_size=1024)
    lstm_mdn_optimizer = optim.Adam(lstm_mdn.parameters(), lr=0.00025)
    if os.path.exists(lstm_mdn_filename):
        print("loading lstm mdn weights")
        lstm_mdn.load_state_dict(torch.load(lstm_mdn_filename))

    for episode in range(1, 2):

        img = env.reset()  # (224, 320, 3)

        step = 0

        reward_dict = {
            "current_score": 0,
            "current_x": 0,
            "current_rings": 0,
            "reward_flow": 0,
            "lives": None
        }

        while True:

            # img = dumyshape_gray_edges(img)  # (1, 128, 128)
            img = dumyshape_gray(img)  # (1, 128, 128)
            img = torch.FloatTensor(img)  # [1, 128, 128]
            img = img.unsqueeze(0)  # [1, 1, 128, 128]

            conv_vae_buffer.append(img)

            ################
            ## conv vae forward pass
            if len(conv_vae_buffer) == batch_size:  # 50

                conv_vae_buffer = torch.cat(conv_vae_buffer)  # [50, 1, 128, 128]

                deconv_img, mu, logvar, z = conv_vae(conv_vae_buffer)
                # [50, 1, 128, 128] [50, 1024] [50, 1024] [50, 1024]

                ##############
                ## lstm buffer
                if latent_vector is None:
                    latent_vector = z
                else:
                    latent_vector = torch.cat((latent_vector, z), dim=0)

                #################
                ## conv vae train
                # conv_vae_loss = conv_vae.conv_vae_loss(deconv_img, conv_vae_buffer, mu, logvar)
                # print(step, "loss:",conv_vae_loss)
                # conv_vae_optimizer.zero_grad()
                # conv_vae_loss.backward()
                # conv_vae_optimizer.step()

                # zero out conv buffer
                conv_vae_buffer = []

            #################
            ## lstm mdn train
            # +1 to represent future step
            # time vector:     [t,t,t,t,t]
            # latent vector: [l,l,l,l,l,l]
            if latent_vector is not None and latent_vector.size(0) >= sequence_len + 1:
                print("lstm mdn training", step, latent_vector.size())

                # cut vector to sequence_len + 1
                latent_window = latent_vector[:sequence_len + 1, :]  # [50, 1024]

                # truncate 1 element from the right to form the input sequence
                # time vector:     [t,t,t,t,t]
                # latent vector: [l,l,l,l,l]
                input_latents = latent_window[:-1, :]  # [49, 1024]
                input_latents = input_latents.unsqueeze(0)  # [1, 49, 1024]

                # pi, sigma, mu = lstm_mdn(input_latents)
                # [1, 49, 5, 1024]

                z_t = lstm_mdn.predict(input_latents)  # [1, 49, 1024]

                # truncate 1 element from the left to form the targets
                # time vector:     [t,t,t,t,t]
                # latent vector:   [l,l,l,l,l]
                # predictions and actual values now line up
                target_latents = latent_window[1:, :]  # [49, 1024]
                target_latents = target_latents.unsqueeze(0)  # [1, 49, 1024]

                # lstm_mdn_loss = lstm_mdn.mdn_loss_function(pi, sigma, mu, target_latents)
                lstm_mse_loss = lstm_mdn.mse_loss_function(
                    actual=target_latents, prediction=z_t)

                lstm_mdn_optimizer.zero_grad()

                # lstm_mdn_loss.backward()
                lstm_mse_loss.backward()

                lstm_mdn_optimizer.step()

                # zero out buffer and states
                latent_vector = None
                lstm_mdn.reset_states()

            action = env.action_space.sample()
            action[7] = 1  # always hold RIGHT (index 7 in the Genesis button layout)

            img, reward, done, info = env.step(action)

            #####################
            ## reward calculation
            # reward_flow = reward_calculation(reward_dict, info)

            # print(reward_flow)
            # time.sleep(.025)
            # env.render()

            step += 1

            #################
            ## save weights
            if step >= 8000:
                print("saving weights")
                torch.save(conv_vae.state_dict(), conv_vae_filename)
                # torch.save(lstm_mdn.state_dict(), lstm_mdn_filename)
                step = 0

            if done:
                break

    env.close()
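
The dumyshape_gray helper is not part of the listing. Judging by the shape comments (a (224, 320, 3) frame in, a (1, 128, 128) array out), a minimal sketch could look like the following; the cv2-based resizing and the [0, 1] scaling are assumptions, not the project's actual implementation.

import cv2
import numpy as np


def dumyshape_gray(img):
    # hypothetical sketch: RGB frame -> (1, 128, 128) grayscale in [0, 1]
    img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)  # (224, 320)
    img = cv2.resize(img, (128, 128))            # (128, 128)
    img = img.astype(np.float32) / 255.0         # scale to [0, 1]
    return img[np.newaxis, :, :]                 # (1, 128, 128)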
Example #2
import os

import torch
import torch.optim as optim

# ConvVAE, LSTM and train_conv_lstm_on_pics are project-local modules;
# their import paths depend on the repo layout.


def prepare_list_pics():
    # conv_vae_filename = "weights/conv_vae_SonicAndKnuckles.pkl"
    # conv_vae_filename = "weights/conv_vae_gray_edges.pkl"
    conv_vae_filename = "weights/conv_vae_gray.pkl"
    # lstm_mdn_filename = "weights/lstm_mdn_SonicAndKnuckles.pkl"
    # lstm_mdn_filename = "weights/lstm_mdn_gray_edges.pkl"
    lstm_mdn_filename = "weights/lstm_mdn_gray.pkl"

    base_dir = "/opt/Projects/dataset/sonic"

    batch_size = 4000

    conv_vae = ConvVAE((1, 128, 128), 1024)
    conv_vae_optimizer = optim.Adam(conv_vae.parameters(), lr=0.00025)
    if os.path.exists(conv_vae_filename):
        print("loading conv vae weights")
        conv_vae.load_state_dict(torch.load(conv_vae_filename))

    lstm_mdn = LSTM(vector_size=1024)
    lstm_mdn_optimizer = optim.Adam(lstm_mdn.parameters(), lr=0.00025)
    if os.path.exists(lstm_mdn_filename):
        print("loading lstm mdn weights")
        lstm_mdn.load_state_dict(torch.load(lstm_mdn_filename))

    for subdir in os.listdir(base_dir):
        # print(subdir) # 1_1  1_2  2_3  2_4 ...

        # epoch inside subdir
        for epoch in range(200):

            src_dir = os.path.join(base_dir, subdir)  # /opt/Projects/dataset/sonic/1_1

            # os.walk returns files in arbitrary order; sort so that frames
            # are fed to the LSTM in sequence order
            list_of_files = sorted(list(os.walk(src_dir))[0][2])

            # full_batches = len(list_of_files) // batch_size
            len_of_files = len(list_of_files)  # 79964
            # print( full_batches ) # 79

            start = 0
            offset = batch_size

            while offset <= (len_of_files - 1):
                batch_list = list_of_files[start:offset]  # 4000 filenames

                train_conv_lstm_on_pics(conv_vae, conv_vae_optimizer, lstm_mdn,
                                        lstm_mdn_optimizer, src_dir,
                                        batch_list)

                start += batch_size
                offset += batch_size

            print(epoch)

            print("saving conv vae weights")
            torch.save(conv_vae.state_dict(), conv_vae_filename)

            print("saving lstm mdn weights")
            torch.save(lstm_mdn.state_dict(), lstm_mdn_filename)

            print("\n")

            lstm_mdn.reset_states()  # after each epoch
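
The train_conv_lstm_on_pics helper is not shown either. Given the arguments it receives and the training pattern of Example #1, a plausible sketch is below, reusing the cv2 import and the dumyshape_gray sketch from the note under Example #1. The loss and predict calls (conv_vae.conv_vae_loss, lstm_mdn.predict, lstm_mdn.mse_loss_function) are names taken from Example #1; the detach() on the latents is an assumption so the LSTM step does not backpropagate through the already-freed VAE graph.

def train_conv_lstm_on_pics(conv_vae, conv_vae_optimizer, lstm_mdn,
                            lstm_mdn_optimizer, src_dir, batch_list):
    # hypothetical sketch: one VAE step and one LSTM-MDN step per batch of frames
    imgs = []
    for filename in batch_list:
        img = cv2.imread(os.path.join(src_dir, filename))  # (224, 320, 3)
        img = dumyshape_gray(img)                          # (1, 128, 128)
        imgs.append(torch.FloatTensor(img).unsqueeze(0))   # [1, 1, 128, 128]
    batch = torch.cat(imgs)                                # [N, 1, 128, 128]

    # conv vae step
    deconv_img, mu, logvar, z = conv_vae(batch)
    conv_vae_loss = conv_vae.conv_vae_loss(deconv_img, batch, mu, logvar)
    conv_vae_optimizer.zero_grad()
    conv_vae_loss.backward()
    conv_vae_optimizer.step()

    # lstm mdn step: predict the latent vector of frame t+1 from frame t
    input_latents = z[:-1, :].detach().unsqueeze(0)   # [1, N-1, 1024]
    target_latents = z[1:, :].detach().unsqueeze(0)   # [1, N-1, 1024]
    z_t = lstm_mdn.predict(input_latents)
    lstm_mse_loss = lstm_mdn.mse_loss_function(actual=target_latents,
                                               prediction=z_t)
    lstm_mdn_optimizer.zero_grad()
    lstm_mse_loss.backward()
    lstm_mdn_optimizer.step()

Note that with batch_size = 4000 this sketch pushes the whole batch through the VAE in one forward pass, which may need to be chunked on smaller GPUs.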