def load_agent(weight_file, overwrite):
    """Rebuild an R2D2 agent from a saved checkpoint.

    `overwrite` must contain "device"; its other keys (e.g. "vdn",
    "multi_step", "gamma", "boltzmann_act", "uniform_priority") optionally
    override the values recorded in the checkpoint's training config.

    Returns (agent, cfg) where cfg is the (flattened) training config.
    """
    cfg = get_train_config(weight_file)
    assert cfg is not None

    # Newer configs nest everything under "core"; flatten for uniform access.
    if "core" in cfg:
        flat = {}
        flatten_dict(cfg, flat)
        cfg = flat

    # Build one throwaway env just to query feature/action dimensions.
    game = create_envs(
        1,
        1,
        cfg["num_player"],
        cfg["train_bomb"],
        [0],  # explore_eps
        [100],  # boltzmann_t
        cfg["max_len"],
        cfg["sad"] if "sad" in cfg else cfg["greedy_extra"],
        cfg["shuffle_obs"],
        cfg["shuffle_color"],
        cfg["hide_action"],
        True,
    )[0]

    device = overwrite["device"]
    config = {
        "vdn": overwrite["vdn"] if "vdn" in overwrite else cfg["method"] == "vdn",
        "multi_step": overwrite.get("multi_step", cfg["multi_step"]),
        "gamma": overwrite.get("gamma", cfg["gamma"]),
        "eta": 0.9,
        "device": device,
        "in_dim": game.feature_size(),
        # older checkpoints stored the hidden size as "rnn_hid_dim"
        "hid_dim": cfg["hid_dim"] if "hid_dim" in cfg else cfg["rnn_hid_dim"],
        "out_dim": game.num_action(),
        # NOTE: cfg wins over overwrite here (opposite precedence to the
        # other keys) — preserved from the original implementation.
        "num_lstm_layer": cfg.get("num_lstm_layer", overwrite.get("num_lstm_layer", 2)),
        "boltzmann_act": overwrite.get("boltzmann_act", cfg["boltzmann_act"]),
        "uniform_priority": overwrite.get("uniform_priority", False),
    }

    agent = r2d2.R2D2Agent(**config).to(device)
    load_weight(agent.online_net, weight_file, device)
    agent.sync_target_with_online()
    return agent, cfg
def evaluate_legacy_model(
    weight_files, num_game, seed, bomb, num_run=1, verbose=True
):
    """Evaluate a set of legacy (SAD-era) checkpoints in self-play.

    One weight file per player. Network dimensions are inferred from the
    saved tensors. Runs `num_run` batches of `num_game` games and returns
    (mean score, standard error of the mean, perfect-game rate).
    """
    agents = []
    num_player = len(weight_files)
    assert num_player > 1, "1 weight file per player"

    # Loop-invariant: legacy models are evaluated on a single fixed GPU.
    device = "cuda:0"
    sad = False
    for weight_file in weight_files:
        if verbose:
            print(
                "evaluating: %s\n\tfor %dx%d games" % (weight_file, num_run, num_game)
            )
        # SAD (and aux-task) checkpoints expect the extended observation;
        # inferred from the filename. NOTE(review): only the LAST file's
        # flag is passed to evaluate() below — all files are assumed to be
        # of the same kind, preserved from the original behavior.
        sad = "sad" in weight_file or "aux" in weight_file

        # map_location keeps loading working regardless of the device the
        # checkpoint was saved on (fix: original omitted it).
        state_dict = torch.load(weight_file, map_location=device)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512
        output_dim = state_dict["fc_a.weight"].size()[0]
        agent = r2d2.R2D2Agent(
            False, 3, 0.999, 0.9, device, input_dim, hid_dim, output_dim, 2, 5, False
        ).to(device)
        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)

    scores = []
    perfect = 0
    for i in range(num_run):
        # Shift the seed each run so games do not repeat across runs.
        _, _, score, p = evaluate(
            agents,
            num_game,
            num_game * i + seed,
            bomb,
            0,  # explore eps
            sad,
        )
        scores.extend(score)
        perfect += p

    mean = np.mean(scores)
    sem = np.std(scores) / np.sqrt(len(scores))
    perfect_rate = perfect / (num_game * num_run)
    if verbose:
        print("score: %f +/- %f" % (mean, sem), "; perfect: ", perfect_rate)
    return mean, sem, perfect_rate
