Example #1
    def reset(self, param: ParamDict):
        assert self.policy_net is not None, "Error: you should call init before reset!"
        policy = param.require("policy state_dict")
        if "fixed policy" in param:
            self.is_fixed = param.require("fixed policy")

        self.policy_net.load_state_dict(policy)
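
The ParamDict helper used throughout these examples is not shown; the sketch below is only an inferred, minimal stand-in based on how require, the in operator, and + are used in the snippets (Examples #1, #2, #8), not the project's actual class.

class ParamDict(dict):
    # Minimal sketch (assumption): require returns one value for a single key,
    # a tuple of values for several keys, and fails loudly on a missing key.
    def require(self, *keys):
        missing = [k for k in keys if k not in self]
        assert not missing, f"Error: required keys {missing} not found"
        values = tuple(self[k] for k in keys)
        return values[0] if len(keys) == 1 else values

    def __add__(self, other):
        # merge two ParamDicts into a new one, as getStateDict does in Example #8
        merged = ParamDict(self)
        merged.update(other)
        return merged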
Example #2
    def broadcast(self, config: ParamDict):
        policy_state, filter_state, max_step, self._batch_size, fixed_env, fixed_policy, fixed_filter = \
            config.require("policy state dict", "filter state dict", "trajectory max step", "batch size",
                           "fixed environment", "fixed policy", "fixed filter")

        self._replay_buffer = []
        policy_state["fixed policy"] = fixed_policy
        filter_state["fixed filter"] = fixed_filter
        cmd = ParamDict({"trajectory max step": max_step,
                         "fixed environment": fixed_env,
                         "filter state dict": filter_state})

        assert self._sync_signal.value < 1, "Last sync event did not finish due to an error; a sub-process may have died, aborting"
        # tell sub-process to reset
        self._sync_signal.value = len(self._policy_proc) + len(self._environment_proc)

        # sync net parameters
        with self._policy_lock:
            for _ in range(len(self._policy_proc)):
                self._param_pipe.send(policy_state)

        # wait for all agents' ready feedback
        while self._sync_signal.value > 0:
            sleep(0.01)

        # sending commands
        with self._environment_lock:
            for _ in range(self._batch_size):
                self._control_pipe.send(cmd)
Example #3
def train_loop(cfg, agent, logger):
    curr_iter, max_iter, eval_iter, eval_batch_sz, batch_sz, save_iter =\
        cfg.require("current training iter", "max iter", "eval interval",
                    "eval batch size", "batch size", "save interval")

    training_cfg = ParamDict({
        "policy state dict": agent.policy().getStateDict(),
        "filter state dict": agent.filter().getStateDict(),
        "trajectory max step": 64,
        "batch size": batch_sz,
        "fixed environment": False,
        "fixed policy": False,
        "fixed filter": False
    })
    validate_cfg = ParamDict({
        "policy state dict": None,
        "filter state dict": None,
        "trajectory max step": 64,
        "batch size": eval_batch_sz,
        "fixed environment": False,
        "fixed policy": True,
        "fixed filter": True
    })

    for i_iter in range(curr_iter, max_iter):

        s_time = float(running_time(fmt=False))

        """sample new batch and perform TRPO update"""
        batch_train, info_train = agent.rollout(training_cfg)
        trpo_step(cfg, batch_train, agent.policy())

        e_time = float(running_time(fmt=False))

        logger.train()
        info_train["duration"] = e_time - s_time
        info_train["epoch"] = i_iter
        logger(info_train)

        cfg["current training iter"] = i_iter + 1
        cfg["policy state dict"] = training_cfg["policy state dict"] = validate_cfg["policy state dict"] = agent.policy().getStateDict()
        cfg["filter state dict"] = training_cfg["filter state dict"] = validate_cfg["filter state dict"] = agent.filter().getStateDict()

        if i_iter % eval_iter == 0:
            batch_eval, info_eval = agent.rollout(validate_cfg)

            logger.train(False)
            info_eval["duration"] = e_time - s_time
            info_eval["epoch"] = i_iter
            logger(info_eval)

        if i_iter != 0 and i_iter % save_iter == 0:
            file_name = os.path.join(model_dir(cfg), f"I_{i_iter}.pkl")
            cfg.save(file_name)
            print(f"Saving current step at {file_name}")

    file_name = os.path.join(model_dir(cfg), f"final.pkl")
    cfg.save(file_name)
    print(f"Total running time: {running_time(fmt=True)}, result saved at {file_name}")
Example #4
    def __init__(self, config: ParamDict, environment: Environment,
                 policy: Policy, filter_op: Filter):
        threads, gpu = config.require("threads", "gpu")
        threads_gpu = config["gpu threads"] if "gpu threads" in config else 2
        super(Agent_async, self).__init__(config, environment, policy,
                                          filter_op)

        # sync signal, -1: terminate, 0: normal running, >0 restart and waiting for parameter update
        self._sync_signal = Value('i', 0)

        # environment sub-process list
        self._environment_proc = []
        # policy sub-process list
        self._policy_proc = []

        # used for synchronize policy parameters
        self._param_pipe = None
        self._policy_lock = Lock()
        # used for synchronize roll-out commands
        self._control_pipe = None
        self._environment_lock = Lock()

        step_pipe = []
        cmd_pipe_child, cmd_pipe_parent = Pipe(duplex=True)
        param_pipe_child, param_pipe_parent = Pipe(duplex=False)
        self._control_pipe = cmd_pipe_parent
        self._param_pipe = param_pipe_parent
        for i_envs in range(threads):
            child_name = f"environment_{i_envs}"
            step_pipe_pi, step_pipe_env = Pipe(duplex=True)
            step_lock = Lock()
            worker_cfg = ParamDict({
                "seed": self.seed + 1024 + i_envs,
                "gpu": gpu
            })
            child = Process(target=Agent_async._environment_worker,
                            name=child_name,
                            args=(worker_cfg, cmd_pipe_child, step_pipe_env,
                                  self._environment_lock, step_lock,
                                  self._sync_signal, deepcopy(environment),
                                  deepcopy(filter_op)))
            self._environment_proc.append(child)
            step_pipe.append((step_pipe_pi, step_lock))
            child.start()

        for i_policies in range(threads_gpu):
            child_name = f"policy_{i_policies}"
            worker_cfg = ParamDict({
                "seed": self.seed + 2048 + i_policies,
                "gpu": gpu
            })
            child = Process(target=Agent_async._policy_worker,
                            name=child_name,
                            args=(worker_cfg, param_pipe_child, step_pipe,
                                  self._policy_lock, self._sync_signal,
                                  deepcopy(policy)))
            self._policy_proc.append(child)
            child.start()
        sleep(5)
Example #5
    def __init__(self, config: ParamDict, environment: Environment,
                 policy: Policy, filter_op: Filter):
        seed, gpu = config.require("seed", "gpu")
        # replay buffer
        self._replay_buffer = []
        self._batch_size = 0

        # device and seed
        self.device = decide_device(gpu)
        self.seed = seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # policy, copied to the worker sub-processes before each roll-out
        self._policy = deepcopy(policy)
        # environment, copied to the worker sub-processes and never initialized in the main process
        self._environment = deepcopy(environment)
        # filter, copied to the worker sub-processes and also kept in the main process
        self._filter = deepcopy(filter_op)

        self._filter.init()
        self._filter.to_device(self.device)
        self._policy.init()
        self._policy.to_device(self.device)
Example #6
 def getStateDict(self) -> ParamDict:
     return ParamDict({
         "value state_dict":
         cpu_state_dict(self.value_net.state_dict()),
         "policy state_dict":
         cpu_state_dict(self.policy_net.state_dict())
     })
Example #7
def bc_step(config: ParamDict, policy: Policy, demo):
    lr_init, lr_factor, l2_reg, bc_method, batch_sz, i_iter = \
        config.require("lr", "lr factor", "l2 reg", "bc method", "batch size", "current training iter")
    states, actions = demo

    # ---- annealing on learning rate ---- #
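    # note: lr_factor is presumably negative, so the rate decays with i_iter and is floored at 1e-8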
    lr = max(lr_init + lr_factor * i_iter, 1.e-8)

    optimizer = Adam(policy.policy_net.parameters(),
                     weight_decay=l2_reg,
                     lr=lr)

    # ---- define BC from demonstrations ---- #
    total_len = states.size(0)
    idx = torch.randperm(total_len, device=policy.device)
    err = 0.
    for i_b in range((total_len + batch_sz - 1) // batch_sz):  # ceiling division avoids an empty final batch
        idx_b = idx[i_b * batch_sz:(i_b + 1) * batch_sz]
        s_b = states[idx_b]
        a_b = actions[idx_b]

        optimizer.zero_grad()
        a_mean_pred, a_logvar_pred = policy.policy_net(s_b)
        bc_loss = mse_loss(a_mean_pred + 0. * a_logvar_pred, a_b)
        err += bc_loss.item() * s_b.size(0) / total_len
        bc_loss.backward()
        optimizer.step()

    return err
Example #8
 def getStateDict(self) -> ParamDict:
     state_dict = super(ZFilter, self).getStateDict()
     return state_dict + ParamDict({
         "zfilter mean": self.mean,
         "zfilter errsum": self.errsum,
         "zfilter n_step": self.n_step,
         "fixed filter": self.is_fixed
     })
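
The zfilter mean / errsum / n_step triple saved above looks like Welford-style running statistics (Example #20 reconstructs std as sqrt(errsum / (n_step - 1))). The snippet below is a hedged sketch of how such a filter could maintain and apply them; the real ZFilter update is not shown in these examples.

import numpy as np

def zfilter_update(mean, errsum, n_step, x):
    # Welford update of a running mean and running sum of squared deviations (assumption)
    n_step += 1
    delta = x - mean
    mean = mean + delta / n_step
    errsum = errsum + delta * (x - mean)
    return mean, errsum, n_step

def zfilter_normalize(x, mean, errsum, n_step, eps=1e-8):
    # matches the std reconstruction used for the demo batch in Example #20
    std = np.sqrt(errsum / (n_step - 1)) if n_step > 1 else np.ones_like(mean)
    return (x - mean) / (std + eps)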
Example #9
    def init(self, config: ParamDict):
        env_name, tag, short, seed = config.require("env name", "tag", "short", "seed")
        self.default_name = f"{tag}-{env_name}-{short}-{seed}"
        self._log_dir = log_dir(config)
        self.loggerx = SummaryWriter(log_dir=self._log_dir)

        self._epoch_train = 0
        self._epoch_val = 0
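
Only init is shown for this logger; judging by how the training loops call logger.train(), logger.train(False) and logger(info_dict), the rest of the class presumably looks something like the hypothetical sketch below (the method bodies are assumptions, not the project's code).

    def train(self, is_train: bool = True):
        # switch between the training and validation counters (assumed behaviour)
        self._is_train = is_train

    def __call__(self, info: dict):
        if getattr(self, "_is_train", True):
            self._epoch_train += 1
            prefix, epoch = "train", self._epoch_train
        else:
            self._epoch_val += 1
            prefix, epoch = "val", self._epoch_val
        for key, value in info.items():
            if isinstance(value, (int, float)):
                self.loggerx.add_scalar(f"{prefix}/{key}", value, epoch)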
Example #10
def trpo_step(config: ParamDict, batch: StepDictList, policy: PolicyWithValue):
    max_kl, damping, l2_reg = config.require("max kl", "damping", "l2 reg")
    states, actions, advantages, returns = get_tensor(batch, policy.device)
    """update critic"""
    update_value_net(policy.value_net, states, returns, l2_reg)
    """update policy"""
    with torch.no_grad():
        fixed_log_probs = policy.policy_net.get_log_prob(states,
                                                         actions).detach()
    """define the loss function for TRPO"""
    def get_loss(volatile=False):
        if volatile:
            with torch.no_grad():
                log_probs = policy.policy_net.get_log_prob(states, actions)
        else:
            log_probs = policy.policy_net.get_log_prob(states, actions)
        action_loss = -advantages * torch.exp(log_probs - fixed_log_probs)
        return action_loss.mean()

    """define Hessian*vector for KL"""

    def Fvp(v):
        kl = policy.policy_net.get_kl(states).mean()

        grads = torch.autograd.grad(kl,
                                    policy.policy_net.parameters(),
                                    create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        kl_v = (flat_grad_kl * v).sum()
        grads = torch.autograd.grad(kl_v, policy.policy_net.parameters())
        flat_grad_grad_kl = torch.cat(
            [grad.contiguous().view(-1) for grad in grads]).detach()

        return flat_grad_grad_kl + v * damping

    for _ in range(2):
        loss = get_loss()
        grads = torch.autograd.grad(loss, policy.policy_net.parameters())
        loss_grad = torch.cat([grad.view(-1) for grad in grads]).detach()
        stepdir = cg(Fvp, -loss_grad, nsteps=10)

        shs = 0.5 * (stepdir.dot(Fvp(stepdir)))
        lm = (max_kl / shs).sqrt_()
        fullstep = stepdir * lm
        expected_improve = -loss_grad.dot(fullstep)

        prev_params = policy.policy_net.get_flat_params().detach()
        success, new_params = line_search(policy.policy_net, get_loss,
                                          prev_params, fullstep,
                                          expected_improve)
        policy.policy_net.set_flat_params(new_params)
        if not success:
            return False
    return True
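
trpo_step relies on a conjugate-gradient helper cg(Fvp, -loss_grad, nsteps=10) that is not included in these snippets. A standard CG solver with that call shape is sketched below; treat it as a generic reference, not the project's implementation.

import torch

def cg(Avp, b, nsteps=10, residual_tol=1e-10):
    # solve A x = b where A is only available through the Hessian-vector product Avp
    x = torch.zeros_like(b)
    r = b.clone()          # residual for x = 0
    p = b.clone()          # search direction
    rdotr = r.dot(r)
    for _ in range(nsteps):
        Ap = Avp(p)
        alpha = rdotr / p.dot(Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x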
Example #11
    def broadcast(self, config: ParamDict):
        policy_state, filter_state, self._max_step, self._batch_size, self._fixed_env, fixed_policy, fixed_filter = \
            config.require("policy state dict", "filter state dict", "trajectory max step",
                           "batch size", "fixed environment", "fixed policy", "fixed filter")

        self._replay_buffer = []
        policy_state["fixed policy"] = fixed_policy
        filter_state["fixed filter"] = fixed_filter

        self._filter.reset(filter_state)
        self._policy.reset(policy_state)
Example #12
def ppo_step(config: ParamDict, batch: SampleBatch, policy: PolicyWithValue):
    lr, l2_reg, clip_epsilon, policy_iter, i_iter, max_iter, mini_batch_sz = \
        config.require("lr", "l2 reg", "clip eps", "optimize policy epochs",
                       "current training iter", "max iter", "optimize batch size")
    lam_entropy = 0.
    states, actions, advantages, returns = get_tensor(batch, policy.device)

    lr_mult = max(1.0 - i_iter / max_iter, 0.)
    clip_epsilon = clip_epsilon * lr_mult

    optimizer_policy = Adam(policy.policy_net.parameters(),
                            lr=lr * lr_mult,
                            weight_decay=l2_reg)
    optimizer_value = Adam(policy.value_net.parameters(),
                           lr=lr * lr_mult,
                           weight_decay=l2_reg)

    with torch.no_grad():
        fixed_log_probs = policy.policy_net.get_log_prob(states,
                                                         actions).detach()

    for _ in range(policy_iter):
        inds = torch.randperm(states.size(0))
        """perform mini-batch PPO update"""
        for i_b in range(inds.size(0) // mini_batch_sz):
            slc = slice(i_b * mini_batch_sz, (i_b + 1) * mini_batch_sz)

            states_i = states[slc]
            actions_i = actions[slc]
            returns_i = returns[slc]
            advantages_i = advantages[slc]
            log_probs_i = fixed_log_probs[slc]
            """update critic"""
            for _ in range(1):
                value_loss = F.mse_loss(policy.value_net(states_i), returns_i)
                optimizer_value.zero_grad()
                value_loss.backward()
                torch.nn.utils.clip_grad_norm_(policy.value_net.parameters(),
                                               0.5)
                optimizer_value.step()
            """update policy"""
            log_probs, entropy = policy.policy_net.get_log_prob_entropy(
                states_i, actions_i)
            ratio = (log_probs - log_probs_i).clamp_max(15.).exp()
            surr1 = ratio * advantages_i
            surr2 = torch.clamp(ratio, 1.0 - clip_epsilon,
                                1.0 + clip_epsilon) * advantages_i
            policy_surr = -torch.min(
                surr1, surr2).mean() - entropy.mean() * lam_entropy
            optimizer_policy.zero_grad()
            policy_surr.backward()
            torch.nn.utils.clip_grad_norm_(policy.policy_net.parameters(), 0.5)
            optimizer_policy.step()
Example #13
def replay_loop(cfg, agent):
    max_iter, display = cfg.require("verify iter", "verify display")

    validate_cfg = ParamDict({
        "policy state dict": cfg["policy state dict"],
        "filter state dict": cfg["filter state dict"],
        "trajectory max step": 64,
        "max iter": max_iter,
        "display": display,
        "fixed environment": False,
        "fixed policy": True,
        "fixed filter": True
    })

    agent.verify(validate_cfg)
Example #14
    def __init__(self, config_dict: ParamDict, env_info: dict):
        # a failed check here usually means the config was not passed to the environment first,
        # which is what fills in the "action dim" and "state dim" fields
        assert "state dim" in env_info and "action dim" in env_info,\
            f"Error: Key 'state dim' or 'action dim' not in env_info, which only contains {env_info.keys()}"
        super(PolicyOnly, self).__init__()

        activation = config_dict.require("activation")

        self.policy_net = None

        # TODO: do some actual work here
        self._action_dim = env_info["action dim"]
        self._state_dim = env_info["state dim"]
        self.is_fixed = False
        self._activation = activation
Example #15
def ppo_step(config: ParamDict, replay_memory: StepDictList,
             policy: PolicyWithValue):
    lr, l2_reg, clip_epsilon, policy_iter, i_iter, max_iter, mini_batch_sz = \
        config.require("lr", "l2 reg", "clip eps", "optimize policy epochs",
                       "current training iter", "max iter", "optimize batch size")
    lam_entropy = 0.0
    states, actions, advantages, returns = get_tensor(policy, replay_memory,
                                                      policy.device)
    """update critic"""
    update_value_net(policy.value_net, states, returns, l2_reg)
    """update policy"""
    lr_mult = max(1.0 - i_iter / max_iter, 0.)
    clip_epsilon = clip_epsilon * lr_mult
    optimizer = Adam(policy.policy_net.parameters(),
                     lr=lr * lr_mult,
                     weight_decay=l2_reg)

    with torch.no_grad():
        fixed_log_probs = policy.policy_net.get_log_prob(states,
                                                         actions).detach()

    for _ in range(policy_iter):
        # reshuffle the sample order each epoch; torch.randperm avoids the silently
        # broken element swaps np.random.shuffle can produce on a torch tensor
        inds = torch.randperm(states.size(0))
        """perform mini-batch PPO update"""
        for i_b in range(int(np.ceil(states.size(0) / mini_batch_sz))):
            ind = inds[i_b * mini_batch_sz:min((i_b + 1) *
                                               mini_batch_sz, inds.size(0))]

            log_probs, entropy = policy.policy_net.get_log_prob_entropy(
                states[ind], actions[ind])
            ratio = torch.exp(log_probs - fixed_log_probs[ind])
            surr1 = ratio * advantages[ind]
            surr2 = torch.clamp(ratio, 1.0 - clip_epsilon,
                                1.0 + clip_epsilon) * advantages[ind]
            policy_surr = -torch.min(surr1, surr2).mean()
            policy_surr -= entropy.mean() * lam_entropy
            optimizer.zero_grad()
            policy_surr.backward()
            torch.nn.utils.clip_grad_norm_(policy.policy_net.parameters(), 0.5)
            optimizer.step()
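
Both PPO variants consume advantages and returns tensors produced by get_tensor, which is not shown here. One common way to build them is generalized advantage estimation over a finished trajectory; the sketch below is illustrative only, with assumed gamma/lam defaults, and is not the project's get_tensor.

import torch

def gae_advantages(rewards, values, last_value, gamma=0.99, lam=0.95):
    # backward pass of GAE(lambda) over a single trajectory
    advantages = torch.zeros_like(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(rewards.size(0))):
        delta = rewards[t] + gamma * next_value - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
        next_value = values[t]
    returns = advantages + values
    # normalizing advantages is a common extra step
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return advantages, returns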
Example #16
    def verify(self, config: ParamDict):
        policy_state, filter_state, max_step, max_iter, display, fixed_env, fixed_policy, fixed_filter =\
             config.require("policy state dict", "filter state dict", "trajectory max step", "max iter", "display",
                            "fixed environment", "fixed policy", "fixed filter")

        policy_state["fixed policy"] = fixed_policy
        filter_state["fixed filter"] = fixed_filter

        self._environment.init(display=display)
        self._policy.reset(policy_state)
        self._filter.reset(filter_state)

        r_sum = []
        for i in range(max_iter):
            current_step = self._environment.reset(random=not fixed_env)
            # sampling
            r_sum.append(0.)
            i_step = 0
            for _ in range(max_step):
                current_step = self._filter.operate_currentStep(current_step)
                last_step = self._policy.step([current_step])[0]
                last_step, current_step, done = self._environment.step(
                    last_step)
                record_step = self._filter.operate_recordStep(last_step)
                r_sum[-1] += record_step['r']
                i_step += 1
                if done:
                    break
            print(f"Verification iter {i}: reward={r_sum[-1]}; step={i_step}")
        # finalization
        self._environment.finalize()
        r_sum = np.asarray(r_sum, dtype=np.float32)
        r_sum_mean = r_sum.mean()
        r_sum_std = r_sum.std()
        r_sum_max = r_sum.max()
        r_sum_min = r_sum.min()
        print(
            f"Verification done with reward sum: avg={r_sum_mean}, std={r_sum_std}, min={r_sum_min}, max={r_sum_max}"
        )
        return r_sum_mean, r_sum_std, r_sum_min, r_sum_max
Example #17
def loadInitConfig(default_cfg: Config):
    """
    This function will deepcopy the cfg and add some fields used for (continue) training
    """
    cfg = deepcopy(default_cfg)
    saved_name = cfg.require("load name")

    cfg.register_item("current training iter", 0, fields=["save"])
    cfg.register_item("policy state dict", None, fields=["save"])
    cfg.register_item("filter state dict", None, fields=["save"])

    # load demos if a demo path is configured
    if "demo path" in cfg:
        demo_path = cfg.require("demo path")
        demo_loader = DemoLoader()
        if os.path.isfile(cfg["demo path"]):
            demo_loader.load_file(demo_path)
            print(f"Info: loading Demo from file {demo_path}")
        elif os.path.isfile(demo_dir(cfg["demo path"])):
            demo_loader.load_file(demo_dir(cfg["demo path"]))
            print(f"Info: loading Demo from file {demo_dir(cfg['demo path'])}")

        cfg.register_item("demo loader", demo_loader)
        if "filter state dict" in demo_loader.info():
            cfg["filter state dict"] = ParamDict(demo_loader.info()["filter state dict"])

    # load saved model if it exists
    if os.path.isfile(saved_name):
        cfg.load(saved_name, "this")
        print(f"Info: try loading saved model from {saved_name}")
    else:
        saved_name = os.path.join(model_dir(cfg), saved_name)
        if os.path.isfile(saved_name):
            cfg.load(saved_name, "this")
            print(f"Info: try loading saved model from {saved_name}")

    return cfg
Example #18
    def verify(self, config: ParamDict, environment: Environment):
        max_iter, max_step, random = \
            config.require("verify iter", "verify max step", "verify random")

        r_sum = []
        for _ in range(max_iter):
            # track each episode's reward sum separately
            r_sum.append(0.)
            current_step = environment.reset(random=random)
            # sampling
            for _ in range(max_step):
                last_step = self._policy.step([current_step])[0]
                last_step, current_step, done = environment.step(last_step)
                r_sum[-1] += last_step['r']
                if done:
                    break
        # finalization
        r_sum = torch.as_tensor(r_sum, dtype=torch.float32, device=self.device)
        r_sum_mean = r_sum.mean()
        r_sum_std = r_sum.std()
        r_sum_max = r_sum.max()
        r_sum_min = r_sum.min()
        print(
            f"Verification done with reward sum: avg={r_sum_mean}, std={r_sum_std}, min={r_sum_min}, max={r_sum_max}"
        )
        return r_sum_mean, r_sum_std, r_sum_min, r_sum_max
Example #19
def train_loop(cfg, agent, logger):
    curr_iter, max_iter, eval_iter, eval_batch_sz, save_iter, demo_loader =\
        cfg.require("current training iter", "max iter", "eval interval",
                    "eval batch size", "save interval", "demo loader")

    validate_cfg = ParamDict({
        "policy state dict": None,
        "filter state dict": None,
        "trajectory max step": 64,
        "batch size": eval_batch_sz,
        "fixed environment": False,
        "fixed policy": True,
        "fixed filter": True
    })

    # we use the entire demo set without sampling
    demo_trajectory = demo_loader.generate_all()
    if demo_trajectory is None:
        raise FileNotFoundError(
            "Demo file does not exist or cannot be loaded, aborting!")
    else:
        print("Info: Demo loaded successfully")
        demo_actions = []
        demo_states = []
        for p in demo_trajectory:
            demo_actions.append(
                torch.as_tensor([t['a'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
            demo_states.append(
                torch.as_tensor([t['s'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
        demo_states = torch.cat(demo_states, dim=0)
        demo_actions = torch.cat(demo_actions, dim=0)
        demo_trajectory = (demo_states, demo_actions)

    for i_iter in range(curr_iter, max_iter):

        s_time = float(running_time(fmt=False))
        """sample new data batch and perform Behavior Cloning update"""
        loss = bc_step(cfg, agent.policy(), demo_trajectory)

        e_time = float(running_time(fmt=False))

        cfg["current training iter"] = i_iter + 1
        cfg["policy state dict"] = validate_cfg[
            "policy state dict"] = agent.policy().getStateDict()
        cfg["filter state dict"] = validate_cfg[
            "filter state dict"] = agent.filter().getStateDict()

        if i_iter % eval_iter == 0:
            batch_eval, info_eval = agent.rollout(validate_cfg)
            logger.train(False)
            info_eval["duration"] = e_time - s_time
            info_eval["epoch"] = i_iter
            info_eval["loss"] = loss
            logger(info_eval)

        if i_iter != 0 and i_iter % save_iter == 0:
            file_name = os.path.join(model_dir(cfg), f"I_{i_iter}.pkl")
            cfg.save(file_name)
            print(f"Saving current step at {file_name}")

    file_name = os.path.join(model_dir(cfg), f"final.pkl")
    cfg.save(file_name)
    print(
        f"Total running time: {running_time(fmt=True)}, result saved at {file_name}"
    )
Example #20
def train_loop(cfg, agent, logger):
    curr_iter, max_iter, eval_iter, eval_batch_sz, batch_sz, save_iter, demo_loader =\
        cfg.require("current training iter", "max iter", "eval interval",
                    "eval batch size", "batch size", "save interval", "demo loader")

    training_cfg = ParamDict({
        "policy state dict": agent.policy().getStateDict(),
        "filter state dict": agent.filter().getStateDict(),
        "trajectory max step": 1024,
        "batch size": batch_sz,
        "fixed environment": False,
        "fixed policy": False,
        "fixed filter": False
    })
    validate_cfg = ParamDict({
        "policy state dict": None,
        "filter state dict": None,
        "trajectory max step": 1024,
        "batch size": eval_batch_sz,
        "fixed environment": False,
        "fixed policy": True,
        "fixed filter": True
    })

    # we use the entire demo set without sampling
    demo_trajectory = demo_loader.generate_all()
    if demo_trajectory is None:
        print("Warning: No demo loaded, fall back compatible with TRPO method")
    else:
        print("Info: Demo loaded successfully")
        demo_actions = []
        demo_states = []
        for p in demo_trajectory:
            demo_actions.append(
                torch.as_tensor([t['a'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
            demo_states.append(
                torch.as_tensor([t['s'] for t in p],
                                dtype=torch.float32,
                                device=agent.policy().device))
        demo_states = torch.cat(demo_states, dim=0)
        demo_actions = torch.cat(demo_actions, dim=0)
        demo_trajectory = (demo_states, demo_actions)

    for i_iter in range(curr_iter, max_iter):

        s_time = float(running_time(fmt=False))
        """sample new batch and perform MCPO update"""
        batch_train, info_train = agent.rollout(training_cfg)

        demo_batch = None
        if demo_trajectory is not None:
            filter_dict = agent.filter().getStateDict()
            errsum, mean, n_step = filter_dict["zfilter errsum"], filter_dict[
                "zfilter mean"], filter_dict["zfilter n_step"]
            errsum = torch.as_tensor(errsum,
                                     dtype=torch.float32,
                                     device=agent.policy().device)
            mean = torch.as_tensor(mean,
                                   dtype=torch.float32,
                                   device=agent.policy().device)
            std = torch.sqrt(errsum / (n_step - 1)) if n_step > 1 else mean
            demo_batch = ((demo_trajectory[0] - mean) / (std + 1e-8),
                          demo_trajectory[1])

        mcpo_step(cfg, batch_train, agent.policy(), demo_batch)

        e_time = float(running_time(fmt=False))

        logger.train()
        info_train["duration"] = e_time - s_time
        info_train["epoch"] = i_iter
        logger(info_train)

        cfg["current training iter"] = i_iter + 1
        cfg["policy state dict"] = training_cfg[
            "policy state dict"] = validate_cfg[
                "policy state dict"] = agent.policy().getStateDict()
        cfg["filter state dict"] = training_cfg[
            "filter state dict"] = validate_cfg[
                "filter state dict"] = agent.filter().getStateDict()

        if i_iter % eval_iter == 0:
            batch_eval, info_eval = agent.rollout(validate_cfg)

            logger.train(False)
            info_eval["duration"] = e_time - s_time
            info_eval["epoch"] = i_iter
            logger(info_eval)

        if i_iter != 0 and i_iter % save_iter == 0:
            file_name = os.path.join(model_dir(cfg), f"I_{i_iter}.pkl")
            cfg.save(file_name)
            print(f"Saving current step at {file_name}")

    file_name = os.path.join(model_dir(cfg), f"final.pkl")
    cfg.save(file_name)
    print(
        f"Total running time: {running_time(fmt=True)}, result saved at {file_name}"
    )
Example #21
    def _environment_worker(setups: ParamDict, pipe_cmd, pipe_step, read_lock, step_lock, sync_signal, environment, filter_op):
        gpu, seed = setups.require("gpu", "seed")

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        environment.init(display=False)
        filter_op.to_device(torch.device("cpu"))
        filter_op.init()
        # -1: syncing, 0: waiting for command, 1: waiting for action
        local_state = 0
        step_buffer = []
        cmd = None

        def _get_piped_data(pipe, lock):
            with lock:
                if pipe.poll(0.001):
                    return pipe.recv()
                else:
                    return None

        while sync_signal.value >= 0:
            # check sync counter for sync event
            if sync_signal.value > 0 and local_state >= 0:
                # on a sync signal: drain the pipes, clear the step buffer, decrease the
                # sync counter, and set the state machine to -1 so this reset runs only once
                while _get_piped_data(pipe_cmd, read_lock) is not None:
                    pass
                while _get_piped_data(pipe_step, step_lock) is not None:
                    pass
                step_buffer.clear()
                with sync_signal.get_lock():
                    sync_signal.value -= 1
                local_state = -1

            # if sync ends, tell state machine to recover from syncing state, and reset environment
            elif sync_signal.value == 0 and local_state == -1:
                local_state = 0

            # idle and waiting for new command
            elif sync_signal.value == 0 and local_state == 0:
                cmd = _get_piped_data(pipe_cmd, read_lock)
                if cmd is not None:
                    step_buffer.clear()
                    cmd.require("fixed environment", "trajectory max step")
                    current_step = environment.reset(random=not cmd["fixed environment"])
                    filter_op.reset(cmd["filter state dict"])

                    policy_step = filter_op.operate_currentStep(current_step)
                    with step_lock:
                        pipe_step.send(policy_step)
                    local_state = 1

            # waiting for action
            elif sync_signal.value == 0 and local_state == 1:
                last_step = _get_piped_data(pipe_step, step_lock)
                if last_step is not None:
                    last_step, current_step, done = environment.step(last_step)
                    record_step = filter_op.operate_recordStep(last_step)
                    step_buffer.append(record_step)
                    if len(step_buffer) >= cmd["trajectory max step"] or done:
                        traj = filter_op.operate_stepList(step_buffer, done=done)
                        with read_lock:
                            pipe_cmd.send(traj)
                        local_state = 0
                    else:
                        policy_step = filter_op.operate_currentStep(current_step)
                        with step_lock:
                            pipe_step.send(policy_step)

        # finalization
        environment.finalize()
        filter_op.finalize()
        pipe_cmd.close()
        pipe_step.close()
        print("Environment sub-process exited")
Example #22
def mcpo_step(config: ParamDict,
              batch: SampleBatch,
              policy: PolicyWithValue,
              demo=None):
    global mmd_cache
    max_kl, damping, l2_reg, bc_method, d_init, d_factor, d_max, i_iter = \
        config.require("max kl", "damping", "l2 reg", "bc method", "constraint", "constraint factor", "constraint max",
                       "current training iter")
    states, actions, advantages, returns, demo_states, demo_actions = \
        get_tensor(batch, demo, policy.device)

    # ---- annealing on constraint tolerance ---- #
    d = min(d_init + (d_factor * i_iter)**2, d_max)

    # ---- update critic ---- #
    update_value_net(policy.value_net, states, returns, l2_reg)

    # ---- update policy ---- #
    with torch.no_grad():
        fixed_log_probs = policy.policy_net.get_log_prob(states,
                                                         actions).detach()

    # ---- define the reward loss function for MCPO ---- #
    def RL_loss():
        log_probs = policy.policy_net.get_log_prob(states, actions)
        action_loss = -advantages * torch.exp(log_probs - fixed_log_probs)
        return action_loss.mean()

    # ---- define the KL divergence used for the trust region ---- #
    def Dkl():
        kl = policy.policy_net.get_kl(states)
        return kl.mean()

    # ---- define MMD constraint from demonstrations ---- #
    # ---- if use oversample, recommended value is 5 ---- #
    def Dmmd(oversample=0):
        from_demo = torch.cat((demo_states, demo_actions), dim=1)
        # here uses distance between (s_e, a_e) and (s_e, pi_fix(s_e)) instead of (s, a) if not exact

        a, _, var = policy.policy_net(demo_states)
        if oversample > 0:
            sample_a = Normal(torch.zeros_like(a), torch.ones_like(a)).sample(
                (oversample, )) * var + a
            sample_s = demo_states.expand(oversample, -1, -1)
            from_policy = torch.cat(
                (sample_s, sample_a),
                dim=2).view(-1,
                            demo_states.size(-1) + a.size(-1))
        else:
            from_policy = torch.cat((demo_states, a + 0. * var), dim=1)

        return mmd(from_policy, from_demo, mmd_cache)

    def Dl2():
        import torch.nn.functional as F
        # here we use the mean
        mean, logstd = policy.policy_net(demo_states)
        return F.mse_loss(mean + 0. * logstd, demo_actions)

    # ---- define BC from demonstrations ---- #
    if bc_method == "l2":
        Dc = Dl2
    else:
        Dc = Dmmd

    def DO_BC():
        assert demo_states is not None, "DO_BC should not be reached when demo_states is None"
        dist = Dc()
        if dist > d:
            policy_optimizer = torch.optim.Adam(policy.policy_net.parameters())
            #print(f"Debug: Constraint not meet, refining tile it satisfies {dist} < {d}")
            for _ in range(500):
                policy_optimizer.zero_grad()
                dist.backward()
                policy_optimizer.step()
                dist = Dc()
                if dist < d:
                    break
            #print(f"Debug: BC margin={d - dist}")
        else:
            print(f"Debug: constraint meet, {dist.item()} < {d}")
        x = policy.policy_net.get_flat_params().detach()
        return x

    # ---- define grad funcs ---- #
    def Hvp_f(v, damping=damping):
        kl = Dkl()

        grads = torch.autograd.grad(kl,
                                    policy.policy_net.parameters(),
                                    create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        kl_v = (flat_grad_kl * v.detach()).sum()
        grads = torch.autograd.grad(kl_v, policy.policy_net.parameters())
        flat_grad_grad_kl = torch.cat([grad.view(-1)
                                       for grad in grads]).detach()

        return flat_grad_grad_kl + v * damping

    f0 = RL_loss()
    grads = torch.autograd.grad(f0, policy.policy_net.parameters())
    g = torch.cat([grad.view(-1) for grad in grads]).detach()
    f0.detach_()

    if demo_states is None:
        d_value = -1.e7
        c = 1.e7
        b = torch.zeros_like(g)
    else:
        d_value = Dc()
        c = (d - d_value).detach()
        grads = torch.autograd.grad(d_value, policy.policy_net.parameters())
        b = torch.cat([grad.view(-1) for grad in grads]).detach()

    # ---- update policy net with CG-LineSearch algorithm---- #
    d_theta, line_search_check_range, case = mcpo_optim(
        g, b, c, Hvp_f, max_kl, DO_BC)

    if torch.isnan(d_theta).any():
        if torch.isnan(b).any():
            print("b is NaN when Dc={}. Rejecting this step!".format(d_value))
        else:
            print("net parameter is NaN. Rejecting this step!")
        success = False
    elif line_search_check_range is not None:
        expected_df = g @ d_theta
        success, new_params = line_search(policy.policy_net, expected_df, f0,
                                          RL_loss, Dc, d, d_theta,
                                          line_search_check_range)
        policy.policy_net.set_flat_params(new_params)
    else:
        # here d_theta is from BC, so skip line search procedure
        success = True
    return (case if success else -1), d_value
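
mcpo_step calls an mmd(from_policy, from_demo, mmd_cache) helper that is not part of these snippets. A generic biased MMD^2 estimate with an RBF kernel is sketched below; the role of mmd_cache is a guess (reusing a bandwidth between calls), so this is only a plausible stand-in, not the actual function.

import torch

def mmd(x, y, cache=None):
    # squared maximum mean discrepancy between two sample sets, RBF kernel
    def rbf(a, b, sigma):
        return torch.exp(-torch.cdist(a, b).pow(2) / (2.0 * sigma ** 2))

    if cache is not None and "sigma" in cache:
        sigma = cache["sigma"]
    else:
        with torch.no_grad():
            # median heuristic for the kernel bandwidth
            sigma = torch.cdist(x, y).median().clamp_min(1e-6)
        if cache is not None:
            cache["sigma"] = sigma
    return rbf(x, x, sigma).mean() + rbf(y, y, sigma).mean() - 2.0 * rbf(x, y, sigma).mean()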
Example #23
    def _policy_worker(setups: ParamDict, pipe_param, pipe_steps, read_lock, sync_signal, policy):
        gpu, seed = setups.require("gpu", "seed")
        device = decide_device(gpu)

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        policy.to_device(device)
        policy.init()
        # -1: syncing, 0: waiting for state
        local_state = 0
        max_batchsz = 8

        def _get_piped_data(pipe, lock):
            with lock:
                if pipe.poll():
                    return pipe.recv()
                else:
                    return None

        while sync_signal.value >= 0:
            # check sync counter for sync event, and waiting for new parameters
            if sync_signal.value > 0:
                # on a sync signal: drain the step pipes, wait for new parameters, decrease
                # the sync counter, and set the state machine to -1 so this reset runs only once
                for _pipe, _lock in pipe_steps:
                    while _get_piped_data(_pipe, _lock) is not None:
                        pass
                if local_state >= 0:
                    _policy_state = _get_piped_data(pipe_param, read_lock)
                    if _policy_state is not None:
                        # set new parameters
                        policy.reset(_policy_state)
                        with sync_signal.get_lock():
                            sync_signal.value -= 1
                        local_state = -1
                    else:
                        sleep(0.01)

            # if sync ends, tell state machine to recover from syncing state, and reset environment
            elif sync_signal.value == 0 and local_state == -1:
                local_state = 0

            # waiting for states (states are list of dicts)
            elif sync_signal.value == 0 and local_state == 0:
                idx = []
                data = []
                for i, (_pipe, _lock) in enumerate(pipe_steps):
                    if len(idx) >= max_batchsz:
                        break
                    _steps = _get_piped_data(_pipe, _lock)
                    if _steps is not None:
                        data.append(_steps)
                        idx.append(i)
                if len(idx) > 0:
                    # prepare for data batch
                    with torch.no_grad():
                        data = policy.step(data)
                    # send back actions
                    for i, d in zip(idx, data):
                        with pipe_steps[i][1]:
                            pipe_steps[i][0].send(d)
                else:
                    sleep(0.00001)

        # finalization
        policy.finalize()
        pipe_param.close()
        for _pipe, _lock in pipe_steps:
            _pipe.close()
        print("Policy sub-process exited")
Example #24
 def getStateDict(self) -> ParamDict:
     return ParamDict()
Example #25
 def reset(self, param: ParamDict):
     self.mean, self.errsum, self.n_step, self.is_fixed =\
         param.require("zfilter mean", "zfilter errsum", "zfilter n_step", "fixed filter")