Example #1
def train(model, device, optimizer, grad_clip, data_loader, epoch, stat):
    """Run one training epoch and feed per-key losses into stat."""
    assert model.training

    # losses = defaultdict(list)
    for batch_idx, batch in enumerate(data_loader):
        batch = common_utils.to_device(batch, device)
        optimizer.zero_grad()
        # loss, all_losses = model.compute_loss(batch)
        loss, all_losses = model(batch)
        # print(loss.mean())
        # print(all_losses)
        loss = loss.mean()

        loss.backward()
        if grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()

        for key, val in all_losses.items():
            v = val.mean().item()
            stat[key].feed(v)

            # losses[key].append(val.item())

    # for key, val in losses.items():
    #     print('\t%s: %.5f' % (key, np.mean(val)))
    # return np.mean(losses['loss'])
    return stat['loss'].mean()
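
The stat argument here (and in Example #2 below) is only used through stat[key].feed(value) and stat['loss'].mean(). A minimal sketch of such a container, written purely from that usage rather than from the project's actual common_utils stats classes, could look like this:

from collections import defaultdict


class MeanMeter:
    """Running mean of the scalar values fed into it (assumed interface)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def feed(self, value):
        self.total += value
        self.count += 1

    def mean(self):
        return self.total / max(self.count, 1)


# stat then behaves like a dict with one meter per loss key.
stat = defaultdict(MeanMeter)
stat['loss'].feed(1.0)
stat['loss'].feed(3.0)
print(stat['loss'].mean())  # 2.0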
Example #2
def evaluate(model, device, data_loader, epoch, stat):
    """Run one evaluation pass and feed per-key losses into stat."""
    assert not model.training

    # losses = defaultdict(list)
    for batch_idx, batch in enumerate(data_loader):
        batch = common_utils.to_device(batch, device)
        loss, all_losses = model(batch)
        # for key, val in all_losses.items():
        #     losses[key].append(val.item())
        for key, val in all_losses.items():
            stat[key].feed(val.mean().item())

    # print('eval:')
    #     print('\t%s: %.5f' % (key, np.mean(val)))

    # return np.mean(losses['loss'])
    return stat['loss'].mean()
Example #3
def evaluate(model, device, data_loader, epoch, name, norm_loss):
    """Run one evaluation pass, print per-key mean losses, and return the mean loss."""
    assert not model.training

    losses = defaultdict(list)
    t = time.time()
    for batch_idx, batch in enumerate(data_loader):
        batch = common_utils.to_device(batch, device)
        if norm_loss:
            loss, all_losses = model.compute_eval_loss(batch)
        else:
            loss, all_losses = model.compute_loss(batch)

        for key, val in all_losses.items():
            losses[key].append(val.item())

    print('%s epoch: %d, time: %.2f' % (name, epoch, time.time() - t))
    for key, val in losses.items():
        print('\t%s: %.5f' % (key, np.mean(val)))

    return np.mean(losses['loss'])
Example #4
def train(model, device, optimizer, grad_clip, data_loader, epoch):
    """Run one training epoch with gradient clipping and return the mean loss."""
    assert model.training

    losses = defaultdict(list)
    t = time.time()
    for batch_idx, batch in enumerate(data_loader):
        batch = common_utils.to_device(batch, device)
        optimizer.zero_grad()
        loss, all_losses = model.compute_loss(batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        for key, val in all_losses.items():
            losses[key].append(val.item())

    print('train epoch: %d, time: %.2f' % (epoch, time.time() - t))
    for key, val in losses.items():
        print('\t%s: %.5f' % (key, np.mean(val)))

    return np.mean(losses['loss'])
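
Every snippet above first converts the batch with common_utils.to_device(batch, device). A hypothetical sketch of such a helper, assuming only that a batch is a (possibly nested) container of tensors; the project's real implementation may differ:

import torch


def to_device(batch, device):
    """Recursively move all tensors in a nested batch onto device (sketch)."""
    if torch.is_tensor(batch):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: to_device(val, device) for key, val in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(to_device(val, device) for val in batch)
    return batch

With a helper like this, a nested dict of tensors coming off a data loader or data channel lands on the GPU in a single call.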
Example #5
    game_option = get_game_option(args)
    ai1_option, ai2_option = get_ai_option(args)
    context, act_dc = create_game(ai1_option, ai2_option, game_option)

    device = torch.device('cuda:%d' % args.gpu)
    executor = Executor.load(args.model_path).to(device)
    print('top 500 insts')
    for inst in executor.inst_dict._idx2inst[:500]:
        print(inst)
    executor_wrapper = ExecutorWrapper(None, executor, args.num_instructions,
                                       args.max_raw_chars, False)
    executor_wrapper.train(False)

    context.start()
    dc = pytube.DataChannelManager([act_dc])
    while not context.terminated():
        data = dc.get_input()['act']
        data = to_device(data, device)
        # import IPython
        # IPython.embed()
        import pdb
        pdb.set_trace()
        reply = executor_wrapper.forward(data)

        # reply = {key : reply[key].detach().cpu() for key in reply}
        dc.set_reply('act', reply)
        print('===end of a step===')

    # import IPython
    # IPython.embed()
Example #6
    device = torch.device('cuda:%d' % args.gpu)
    coach = ConvRnnCoach.load(args.coach_path).to(device)
    coach.max_raw_chars = args.max_raw_chars
    executor = Executor.load(args.model_path).to(device)
    executor_wrapper = ExecutorWrapper(coach, executor, coach.num_instructions,
                                       args.max_raw_chars, args.cheat,
                                       args.inst_mode)
    executor_wrapper.train(False)

    game_option = get_game_option(args)
    ai1_option, ai2_option = get_ai_options(args, coach.num_instructions)

    context, act_dc = create_game(args.num_thread, ai1_option, ai2_option,
                                  game_option)
    context.start()
    dc = DataChannelManager([act_dc])

    result_stat = ResultStat('reward', None)
    while not context.terminated():
        data = dc.get_input(max_timeout_s=1)
        if len(data) == 0:
            continue
        data = to_device(data['act'], device)
        result_stat.feed(data)
        reply = executor_wrapper.forward(data)

        dc.set_reply('act', reply)

    print(result_stat.log(0))
    dc.terminate()
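
ResultStat('reward', None) is only exercised above through feed(batch), log(epoch) and the win/loss counters. The stand-in below is written from that usage alone and is not the project's ResultStat; it assumes each batch carries terminal and reward tensors, as the batches in Example #12 do:

import torch


class SimpleResultStat:
    """Illustrative stand-in for ResultStat: counts finished games by reward sign."""

    def __init__(self, reward_key='reward'):
        self.reward_key = reward_key
        self.win = 0
        self.loss = 0

    @property
    def num_games(self):
        return self.win + self.loss

    def feed(self, batch):
        terminal = batch['terminal'].flatten()
        reward = batch[self.reward_key].flatten()
        self.win += int(((terminal == 1) & (reward > 0)).sum().item())
        self.loss += int(((terminal == 1) & (reward < 0)).sum().item())

    def log(self, epoch):
        total = max(self.num_games, 1)
        return 'epoch %d: win %d, loss %d, win rate %.3f' % (
            epoch, self.win, self.loss, self.win / total)


stat = SimpleResultStat()
stat.feed({'terminal': torch.tensor([1, 0, 1]),
           'reward': torch.tensor([1.0, 0.0, -1.0])})
print(stat.log(0))  # epoch 0: win 1, loss 1, win rate 0.500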
Example #7
# Copyright (c) Facebook, Inc. and its affiliates.
Example #8
    result1 = ResultStat('reward', None)
    result2 = ResultStat('reward', None)
    i = 0
    while not context.terminated():
        i += 1
        if i % 1000 == 0:
            print('%d, progress agent1: win %d, loss %d' %
                  (i, result1.win, result1.loss))

        data = dc.get_input(max_timeout_s=1)
        if len(data) == 0:
            continue
        for key in data:
            # print(key)
            batch = to_device(data[key], device)
            if key == 'act1':
                result1.feed(batch)
                with torch.no_grad():
                    reply = model1.forward(batch)
            elif key == 'act2':
                result2.feed(batch)
                with torch.no_grad():
                    reply = model2.forward(batch)
            else:
                assert False
            dc.set_reply(key, reply)

    print(result1.log(0))
    print(result2.log(0))
    dc.terminate()
Example #9
    def analyze_rule_games(self,
                           epoch,
                           rule_idx,
                           split="valid",
                           viz=False,
                           num_games=100,
                           num_sp=0):
        device = torch.device("cuda:%d" % self.args.gpu)

        if split == "valid":
            permute = self.valid_permute
        elif split == "test":
            permute = self.test_permute
        elif split == "train":
            permute = self.train_permute
        else:
            raise Exception("Invalid split.")

        cur_iter_idx = 0
        results = {}
        for rule_id in rule_idx:  ##TODO: Not randomized
            rule = permute[rule_id]
            self.init_rule_games(rule,
                                 num_sp=num_sp,
                                 num_rb=num_games,
                                 viz=viz)
            agent1, agent2 = self.start()

            agent1.eval()
            agent2.eval()

            if num_sp > 0:
                pbar = tqdm(total=num_games * 2 + num_sp)
            else:
                pbar = tqdm(total=num_games)

            while not self.finished():

                data = self.get_input()

                if len(data) == 0:
                    continue
                for key in data:
                    # print(key)
                    batch = to_device(data[key], device)

                    if key == "act1":
                        batch["actor"] = "act1"
                        reply = agent1.simulate(cur_iter_idx, batch)
                        t_count = agent1.update_logs(cur_iter_idx, batch,
                                                     reply)

                    elif key == "act2":
                        batch["actor"] = "act2"
                        reply = agent2.simulate(cur_iter_idx, batch)
                        t_count = agent2.update_logs(cur_iter_idx, batch,
                                                     reply)

                    else:
                        assert False

                    self.set_reply(key, reply)
                    pbar.update(t_count)

            a1_result = self.agent1.result

            results[rule_id] = {
                "win": a1_result.win / a1_result.num_games,
                "loss": a1_result.loss / a1_result.num_games,
            }

            if num_sp > 0:
                print("#" * 50)
                print(f"Win: {a1_result.win / a1_result.num_games}")
                print(
                    f"Loss: {self.agent2.result.win / self.agent2.result.num_games}"
                )
                print(
                    f"Draw: {(a1_result.loss - self.agent2.result.win) / self.agent2.result.num_games}"
                )
                print("#" * 50)

            cur_iter_idx += 1
            print(results)
            # counter = Counter()
            # for game_id, insts in agent1.traj_dict.items():
            #     for inst in insts:
            #         counter[inst] += 1
            #
            # print("##### TOP N Instructions #####")
            # print(counter.most_common(10))
            # print("##############################")

            pbar.close()
            self.terminate()

        avg_win_rate = 0
        for rule, wl in results.items():
            wandb.log({
                "{}/{}/Win".format(split, rule): wl["win"],
                "{}/{}/Loss".format(split, rule): wl["loss"],
            })
            avg_win_rate += wl["win"] / len(rule_idx)

        if num_sp == 0:
            print("Average win rate: {}".format(avg_win_rate))

        return results
Example #10
    def evaluate(self, epoch, split="valid", num_rules=5):
        print("Validating...")
        device = torch.device("cuda:%d" % self.args.gpu)
        num_games = 100

        if split == "valid":
            permute = self.valid_permute
        elif split == "test":
            permute = self.test_permute
        elif split == "train":
            permute = self.train_permute
        else:
            raise Exception("Invalid split.")

        cur_iter_idx = 0
        results = {}
        for rule_idx in range(num_rules):  ##TODO: Not randomized
            rule = permute[rule_idx]
            self.init_rule_games(rule, num_sp=0, num_rb=num_games)
            print(f"Validating on rule ({rule_idx}): {rule}")

            agent1, agent2 = self.start()

            agent1.eval()
            agent2.eval()

            pbar = tqdm(total=num_games)

            while not self.finished():

                data = self.get_input()

                if len(data) == 0:
                    continue
                for key in data:
                    # print(key)
                    batch = to_device(data[key], device)

                    rule_tensor = (torch.tensor([
                        UNIT_DICT[unit] for unit in rule
                    ]).to(device).repeat(batch["game_id"].size(0), 1))
                    batch["rule_tensor"] = rule_tensor

                    if key == "act1":
                        batch["actor"] = "act1"
                        reply = agent1.simulate(cur_iter_idx, batch)
                        t_count = agent1.update_logs(cur_iter_idx, batch,
                                                     reply)

                    elif key == "act2":
                        batch["actor"] = "act2"
                        reply = agent2.simulate(cur_iter_idx, batch)
                        t_count = agent2.update_logs(cur_iter_idx, batch,
                                                     reply)

                    else:
                        assert False

                    self.set_reply(key, reply)
                    pbar.update(t_count)

            a1_result = self.agent1.result

            results[rule_idx] = {
                "win": a1_result.win / a1_result.num_games,
                "loss": a1_result.loss / a1_result.num_games,
            }

            cur_iter_idx += 1
            pbar.close()
            self.terminate()

        avg_win_rate = 0
        for rule, wl in results.items():
            wandb.log(
                {
                    "Zero/{}/{}/Win".format(split, rule): wl["win"],
                    "Zero/{}/{}/Loss".format(split, rule): wl["loss"],
                },
                step=epoch,
            )
            avg_win_rate += wl["win"] / num_rules

        print("Average win rate: {}".format(avg_win_rate))
        if avg_win_rate > self.agent1.best_test_win_pct:
            self.agent1.best_test_win_pct = avg_win_rate
            self.agent1.save_coach(epoch)
            self.agent1.save_executor(epoch)

        return results
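
The rule_tensor built above encodes one rule (a list of unit types) as integer ids and broadcasts it across the batch via .repeat(batch["game_id"].size(0), 1). A self-contained illustration with a made-up UNIT_DICT; the real mapping is defined elsewhere in the project:

import torch

# Hypothetical unit-id mapping, for illustration only.
UNIT_DICT = {'SWORDMAN': 0, 'SPEARMAN': 1, 'CAVALRY': 2, 'ARCHER': 3, 'DRAGON': 4}

rule = ['CAVALRY', 'ARCHER', 'SWORDMAN', 'DRAGON', 'SPEARMAN']
batch_size = 4  # stands in for batch['game_id'].size(0)

rule_tensor = torch.tensor([UNIT_DICT[unit] for unit in rule]).repeat(batch_size, 1)
print(rule_tensor.shape)  # torch.Size([4, 5])
print(rule_tensor[0])     # tensor([2, 3, 0, 4, 1])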
Example #11
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import os
import sys
import pprint
from set_path import append_sys_path
append_sys_path()
import torch
import tube
from pytube import DataChannelManager
import minirts
import numpy as np
import random
import pickle
from collections import defaultdict
from rnn_coach import ConvRnnCoach
from onehot_coach import ConvOneHotCoach
from rnn_generator import RnnGenerator
from itertools import groupby
from executor_wrapper import ExecutorWrapper
from executor import Executor
from common_utils import to_device, ResultStat, Logger
from best_models import best_executors, best_coaches
from tqdm import tqdm
p1dict = defaultdict(list)
Example #12
def run_eval(args, model1, model2, device, num_games=100):

    num_eval_games = num_games

    result1 = ResultStat("reward", None)
    result2 = ResultStat("reward", None)

    game_option = get_game_option(args)
    ai1_option, ai2_option = get_ai_options(
        args, [model1.coach.num_instructions, model2.coach.num_instructions])

    if args.opponent == "sp":
        context, act1_dc, act2_dc = init_mt_games(num_eval_games, 0, args,
                                                  ai1_option, ai2_option,
                                                  game_option)
        pbar = tqdm(total=num_eval_games * 2)
    else:
        context, act1_dc, act2_dc = init_mt_games(0, num_eval_games, args,
                                                  ai1_option, ai2_option,
                                                  game_option)
        pbar = tqdm(total=num_eval_games)
    # context, act1_dc, act2_dc = init_games(
    #     num_eval_games, ai1_option, ai2_option, game_option)
    context.start()
    dc = DataChannelManager([act1_dc, act2_dc])

    i = 0
    model1.eval()
    model2.eval()

    while not context.terminated():
        i += 1
        # if i % 1000 == 0:
        #     print('%d, progress agent1: win %d, loss %d' % (i, result1.win, result1.loss))

        data = dc.get_input(max_timeout_s=1)
        if len(data) == 0:
            continue
        for key in data:
            # print(key)
            batch = to_device(data[key], device)
            if key == "act1":
                batch["actor"] = "act1"
                ## Add batches to state table using sampling before adding
                ## Add based on the game_id

                result1.feed(batch)
                with torch.no_grad():
                    reply, _ = model1.forward(batch)  # , exec_sample=True)

            elif key == "act2":
                batch["actor"] = "act2"
                result2.feed(batch)

                with torch.no_grad():
                    reply, _ = model2.forward(batch)

            else:
                assert False

            dc.set_reply(key, reply)

            game_ids = batch["game_id"].cpu().numpy()
            terminals = batch["terminal"].cpu().numpy().flatten()

            # use a separate index to avoid shadowing the outer step counter i
            for j, g_id in enumerate(game_ids):
                if terminals[j] == 1:
                    pbar.update(1)

    model1.eval()
    model2.eval()
    pbar.close()

    return result1, result2
Example #13
def self_play(args):

    wandb.init(project="adapt-minirts-pop",
               sync_tensorboard=True,
               dir=args.wandb_dir)
    # run_id = f"multitask-fixed_selfplay-{args.coach1}-{args.executor1}-{args.train_mode}-rule{args.rule}-{args.tag}"
    log_series = ",".join(args.rule_series)
    wandb.run.name = (
        f"multitask-pop_selfplay-{wandb.run.id}-{args.coach1}-{args.executor1}"
        f"-{args.train_mode}-rule_series={log_series}-random_coach={args.coach_random_init}-{args.tag}"
    )
    # wandb.run.save()
    wandb.config.update(args)

    print("args:")
    pprint.pprint(vars(args))

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    print("Train Mode: {}".format(args.train_mode))

    if args.coach_reload:
        print("Reloading coach model.... ")
        args.coach1 = args.coach_load_file
        _coach1 = os.path.basename(args.coach1).replace(".pt", "")

    else:
        _coach1 = args.coach1
        args.coach1 = best_coaches[args.coach1]

    if args.exec_reload:
        print("Reloading executor model.... ")
        args.executor1 = args.exec_load_file
        _executor1 = os.path.basename(args.executor1).replace(".pt", "")
    else:
        _executor1 = args.executor1
        args.executor1 = best_executors[args.executor1]

    log_name = "multitask-pop_c1_type={}_c2_type={}__e1_type={}_e2_type={}__lr={}__num_sp={}__num_rb={}_coach_random_init={}_{}_{}".format(
        _coach1,
        args.coach2,
        _executor1,
        args.executor2,
        args.lr,
        args.num_sp,
        args.num_rb,
        args.coach_random_init,
        args.tag,
        random.randint(1111, 9999),
    )
    writer = SummaryWriter(comment=log_name)

    args.coach2 = best_coaches[args.coach2]
    args.executor2 = best_executors[args.executor2]

    logger_path = os.path.join(args.save_dir, "train.log")

    sys.stdout = Logger(logger_path)

    device = torch.device("cuda:%d" % args.gpu)

    agent_dict = {}
    for rule in args.rule_series:
        sp_agent = Agent(
            coach=args.coach1,
            executor=args.executor1,
            device=device,
            args=args,
            writer=writer,
            trainable=True,
            exec_sample=True,
            pg=args.pg,
            tag=f"_{rule}",
        )

        ## Sharing executors
        args.executor1 = sp_agent.model.executor
        sp_agent.init_save_folder(wandb.run.name)

        bc_agent = Agent(
            coach=args.coach2,
            executor=args.executor2,
            device=device,
            args=args,
            writer=writer,
            trainable=False,
            exec_sample=False,
            tag=f"_{rule}",
        )

        agent_dict[int(rule)] = {"sp_agent": sp_agent, "bc_agent": bc_agent}

    if args.same_opt:
        params = []
        for k, v in agent_dict.items():
            agent = v["sp_agent"]
            coach_params = list(agent.model.coach.parameters())
            params += coach_params

        params += list(agent.model.executor.parameters())
        optimizer = optim.Adam(params, lr=args.lr)

        for k, v in agent_dict.items():
            agent = v["sp_agent"]
            agent.set_optimizer(optimizer)

    print("Progress: ")
    ## Create Save folder:
    working_rule_dir = os.path.join(sp_agent.save_folder, "rules")
    create_working_dir(args, working_rule_dir)

    cur_iter_idx = 1
    rules = [int(str_rule) for str_rule in args.rule_series]
    agg_agents = []
    agg_win_batches = defaultdict(dict)
    agg_loss_batches = defaultdict(dict)

    for epoch in range(args.train_epochs):
        for rule_idx in rules:
            if cur_iter_idx % args.eval_factor == 0:
                for eval_rule_idx in rules:
                    sp_agent = agent_dict[eval_rule_idx]["sp_agent"]
                    bc_agent = agent_dict[eval_rule_idx]["bc_agent"]
                    game = MultiTaskGame(sp_agent, bc_agent, cur_iter_idx,
                                         args, working_rule_dir)
                    game.evaluate_lifelong_rules(cur_iter_idx, [eval_rule_idx],
                                                 "train")
                    game.terminate()
                    del game

            sp_agent = agent_dict[rule_idx]["sp_agent"]
            bc_agent = agent_dict[rule_idx]["bc_agent"]

            print("Current rule: {}".format(rule_idx))
            game = MultiTaskGame(sp_agent, bc_agent, cur_iter_idx, args,
                                 working_rule_dir)

            rule = game.train_permute[rule_idx]
            print("Current rule: {}".format(rule))
            game.init_rule_games(rule)
            agent1, agent2 = game.start()

            agent1.train()
            agent2.train()

            pbar = tqdm(total=(args.num_sp * 2 + args.num_rb))

            while not game.finished():

                data = game.get_input()

                if len(data) == 0:
                    continue
                for key in data:
                    # print(key)
                    batch = to_device(data[key], device)

                    if key == "act1":
                        batch["actor"] = "act1"
                        reply = agent1.simulate(cur_iter_idx, batch)
                        t_count = agent1.update_logs(cur_iter_idx, batch,
                                                     reply)

                    elif key == "act2":
                        batch["actor"] = "act2"
                        reply = agent2.simulate(cur_iter_idx, batch)
                        t_count = agent2.update_logs(cur_iter_idx, batch,
                                                     reply)

                    else:
                        assert False

                    game.set_reply(key, reply)
                    pbar.update(t_count)

            if not args.split_train:
                if args.train_mode == "coach":
                    agent1.train_coach(cur_iter_idx)
                elif args.train_mode == "executor":
                    agent1.train_executor(cur_iter_idx)
                elif args.train_mode == "both":
                    agent1.train_both(cur_iter_idx)
                else:
                    raise Exception("Invalid train mode.")
                game.print_logs(cur_iter_idx)
                game.terminate()
            else:

                if cur_iter_idx % len(rules):
                    win_batches, loss_batches = agent1.train_coach(
                        cur_iter_idx)
                    agg_win_batches.update(win_batches)
                    agg_loss_batches.update(loss_batches)
                    agg_agents.append((agent1, agent2))
                    game.print_logs(cur_iter_idx)
                    game.terminate(keep_agents=True)

                else:
                    win_batches, loss_batches = agent1.train_coach(
                        cur_iter_idx)

                    agg_win_batches.update(win_batches)
                    agg_loss_batches.update(loss_batches)

                    # Change shuffling
                    agent1.train_executor(
                        cur_iter_idx,
                        agg_win_batches=agg_win_batches,
                        agg_loss_batches=agg_loss_batches,
                    )

                    for agent1, agent2 in agg_agents:
                        agent1.reset()
                        agent2.reset()

                    game.print_logs(cur_iter_idx)
                    game.terminate()

                    del agg_loss_batches
                    del agg_win_batches

                    agg_win_batches = defaultdict(dict)
                    agg_loss_batches = defaultdict(dict)

                cur_iter_idx += 1
                wandb.run.summary["max_iterations"] = cur_iter_idx
                pbar.close()

        del game

    writer.close()
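
When args.same_opt is set, the loop above collects every sp_agent's coach parameters plus the single shared executor into one Adam optimizer, which is then handed to each agent via set_optimizer. A minimal sketch of that sharing pattern, using toy nn.Linear modules in place of the real coach and executor models:

from torch import nn, optim

# Toy modules standing in for the per-rule coaches and the shared executor.
coaches = [nn.Linear(4, 4) for _ in range(3)]
executor = nn.Linear(4, 4)

params = []
for coach in coaches:
    params += list(coach.parameters())   # each coach contributes its own parameters
params += list(executor.parameters())    # the shared executor is added exactly once

optimizer = optim.Adam(params, lr=1e-3)
num_params = sum(p.numel() for group in optimizer.param_groups
                 for p in group['params'])
print(num_params)  # 3 * (4*4 + 4) + (4*4 + 4) = 80

A single optimizer.step() then updates all coaches and the shared executor together.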
Example #14
def self_play(args):

    wandb.init(project="adapt-minirts-zero",
               sync_tensorboard=True,
               dir=args.wandb_dir)
    # run_id = f"multitask-fixed_selfplay-{args.coach1}-{args.executor1}-{args.train_mode}-rule{args.rule}-{args.tag}"
    wandb.run.name = (
        f"multitask-zero_selfplay-{wandb.run.id}-{args.coach1}-{args.executor1}"
        f"-{args.train_mode}-{args.tag}")
    # wandb.run.save()
    wandb.config.update(args)

    print("args:")
    pprint.pprint(vars(args))

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    print("Train Mode: {}".format(args.train_mode))

    if args.coach_reload:
        print("Reloading coach model.... ")
        args.coach1 = args.coach_load_file
        _coach1 = os.path.basename(args.coach1).replace(".pt", "")

    else:
        _coach1 = args.coach1
        args.coach1 = best_coaches[args.coach1]

    if args.exec_reload:
        print("Reloading executor model.... ")
        args.executor1 = args.exec_load_file
        _executor1 = os.path.basename(args.executor1).replace(".pt", "")
    else:
        _executor1 = args.executor1
        args.executor1 = best_executors[args.executor1]

    log_name = (
        f"multitask-zero_c1_type={_coach1}_c2_type={args.coach2}__e1_type={_executor1}_e2_type={args.executor2}__lr={args.lr}_coach_emb"
        f"={args.coach_rule_emb_size}_exec_emb={args.executor_rule_emb_size}__num_sp={args.num_sp}__num_rb={args.num_rb}_{args.tag}_{random.randint(1111, 99999)}"
    )
    writer = SummaryWriter(comment=log_name)

    args.coach2 = best_coaches[args.coach2]
    args.executor2 = best_executors[args.executor2]

    logger_path = os.path.join(args.save_dir, "train.log")

    sys.stdout = Logger(logger_path)

    device = torch.device("cuda:%d" % args.gpu)

    sp_agent = Agent(
        coach=args.coach1,
        executor=args.executor1,
        device=device,
        args=args,
        writer=writer,
        trainable=True,
        exec_sample=True,
        pg=args.pg,
    )

    sp_agent.init_save_folder(wandb.run.name)

    bc_agent = Agent(
        coach=args.coach2,
        executor=args.executor2,
        device=device,
        args=args,
        writer=writer,
        trainable=False,
        exec_sample=False,
    )

    print("Progress: ")
    ## Create Save folder:
    working_rule_dir = os.path.join(sp_agent.save_folder, "rules")
    create_working_dir(args, working_rule_dir)

    cur_iter_idx = 1
    for epoch in range(args.train_epochs):
        print("Current epoch: {}".format(epoch))

        game = MultiTaskGame(sp_agent, bc_agent, epoch, args, working_rule_dir)
        # game.evaluate(epoch, 'valid', 3)

        for rule_idx in range(game.num_train_rules):

            # if rule_idx%args.eval_factor == 0:
            #     game.evaluate(epoch*game.num_train_rules + rule_idx, 'valid', 10)

            rule = game.train_permute[rule_idx]
            print(f"Current rule ({rule_idx}): {rule}")
            game.init_rule_games(rule)
            agent1, agent2 = game.start()

            agent1.train()
            agent2.train()

            pbar = tqdm(total=(args.num_sp * 2 + args.num_rb))

            while not game.finished():

                data = game.get_input()

                if len(data) == 0:
                    continue
                for key in data:

                    batch = to_device(data[key], device)
                    rule_tensor = (torch.tensor([
                        UNIT_DICT[unit] for unit in rule
                    ]).to(device).repeat(batch["game_id"].size(0), 1))
                    batch["rule_tensor"] = rule_tensor

                    if key == "act1":
                        batch["actor"] = "act1"
                        reply = agent1.simulate(cur_iter_idx, batch)
                        t_count = agent1.update_logs(cur_iter_idx, batch,
                                                     reply)

                    elif key == "act2":
                        batch["actor"] = "act2"
                        reply = agent2.simulate(cur_iter_idx, batch)
                        t_count = agent2.update_logs(cur_iter_idx, batch,
                                                     reply)

                    else:
                        assert False

                    game.set_reply(key, reply)
                    pbar.update(t_count)

            cur_iter_idx += 1
            pbar.close()

            if cur_iter_idx % args.update_iter:
                if args.train_mode == "coach":
                    agent1.train_coach(cur_iter_idx)
                elif args.train_mode == "executor":
                    agent1.train_executor(cur_iter_idx)
                elif args.train_mode == "both":
                    agent1.train_both(cur_iter_idx)
                else:
                    raise Exception("Invalid train mode.")
                game.print_logs(cur_iter_idx)
                game.terminate()
            else:
                game.terminate(keep_agents=True)

        del game

    writer.close()
Example #15
def self_play(args):

    wandb.init(project="adapt-minirts",
               sync_tensorboard=True,
               dir=args.wandb_dir)
    # run_id = f"multitask-fixed_selfplay-{args.coach1}-{args.executor1}-{args.train_mode}-rule{args.rule}-{args.tag}"
    wandb.run.name = (
        f"multitask-fixed_selfplay-{wandb.run.id}-{args.coach1}-{args.executor1}"
        f"-{args.train_mode}-rule{args.rule}-{args.tag}")
    # wandb.run.save()
    wandb.config.update(args)

    print("args:")
    pprint.pprint(vars(args))

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    print("Train Mode: {}".format(args.train_mode))

    if args.coach_reload:
        print("Reloading coach model.... ")
        args.coach1 = args.coach_load_file
        _coach1 = os.path.basename(args.coach1).replace(".pt", "")

    else:
        _coach1 = args.coach1
        args.coach1 = best_coaches[args.coach1]

    if args.exec_reload:
        print("Reloading executor model.... ")
        args.executor1 = args.exec_load_file
        _executor1 = os.path.basename(args.executor1).replace(".pt", "")
    else:
        _executor1 = args.executor1
        args.executor1 = best_executors[args.executor1]

    log_name = "multitask-fixed_c1_type={}_c2_type={}__e1_type={}_e2_type={}__lr={}__num_sp={}__num_rb={}_{}_{}".format(
        _coach1,
        args.coach2,
        _executor1,
        args.executor2,
        args.lr,
        args.num_sp,
        args.num_rb,
        args.tag,
        random.randint(1111, 9999),
    )
    writer = SummaryWriter(comment=log_name)

    args.coach2 = best_coaches[args.coach2]
    args.executor2 = best_executors[args.executor2]

    logger_path = os.path.join(args.save_dir, "train.log")

    sys.stdout = Logger(logger_path)

    device = torch.device("cuda:%d" % args.gpu)

    sp_agent = Agent(
        coach=args.coach1,
        executor=args.executor1,
        device=device,
        args=args,
        writer=writer,
        trainable=True,
        exec_sample=True,
        pg=args.pg,
    )

    sp_agent.init_save_folder(wandb.run.name)

    bc_agent = Agent(
        coach=args.coach2,
        executor=args.executor2,
        device=device,
        args=args,
        writer=writer,
        trainable=False,
        exec_sample=False,
    )

    print("Progress: ")
    ## Create Save folder:
    working_rule_dir = os.path.join(sp_agent.save_folder, "rules")
    create_working_dir(args, working_rule_dir)

    cur_iter_idx = 1
    rules = [args.rule]
    for rule_idx in rules:

        print("Current rule: {}".format(rule_idx))
        game = MultiTaskGame(sp_agent, bc_agent, cur_iter_idx, args,
                             working_rule_dir)

        for epoch in range(args.train_epochs):
            if epoch % args.eval_factor == 0:
                game.evaluate_rules(cur_iter_idx, rules, "train")

            rule = game.train_permute[rule_idx]
            print("Current rule: {}".format(rule))
            game.init_rule_games(rule)
            agent1, agent2 = game.start()

            agent1.train()
            agent2.train()

            pbar = tqdm(total=(args.num_sp * 2 + args.num_rb))

            while not game.finished():

                data = game.get_input()

                if len(data) == 0:
                    continue
                for key in data:
                    # print(key)
                    batch = to_device(data[key], device)

                    if key == "act1":
                        batch["actor"] = "act1"
                        reply = agent1.simulate(cur_iter_idx, batch)
                        t_count = agent1.update_logs(cur_iter_idx, batch,
                                                     reply)

                    elif key == "act2":
                        batch["actor"] = "act2"
                        reply = agent2.simulate(cur_iter_idx, batch)
                        t_count = agent2.update_logs(cur_iter_idx, batch,
                                                     reply)

                    else:
                        assert False

                    game.set_reply(key, reply)
                    pbar.update(t_count)

            if args.train_mode == "coach":
                agent1.train_coach(cur_iter_idx)
            elif args.train_mode == "executor":
                agent1.train_executor(cur_iter_idx)
            elif args.train_mode == "both":
                agent1.train_both(cur_iter_idx)
            else:
                raise Exception("Invalid train mode.")

            game.print_logs(cur_iter_idx)
            cur_iter_idx += 1
            wandb.run.summary["max_iterations"] = cur_iter_idx
            pbar.close()
            game.terminate()

        del game

    writer.close()
Example #17
    def analyze_rule_games_vbot(self,
                                epoch,
                                rule_idx,
                                split="valid",
                                viz=False,
                                num_games=100):
        device = torch.device("cuda:%d" % self.args.gpu)

        if split == "valid":
            permute = self.valid_permute
        elif split == "test":
            permute = self.test_permute
        elif split == "train":
            permute = self.train_permute
        else:
            raise Exception("Invalid split.")

        cur_iter_idx = 0
        results = {}
        unitidx = [0, 1, 2, 3, 4]
        botidx = random.choice(unitidx)
        counter = Counter()
        idx2utype = [
            "SWORDMAN",
            "SPEARMAN",
            "CAVALRY",
            "ARCHER",
            "DRAGON",
        ]

        rule = permute[rule_idx]
        rule_rps_dict = rps_dict.copy()
        for i, unit in enumerate(rule):
            rule_rps_dict[UNITS[i]] = rps_dict[unit]

        print("############RULE RPS###################")
        for unit, multiplier in rule_rps_dict.items():
            print(f"{unit}: {multiplier}")
        print("#######################################")

        print(f"Playing against bot {idx2utype[botidx]}")
        self.init_rule_games_vbot(botidx=botidx,
                                  rule=rule,
                                  num_games=num_games,
                                  viz=viz)
        agent1, agent2 = self.start()

        agent1.eval()
        agent2.eval()

        pbar = tqdm(total=num_games)

        while not self.finished():

            data = self.get_input()

            if len(data) == 0:
                continue
            for key in data:
                # print(key)
                batch = to_device(data[key], device)

                if key == "act1":
                    batch["actor"] = "act1"
                    reply = agent1.simulate(cur_iter_idx, batch)
                    t_count = agent1.update_logs(cur_iter_idx, batch, reply)

                elif key == "act2":
                    batch["actor"] = "act2"
                    reply = agent2.simulate(cur_iter_idx, batch)
                    t_count = agent2.update_logs(cur_iter_idx, batch, reply)

                else:
                    assert False

                self.set_reply(key, reply)
                pbar.update(t_count)

        a1_result = self.agent1.result

        results[rule_idx] = {
            "win": a1_result.win / a1_result.num_games,
            "loss": a1_result.loss / a1_result.num_games,
        }

        cur_iter_idx += 1
        print(results)

        for game_id, insts in agent1.traj_dict.items():
            for inst in insts:
                counter[inst] += 1

        pbar.close()
        self.terminate()

        avg_win_rate = 0
        for rule, wl in results.items():
            wandb.log({
                "{}/{}/Win".format(split, rule): wl["win"],
                "{}/{}/Loss".format(split, rule): wl["loss"],
            })
            avg_win_rate += wl["win"]
        print(f"Top-10 Instructions: {counter.most_common(10)}")
        print(f"Average win rate: {avg_win_rate}")
        if avg_win_rate > self.agent1.best_test_win_pct:
            self.agent1.best_test_win_pct = avg_win_rate
            self.agent1.save_coach(epoch)
            self.agent1.save_executor(epoch)

        return idx2utype[botidx], rule_idx, avg_win_rate, counter.most_common(10)
Example #18
    def evaluate_lifelong_rules(self, epoch, rule_series, split="train"):
        device = torch.device("cuda:%d" % self.args.gpu)
        num_games = 100

        if split == "valid":
            permute = self.valid_permute
        elif split == "test":
            permute = self.test_permute
        elif split == "train":
            permute = self.train_permute
        else:
            raise Exception("Invalid split.")

        cur_iter_idx = 0
        results = {}
        for rule_id in rule_series:
            rule = permute[rule_id]
            print("Evaluating current rule: {}".format(rule))
            self.init_rule_games(rule, num_sp=0, num_rb=num_games)
            agent1, agent2 = self.start()

            agent1.eval()
            agent2.eval()

            pbar = tqdm(total=num_games)

            while not self.finished():

                data = self.get_input()

                if len(data) == 0:
                    continue
                for key in data:
                    # print(key)
                    batch = to_device(data[key], device)

                    if key == "act1":
                        batch["actor"] = "act1"
                        reply = agent1.simulate(cur_iter_idx, batch)
                        t_count = agent1.update_logs(cur_iter_idx, batch,
                                                     reply)

                    elif key == "act2":
                        batch["actor"] = "act2"
                        reply = agent2.simulate(cur_iter_idx, batch)
                        t_count = agent2.update_logs(cur_iter_idx, batch,
                                                     reply)

                    else:
                        assert False

                    self.set_reply(key, reply)
                    pbar.update(t_count)

            a1_result = self.agent1.result

            results[rule_id] = {
                "win": a1_result.win / a1_result.num_games,
                "loss": a1_result.loss / a1_result.num_games,
            }

            cur_iter_idx += 1
            pbar.close()
            self.terminate()

        avg_win_rate = 0
        for rule, wl in results.items():
            wandb.log({
                "Lifelong/{}/{}/Win".format(split, rule): wl["win"],
                "Lifelong/{}/{}/Loss".format(split, rule): wl["loss"],
            })
            avg_win_rate += wl["win"] / len(rule_series)

        print("Average win rate: {}".format(avg_win_rate))
        if avg_win_rate > self.agent1.best_test_win_pct:
            wandb.run.summary[f"best_win_rate{self.agent1.tag}"] = avg_win_rate
            wandb.run.summary[f"best_iteration{self.agent1.tag}"] = epoch

            self.agent1.best_test_win_pct = avg_win_rate
            self.agent1.save_coach(epoch)
            self.agent1.save_executor(epoch)

        return results
Example #19
    def drift_analysis_games(self,
                             epoch,
                             rule_idx,
                             split="valid",
                             viz=False,
                             num_games=1):
        device = torch.device("cuda:%d" % self.args.gpu)
        permute = self.train_permute

        cur_iter_idx = 0
        results = {}
        reply_dicts = []
        for rule_id in rule_idx:  ##TODO: Not randomized
            rule = permute[rule_id]
            self.init_drift_games(rule, num_sp=0, num_rb=num_games, viz=viz)
            agent1, agent2 = self.start()

            agent1.eval()
            agent2.eval()

            # pbar = tqdm(total=num_games)

            while not self.finished():

                data = self.get_input()

                if len(data) == 0:
                    continue
                for key in data:
                    # print(key)
                    batch = to_device(data[key], device)

                    if key == "act1":
                        batch["actor"] = "act1"
                        reply, replies = agent1.simulate(cur_iter_idx, batch)
                        t_count = agent1.update_logs(cur_iter_idx, batch,
                                                     reply)
                        reply_dicts.append(replies)

                    elif key == "act2":
                        batch["actor"] = "act2"
                        reply = agent2.simulate(cur_iter_idx, batch)
                        t_count = agent2.update_logs(cur_iter_idx, batch,
                                                     reply)

                    else:
                        assert False

                    self.set_reply(key, reply)
                    # pbar.update(t_count)

            a1_result = self.agent1.result

            results[rule_id] = {
                "win": a1_result.win / a1_result.num_games,
                "loss": a1_result.loss / a1_result.num_games,
            }

            cur_iter_idx += 1
            print(results)
            # counter = Counter()
            # for game_id, insts in agent1.traj_dict.items():
            #     for inst in insts:
            #         counter[inst] += 1
            #
            # print("##### TOP N Instructions #####")
            # print(counter.most_common(10))
            # print("##############################")

            # pbar.close()
            self.terminate()

        avg_win_rate = 0
        for rule, wl in results.items():
            wandb.log({
                "{}/{}/Win".format(split, rule): wl["win"],
                "{}/{}/Loss".format(split, rule): wl["loss"],
            })
            avg_win_rate += wl["win"] / len(rule_idx)

        print("Average win rate: {}".format(avg_win_rate))

        return results, reply_dicts