コード例 #1
0
def load_agent(weight_file, overwrite):
    """Rebuild an R2D2 agent from a checkpoint and its saved training config.

    Args:
        weight_file: path to the checkpoint. Its training config is fetched
            with ``get_train_config`` and must not be None.
        overwrite: mapping of settings. It MUST contain ``"device"``.
            ``"vdn"``, ``"multi_step"``, ``"gamma"`` and ``"boltzmann_act"``
            take precedence over the saved config when present.
            ``"uniform_priority"`` defaults to False. ``"num_lstm_layer"`` is
            only used when the saved config lacks that key (falling back to 2).

    Returns:
        ``(agent, cfg)``. ``agent`` has its online weights loaded and its
        target network synced. ``cfg`` is the training config, flattened
        first if it contained a ``"core"`` section.
    """
    def first_of(mapping, primary, fallback):
        # Prefer ``primary``; otherwise index ``fallback`` so that a missing
        # key still surfaces as KeyError(fallback).
        return mapping[primary] if primary in mapping else mapping[fallback]

    train_cfg = get_train_config(weight_file)
    assert train_cfg is not None

    # Configs with a "core" section are nested; flatten them so every key
    # below can be looked up at the top level.
    if "core" in train_cfg:
        flat = {}
        flatten_dict(train_cfg, flat)
        train_cfg = flat

    # A single environment is created only to query feature / action sizes.
    game = create_envs(
        1,
        1,
        train_cfg["num_player"],
        train_cfg["train_bomb"],
        [0],  # explore_eps
        [100],  # boltzmann_t
        train_cfg["max_len"],
        first_of(train_cfg, "sad", "greedy_extra"),
        train_cfg["shuffle_obs"],
        train_cfg["shuffle_color"],
        train_cfg["hide_action"],
        True,
    )[0]

    if "vdn" in overwrite:
        use_vdn = overwrite["vdn"]
    else:
        use_vdn = train_cfg["method"] == "vdn"

    agent_kwargs = dict(
        vdn=use_vdn,
        multi_step=overwrite.get("multi_step", train_cfg["multi_step"]),
        gamma=overwrite.get("gamma", train_cfg["gamma"]),
        eta=0.9,
        device=overwrite["device"],
        in_dim=game.feature_size(),
        hid_dim=first_of(train_cfg, "hid_dim", "rnn_hid_dim"),
        out_dim=game.num_action(),
        # Saved architecture wins here; overwrite is only a fallback.
        num_lstm_layer=train_cfg.get(
            "num_lstm_layer", overwrite.get("num_lstm_layer", 2)
        ),
        boltzmann_act=overwrite.get("boltzmann_act", train_cfg["boltzmann_act"]),
        uniform_priority=overwrite.get("uniform_priority", False),
    )

    device = agent_kwargs["device"]
    agent = r2d2.R2D2Agent(**agent_kwargs).to(device)
    load_weight(agent.online_net, weight_file, device)
    agent.sync_target_with_online()
    return agent, train_cfg
コード例 #2
0
def evaluate_legacy_model(weight_files,
                          num_game,
                          seed,
                          bomb,
                          num_run=1,
                          verbose=True,
                          device="cuda:0"):
    """Evaluate a team of legacy R2D2 checkpoints playing together.

    One checkpoint is loaded per player and rebuilt with a fixed architecture.
    The input and output sizes are read from the checkpoint; everything else
    is hard-coded. The team then plays ``num_run`` batches of ``num_game``
    games each.

    Args:
        weight_files: one checkpoint path per player (at least two).
        num_game: number of games per run.
        seed: base seed. Run ``i`` uses ``num_game * i + seed``.
        bomb: forwarded unchanged to ``evaluate``.
        num_run: number of evaluation runs.
        verbose: print progress and the final score.
        device: device the agents are built on and the weights loaded to.
            Defaults to ``"cuda:0"``, the value previously hard-coded.

    Returns:
        ``(mean, sem, perfect_rate)`` over all games played: mean score,
        standard error of the mean, and the fraction of games for which
        ``evaluate`` reported a perfect result.

    Raises:
        AssertionError: if fewer than two weight files are given.
    """
    agents = []
    num_player = len(weight_files)
    assert num_player > 1, "1 weight file per player"

    sad = False
    for weight_file in weight_files:
        if verbose:
            print("evaluating: %s\n\tfor %dx%d games" %
                  (weight_file, num_run, num_game))
        # The SAD flag is inferred from the file name. evaluate() takes a
        # single flag, so the value from the LAST file is the one used.
        # NOTE(review): mixed SAD / non-SAD teams get the last file's setting.
        sad = "sad" in weight_file or "aux" in weight_file

        # The checkpoint is read here only to recover layer sizes; the weights
        # are loaded into the agent by utils.load_weight below. map_location
        # keeps the load independent of the device the checkpoint was saved on.
        state_dict = torch.load(weight_file, map_location=device)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512  # hard-coded; not read from the checkpoint
        output_dim = state_dict["fc_a.weight"].size()[0]

        agent = r2d2.R2D2Agent(False, 3, 0.999, 0.9, device, input_dim,
                               hid_dim, output_dim, 2, 5, False).to(device)
        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)

    scores = []
    perfect = 0
    for i in range(num_run):
        _, _, score, p = evaluate(
            agents,
            num_game,
            num_game * i + seed,  # seed offset by num_game per run
            bomb,
            0,
            sad,
        )
        scores.extend(score)
        perfect += p

    mean = np.mean(scores)
    sem = np.std(scores) / np.sqrt(len(scores))
    perfect_rate = perfect / (num_game * num_run)
    if verbose:
        print("score: %f +/- %f" % (mean, sem), "; perfect: ", perfect_rate)
    return mean, sem, perfect_rate
