Example #1: TRPO training loop
def train_loop(cfg, agent, logger):
    curr_iter, max_iter, eval_iter, eval_batch_sz, batch_sz, save_iter =\
        cfg.require("current training iter", "max iter", "eval interval",
                    "eval batch size", "batch size", "save interval")

    training_cfg = ParamDict({
        "policy state dict": agent.policy().getStateDict(),
        "filter state dict": agent.filter().getStateDict(),
        "trajectory max step": 64,
        "batch size": batch_sz,
        "fixed environment": False,
        "fixed policy": False,
        "fixed filter": False
    })
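    # the evaluation config starts with empty state dicts; they are refreshed
    # from the agent at every iteration before an evaluation rollout is run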
    validate_cfg = ParamDict({
        "policy state dict": None,
        "filter state dict": None,
        "trajectory max step": 64,
        "batch size": eval_batch_sz,
        "fixed environment": False,
        "fixed policy": True,
        "fixed filter": True
    })

    for i_iter in range(curr_iter, max_iter):

        s_time = float(running_time(fmt=False))

        """sample new batch and perform TRPO update"""
        batch_train, info_train = agent.rollout(training_cfg)
        trpo_step(cfg, batch_train, agent.policy())

        e_time = float(running_time(fmt=False))

        logger.train()
        info_train["duration"] = e_time - s_time
        info_train["epoch"] = i_iter
        logger(info_train)

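        # advance the iteration counter and keep cfg and both rollout configs
        # pointing at the latest policy / filter state so that checkpoints
        # saved below can resume training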
        cfg["current training iter"] = i_iter + 1
        cfg["policy state dict"] = training_cfg["policy state dict"] = validate_cfg["policy state dict"] = agent.policy().getStateDict()
        cfg["filter state dict"] = training_cfg["filter state dict"] = validate_cfg["filter state dict"] = agent.filter().getStateDict()

        if i_iter % eval_iter == 0:
            batch_eval, info_eval = agent.rollout(validate_cfg)

            logger.train(False)
            info_eval["duration"] = e_time - s_time
            info_eval["epoch"] = i_iter
            logger(info_eval)

        if i_iter != 0 and i_iter % save_iter == 0:
            file_name = os.path.join(model_dir(cfg), f"I_{i_iter}.pkl")
            cfg.save(file_name)
            print(f"Saving current step at {file_name}")

    file_name = os.path.join(model_dir(cfg), "final.pkl")
    cfg.save(file_name)
    print(f"Total running time: {running_time(fmt=True)}, result saved to {file_name}")
Example #2: MCPO training loop with demonstration data
def train_loop(cfg, agent, logger):
    curr_iter, max_iter, eval_iter, eval_batch_sz, batch_sz, save_iter, demo_loader =\
        cfg.require("current training iter", "max iter", "eval interval",
                    "eval batch size", "batch size", "save interval", "demo loader")

    training_cfg = ParamDict({
        "policy state dict": agent.policy().getStateDict(),
        "filter state dict": agent.filter().getStateDict(),
        "trajectory max step": 1024,
        "batch size": batch_sz,
        "fixed environment": False,
        "fixed policy": False,
        "fixed filter": False
    })
    validate_cfg = ParamDict({
        "policy state dict": None,
        "filter state dict": None,
        "trajectory max step": 1024,
        "batch size": eval_batch_sz,
        "fixed environment": False,
        "fixed policy": True,
        "fixed filter": True
    })

    # we use the entire demo set without sampling
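    # generate_all() is assumed to return either None (no demo available) or a
    # list of trajectories, where each step t is a dict holding 's' (state)
    # and 'a' (action) entries (see the tensor conversion below)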
    demo_trajectory = demo_loader.generate_all()
    if demo_trajectory is None:
        print("Warning: No demo loaded, fall back compatible with TRPO method")
    else:
        print("Info: Demo loaded successfully")
        demo_actions = []
        demo_states = []
        for p in demo_trajectory:
            demo_actions.append(
                torch.as_tensor([t['a'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
            demo_states.append(
                torch.as_tensor([t['s'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
        demo_states = torch.cat(demo_states, dim=0)
        demo_actions = torch.cat(demo_actions, dim=0)
        demo_trajectory = (demo_states, demo_actions)

    for i_iter in range(curr_iter, max_iter):

        s_time = float(running_time(fmt=False))
        """sample new batch and perform MCPO update"""
        batch_train, info_train = agent.rollout(training_cfg)

        demo_batch = None
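        # normalize the demo states with the agent's running observation-filter
        # statistics so they match the filtered states seen in on-policy rollouts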
        if demo_trajectory is not None:
            filter_dict = agent.filter().getStateDict()
            errsum = filter_dict["zfilter errsum"]
            mean = filter_dict["zfilter mean"]
            n_step = filter_dict["zfilter n_step"]
            errsum = torch.as_tensor(errsum,
                                     dtype=torch.float32,
                                     device=agent.policy().device)
            mean = torch.as_tensor(mean,
                                   dtype=torch.float32,
                                   device=agent.policy().device)
            std = torch.sqrt(errsum / (n_step - 1)) if n_step > 1 else mean
            demo_batch = ((demo_trajectory[0] - mean) / (std + 1e-8),
                          demo_trajectory[1])

        mcpo_step(cfg, batch_train, agent.policy(), demo_batch)

        e_time = float(running_time(fmt=False))

        logger.train()
        info_train["duration"] = e_time - s_time
        info_train["epoch"] = i_iter
        logger(info_train)

        cfg["current training iter"] = i_iter + 1
        cfg["policy state dict"] = training_cfg[
            "policy state dict"] = validate_cfg[
                "policy state dict"] = agent.policy().getStateDict()
        cfg["filter state dict"] = training_cfg[
            "filter state dict"] = validate_cfg[
                "filter state dict"] = agent.filter().getStateDict()

        if i_iter % eval_iter == 0:
            batch_eval, info_eval = agent.rollout(validate_cfg)

            logger.train(False)
            info_eval["duration"] = e_time - s_time
            info_eval["epoch"] = i_iter
            logger(info_eval)

        if i_iter != 0 and i_iter % save_iter == 0:
            file_name = os.path.join(model_dir(cfg), f"I_{i_iter}.pkl")
            cfg.save(file_name)
            print(f"Saving current step at {file_name}")

    file_name = os.path.join(model_dir(cfg), "final.pkl")
    cfg.save(file_name)
    print(f"Total running time: {running_time(fmt=True)}, result saved to {file_name}")
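
The demo normalization above reads running statistics ("zfilter mean", "zfilter errsum", "zfilter n_step") out of the agent's observation filter and forms std = sqrt(errsum / (n_step - 1)). The sketch below shows the kind of Welford-style running filter those fields imply; the class and attribute names are illustrative, not the project's actual filter.

import numpy as np


class RunningZFilter:
    """Welford-style running mean / variance tracker (illustrative sketch)."""

    def __init__(self, shape):
        self.n_step = 0
        self.mean = np.zeros(shape, dtype=np.float64)
        self.errsum = np.zeros(shape, dtype=np.float64)  # sum of squared deviations

    def update(self, x):
        x = np.asarray(x, dtype=np.float64)
        self.n_step += 1
        delta = x - self.mean
        self.mean += delta / self.n_step
        self.errsum += delta * (x - self.mean)  # uses old and new mean (Welford)

    def normalize(self, x):
        # mirrors the demo-batch normalization in the loop above
        std = np.sqrt(self.errsum / (self.n_step - 1)) if self.n_step > 1 else self.mean
        return (x - self.mean) / (std + 1e-8)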
Example #3: Behavior Cloning training loop
def train_loop(cfg, agent, logger):
    curr_iter, max_iter, eval_iter, eval_batch_sz, save_iter, demo_loader =\
        cfg.require("current training iter", "max iter", "eval interval",
                    "eval batch size", "save interval", "demo loader")

    validate_cfg = ParamDict({
        "policy state dict": None,
        "filter state dict": None,
        "trajectory max step": 64,
        "batch size": eval_batch_sz,
        "fixed environment": False,
        "fixed policy": True,
        "fixed filter": True
    })

    # we use the entire demo set without sampling
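    # as in Example #2, each demo step is expected to be a dict holding
    # 's' (state) and 'a' (action) entries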
    demo_trajectory = demo_loader.generate_all()
    if demo_trajectory is None:
        raise FileNotFoundError(
            "Demo file does not exist or could not be loaded, aborting!")
    else:
        print("Info: Demo loaded successfully")
        demo_actions = []
        demo_states = []
        for p in demo_trajectory:
            demo_actions.append(
                torch.as_tensor([t['a'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
            demo_states.append(
                torch.as_tensor([t['s'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
        demo_states = torch.cat(demo_states, dim=0)
        demo_actions = torch.cat(demo_actions, dim=0)
        demo_trajectory = (demo_states, demo_actions)

    for i_iter in range(curr_iter, max_iter):

        s_time = float(running_time(fmt=False))
        """sample new data batch and perform Behavior Cloning update"""
        loss = bc_step(cfg, agent.policy(), demo_trajectory)

        e_time = float(running_time(fmt=False))

        cfg["current training iter"] = i_iter + 1
        cfg["policy state dict"] = validate_cfg[
            "policy state dict"] = agent.policy().getStateDict()
        cfg["filter state dict"] = validate_cfg[
            "filter state dict"] = agent.filter().getStateDict()

        if i_iter % eval_iter == 0:
            batch_eval, info_eval = agent.rollout(validate_cfg)
            logger.train(False)
            info_eval["duration"] = e_time - s_time
            info_eval["epoch"] = i_iter
            info_eval["loss"] = loss
            logger(info_eval)

        if i_iter != 0 and i_iter % save_iter == 0:
            file_name = os.path.join(model_dir(cfg), f"I_{i_iter}.pkl")
            cfg.save(file_name)
            print(f"Saving current step at {file_name}")

    file_name = os.path.join(model_dir(cfg), "final.pkl")
    cfg.save(file_name)
    print(f"Total running time: {running_time(fmt=True)}, result saved to {file_name}")