def load_sad_model(weight_files, device):
    """Load legacy SAD checkpoints as R2D2 agents on `device`."""
    agents = []
    for path in weight_files:
        # Infer network dimensions directly from the saved tensors;
        # the hidden size of these checkpoints is a fixed 512.
        sd = torch.load(path, map_location=device)
        in_dim = sd["net.0.weight"].size()[1]
        out_dim = sd["fc_a.weight"].size()[0]
        agent = r2d2.R2D2Agent(
            False, 3, 0.999, 0.9, device, in_dim, 512, out_dim, 2, 5, False
        ).to(device)
        load_weight(agent.online_net, path, device)
        agents.append(agent)
    return agents
def load_op_model(method, idx1, idx2, device):
    """Load "other-play" (OP) models; OP models were trained for 2 players only.

    `idx1`/`idx2` select the architecture variant per player (None skips a
    slot). Indices 0-2 / 3-5 / 6-8 / otherwise map to
    (num_fc, skip_connect) = (1, False) / (1, True) / (2, False) / (2, True).

    Raises FileNotFoundError when a weight file is missing.
    """
    root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    # assume model saved in root/models/op
    folder = os.path.join(root, "models", "op", method)
    agents = []
    for idx in [idx1, idx2]:
        if idx is None:
            continue
        if 0 <= idx < 3:
            num_fc, skip_connect = 1, False
        elif 3 <= idx < 6:
            num_fc, skip_connect = 1, True
        elif 6 <= idx < 9:
            num_fc, skip_connect = 2, False
        else:
            num_fc, skip_connect = 2, True
        weight_file = os.path.join(folder, f"M{idx}.pthw")
        if not os.path.exists(weight_file):
            # fix: was `print + assert False`, which is stripped under
            # `python -O` and then crashes obscurely inside torch.load.
            raise FileNotFoundError(f"Cannot find weight at: {weight_file}")

        # map_location honors the requested device (fix: original omitted it)
        state_dict = torch.load(weight_file, map_location=device)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512
        output_dim = state_dict["fc_a.weight"].size()[0]
        agent = r2d2.R2D2Agent(
            False,
            3,
            0.999,
            0.9,
            device,
            input_dim,
            hid_dim,
            output_dim,
            2,
            5,
            False,
            num_fc_layer=num_fc,
            skip_connect=skip_connect,
        ).to(device)
        load_weight(agent.online_net, weight_file, device)
        agents.append(agent)
    return agents
def load_sad_model(weight_files):
    """Load legacy SAD checkpoints as R2D2 agents on cuda:0.

    NOTE(review): this redefines (shadows) the earlier
    `load_sad_model(weight_files, device)` in this file — consider merging
    the two.

    Fix: the previous body referenced undefined `verbose`, `num_run` and
    `num_game` (copy-pasted from evaluate_legacy_model), raising NameError
    when called with any weight file; it also computed an unused `sad`
    flag. Both are removed; loading behavior is unchanged.
    """
    agents = []
    device = "cuda:0"  # legacy SAD models are pinned to the first GPU
    for weight_file in weight_files:
        # Infer network dimensions from the saved tensors.
        state_dict = torch.load(weight_file, map_location=device)
        input_dim = state_dict["net.0.weight"].size()[1]
        hid_dim = 512
        output_dim = state_dict["fc_a.weight"].size()[0]
        agent = r2d2.R2D2Agent(
            False, 3, 0.999, 0.9, device, input_dim, hid_dim, output_dim, 2, 5, False
        ).to(device)
        utils.load_weight(agent.online_net, weight_file, device)
        agents.append(agent)
    return agents
args.hand_size, args.train_bomb, explore_eps, args.max_len, args.sad, args.shuffle_obs, args.shuffle_color, ) agent = r2d2.R2D2Agent( (args.method == "vdn"), args.multi_step, args.gamma, args.eta, args.train_device, games[0].feature_size(), args.rnn_hid_dim, games[0].num_action(), args.num_lstm_layer, args.hand_size, False, # uniform priority ) agent.sync_target_with_online() if args.load_model: print("*****loading pretrained model*****") utils.load_weight(agent.online_net, args.load_model, args.train_device) print("*****done*****") agent = agent.to(args.train_device) optim = torch.optim.Adam(agent.online_net.parameters(),