def train_fetch(experiment: sacred.Experiment, agent: Any, eval_env: FetchEnv, progressive_noise: bool, small_goal: bool):
    reporting.register_field("eval_success_rate")
    reporting.register_field("action_norm")
    reporting.finalize_fields()
    # All experiment variants (progressive_noise, small_goal, default) train for
    # the same number of iterations, so the original three-way branch collapses
    # to a single call.
    trange = tqdm.trange(2000000)
    for iteration in trange:
        if iteration % 10000 == 0:
            action_norms = []
            success_rate = 0
            for i in range(50):
                state = eval_env.reset()
                while not eval_env.needs_reset:
                    action = agent.eval_action(state)
                    action_norms.append(np.linalg.norm(action))
                    state, reward, is_terminal, info = eval_env.step(action)
                    if reward > -1.:
                        success_rate += 1
                        break
            reporting.iter_record("eval_success_rate", success_rate)
            reporting.iter_record("action_norm", np.mean(action_norms).item())

        if iteration % 20000 == 0:
            policy_path = f"/tmp/policy_{iteration}"
            with open(policy_path, 'wb') as f:
                torch.save(agent.freeze_policy(torch.device('cpu')), f)
            experiment.add_artifact(policy_path)

        agent.update()
        reporting.iterate()
        trange.set_description(f"{iteration} -- " + reporting.get_description(["return", "td_loss", "env_steps"]))
    def __init__(self,
                 make_env: Callable[[],
                                    environment.Environment[Union[np.ndarray,
                                                                  int]]],
                 device: torch.device, params: AsyncReplayAgentParams):
        self._params = params
        self._device = device
        self._actor_device = torch.device(self._params.actor_device)
        self._executor = futures.ProcessPoolExecutor(
            max_workers=self._params.num_envs,
            mp_context=torch.multiprocessing.get_context('spawn'),
            initializer=self._initialize_process_env,
            initargs=(make_env, ))
        self._futures = None

        self._make_env = make_env
        self._env = make_env()
        self._state_dim = self._env.state_dim
        self._action_dim = self._env.action_dim

        self._replay_description = self.get_description(self._env)
        self._buffer = ram_buffer.RamBuffer(
            self._params.replay_size, self._replay_description.num_columns,
            device)

        reporting.register_field("return")
        self._env_steps = 0
        reporting.register_field("env_steps")
        self.__update_count = 0
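
The ProcessPoolExecutor above is configured with a spawn context and an initializer, the standard way to give each worker process its own long-lived environment. The self-contained sketch below shows that pattern; the DummyEnv, make_env and rollout names are illustrative, and _initialize_process_env in the repository presumably does something similar with make_env.

import concurrent.futures as futures
import multiprocessing
import random

class DummyEnv:
    # Stand-in for environment.Environment; only the pattern matters here.
    def reset(self):
        return random.random()

def make_env():
    return DummyEnv()

_worker_env = None  # one long-lived environment per worker process

def _init_worker(env_factory):
    # Runs once in each spawned worker: build the env and keep it in a
    # module-level global, as _initialize_process_env presumably does.
    global _worker_env
    _worker_env = env_factory()

def rollout(_task_id):
    # Every submitted task reuses the per-process environment.
    return _worker_env.reset()

if __name__ == "__main__":
    with futures.ProcessPoolExecutor(
            max_workers=4,
            mp_context=multiprocessing.get_context("spawn"),
            initializer=_init_worker,
            initargs=(make_env, )) as executor:
        print(list(executor.map(rollout, range(4))))
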
    def __init__(self,
                 make_env: Callable[[],
                                    environment.Environment[Union[np.ndarray,
                                                                  int]]],
                 device: torch.device, q1: torch.nn.Module,
                 q2: torch.nn.Module, policy_net: torch.nn.Module,
                 params: HerTD3Params):
        super().__init__(make_env, device, params)
        spec_env = make_env()
        action_dim = spec_env.action_dim

        self._online_policy = policy_net
        self._target_policy = copy.deepcopy(policy_net)
        self._behavior_policy = gaussian.SphericalGaussianPolicy(
            device,
            action_dim,
            self._target_policy,
            fixed_noise=self._params.exploration_noise,
            eval_use_mean=True)
        self._critic = double_q_critic.DoubleQCritic(
            q1, q2, self._target_policy, params.target_update_step,
            params.discount_factor, params.target_action_noise)

        self._target_update_rate = 2
        self._update_count = -1

        self._critic_optim = torch.optim.RMSprop(
            self._critic.parameters, lr=params.critic_learning_rate)
        self._actor_optim = torch.optim.RMSprop(
            self._online_policy.parameters(),
            lr=params.critic_learning_rate,
            weight_decay=params.actor_weight_decay)

        reporting.register_field("advantage_loss")
        reporting.register_field("td_loss")
Example #4
    def __init__(self, net: torch.nn.Module, learning_rate: float, batch_size: int):
        self._train_x = np.random.random(size=(500, 2)).astype(np.float32)
        self._train_y = (self._train_x[:, 0] * 3 + self._train_x[:, 1] +
                         np.random.randn(self._train_x.shape[0]).astype(np.float32) * 0.1)
        self._valid_x = np.random.random(size=(100, 2)).astype(np.float32)
        self._valid_y = (self._valid_x[:, 0] * 3 + self._valid_x[:, 1] +
                         np.random.randn(self._valid_x.shape[0]).astype(np.float32) * 0.1)

        self._net = net
        self._optim = torch.optim.Adam(self._net.parameters(), learning_rate)
        self._batch_size = batch_size

        reporting.register_field("train_loss")
        reporting.register_field("valid_loss")
Example #5
    def __init__(self, make_env: Callable[[], environment.Environment[Union[np.ndarray, int]]], device: torch.device,
                 params: ReplayAgentParams):
        self._params = params
        self._env = make_env()
        self._env.reset()
        self._state_dim = self._env.state_dim
        self._action_dim = self._env.action_dim
        self._device = device

        # state, action, reward, next_state, timeout, terminal, action_logprob
        buffer_dim = 2 * self._state_dim + self._action_dim + 4
        if params.buffer_type == 'ram':
            self._buffer = ram_buffer.RamBuffer(self._params.replay_size, buffer_dim, device)
        elif params.buffer_type == 'vram':
            self._buffer = vram_buffer.VramBuffer(self._params.replay_size, buffer_dim, device)
        else:
            raise ValueError(f"unknown buffer_type: {params.buffer_type}")

        reporting.register_field("return")
    def __init__(self, q1: torch.nn.Module, q2: torch.nn.Module,
                 density_q_model, goal_dim: int,
                 target_policy: torch.nn.Module, target_update_step: float,
                 discount_factor: float, action_noise_stddev: float,
                 replace_goal_fraction: float, shuffle_goals: bool,
                 params: UVDParams):
        super().__init__(double_q_critic.DoubleQModel(q1, q2), 2,
                         target_update_step)
        self._online_q1 = q1
        self._density_q_model = density_q_model
        self._target_policy = target_policy
        self._discount_factor = discount_factor
        self._action_noise_stddev = action_noise_stddev
        self._goal_dim = goal_dim
        self._replace_goal_fraction = replace_goal_fraction
        self._largest_target = 0.
        self._shuffle_goals = shuffle_goals
        self._params = params
        reporting.register_field("target_fraction")
        reporting.register_field("target_q")
    def __init__(self,
                 make_env: Callable[[],
                                    environment.Environment[Union[np.ndarray,
                                                                  int]]],
                 device: torch.device, params: OnlineAgentParams):
        self._params = params
        self._envs = [make_env() for _ in range(params.num_envs)]
        state_dim = self._envs[0].state_dim
        action_dim = self._envs[0].action_dim
        self._device = device

        self._states = torch.zeros((params.batch_size, state_dim)).pin_memory()
        self._actions = torch.zeros(
            (params.batch_size, action_dim)).pin_memory()
        self._rewards = torch.zeros((params.batch_size, )).pin_memory()
        self._bootstrap_weights = torch.zeros(
            (params.batch_size, )).pin_memory()
        self._bootstrap_states = torch.zeros(
            (params.batch_size, state_dim)).pin_memory()
        self._bootstrap_actions = torch.zeros(
            (params.batch_size, action_dim)).pin_memory()

        reporting.register_field("return")
    def __init__(self, make_env: Callable[[], environment.Environment],
                 policy_: policy.Policy, value_function: torch.nn.Module,
                 params: A2CParams):
        super().__init__(make_env,
                         list(value_function.parameters())[0].device, params)
        self._policy = policy_
        self._params = params

        self._critic = value_td.ValueTD(
            model=value_function,
            target_update_rate=params.value_target_update_rate)
        self._actor_optimizer = torch.optim.RMSprop(
            set(self._policy.parameters), eps=0.1, lr=params.learning_rate)
        self._value_optimizer = torch.optim.RMSprop(
            value_function.parameters(), eps=0.1, lr=params.learning_rate)
        reporting.register_field("entropy_loss")
        reporting.register_field("advantage_loss")
        reporting.register_field("td_loss")

        self._num_updates = 0
    def __init__(self,
                 make_env: Callable[[],
                                    environment.Environment[Union[np.ndarray,
                                                                  int]]],
                 device: torch.device, goal_r: torch.nn.Module,
                 q1: torch.nn.Module, q2: torch.nn.Module,
                 policy_net: torch.nn.Module, params: UVDParams):
        super().__init__(make_env, device, params)
        spec_env = make_env()
        action_dim = spec_env.action_dim

        self._goal_r = goal_r
        self._online_policy = policy_net
        self._target_policy = copy.deepcopy(policy_net)
        self._behavior_policy = gaussian.SphericalGaussianPolicy(
            device,
            action_dim,
            self._target_policy,
            fixed_noise=self._params.exploration_noise,
            eval_use_mean=True)
        self._critic = UVDCritic(
            q1, q2, goal_r, spec_env.goal_dim, self._target_policy,
            params.target_update_step, params.discount_factor,
            params.target_action_noise, params.replace_goal_fraction,
            params.shuffle_goals, params)

        self._target_update_rate = 2
        self._update_count = -1
        self._goal_dim = spec_env.goal_dim

        self._critic_optim = torch.optim.RMSprop(
            self._critic.parameters, lr=params.critic_learning_rate)
        self._r_optim = torch.optim.RMSprop(self._goal_r.parameters(),
                                            lr=params.density_learning_rate,
                                            weight_decay=1e-5)
        self._actor_optim = torch.optim.RMSprop(
            self._online_policy.parameters(), lr=params.policy_learning_rate)

        reporting.register_field("td_loss")
        reporting.register_field("r_loss")
        reporting.register_field("mean_r")
        reporting.register_field("mean_r_with_goal")
        reporting.register_field("max_r_with_goal")
        reporting.register_field("min_r_with_goal")
        reporting.register_field("actor_loss")
        reporting.register_field("valid_r_loss")
        reporting.register_field("valid_target_r")
        reporting.register_field("valid_target_v")
def train(density_learning_rate: float, _config: sacred.config.ConfigDict):
    target_dir = "/home/anon/generated_data/algorithms"
    reporting.register_global_reporter(experiment, target_dir)
    device = torch.device('cuda:0')
    demo_states, _ = load_demos()
    demo_min = np.min(demo_states, axis=0)
    demo_max = np.max(demo_states, axis=0)
    random_min, random_max = random_rollout_bounds(10)
    min_states = np.minimum(demo_min, random_min)
    max_states = np.maximum(demo_max, random_max)
    make_normalized_env = functools.partial(environment_adapters.NormalizedEnv,
                                            make_env, min_states, max_states)
    eval_env = make_normalized_env()
    np.savetxt(target_dir + "/normalization", [min_states, max_states])
    experiment.add_artifact(target_dir + "/normalization")

    demo_states = eval_env.normalize_state(demo_states)
    demo_states = torch.from_numpy(demo_states).to(device)
    demo_actions = None
    print(demo_states.shape)

    state_dim = demo_states.shape[1]
    action_dim = eval_env.action_dim

    density_model = DensityModel(device, state_dim, action_dim)
    state_density_model = StateDensityModel(device, state_dim, action_dim)
    policy = PolicyNetwork(state_dim, action_dim).to(device)
    params_parser = util.ConfigParser(vdi.VDIParams)
    params = params_parser.parse(_config)

    q1 = QNetwork(state_dim, action_dim).to(device)
    q2 = QNetwork(state_dim, action_dim).to(device)

    agent = vdi.VDI(make_normalized_env, device, density_model,
                    state_density_model, policy, q1, q2, params, demo_states,
                    demo_actions)

    reporting.register_field("eval_return")
    reporting.finalize_fields()
    trange = tqdm.trange(1000000)
    for iteration in trange:
        agent.update()
        reporting.iterate()
        if iteration % 20000 == 0:
            eval_reward = 0
            for i in range(2):
                state = eval_env.reset()
                cumulative_reward = 0
                while not eval_env.needs_reset:
                    action = agent.eval_action(state)
                    state, reward, is_terminal, _ = eval_env.step(action)
                    cumulative_reward += reward
                eval_reward += cumulative_reward / 2
            reporting.iter_record("eval_return", eval_reward)

        if iteration % 10000 == 0:
            policy_path = f"{target_dir}/policy_{iteration}"
            with open(policy_path, 'wb') as f:
                torch.save(agent.freeze_policy(torch.device('cpu')), f)
            experiment.add_artifact(policy_path)
            density_model_path = f"{target_dir}/dm_{iteration}"
            with open(density_model_path, 'wb') as f:
                torch.save(density_model, f)
            experiment.add_artifact(density_model_path)

        trange.set_description(f"{iteration} -- " + reporting.get_description([
            "return", "eval_return", "density_loss", "actor_loss", "td_loss",
            "env_steps"
        ]))
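
The frozen policies and density models saved by the loop above are plain torch.save artifacts, so they can be reloaded later for offline evaluation. A brief sketch, assuming one of the hypothetical paths written above and that freeze_policy returned an nn.Module-like object:

import torch

# Hypothetical path: one of the artifacts written by the loop above.
policy = torch.load("/home/anon/generated_data/algorithms/policy_20000",
                    map_location=torch.device('cpu'))
policy.eval()  # assumes the frozen policy is an nn.Module-like object
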
Example #11
    def __init__(self, model: torch.nn.Module, target_update_rate: int, *args,
                 **kwargs):
        super().__init__(model, target_update_rate, *args, **kwargs)
        reporting.register_field("td_target_value")
    def __init__(self, device: torch.device, network: torch.nn.Module):
        self._network = network
        self._device = device
        reporting.register_field('max pi')
Example #13
    def __init__(self, q1: torch.nn.Module, q2: torch.nn.Module,
                 density_q_model, state_density_model,
                 target_policy: torch.nn.Module, params: VDIParams):
        super().__init__(DoubleQModel(q1, q2), 2, params.target_update_step)
        self._online_q1 = q1
        self._density_q_model = density_q_model
        self._state_density_model = state_density_model
        self._target_policy = target_policy
        self._discount_factor = params.discount_factor
        self._action_noise_stddev = params.action_noise_stddev
        self._largest_target = 0.
        self._params = params
        self._temporal_smoothing = params.temporal_smoothing
        reporting.register_field("target_fraction")
        reporting.register_field("target_q")
        reporting.register_field("max_state_density")
        reporting.register_field("min_state_density")
        reporting.register_field("mean_state_density")
        reporting.register_field("state_density_bound")
        reporting.register_field("terminal_weight")
Example #14
    def __init__(self, make_env: Callable[[], environment.Environment[Union[np.ndarray, int]]], device: torch.device,
                 density_model: rnvp.SimpleRealNVP, state_density_model: rnvp.SimpleRealNVP,
                 policy_net: torch.nn.Module, q1: torch.nn.Module, q2: torch.nn.Module, params: VDIParams,
                 demo_states: torch.Tensor, demo_actions: Optional[torch.Tensor]):
        spec_env = make_env()
        state_dim = spec_env.state_dim
        action_dim = spec_env.action_dim
        super().__init__(state_dim, make_env, device, params)
        print(demo_states.shape)

        self._density_update_rate = params.burnin_density_update_rate
        self._online_q1 = q1
        self._density_model = density_model
        self._target_density_model = util.target_network(density_model)
        self._state_density_model = state_density_model
        self._target_state_density_model = util.target_network(state_density_model)
        self._online_policy = policy_net
        self._target_policy = util.target_network(policy_net)
        self._behavior_policy = gaussian.SphericalGaussianPolicy(
            device, action_dim, self._target_policy, fixed_noise=self._params.exploration_noise, eval_use_mean=True)

        self.spatial_smoothing = self._params.spatial_smoothing
        self._target_update_rate = 1
        self._update_count = -1
        self._demo_weights = np.ones((demo_states.shape[0], ))/demo_states.shape[0]

        self._density_optim = torch.optim.RMSprop(
            set(self._density_model.parameters()) | set(self._state_density_model.parameters()),
            lr=params.density_learning_rate, weight_decay=self._params.density_l2)
        self._critic = VDICritic(q1, q2, self._target_density_model, self._target_state_density_model, self._target_policy, params)
        self._actor_optim = torch.optim.RMSprop(self._online_policy.parameters(), lr=params.policy_learning_rate,
                                                weight_decay=self._params.policy_l2)
        self._critic_optim = torch.optim.RMSprop(self._critic.parameters, lr=params.critic_learning_rate,
                                                 weight_decay=params.critic_l2)
        if len(params.lr_decay_iterations) > 0:
            self._actor_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self._actor_optim, params.lr_decay_iterations, params.lr_decay_rate)
            self._critic_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self._critic_optim, params.lr_decay_iterations, params.lr_decay_rate)

        self._demo_states = demo_states
        self._demo_actions = demo_actions
        self._last_imagined_samples = None

        reporting.register_field("q_norm")
        reporting.register_field("policy_norm")
        reporting.register_field("density_loss")
        reporting.register_field("state_density_loss")
        reporting.register_field("valid_density_loss")
        reporting.register_field("actor_loss")
        reporting.register_field("bc_loss")
        reporting.register_field("td_loss")
        reporting.register_field("actor_lr")
        reporting.register_field("critic_lr")