コード例 #3
0
def load_sad_model(weight_files, device):
    """Build one R2D2 agent per checkpoint and load its weights.

    The input and output sizes of each agent are recovered from the
    checkpoint tensors; the remaining constructor arguments are fixed.

    Args:
        weight_files: iterable of checkpoint paths.
        device: device the checkpoint is mapped to, the agent is moved to,
            and the weights are loaded on.

    Returns:
        list of agents, in the same order as ``weight_files``.
    """
    def build_agent(path):
        # Read the checkpoint only to discover layer sizes; load_weight does
        # the actual weight loading afterwards.
        weights = torch.load(path, map_location=device)
        in_dim = weights["net.0.weight"].size()[1]
        out_dim = weights["fc_a.weight"].size()[0]
        model = r2d2.R2D2Agent(
            False, 3, 0.999, 0.9, device, in_dim, 512, out_dim, 2, 5, False
        ).to(device)
        load_weight(model.online_net, path, device)
        return model

    return [build_agent(path) for path in weight_files]
コード例 #4
0
def load_op_model(method, idx1, idx2, device):
    """Load OP models; OP models were trained for 2-player games only.

    Weights are read from ``<root>/models/op/<method>/M<idx>.pthw``, where
    ``<root>`` is the parent of the directory containing this file. The
    architecture of each model is chosen from its index:

    * 0-2: 1 FC layer, no skip connection
    * 3-5: 1 FC layer, skip connection
    * 6-8: 2 FC layers, no skip connection
    * anything else: 2 FC layers, skip connection

    Args:
        method: name of the sub-folder under ``models/op``.
        idx1: index of the first model, or None to skip it.
        idx2: index of the second model, or None to skip it.
        device: device the agents are built on and the weights loaded to.

    Returns:
        list of agents for the non-None indices, in ``idx1, idx2`` order.

    Raises:
        FileNotFoundError: if an expected weight file does not exist.
    """
    root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    folder = os.path.join(root, "models", "op", method)
    agents = []
    for idx in (idx1, idx2):
        if idx is None:
            continue

        if 0 <= idx < 3:
            num_fc, skip_connect = 1, False
        elif 3 <= idx < 6:
            num_fc, skip_connect = 1, True
        elif 6 <= idx < 9:
            num_fc, skip_connect = 2, False
        else:
            num_fc, skip_connect = 2, True

        weight_file = os.path.join(folder, f"M{idx}.pthw")
        # Raise explicitly: the previous `assert False` disappears under
        # `python -O`, leaving torch.load to fail with a less clear error.
        if not os.path.exists(weight_file):
            raise FileNotFoundError(f"Cannot find weight at: {weight_file}")

        # Only layer sizes are read here; load_weight loads the weights below.
        state_dict = torch.load(weight_file, map_location=device)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512  # hard-coded; not read from the checkpoint
        output_dim = state_dict["fc_a.weight"].size()[0]
        agent = r2d2.R2D2Agent(
            False,
            3,
            0.999,
            0.9,
            device,
            input_dim,
            hid_dim,
            output_dim,
            2,
            5,
            False,
            num_fc_layer=num_fc,
            skip_connect=skip_connect,
        ).to(device)
        load_weight(agent.online_net, weight_file, device)
        agents.append(agent)
    return agents
コード例 #5
0
ファイル: eval_model.py プロジェクト: jzand/hanabi_SAD
def load_sad_model(weight_files, device="cuda:0", verbose=False):
    """Build one legacy R2D2 agent per checkpoint and load its weights.

    The input and output sizes of each agent are read from the checkpoint;
    the remaining constructor arguments are fixed.

    Args:
        weight_files: iterable of checkpoint paths.
        device: device the agents are built on and the weights loaded to.
            Defaults to ``"cuda:0"``, the value previously hard-coded here.
        verbose: print each path as it is loaded.

    Returns:
        list of agents, in the same order as ``weight_files``.
    """
    # Previously this loop printed a message built from `verbose`, `num_run`
    # and `num_game`, none of which exist in this function, and computed an
    # unused `sad` flag. Both are removed.
    agents = []
    for weight_file in weight_files:
        if verbose:
            print("loading: %s" % weight_file)

        # Only layer sizes are read here; utils.load_weight loads the weights.
        state_dict = torch.load(weight_file, map_location=device)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512  # hard-coded; not read from the checkpoint
        output_dim = state_dict["fc_a.weight"].size()[0]

        agent = r2d2.R2D2Agent(
            False, 3, 0.999, 0.9, device, input_dim, hid_dim, output_dim, 2, 5, False
        ).to(device)
        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)
    return agents
コード例 #6
0
        args.hand_size,
        args.train_bomb,
        explore_eps,
        args.max_len,
        args.sad,
        args.shuffle_obs,
        args.shuffle_color,
    )

    agent = r2d2.R2D2Agent(
        (args.method == "vdn"),
        args.multi_step,
        args.gamma,
        args.eta,
        args.train_device,
        games[0].feature_size(),
        args.rnn_hid_dim,
        games[0].num_action(),
        args.num_lstm_layer,
        args.hand_size,
        False,  # uniform priority
    )
    agent.sync_target_with_online()

    if args.load_model:
        print("*****loading pretrained model*****")
        utils.load_weight(agent.online_net, args.load_model, args.train_device)
        print("*****done*****")

    agent = agent.to(args.train_device)
    optim = torch.optim.Adam(agent.online_net.parameters(),