Example #1
# Copyright (c) Facebook, Inc. and its affiliates.
Example #2
File: match.py Project: apjacob/minirts
    device = torch.device('cuda:%d' % args.gpu)
    coach = ConvRnnCoach.load(args.coach_path).to(device)
    coach.max_raw_chars = args.max_raw_chars
    executor = Executor.load(args.model_path).to(device)
    executor_wrapper = ExecutorWrapper(coach, executor, coach.num_instructions,
                                       args.max_raw_chars, args.cheat,
                                       args.inst_mode)
    executor_wrapper.train(False)

    game_option = get_game_option(args)
    ai1_option, ai2_option = get_ai_options(args, coach.num_instructions)

    context, act_dc = create_game(args.num_thread, ai1_option, ai2_option,
                                  game_option)
    context.start()
    dc = DataChannelManager([act_dc])

    result_stat = ResultStat('reward', None)
    while not context.terminated():
        data = dc.get_input(max_timeout_s=1)
        if len(data) == 0:
            continue
        data = to_device(data['act'], device)
        result_stat.feed(data)
        reply = executor_wrapper.forward(data)

        dc.set_reply('act', reply)

    print(result_stat.log(0))
    dc.terminate()
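
# The loop above ships each batch to the GPU with `to_device`, imported from
# common_utils in Example #4. A minimal sketch of such a helper is shown below as
# an assumption about its behaviour, not the project's actual implementation.
import torch


def to_device_sketch(batch, device):
    """Recursively move a (possibly nested) dict of tensors to `device`."""
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {k: to_device_sketch(v, device) for k, v in batch.items()}
    return batch  # leave non-tensor entries untouched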
Example #3
    result1 = ResultStat('reward', None)
    result2 = ResultStat('reward', None)
    i = 0
    while not context.terminated():
        i += 1
        if i % 1000 == 0:
            print('%d, progress agent1: win %d, loss %d' %
                  (i, result1.win, result1.loss))

        data = dc.get_input(max_timeout_s=1)
        if len(data) == 0:
            continue
        for key in data:
            # print(key)
            batch = to_device(data[key], device)
            if key == 'act1':
                result1.feed(batch)
                with torch.no_grad():
                    reply = model1.forward(batch)
            elif key == 'act2':
                result2.feed(batch)
                with torch.no_grad():
                    reply = model2.forward(batch)
            else:
                assert False
            dc.set_reply(key, reply)

    print(result1.log(0))
    print(result2.log(0))
    dc.terminate()
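
# Both loops above rely on ResultStat (from common_utils) to turn per-step batches
# into win/loss counts. The stand-in below only mirrors the feed/log surface used
# here and assumes wins and losses are signalled by reward +1/-1 on terminal
# steps; the real class may track more (e.g. num_games).
import torch


class ResultStatSketch:
    def __init__(self):
        self.win = 0
        self.loss = 0

    def feed(self, batch):
        terminal = batch["terminal"].flatten()
        reward = batch["reward"].flatten()
        self.win += int(((terminal == 1) & (reward > 0)).sum())
        self.loss += int(((terminal == 1) & (reward < 0)).sum())

    def log(self, epoch):
        return "epoch %d: win %d, loss %d" % (epoch, self.win, self.loss)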
Example #4
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import os
import sys
import pprint
from set_path import append_sys_path
append_sys_path()
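# Note: append_sys_path() (from the project's set_path module) presumably adds the
# compiled build directories to sys.path, so the `tube`/`pytube`/`minirts` imports
# below resolve to the repo's own game-communication bindings rather than any
# PyPI package of the same name.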
import torch
import tube
from pytube import DataChannelManager
import minirts
import numpy as np
import random
import pickle
from collections import defaultdict
from rnn_coach import ConvRnnCoach
from onehot_coach import ConvOneHotCoach
from rnn_generator import RnnGenerator
from itertools import groupby
from executor_wrapper import ExecutorWrapper
from executor import Executor
from common_utils import to_device, ResultStat, Logger
from best_models import best_executors, best_coaches
from tqdm import tqdm
p1dict = defaultdict(list)
Example #5
def run_eval(args, model1, model2, device, num_games=100):

    num_eval_games = num_games

    result1 = ResultStat("reward", None)
    result2 = ResultStat("reward", None)

    game_option = get_game_option(args)
    ai1_option, ai2_option = get_ai_options(
        args, [model1.coach.num_instructions, model2.coach.num_instructions])

    if args.opponent == "sp":
        context, act1_dc, act2_dc = init_mt_games(num_eval_games, 0, args,
                                                  ai1_option, ai2_option,
                                                  game_option)
        pbar = tqdm(total=num_eval_games * 2)
    else:
        context, act1_dc, act2_dc = init_mt_games(0, num_eval_games, args,
                                                  ai1_option, ai2_option,
                                                  game_option)
        pbar = tqdm(total=num_eval_games)
    # context, act1_dc, act2_dc = init_games(
    #     num_eval_games, ai1_option, ai2_option, game_option)
    context.start()
    dc = DataChannelManager([act1_dc, act2_dc])

    i = 0
    model1.eval()
    model2.eval()

    while not context.terminated():
        i += 1
        # if i % 1000 == 0:
        #     print('%d, progress agent1: win %d, loss %d' % (i, result1.win, result1.loss))

        data = dc.get_input(max_timeout_s=1)
        if len(data) == 0:
            continue
        for key in data:
            # print(key)
            batch = to_device(data[key], device)
            if key == "act1":
                batch["actor"] = "act1"
                ## Optionally add batches to a state table here (sampled, keyed by game_id)

                result1.feed(batch)
                with torch.no_grad():
                    reply, _ = model1.forward(batch)  # , exec_sample=True)

            elif key == "act2":
                batch["actor"] = "act2"
                result2.feed(batch)

                with torch.no_grad():
                    reply, _ = model2.forward(batch)

            else:
                assert False

            dc.set_reply(key, reply)

            game_ids = batch["game_id"].cpu().numpy()
            terminals = batch["terminal"].cpu().numpy().flatten()

            # use a separate index so the outer progress counter `i` is not shadowed
            for game_idx, g_id in enumerate(game_ids):
                if terminals[game_idx] == 1:
                    pbar.update(1)

    model1.eval()
    model2.eval()
    pbar.close()

    return result1, result2
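
# A minimal usage sketch for run_eval above. `args` and two Agent instances
# `agent1`/`agent2` (see Example #6) are assumed to exist already; the wrapper
# name, the use of args.gpu, and the win-rate return value are illustrative,
# not project code.
def eval_pair_sketch(args, agent1, agent2, num_games=100):
    device = torch.device("cuda:%d" % args.gpu)
    result1, result2 = run_eval(args, agent1.model, agent2.model, device, num_games)
    print(result1.log(0))
    print(result2.log(0))
    return result1.win / max(result1.num_games, 1)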
Example #6
class Agent:
    def __init__(
        self,
        coach,
        executor,
        device,
        args,
        writer,
        trainable=False,
        exec_sample=False,
        pg="ppo",
        tag="",
    ):

        self.__coach = coach
        self.__executor = executor
        self.__trainable = trainable
        self.device = device
        self.args = args
        self.tb_log = args.tb_log
        self.save_folder = None
        self.tag = tag
        self.best_test_win_pct = 0
        self.tb_writer = writer
        self.traj_dict = defaultdict(list)
        self.result_dict = {}
        self.result = ResultStat("reward", None)
        self.model = self.load_model(self.__coach, self.__executor, self.args)
        self.exec_sample = exec_sample

        print("Using pg {} algorithm".format(args.pg))
        if self.__trainable:
            # if args.split_train:
            #     self.executor_optimizer = optim.Adam(
            #         self.model.executor.parameters(), lr=args.lr
            #     )
            #     self.coach_optimizer = optim.Adam(
            #         self.model.coach.parameters(), lr=args.lr
            #     )
            # else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
            self.sa_buffer = StateActionBuffer(
                max_buffer_size=args.max_table_size,
                buffer_add_prob=args.sampling_freq)
            # wandb.watch(self.model)
            self.pg = pg
        else:
            self.optimizer = None
            self.sa_buffer = None
            self.pg = None

    @property
    def executor(self):
        return self.__executor

    def set_optimizer(self, optimizer):
        self.optimizer = optimizer

    @property
    def coach(self):
        return self.__coach

    def clone(self, model_type="current"):
        if model_type == "current":
            print("Cloning current model.")
            cpy = Agent(
                self.__coach,
                self.__executor,
                self.device,
                self.args,
                self.tb_writer,
                trainable=False,
            )
            cpy.model = self.model
        else:
            raise NameError

        return cpy

    def load_model(self, coach_path, executor_path, args):
        coach_rule_emb_size = getattr(args, "coach_rule_emb_size", 0)
        executor_rule_emb_size = getattr(args, "executor_rule_emb_size", 0)
        inst_dict_path = getattr(args, "inst_dict_path", None)
        coach_random_init = getattr(args, "coach_random_init", False)

        if isinstance(coach_path, str):
            if "onehot" in coach_path:
                coach = ConvOneHotCoach.load(coach_path).to(self.device)
            elif "gen" in coach_path:
                coach = RnnGenerator.load(coach_path).to(self.device)
            else:
                coach = ConvRnnCoach.rl_load(
                    coach_path,
                    coach_rule_emb_size,
                    inst_dict_path,
                    coach_random_init=coach_random_init,
                ).to(self.device)
        else:
            print("Sharing coaches.")
            coach = coach_path
        coach.max_raw_chars = args.max_raw_chars

        if isinstance(executor_path, str):
            executor = Executor.rl_load(executor_path, executor_rule_emb_size,
                                        inst_dict_path).to(self.device)
        else:
            print("Sharing executors.")
            executor = executor_path

        executor_wrapper = ExecutorWrapper(
            coach,
            executor,
            coach.num_instructions,
            args.max_raw_chars,
            args.cheat,
            args.inst_mode,
        )
        executor_wrapper.train(False)
        return executor_wrapper

    def simulate(self, index, batch):
        with torch.no_grad():
            reply, coach_reply = self.model.forward(batch, self.exec_sample)

            if self.__trainable and self.model.training:
                assert self.sa_buffer is not None

                ## Add coach reply to state-reply dict
                # batch.update(coach_reply['samples'])
                rv = self.sa_buffer.push(index, batch)

            return reply

    def reset(self):
        if self.__trainable and self.model.training:
            assert len(self.sa_buffer) == 0
        else:
            self.sa_buffer = StateActionBuffer(
                max_buffer_size=self.args.max_table_size,
                buffer_add_prob=self.args.sampling_freq,
            )

        self.result = ResultStat("reward", None)
        self.traj_dict = defaultdict(list)
        self.result_dict = {}

    def update_logs(self, index, batch, reply):
        self.result.feed(batch)

        inst = reply["inst"]  ## instruction indices, decoded via inst_dict.get_inst below
        game_ids = batch["game_id"].cpu().numpy()
        terminals = batch["terminal"].cpu().numpy().flatten()
        rewards = batch["reward"].cpu().numpy().flatten()

        # print(len(state_table))
        games_terminated = 0
        for i, g_id in enumerate(game_ids):

            act_dict = self.traj_dict
            act_win_dict = self.result_dict
            dict_key = str(index) + "_" + str(g_id[0])

            act_dict[dict_key].append(
                self.model.coach.inst_dict.get_inst(inst[i]))

            if terminals[i] == 1:
                games_terminated += 1

                # print("Game {} has terminated.".format(g_id[0]))
                if rewards[i] == 1:
                    act_win_dict[dict_key] = 1
                elif rewards[i] == -1:
                    act_win_dict[dict_key] = -1

        return games_terminated

    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def eval_model(self, gen_id, other_agent, num_games=100):
        print("Evaluating model....")
        e_result1, e_result2 = run_eval(self.args, self.model,
                                        other_agent.model, self.device,
                                        num_games)
        test_win_pct = e_result1.win / e_result1.num_games

        print(e_result1.log(0))
        print(e_result2.log(0))

        if self.tb_log:
            a1_win = e_result1.win / e_result1.num_games
            a1_loss = e_result1.loss / e_result1.num_games

            if self.args.opponent == "rb":
                a2_win = e_result1.loss / e_result1.num_games
                a2_loss = e_result1.win / e_result1.num_games
            else:
                a2_win = e_result2.win / e_result2.num_games
                a2_loss = e_result2.loss / e_result2.num_games

            wandb.log(
                {
                    "Test/Agent-1/Win": a1_win,
                    "Test/Agent-1/Loss": a1_loss,
                    "Test/Eval_model/Win": a2_win,
                    "Test/Eval_model/Loss": a2_loss,
                },
                step=gen_id,
            )

            self.tb_writer.add_scalar("Test/Agent-1/Win", a1_win, gen_id)
            self.tb_writer.add_scalar("Test/Agent-1/Loss", a1_loss, gen_id)

            self.tb_writer.add_scalar("Test/Eval_model/Win", a2_win, gen_id)
            self.tb_writer.add_scalar("Test/Eval_model/Loss", a2_loss, gen_id)

            if test_win_pct > self.best_test_win_pct:
                self.best_test_win_pct = test_win_pct
                self.save_coach(gen_id)
                self.save_executor(gen_id)
                # wandb.save('{}/*.pt'.format(self.save_folder))
                # wandb.save('{}/*.params'.format(self.save_folder))

        return e_result1, e_result2

    def init_save_folder(self, log_name):

        self.save_folder = os.path.join(self.args.save_folder, log_name)

        if os.path.exists(self.save_folder):
            print("Save folder already exists; skipping creation.")
        else:
            os.makedirs(self.save_folder)

    #######################
    ## Executor specific ##
    #######################

    def align_executor_actions(self, batches):
        # action_probs = ['glob_cont_prob',
        #  'cmd_type_prob',
        #  'gather_idx_prob',
        #  'attack_idx_prob',
        #  'unit_type_prob',
        #  'building_type_prob',
        #  'building_loc_prob',
        #  'move_loc_prob']

        action_samples = [
            "current_cmd_type",
            "current_cmd_unit_type",
            "current_cmd_x",
            "current_cmd_y",
            "current_cmd_gather_idx",
            "current_cmd_attack_idx",
        ]

        action_samples_dict = {key: "t_" + key for key in action_samples}

        for (g_id, elements) in batches.items():

            for asp in action_samples_dict:
                elements[action_samples_dict[asp]] = elements[asp][1:]

            for (key, element) in elements.items():

                ## Alignment
                if key in action_samples_dict.values():
                    continue
                else:
                    elements[key] = element[:-1]
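
    # A toy illustration (an assumption about intent, not project code) of the
    # alignment performed above: the sampled action keys become next-step targets
    # (prefixed "t_"), and every other per-step tensor is truncated by one so all
    # entries end up with the same length.
    #
    #   elements = {"current_cmd_type": torch.tensor([0, 1, 2, 3]),
    #               "reward":           torch.tensor([0., 0., 0., 1.])}
    #   elements["t_current_cmd_type"] = elements["current_cmd_type"][1:]   # [1, 2, 3]
    #   elements["current_cmd_type"]   = elements["current_cmd_type"][:-1]  # [0, 1, 2]
    #   elements["reward"]             = elements["reward"][:-1]            # [0., 0., 0.]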

    def train_executor(self,
                       gen_id,
                       agg_win_batches=None,
                       agg_loss_batches=None):
        assert self.__trainable
        assert self.args.sampling_freq >= 1.0

        if len(self.sa_buffer) or (agg_win_batches is not None
                                   and agg_loss_batches is not None):
            # self.align_executor_actions(win_batches)
            # self.align_executor_actions(loss_batches)

            if agg_loss_batches is not None and agg_win_batches is not None:
                win_batches = agg_win_batches
                loss_batches = agg_loss_batches
            else:
                win_batches, loss_batches = self.sa_buffer.pop(
                    self.result_dict)

            if self.pg == "vanilla":
                l1_loss, mse_loss, value = self.__update_executor_vanilla(
                    win_batches, loss_batches)

                if self.tb_log:
                    wandb.log(
                        {
                            "Loss/Executor/RL-Loss": l1_loss,
                            "Loss/Executor/MSE-Loss": mse_loss,
                            "Loss/Executor/Value": value,
                        },
                        step=gen_id,
                    )
                    self.tb_writer.add_scalar("Loss/Executor/RL-Loss", l1_loss,
                                              gen_id)
                    self.tb_writer.add_scalar("Loss/Executor/MSE-Loss",
                                              mse_loss, gen_id)
                    self.tb_writer.add_scalar("Loss/Executor/Value", value,
                                              gen_id)

            elif self.pg == "ppo":
                action_loss, value_loss, entropy = self.__update_executor_ppo(
                    win_batches, loss_batches)

                if self.tb_log:
                    wandb.log(
                        {
                            "Loss/Executor/Action-Loss": action_loss,
                            "Loss/Executor/Value-Loss": value_loss,
                            "Loss/Executor/Entropy": entropy,
                        },
                        step=gen_id,
                    )

                    self.tb_writer.add_scalar("Loss/Executor/Action-Loss",
                                              action_loss, gen_id)
                    self.tb_writer.add_scalar("Loss/Executor/Value-Loss",
                                              value_loss, gen_id)
                    self.tb_writer.add_scalar("Loss/Executor/Entropy", entropy,
                                              gen_id)
            else:
                raise NotImplementedError
        else:
            print("State-Action Buffer is empty.")

    def __update_executor_vanilla(self, win_batches, loss_batches):
        assert self.__trainable

        self.model.train()
        self.optimizer.zero_grad()

        mse = torch.nn.MSELoss(reduction="none")

        denom = len(win_batches) + len(loss_batches)

        l1_losses = 0
        mse_losses = 0
        total_values = 0

        for (kind, r), batches in zip(reward_tuple,
                                      [win_batches, loss_batches]):

            for game_id, batch in batches.items():

                episode_len = batch["reward"].size(0)
                intervals = list(
                    range(0, episode_len, self.args.train_batch_size))
                interval_tuples_iter = zip(intervals,
                                           intervals[1:] + [episode_len])

                for (s, e) in interval_tuples_iter:

                    sliced_batch = slice_batch(batch, s, e)

                    (
                        log_prob,
                        all_losses,
                        value,
                    ) = self.model.get_executor_vanilla_rl_train_loss(
                        sliced_batch)
                    l1 = (
                        r * log_prob
                    )  # (r * torch.ones_like(value) - value.detach()) * log_prob
                    l2 = mse(value, r * torch.ones_like(value))

                    l1_mean = -1.0 * l1.sum() / denom
                    mse_loss = 1.0 * l2.sum() / denom
                    policy_loss = l1_mean  # + mse_loss

                    policy_loss.backward()

                    l1_losses += l1_mean.item()
                    mse_losses += mse_loss.item()
                    total_values += value.sum().item() / denom

        torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                       self.args.grad_clip)
        self.optimizer.step()
        self.optimizer.zero_grad()

        self.model.train()
        return l1_losses, mse_losses, total_values

    def __update_executor_ppo(self, win_batches, loss_batches):
        assert self.__trainable

        self.model.train()
        self.optimizer.zero_grad()

        mse = torch.nn.MSELoss()

        action_loss_mean = 0
        value_loss_mean = 0
        entropy_mean = 0
        num_iters = 0

        ## TODO: add a discount factor

        for epoch in range(self.args.ppo_epochs):

            for (kind, r), batches in zip(reward_tuple,
                                          [win_batches, loss_batches]):

                for game_id, batch in batches.items():
                    episode_len = batch["reward"].size(0)
                    intervals = list(
                        range(0, episode_len, self.args.train_batch_size))
                    interval_tuples_iter = zip(intervals,
                                               intervals[1:] + [episode_len])

                    for (s, e) in interval_tuples_iter:
                        sliced_batch = slice_batch(batch, s, e)
                        (
                            log_prob,
                            old_exec_log_probs,
                            entropy,
                            value,
                        ) = self.model.get_executor_ppo_train_loss(
                            sliced_batch)
                        adv = r * torch.ones_like(value) - value.detach()

                        ratio = torch.exp(log_prob - old_exec_log_probs)
                        surr1 = ratio * adv
                        surr2 = (torch.clamp(ratio, 1.0 - self.args.ppo_eps,
                                             1.0 + self.args.ppo_eps) * adv)

                        action_loss = -torch.min(surr1, surr2).mean()
                        value_loss = 1.0 * mse(value,
                                               r * torch.ones_like(value))
                        entropy_loss = -1.0 * entropy.mean()

                        self.optimizer.zero_grad()
                        policy_loss = (action_loss + value_loss).mean()
                        policy_loss.backward()
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), 0.5 * self.args.grad_clip)
                        self.optimizer.step()

                        action_loss_mean += action_loss.item()
                        value_loss_mean += value_loss.item()
                        entropy_mean += entropy_loss.item()
                        num_iters += 1

        action_loss_mean = action_loss_mean / num_iters
        value_loss_mean = value_loss_mean / num_iters
        entropy_mean = entropy_mean / num_iters

        self.model.train()
        return action_loss_mean, value_loss_mean, entropy_mean
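
    # The update above is the standard PPO clipped surrogate. A self-contained
    # numeric illustration (ppo_eps assumed to be 0.2 here, not taken from the
    # project's configs):
    #
    #   adv = torch.tensor([1.0, -1.0])
    #   ratio = torch.tensor([1.5, 0.5])             # exp(new_logp - old_logp)
    #   surr1 = ratio * adv                          # [ 1.5, -0.5]
    #   surr2 = torch.clamp(ratio, 0.8, 1.2) * adv   # [ 1.2, -0.8]
    #   action_loss = -torch.min(surr1, surr2).mean()
    #   # elementwise min keeps the pessimistic surrogate [1.2, -0.8] -> loss = -0.2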

    def save_executor(self, index):
        assert self.save_folder is not None

        model_file = os.path.join(
            self.save_folder, f"best_exec_checkpoint{self.tag}_{index}.pt")
        print("Saving model exec to: ", model_file)
        self.model.executor.save(model_file)
        wandb.save(model_file)
        wandb.save(model_file + ".params")

        model_file = os.path.join(self.save_folder,
                                  f"best_exec_checkpoint{self.tag}.pt")
        print("Saving model exec to: ", model_file)
        self.model.executor.save(model_file)
        wandb.save(model_file)
        wandb.save(model_file + ".params")

    ####################
    ## Coach specific ##
    ####################

    def train_coach(self, gen_id):
        assert self.__trainable

        if len(self.sa_buffer):
            win_batches, loss_batches = self.sa_buffer.pop(self.result_dict)

            if self.pg == "vanilla":
                l1_loss, mse_loss, value = self.__update_coach_vanilla(
                    win_batches, loss_batches)

                if self.tb_log:
                    wandb.log(
                        {
                            "Loss/Coach/Action-Loss": l1_loss,
                            "Loss/Coach/Value-Loss": mse_loss,
                            "Loss/Coach/Value": value,
                        },
                        step=gen_id,
                    )

                    self.tb_writer.add_scalar("Loss/Coach/Action-Loss",
                                              l1_loss, gen_id)
                    self.tb_writer.add_scalar("Loss/Coach/Value-Loss",
                                              mse_loss, gen_id)
                    self.tb_writer.add_scalar("Loss/Coach/Value", value,
                                              gen_id)

            elif self.pg == "ppo":
                action_loss, value_loss, entropy = self.__update_coach_ppo(
                    win_batches, loss_batches)

                if self.tb_log:
                    wandb.log(
                        {
                            "Loss/Coach/Action-Loss": action_loss,
                            "Loss/Coach/Value-Loss": value_loss,
                            "Loss/Coach/Entropy": entropy,
                        },
                        step=gen_id,
                    )

                    self.tb_writer.add_scalar("Loss/Coach/Action-Loss",
                                              action_loss, gen_id)
                    self.tb_writer.add_scalar("Loss/Coach/Value-Loss",
                                              value_loss, gen_id)
                    self.tb_writer.add_scalar("Loss/Coach/Entropy", entropy,
                                              gen_id)
                    return win_batches, loss_batches
            else:
                raise NotImplementedError

        else:
            print("State-Action Buffer is empty.")

    def __update_coach_vanilla(self, win_batches, loss_batches):
        assert self.__trainable

        self.model.train()
        self.optimizer.zero_grad()

        mse = torch.nn.MSELoss(reduction="none")

        denom = np.sum([
            y["inst"].size()[0] for x, y in win_batches.items()
        ]) + np.sum([y["inst"].size()[0] for x, y in loss_batches.items()])

        l1_loss_mean = 0
        mse_loss_mean = 0

        total_value = 0
        for (kind, r), batches in zip(reward_tuple,
                                      [win_batches, loss_batches]):
            # if kind == "win":
            #     denom = np.sum([y['inst'].size()[0] for x, y in win_batches.items()])
            # else:
            #     denom = np.sum([y['inst'].size()[0] for x, y in loss_batches.items()])

            for game_id, batch in batches.items():
                log_prob, value = self.model.get_coach_vanilla_rl_train_loss(
                    batch)
                l1 = (r * torch.ones_like(value) - value.detach()) * log_prob
                l2 = mse(value, r * torch.ones_like(value))

                l1_mean = -1.0 * (l1.sum()) / denom
                mse_loss = 0.1 * (l2.sum()) / denom
                policy_loss = l1_mean + mse_loss

                policy_loss.backward()
                l1_loss_mean += l1_mean.item()
                mse_loss_mean += mse_loss.item()

                total_value += value.sum().item() / denom

            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.args.grad_clip * 10)
            self.optimizer.step()
            self.optimizer.zero_grad()

        self.model.train()
        return l1_loss_mean, mse_loss_mean, total_value

    def __update_coach_ppo(self, win_batches, loss_batches):
        assert self.__trainable

        self.model.train()
        self.optimizer.zero_grad()

        mse = torch.nn.MSELoss()

        action_loss_mean = 0
        value_loss_mean = 0
        entropy_mean = 0
        num_iters = 0

        ## TODO: add a discount factor

        for epoch in range(self.args.ppo_epochs):

            for (kind, r), batches in zip(reward_tuple,
                                          [win_batches, loss_batches]):

                for game_id, batch in batches.items():
                    episode_len = batch["reward"].size(0)
                    intervals = list(
                        range(0, episode_len, self.args.train_batch_size))
                    interval_tuples_iter = zip(intervals,
                                               intervals[1:] + [episode_len])

                    for (s, e) in interval_tuples_iter:
                        sliced_batch = slice_batch(batch, s, e)
                        (
                            log_prob,
                            old_coach_log_probs,
                            entropy,
                            value,
                        ) = self.model.get_coach_ppo_rl_train_loss(
                            sliced_batch)
                        adv = r * torch.ones_like(value) - value.detach()

                        ratio = torch.exp(log_prob - old_coach_log_probs)
                        surr1 = ratio * adv
                        surr2 = (torch.clamp(ratio, 1.0 - self.args.ppo_eps,
                                             1.0 + self.args.ppo_eps) * adv)

                        action_loss = -torch.min(surr1, surr2).mean()
                        value_loss = 1.0 * mse(value,
                                               r * torch.ones_like(value))
                        entropy_loss = 0.0  # -0.001*entropy.mean()

                        self.optimizer.zero_grad()
                        policy_loss = (
                            action_loss +
                            value_loss).mean()  # + entropy_loss).mean()
                        policy_loss.backward()
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), 0.5 * self.args.grad_clip)
                        self.optimizer.step()

                        action_loss_mean += action_loss.item()
                        value_loss_mean += value_loss.item()
                        entropy_mean += 0  # entropy_loss.item()
                        num_iters += 1

        action_loss_mean = action_loss_mean / num_iters
        value_loss_mean = value_loss_mean / num_iters
        entropy_mean = entropy_mean / num_iters

        self.model.train()
        return action_loss_mean, value_loss_mean, entropy_mean

    def save_coach(self, index):
        assert self.save_folder is not None

        model_file = os.path.join(
            self.save_folder, f"best_coach_checkpoint{self.tag}_{index}.pt")
        print("Saving model coach to: ", model_file)
        self.model.coach.save(model_file)
        wandb.save(model_file)
        wandb.save(model_file + ".params")

        model_file = os.path.join(self.save_folder,
                                  f"best_coach_checkpoint{self.tag}.pt")
        print("Saving model coach to: ", model_file)
        self.model.coach.save(model_file)
        wandb.save(model_file)
        wandb.save(model_file + ".params")

    ## Dual RL Training
    def train_both(self, gen_id):
        assert self.__trainable
        assert self.args.sampling_freq >= 1.0

        if len(self.sa_buffer):

            win_batches, loss_batches = self.sa_buffer.pop(self.result_dict)
            # self.align_executor_actions(win_batches)
            # self.align_executor_actions(loss_batches)

            if self.pg == "vanilla":
                l1_loss, mse_loss, value = self.__update_both_vanilla(
                    win_batches, loss_batches)

                if self.tb_log:
                    wandb.log(
                        {
                            "Loss/Both/RL-Loss": l1_loss,
                            "Loss/Both/MSE-Loss": mse_loss,
                            "Loss/Both/Value": value,
                        },
                        step=gen_id,
                    )
                    self.tb_writer.add_scalar("Loss/Both/RL-Loss", l1_loss,
                                              gen_id)
                    self.tb_writer.add_scalar("Loss/Both/MSE-Loss", mse_loss,
                                              gen_id)
                    self.tb_writer.add_scalar("Loss/Both/Value", value, gen_id)

            elif self.pg == "ppo":
                action_loss, value_loss, entropy = self.__update_both_ppo(
                    win_batches, loss_batches)

                if self.tb_log:
                    wandb.log(
                        {
                            "Loss/Both/Action-Loss": action_loss,
                            "Loss/Both/Value-Loss": value_loss,
                            "Loss/Both/Entropy": entropy,
                        },
                        step=gen_id,
                    )

                    self.tb_writer.add_scalar("Loss/Both/Action-Loss",
                                              action_loss, gen_id)
                    self.tb_writer.add_scalar("Loss/Both/Value-Loss",
                                              value_loss, gen_id)
                    self.tb_writer.add_scalar("Loss/Both/Entropy", entropy,
                                              gen_id)
            else:
                raise NotImplementedError
        else:
            print("State-Action Buffer is empty.")

    def __update_both_vanilla(self, win_batches, loss_batches):
        raise NotImplementedError
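        # NOTE: the raise above makes everything below unreachable; the vanilla
        # joint (coach + executor) update is disabled and kept only for reference.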
        assert self.__trainable

        self.model.train()
        self.optimizer.zero_grad()

        mse = torch.nn.MSELoss(reduction="none")

        denom = np.sum([
            y["inst"].size()[0] for x, y in win_batches.items()
        ]) + np.sum([y["inst"].size()[0] for x, y in loss_batches.items()])

        l1_loss_mean = 0
        mse_loss_mean = 0

        total_value = 0
        for (kind, r), batches in zip(reward_tuple,
                                      [win_batches, loss_batches]):
            # if kind == "win":
            #     denom = np.sum([y['inst'].size()[0] for x, y in win_batches.items()])
            # else:
            #     denom = np.sum([y['inst'].size()[0] for x, y in loss_batches.items()])

            for game_id, batch in batches.items():
                log_prob, value = self.model.get_coach_vanilla_rl_train_loss(
                    batch)
                l1 = (r * torch.ones_like(value) - value.detach()) * log_prob
                l2 = mse(value, r * torch.ones_like(value))

                l1_mean = -1.0 * (l1.sum()) / denom
                mse_loss = 0.1 * (l2.sum()) / denom
                policy_loss = l1_mean + mse_loss

                policy_loss.backward()
                l1_loss_mean += l1_mean.item()
                mse_loss_mean += mse_loss.item()

                total_value += value.sum().item() / denom

            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.args.grad_clip * 10)
            self.optimizer.step()
            self.optimizer.zero_grad()

        self.model.train()
        return l1_loss_mean, mse_loss_mean, total_value

    def __update_both_ppo(self, win_batches, loss_batches):
        assert self.__trainable

        self.model.train()
        self.optimizer.zero_grad()

        mse = torch.nn.MSELoss()

        action_loss_mean = 0
        value_loss_mean = 0
        entropy_mean = 0
        num_iters = 0

        ## TODO: add a discount factor

        for epoch in range(self.args.ppo_epochs):

            for (kind, r), batches in zip(reward_tuple,
                                          [win_batches, loss_batches]):

                for game_id, batch in batches.items():
                    episode_len = batch["reward"].size(0)
                    intervals = list(
                        range(0, episode_len, self.args.train_batch_size))
                    interval_tuples_iter = zip(intervals,
                                               intervals[1:] + [episode_len])

                    for (s, e) in interval_tuples_iter:
                        sliced_batch = slice_batch(batch, s, e)
                        (
                            c_log_prob,
                            old_coach_log_probs,
                            c_entropy,
                            value,
                        ) = self.model.get_coach_ppo_rl_train_loss(
                            sliced_batch)
                        (
                            e_log_prob,
                            old_exec_log_probs,
                            e_entropy,
                            value,
                        ) = self.model.get_executor_ppo_train_loss(
                            sliced_batch)

                        log_prob = c_log_prob + e_log_prob
                        old_log_prob = old_coach_log_probs + old_exec_log_probs
                        adv = r * torch.ones_like(value) - value.detach()
                        ratio = torch.exp(log_prob - old_log_prob)
                        surr1 = ratio * adv
                        surr2 = (torch.clamp(ratio, 1.0 - self.args.ppo_eps,
                                             1.0 + self.args.ppo_eps) * adv)
                        action_loss = -torch.min(surr1, surr2).mean()
                        value_loss = 1.0 * mse(value,
                                               r * torch.ones_like(value))
                        entropy_loss = 0.0  # -0.001*entropy.mean()

                        self.optimizer.zero_grad()
                        policy_loss = (
                            action_loss +
                            value_loss).mean()  # + entropy_loss).mean()
                        policy_loss.backward()
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), 0.5 * self.args.grad_clip)
                        self.optimizer.step()

                        action_loss_mean += action_loss.item()
                        value_loss_mean += value_loss.item()
                        entropy_mean += 0  # entropy_loss.item()
                        num_iters += 1

        action_loss_mean = action_loss_mean / num_iters
        value_loss_mean = value_loss_mean / num_iters
        entropy_mean = entropy_mean / num_iters

        self.model.train()
        return action_loss_mean, value_loss_mean, entropy_mean
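
# The Agent class above also depends on a StateActionBuffer that is not shown in
# these examples. The sketch below only illustrates the interface the class uses
# (push, pop, len), under the assumption that pop() splits stored per-game batches
# into win and loss groups using the +1/-1 outcomes recorded in result_dict; the
# real buffer additionally samples pushes with probability `buffer_add_prob`,
# accumulates per-step data for each game, and caps its size.
import random


class StateActionBufferSketch:
    def __init__(self, max_buffer_size=100, buffer_add_prob=1.0):
        self.max_buffer_size = max_buffer_size
        self.buffer_add_prob = buffer_add_prob
        self._store = {}  # "index_gameid" -> latest batch dict for that game

    def __len__(self):
        return len(self._store)

    def push(self, index, batch):
        if random.random() > self.buffer_add_prob or len(self._store) >= self.max_buffer_size:
            return False
        for g_id in batch["game_id"].cpu().numpy():
            key = "%d_%d" % (index, int(g_id[0]))
            # A real buffer would append this step's slice for the game; the
            # sketch just keeps the most recent batch under the key.
            self._store[key] = batch
        return True

    def pop(self, result_dict):
        win, loss = {}, {}
        for key, stored in self._store.items():
            if result_dict.get(key) == 1:
                win[key] = stored
            elif result_dict.get(key) == -1:
                loss[key] = stored
        self._store = {}
        return win, loss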
Example #7
# Copyright (c) Facebook, Inc. and its affiliates.
Example #8
# Copyright (c) Facebook, Inc. and its affiliates.