Example 1
def main():
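    # Overview: trains a DQN-style agent for live streaming. Each step the agent picks one
    # flat action over an action_dims[0] x action_dims[1] grid, the transition is stored in a
    # replay buffer, and the target network is synced periodically. args, Config, Env, Agent
    # and Reply_Buffer are assumed to be defined elsewhere in this repository.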
    initial_latency = args.init_latency
    restore = args.restore
    bw_amplify = args.bw_amplify

    # Load env
    env = Env.Live_Streaming(initial_latency)
    _, action_dims = env.get_action_info()
    reply_buffer = Reply_Buffer(Config.reply_buffer_size)
    agent = Agent(action_dims)
    reward_logs = []
    loss_logs = []

    logs_path = Config.logs_path + '/'
    if bw_amplify:
        logs_path += 'latency_' + str(initial_latency) + 's_amplified/'
    else:
        logs_path += 'latency_' + str(initial_latency) + 's/'
    if not os.path.exists(logs_path):
        os.makedirs(logs_path)

    starting_episode = 1
    # restore model
    if restore:
        starting_episode = agent.train_restore(logs_path)

    print("Episode starts from: ", starting_episode)
    for episode in range(starting_episode, Config.total_episode + 1):
        # reset env
        env_end = env.reset(bw_amplify=bw_amplify)
        env.act(0, 1)  # Default
        state = env.get_state()
        total_reward = 0.0

        # Update epsilon
        agent.update_epsilon_by_epoch(episode)
        while not env.streaming_finish():
            if Config.model_version == 0:
                action = agent.take_action(np.array([state]))
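                # decode the flat action index into its two sub-actions
                # (row-major over the action_dims[0] x action_dims[1] grid)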
                action_1 = action // action_dims[1]
                action_2 = action % action_dims[1]
                reward = env.act(action_1, action_2)
                # print(reward)
                state_new = env.get_state()
                total_reward += reward
                action_onehot = np.zeros(action_dims[0] * action_dims[1])
                action_onehot[action] = 1
                # print(env.streaming_finish())
                reply_buffer.append((state, action_onehot, reward, state_new,
                                     env.streaming_finish()))
                state = state_new
            # elif Config.model_version == 1 or Config.model_version == 2:
            #     action_1, action_2 = agent.take_action(np.array([state]))
            #     # print(action_1, action_2)
            #     reward = env.act(action_1, action_2)
            #     # print(reward)
            #     state_new = env.get_state()
            #     total_reward += reward
            #     action_onehots = []
            #     action_1_onehot = np.zeros(action_dims[0])
            #     action_2_onehot = np.zeros(action_dims[1])
            #     action_1_onehot[action_1] = 1
            #     action_2_onehot[action_2] = 1
            #     # print(env.streaming_finish())
            #     reply_buffer.append((state, action_1_onehot, action_2_onehot, reward, state_new, env.streaming_finish()))
            #     state = state_new

        # skip learning updates until the initial observation episodes have populated the replay buffer
        if episode < starting_episode + Config.observe_episode:
            continue

        # update target network
        if episode % Config.update_target_frequency == 0:
            agent.update_target_network()

        # sample a minibatch from the replay buffer and update the Q network
        if Config.model_version == 0:
            batch_state, batch_actions, batch_reward, batch_state_new, batch_over = reply_buffer.sample()
            loss = agent.update_Q_network_v0(batch_state, batch_actions,
                                             batch_reward, batch_state_new,
                                             batch_over)
        # elif Config.model_version == 1:
        #     batch_state, batch_actions_1, batch_actions_2, batch_reward, batch_state_new, batch_over = reply_buffer.sample()
        #     loss = agent.update_Q_network_v1(batch_state, batch_actions_1, batch_actions_2, batch_reward, batch_state_new, batch_over)
        # elif Config.model_version == 2:
        #     batch_state, batch_actions_1, batch_actions_2, batch_reward, batch_state_new, batch_over = reply_buffer.sample()
        #     loss = agent.update_Q_network_v2(batch_state, batch_actions_1, batch_actions_2, batch_reward, batch_state_new, batch_over)

        loss_logs.append([episode, loss])
        reward_logs.append([episode, total_reward])

        # save model
        if episode % Config.save_logs_frequency == 0:
            print("episode:", episode)
            agent.save(episode, logs_path)
            np.save(os.path.join(logs_path, 'loss.npy'), np.array(loss_logs))
            np.save(os.path.join(logs_path, 'reward.npy'),
                    np.array(reward_logs))

        # print reward and loss
        if episode % Config.show_loss_frequency == 0:
            if Config.loss_version == 0:
                print('Episode: {} Reward: {:.3f} Loss: {:.3f}'.format(
                    episode, total_reward, loss[0]))
            # elif Config.loss_version == 1:
            #     print('Episode: {} Reward: {:.3f} Loss: {:.3f} and {:.3f}' .format(episode, total_reward, loss[0], loss[1]))
        agent.update_epsilon_by_epoch(episode)
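
A note on the action encoding used above: the agent outputs one flat index over an action_dims[0] x action_dims[1] grid, and the one-hot vector stored in the replay buffer uses the same flat layout. A minimal round-trip sketch, with illustrative dimensions (the real values come from env.get_action_info()):

import numpy as np

action_dims = [6, 3]  # illustrative only, e.g. 6 bitrates x 3 playback speeds (assumed)

def decode(action, dims):
    # same row-major decoding as the training loop above
    return action // dims[1], action % dims[1]

def encode(action_1, action_2, dims):
    return action_1 * dims[1] + action_2

for a in range(action_dims[0] * action_dims[1]):
    a1, a2 = decode(a, action_dims)
    assert encode(a1, a2, action_dims) == a
    onehot = np.zeros(action_dims[0] * action_dims[1])
    onehot[a] = 1  # matches the one-hot layout appended to the replay buffer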
Example 2
def main():
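    # Overview: restores a trained model and evaluates it in testing mode, either over a
    # batch of traces (massive=True) or a single trace, writing a per-trace result log plus
    # the throughput/time trace segment covered by the playback session.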
    massive = args.massive
    episode = args.episode
    model_v = args.model_version
    init_latency = args.init_latency
    random_latency = args.random_latency
    bw_amplify = args.bw_amplify

    env = Env.Live_Streaming(init_latency, testing=True, massive=massive, random_latency=random_latency)
    _, action_dims = env.get_action_info()
    # reply_buffer = Reply_Buffer(Config.reply_buffer_size)
    agent = Agent(action_dims, model_version=model_v)
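    # Note: only model_version 0 is wired up below; for other values model_path would be
    # left undefined before agent.restore(model_path).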
    if model_v == 0:
        if bw_amplify:
            model_path = './models/logs_m_' + str(model_v) + '/t_0/l_0/latency_Nones_amplified/model-' + str(episode) + '.pth'
        else:
            model_path = './models/logs_m_' + str(model_v) + '/t_0/l_0/latency_Nones/model-' + str(episode) + '.pth'
    agent.restore(model_path)

    if massive:
        if bw_amplify:
            compare_path = Config.a_cdf_dir
            result_path = Config.a_massive_result_files + '/latency_Nones/'
        else:
            compare_path = Config.cdf_dir
            result_path = Config.massive_result_files + 'model_' + str(model_v) + '/latency_Nones/'
        if not os.path.exists(compare_path):
            os.makedirs(compare_path)
        if not os.path.exists(result_path):
            os.makedirs(result_path)
        if random_latency:
            compare_file = open(compare_path + 'naive_speed_normal.txt', 'w')
        else:
            compare_file = open(compare_path + 'naive_speed_' + str(int(init_latency)) + 's.txt', 'w')

        while True:
            # Start testing
            env_end = env.reset(testing=True, bw_amplify=bw_amplify)
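            # env.reset is assumed to return a truthy flag once every test trace has been consumed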
            if env_end:
                break
            testing_start_time = env.get_server_time()
            print("Initial latency is: ", testing_start_time)
            tp_trace, time_trace, trace_name, starting_idx = env.get_player_trace_info()
            print("Trace name is: ", trace_name)
            
            # print(massive, episode, model_v)
            log_path = result_path + trace_name 
            log_file = open(log_path, 'w')
            env.act(0, 1, massive=massive)   # Default
            state = env.get_state()
            total_reward = 0.0
            while not env.streaming_finish():
                if model_v == 0:
                    action = agent.testing_take_action(np.array([state]))
                    action_1 = action // action_dims[1]
                    action_2 = action % action_dims[1]
                    reward = env.act(action_1, action_2, log_file, massive=massive)
                    # print(reward)
                    state_new = env.get_state()
                    state = state_new
                    total_reward += reward
                    # print(action_1, action_2, reward)
            print('File: ', trace_name, ' reward is: ', total_reward) 
            # Record the session duration, the throughput/time trace segment it covered, and the initial latency
            testing_duration = env.get_server_time() - testing_start_time
            tp_record, time_record = get_tp_time_trace_info(tp_trace, time_trace, starting_idx, testing_duration + env.player.get_buffer())
            log_file.write('\t'.join(str(tp) for tp in tp_record))
            log_file.write('\n')
            log_file.write('\t'.join(str(time) for time in time_record))
            # log_file.write('\n' + str(IF_NEW))
            log_file.write('\n' + str(testing_start_time))
            log_file.write('\n')
            log_file.close()
            env.massive_save(trace_name, compare_file)
            env.save_bw_trace(trace_name, compare_file)
        compare_file.close()
    else:
        # check results log path
        if bw_amplify:
            result_path = Config.a_regular_test_files + 'model_' + str(model_v) + '/latency_' + str(init_latency) + 's/'
        else:
            result_path = Config.regular_test_files + 'model_' + str(model_v) + '/latency_' + str(init_latency) + 's/'
        if not os.path.exists(result_path):
            os.makedirs(result_path)
        # Start testing
        env_end = env.reset(testing=True, bw_amplify=bw_amplify)
        testing_start_time = env.get_server_time()
        print("Initial latency is: ", testing_start_time)
        tp_trace, time_trace, trace_name, starting_idx = env.get_player_trace_info()
        print("Trace name is: ", trace_name, starting_idx)
        
        # print(massive, episode, model_v)
        log_path = result_path + trace_name + '.txt'
        log_file = open(log_path, 'w')
        env.act(0, 1, log_file)   # Default
        state = env.get_state()
        total_reward = 0.0
        while not env.streaming_finish():
            if model_v == 0:
                action = agent.testing_take_action(np.array([state]))
                action_1 = action // action_dims[1]
                action_2 = action % action_dims[1]
                reward = env.act(action_1, action_2, log_file)
                # print(reward)
                state_new = env.get_state()
                state = state_new
                total_reward += reward
                # print(action_1, action_2, reward)
        print('File: ', trace_name, ' reward is: ', total_reward) 
        # Record the session duration, the throughput/time trace segment it covered, and the initial latency
        testing_duration = env.get_server_time() - testing_start_time
        tp_record, time_record = get_tp_time_trace_info(tp_trace, time_trace, starting_idx, testing_duration + env.player.get_buffer())
        log_file.write('\t'.join(str(tp) for tp in tp_record))
        log_file.write('\n')
        log_file.write('\t'.join(str(time) for time in time_record))
        # log_file.write('\n' + str(IF_NEW))
        log_file.write('\n' + str(testing_start_time))
        log_file.write('\n')
        log_file.close()
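
The tail of each result log written above is three lines: a tab-separated throughput record, a tab-separated time record, and the initial latency. A hedged reader sketch for that tail (the per-step lines produced by env.act are environment-specific and are simply skipped here):

def read_log_tail(log_path):
    # Assumes the layout written above: per-step lines, then tp_record (tab-separated),
    # time_record (tab-separated), and testing_start_time on the last three non-empty lines.
    with open(log_path) as f:
        lines = [line.rstrip('\n') for line in f if line.strip()]
    tp_record = [float(x) for x in lines[-3].split('\t')]
    time_record = [float(x) for x in lines[-2].split('\t')]
    testing_start_time = float(lines[-1])
    return tp_record, time_record, testing_start_time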