Example #1
 def __init__(
     self,
     feature_net: nn.Module,
     feature_dim: int,
     action_dim: int,
     hidden_sizes: Sequence[int] = (),
     device: Union[str, torch.device] = "cpu"
 ) -> None:
     super().__init__()
     self.feature_net = feature_net
     self.forward_model = MLP(
         feature_dim + action_dim,
         output_dim=feature_dim,
         hidden_sizes=hidden_sizes,
         device=device
     )
     self.inverse_model = MLP(
         feature_dim * 2,
         output_dim=action_dim,
         hidden_sizes=hidden_sizes,
         device=device
     )
     self.feature_dim = feature_dim
     self.action_dim = action_dim
     self.device = device
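The dimensions above imply the usual curiosity-module pairing: the forward model maps a (feature, action) pair to a predicted next feature, and the inverse model maps a (feature, next-feature) pair back to action logits. The following is only an illustrative, self-contained PyTorch sketch of that pairing with hypothetical stand-in networks; it mirrors the shapes of the constructor, not the library's actual forward pass.

import torch
import torch.nn as nn
import torch.nn.functional as F

feature_dim, action_dim, batch = 16, 4, 8
feature_net = nn.Sequential(nn.Linear(10, feature_dim), nn.ReLU())
forward_model = nn.Linear(feature_dim + action_dim, feature_dim)  # phi(s), a -> phi(s')
inverse_model = nn.Linear(feature_dim * 2, action_dim)            # phi(s), phi(s') -> a

obs, obs_next = torch.rand(batch, 10), torch.rand(batch, 10)
act = torch.randint(action_dim, (batch,))

phi1, phi2 = feature_net(obs), feature_net(obs_next)
pred_phi2 = forward_model(torch.cat([phi1, F.one_hot(act, action_dim).float()], dim=1))
pred_act = inverse_model(torch.cat([phi1, phi2], dim=1))

forward_loss = F.mse_loss(pred_phi2, phi2)    # prediction error, also the intrinsic reward signal
inverse_loss = F.cross_entropy(pred_act, act)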
Example #2
 def __init__(
     self,
     preprocess_net: nn.Module,
     action_shape: Sequence[int],
     hidden_sizes: Sequence[int] = (),
     max_action: float = 1.0,
     device: Union[str, int, torch.device] = "cpu",
     unbounded: bool = False,
     conditioned_sigma: bool = False,
     preprocess_net_output_dim: Optional[int] = None,
 ) -> None:
     super().__init__()
     self.preprocess = preprocess_net
     self.device = device
     self.output_dim = int(np.prod(action_shape))
     input_dim = getattr(preprocess_net, "output_dim",
                         preprocess_net_output_dim)
     self.mu = MLP(input_dim,
                   self.output_dim,
                   hidden_sizes,
                   device=self.device)
     self._c_sigma = conditioned_sigma
     if conditioned_sigma:
         self.sigma = MLP(input_dim,
                          self.output_dim,
                          hidden_sizes,
                          device=self.device)
     else:
         self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
     self._max = max_action
     self._unbounded = unbounded
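A forward pass consistent with this constructor would compute mu from the preprocessed features, optionally squash it with tanh into the action range, and derive sigma either from a second head (conditioned_sigma=True) or from the state-independent sigma_param. The sketch below is a hypothetical illustration of that flow, not the library's implementation.

import torch

def gaussian_head(obs, preprocess, mu_head, sigma_head=None, sigma_param=None,
                  max_action=1.0, unbounded=False):
    hidden = preprocess(obs)                  # features from preprocess_net
    mu = mu_head(hidden)
    if not unbounded:
        mu = max_action * torch.tanh(mu)      # squash into [-max_action, max_action]
    if sigma_head is not None:                # conditioned_sigma=True
        sigma = torch.exp(torch.clamp(sigma_head(hidden), -20, 2))
    else:                                     # state-independent log std
        sigma = torch.exp(sigma_param.view(1, -1)).expand_as(mu)
    return mu, sigma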
Example #3
def test_net():
    # here we test the networks that do not appear in the other scripts
    bsz = 64
    # MLP
    data = torch.rand([bsz, 3])
    mlp = MLP(3, 6, hidden_sizes=[128])
    assert list(mlp(data).shape) == [bsz, 6]
    # output == 0 and len(hidden_sizes) == 0 means identity model
    mlp = MLP(6, 0)
    data = torch.rand([bsz, 6])
    assert data.shape == mlp(data).shape
    # common net
    state_shape = (10, 2)
    action_shape = (5, )
    data = torch.rand([bsz, *state_shape])
    expect_output_shape = [bsz, *action_shape]
    net = Net(
        state_shape,
        action_shape,
        hidden_sizes=[128, 128],
        norm_layer=torch.nn.LayerNorm,
        activation=None
    )
    assert list(net(data)[0].shape) == expect_output_shape
    assert str(net).count("LayerNorm") == 2
    assert str(net).count("ReLU") == 0
    Q_param = V_param = {"hidden_sizes": [128, 128]}
    net = Net(
        state_shape,
        action_shape,
        hidden_sizes=[128, 128],
        dueling_param=(Q_param, V_param)
    )
    assert list(net(data)[0].shape) == expect_output_shape
    # concat
    net = Net(state_shape, action_shape, hidden_sizes=[128], concat=True)
    data = torch.rand([bsz, np.prod(state_shape) + np.prod(action_shape)])
    expect_output_shape = [bsz, 128]
    assert list(net(data)[0].shape) == expect_output_shape
    net = Net(
        state_shape,
        action_shape,
        hidden_sizes=[128],
        concat=True,
        dueling_param=(Q_param, V_param)
    )
    assert list(net(data)[0].shape) == expect_output_shape
    # recurrent actor/critic
    data = torch.rand([bsz, *state_shape]).flatten(1)
    expect_output_shape = [bsz, *action_shape]
    net = RecurrentActorProb(3, state_shape, action_shape)
    mu, sigma = net(data)[0]
    assert mu.shape == sigma.shape
    assert list(mu.shape) == [bsz, 5]
    net = RecurrentCritic(3, state_shape, action_shape)
    data = torch.rand([bsz, 8, np.prod(state_shape)])
    act = torch.rand(expect_output_shape)
    assert list(net(data, act).shape) == [bsz, 1]
Example #4
 def __init__(
     self,
     preprocess_net: nn.Module,
     hidden_sizes: Sequence[int] = (),
     device: Union[str, int, torch.device] = "cpu",
     preprocess_net_output_dim: Optional[int] = None,
 ) -> None:
     super().__init__()
     self.device = device
     self.preprocess = preprocess_net
     self.output_dim = 1
     input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
     self.last = MLP(input_dim, 1, hidden_sizes, device=self.device)
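This builds a single-output MLP head on top of preprocess_net, i.e. a value/critic head. Below is a minimal plain-PyTorch sketch of how such a head is typically evaluated on a (state, action) batch, using hypothetical stand-in modules rather than the library's classes.

import torch
import torch.nn as nn

preprocess = nn.Sequential(nn.Linear(8 + 2, 64), nn.ReLU())  # stand-in: obs_dim=8, act_dim=2
last = nn.Linear(64, 1)                                      # the 1-dimensional value head

obs, act = torch.rand(32, 8), torch.rand(32, 2)
value = last(preprocess(torch.cat([obs, act], dim=1)))       # shape: [32, 1]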
Example #5
 def __init__(
     self,
     preprocess_net: nn.Module,
     hidden_sizes: Sequence[int] = (),
     last_size: int = 1,
     preprocess_net_output_dim: Optional[int] = None,
 ) -> None:
     super().__init__()
     self.preprocess = preprocess_net
     self.output_dim = last_size
     input_dim = getattr(preprocess_net, "output_dim",
                         preprocess_net_output_dim)
     self.last = MLP(input_dim, last_size, hidden_sizes)
Example #6
 def __init__(
     self,
     preprocess_net: nn.Module,
     action_shape: Sequence[int],
     hidden_sizes: Sequence[int] = (),
     softmax_output: bool = True,
     preprocess_net_output_dim: Optional[int] = None,
 ) -> None:
     super().__init__()
     self.preprocess = preprocess_net
     self.output_dim = int(np.prod(action_shape))
     input_dim = getattr(preprocess_net, "output_dim",
                         preprocess_net_output_dim)
     self.last = MLP(input_dim, self.output_dim, hidden_sizes)
     self.softmax_output = softmax_output
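Here softmax_output only decides whether the final logits are normalized into action probabilities. A tiny illustrative sketch with hypothetical stand-ins:

import torch
import torch.nn as nn
import torch.nn.functional as F

preprocess = nn.Sequential(nn.Linear(4, 64), nn.ReLU())  # stand-in feature extractor
last = nn.Linear(64, 3)                                  # 3 discrete actions

logits = last(preprocess(torch.rand(16, 4)))
probs = F.softmax(logits, dim=-1)  # what softmax_output=True would return
# with softmax_output=False the raw logits are returned instead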
Example #7
 def __init__(
     self,
     preprocess_net: nn.Module,
     action_shape: Sequence[int],
     hidden_sizes: Sequence[int] = (),
     max_action: float = 1.0,
     device: Union[str, int, torch.device] = "cpu",
     preprocess_net_output_dim: Optional[int] = None,
 ) -> None:
     super().__init__()
     self.device = device
     self.preprocess = preprocess_net
     self.output_dim = int(np.prod(action_shape))
     input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
     self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
     self._max = max_action
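The stored self._max suggests the common deterministic-actor convention of tanh-squashing the head output into [-max_action, max_action]. An illustrative sketch (hypothetical stand-ins, not the library's forward):

import torch
import torch.nn as nn

preprocess = nn.Sequential(nn.Linear(6, 64), nn.ReLU())  # stand-in feature extractor
last = nn.Linear(64, 2)                                  # 2-dimensional continuous action
max_action = 2.0

action = max_action * torch.tanh(last(preprocess(torch.rand(32, 6))))  # values in [-2, 2]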
Example #8
    def __init__(self,
                 state_shape: Union[int, Sequence[int]],
                 action_shape: Union[int, Sequence[int]] = 0,
                 hidden_sizes: Sequence[int] = (256, 256),
                 activation: Optional[ModuleType] = nn.ReLU,
                 device: Union[str, int, torch.device] = "cpu") -> None:
        super().__init__()
        self.device = device
        input_dim = int(np.prod(state_shape))
        action_dim = int(np.prod(action_shape))

        self.net = MLP(input_dim,
                       action_dim,
                       hidden_sizes,
                       norm_layer=None,
                       activation=activation,
                       device=device)
        self.output_dim = self.net.output_dim
Example #9
 def __init__(
     self,
     preprocess_net: nn.Module,  # this is our dilated net
     action_shape: Sequence[int],
     hidden_sizes: Sequence[int] = (),
     softmax_output: bool = True,
     preprocess_net_output_dim: Optional[int] = None,
     device: Union[str, int, torch.device] = "cpu",
 ) -> None:
     super().__init__()
     self.device = device
     self.preprocess = preprocess_net
     self.output_dim = int(np.prod(action_shape))
     input_dim = getattr(preprocess_net, "output_dim",
                         preprocess_net_output_dim)
     self.last = MLP(input_dim,
                     self.output_dim,
                     hidden_sizes,
                     device=self.device)
     self.softmax_output = softmax_output
Example #10
 def __init__(
     self,
     preprocess_net: nn.Module,
     hidden_sizes: Sequence[int] = (),
     device: Union[str, int, torch.device] = "cpu",
     preprocess_net_output_dim: Optional[int] = None,
     linear_layer: Type[nn.Linear] = nn.Linear,
     flatten_input: bool = True,
 ) -> None:
     super().__init__()
     self.device = device
     self.preprocess = preprocess_net
     self.output_dim = 1
     input_dim = getattr(preprocess_net, "output_dim",
                         preprocess_net_output_dim)
     self.last = MLP(
         input_dim,  # type: ignore
         1,
         hidden_sizes,
         device=self.device,
         linear_layer=linear_layer,
         flatten_input=flatten_input,
     )
Example #11
def test_bcq(args=get_args()):
    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
        if args.load_buffer_name.endswith(".hdf5"):
            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
        else:
            buffer = pickle.load(open(args.load_buffer_name, "rb"))
    else:
        buffer = gather_data()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    if args.reward_threshold is None:
        # too low?
        default_reward_threshold = {"Pendulum-v0": -1100, "Pendulum-v1": -1100}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold
        )

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # perturbation network
    net_a = MLP(
        input_dim=args.state_dim + args.action_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Perturbation(
        net_a, max_action=args.max_action, device=args.device, phi=args.phi
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    # vae
    # output_dim = 0, so the last Module in the encoder is ReLU
    vae_encoder = MLP(
        input_dim=args.state_dim + args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    if not args.latent_dim:
        args.latent_dim = args.action_dim * 2
    vae_decoder = MLP(
        input_dim=args.state_dim + args.latent_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    vae = VAE(
        vae_encoder,
        vae_decoder,
        hidden_dim=args.vae_hidden_sizes[-1],
        latent_dim=args.latent_dim,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    vae_optim = torch.optim.Adam(vae.parameters())

    policy = BCQPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        vae,
        vae_optim,
        device=args.device,
        gamma=args.gamma,
        tau=args.tau,
        lmbda=args.lmbda,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    # buffer has been gathered
    # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq'
    log_path = os.path.join(args.logdir, args.task, 'bcq', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    def watch():
        policy.load_state_dict(
            torch.load(
                os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu')
            )
        )
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)

    # trainer
    result = offline_trainer(
        policy,
        buffer,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.test_num,
        args.batch_size,
        save_best_fn=save_best_fn,
        stop_fn=stop_fn,
        logger=logger,
    )
    assert stop_fn(result['best_reward'])

    # Let's watch its performance!
    if __name__ == '__main__':
        pprint.pprint(result)
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
Example #12
def test_ppo(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v0": 195}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold)
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape,
              hidden_sizes=args.hidden_sizes,
              device=args.device)
    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
    critic = Critic(net, device=args.device).to(args.device)
    actor_critic = ActorCritic(actor, critic)
    # orthogonal initialization
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            torch.nn.init.zeros_(m.bias)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = PPOPolicy(actor,
                       critic,
                       optim,
                       dist,
                       discount_factor=args.gamma,
                       max_grad_norm=args.max_grad_norm,
                       eps_clip=args.eps_clip,
                       vf_coef=args.vf_coef,
                       ent_coef=args.ent_coef,
                       gae_lambda=args.gae_lambda,
                       reward_normalization=args.rew_norm,
                       dual_clip=args.dual_clip,
                       value_clip=args.value_clip,
                       action_space=env.action_space,
                       deterministic_eval=True,
                       advantage_normalization=args.norm_adv,
                       recompute_advantage=args.recompute_adv)
    feature_dim = args.hidden_sizes[-1]
    feature_net = MLP(np.prod(args.state_shape),
                      output_dim=feature_dim,
                      hidden_sizes=args.hidden_sizes[:-1],
                      device=args.device)
    action_dim = np.prod(args.action_shape)
    icm_net = IntrinsicCuriosityModule(feature_net,
                                       feature_dim,
                                       action_dim,
                                       hidden_sizes=args.hidden_sizes[-1:],
                                       device=args.device).to(args.device)
    icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
    policy = ICMPolicy(policy, icm_net, icm_optim, args.lr_scale,
                       args.reward_scale, args.forward_loss_weight)
    # collector
    train_collector = Collector(
        policy, train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)))
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'ppo_icm')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    # trainer
    result = onpolicy_trainer(policy,
                              train_collector,
                              test_collector,
                              args.epoch,
                              args.step_per_epoch,
                              args.repeat_per_collect,
                              args.test_num,
                              args.batch_size,
                              step_per_collect=args.step_per_collect,
                              stop_fn=stop_fn,
                              save_best_fn=save_best_fn,
                              logger=logger)
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
Example #13
def test_dqn_icm(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if args.reward_threshold is None:
        default_reward_threshold = {"CartPole-v0": 195}
        args.reward_threshold = default_reward_threshold.get(
            args.task, env.spec.reward_threshold)
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # Q_param = V_param = {"hidden_sizes": [128]}
    # model
    net = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
        # dueling=(Q_param, V_param),
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(
        net,
        optim,
        args.gamma,
        args.n_step,
        target_update_freq=args.target_update_freq,
    )
    feature_dim = args.hidden_sizes[-1]
    feature_net = MLP(np.prod(args.state_shape),
                      output_dim=feature_dim,
                      hidden_sizes=args.hidden_sizes[:-1],
                      device=args.device)
    action_dim = np.prod(args.action_shape)
    icm_net = IntrinsicCuriosityModule(feature_net,
                                       feature_dim,
                                       action_dim,
                                       hidden_sizes=args.hidden_sizes[-1:],
                                       device=args.device).to(args.device)
    icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
    policy = ICMPolicy(policy, icm_net, icm_optim, args.lr_scale,
                       args.reward_scale, args.forward_loss_weight)
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedVectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(train_envs),
            alpha=args.alpha,
            beta=args.beta,
        )
    else:
        buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
    # collector
    train_collector = Collector(policy,
                                train_envs,
                                buf,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn_icm')
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= args.reward_threshold

    def train_fn(epoch, env_step):
        # eps annealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        update_per_step=args.update_per_step,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
    )
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
Example #14
def test_bcq():
    args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low),
          np.max(env.action_space.high))

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    print("Max_action", args.max_action)

    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # perturbation network
    net_a = MLP(
        input_dim=args.state_dim + args.action_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Perturbation(net_a,
                         max_action=args.max_action,
                         device=args.device,
                         phi=args.phi).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    # vae
    # output_dim = 0, so the last Module in the encoder is ReLU
    vae_encoder = MLP(
        input_dim=args.state_dim + args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    if not args.latent_dim:
        args.latent_dim = args.action_dim * 2
    vae_decoder = MLP(
        input_dim=args.state_dim + args.latent_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    vae = VAE(
        vae_encoder,
        vae_decoder,
        hidden_dim=args.vae_hidden_sizes[-1],
        latent_dim=args.latent_dim,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    vae_optim = torch.optim.Adam(vae.parameters())

    policy = BCQPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        vae,
        vae_optim,
        device=args.device,
        gamma=args.gamma,
        tau=args.tau,
        lmbda=args.lmbda,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "bcq"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def watch():
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, "policy.pth")

        policy.load_state_dict(
            torch.load(args.resume_path, map_location=torch.device("cpu")))
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)

    if not args.watch:
        dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
        dataset_size = dataset["rewards"].size

        print("dataset_size", dataset_size)
        replay_buffer = ReplayBuffer(dataset_size)

        for i in range(dataset_size):
            replay_buffer.add(
                Batch(
                    obs=dataset["observations"][i],
                    act=dataset["actions"][i],
                    rew=dataset["rewards"][i],
                    done=dataset["terminals"][i],
                    obs_next=dataset["next_observations"][i],
                ))
        print("dataset loaded")
        # trainer
        result = offline_trainer(
            policy,
            replay_buffer,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.test_num,
            args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        )
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num,
                                    render=args.render)
    print(
        f"Final reward: {result['rews'].mean()}, length: {result['lens'].mean()}"
    )