Example #1
def offpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    episode_per_test: Union[int, List[int]],
    batch_size: int,
    update_per_step: int = 1,
    train_fn: Optional[Callable[[int], None]] = None,
    test_fn: Optional[Callable[[int], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    log_fn: Optional[Callable[[dict], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
    # test_in_train: bool = True,
    test_in_train: bool = False,
) -> Dict[str, Union[float, str]]:
    """A wrapper for off-policy trainer procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might finish before reaching ``max_epoch``.
    :param int step_per_epoch: the number of policy network update steps in
        one epoch.
    :param int collect_per_step: the number of frames the collector collects
        before each network update, i.e., collect some frames and then perform
        some policy network updates.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of the sample data that is fed into
        the policy network.
    :param int update_per_step: the number of times the policy network is
        updated after the frames are collected, i.e., collect some frames and
        then perform some policy network updates.
    :param function train_fn: a function that receives the current epoch index
        and performs some operations at the beginning of training in that
        epoch.
    :param function test_fn: a function that receives the current epoch index
        and performs some operations at the beginning of testing in that
        epoch.
    :param function save_fn: a function for saving the policy whenever the
        average undiscounted reward in the evaluation phase improves.
    :param function stop_fn: a function that receives the average undiscounted
        return of the test result and returns a boolean indicating whether the
        goal has been reached.
    :param function log_fn: a function that receives the env info for logging.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the logging interval of the writer.
    :param bool verbose: whether to print the information.
    :param bool test_in_train: whether to test during the training phase.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    test_in_train = test_in_train and train_collector.policy == policy
    # change: accumulate periodic evaluation results for the Excel export below
    training_res = []
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step,
                                                 log_fn=log_fn)
                # data = {}
                # if test_in_train and stop_fn and stop_fn(result['rew']):
                #     test_result = test_episode(
                #         policy, test_collector, test_fn,
                #         epoch, episode_per_test)
                #     if stop_fn and stop_fn(test_result['rew']):
                #         if save_fn:
                #             save_fn(policy)
                #         for k in result.keys():
                #             data[k] = f'{result[k]:.2f}'
                #         t.set_postfix(**data)
                #         return gather_info(
                #             start_time, train_collector, test_collector,
                #             test_result['rew'])
                #     else:
                #         policy.train()
                #         if train_fn:
                #             train_fn(epoch)
                for i in range(update_per_step * min(
                        result['n/st'] // collect_per_step, t.total - t.n)):
                    global_step += 1
                    losses = policy.learn(train_collector.sample(batch_size))
                    # for k in result.keys():
                    #     data[k] = f'{result[k]:.2f}'
                    #     if writer and global_step % log_interval == 0:
                    #         writer.add_scalar(
                    #             k, result[k], global_step=global_step)
                    # for k in losses.keys():
                    #     if stat.get(k) is None:
                    #         stat[k] = MovAvg()
                    #     stat[k].add(losses[k])
                    #     data[k] = f'{stat[k].get():.6f}'
                    #     if writer and global_step % log_interval == 0:
                    #         writer.add_scalar(
                    #             k, stat[k].get(), global_step=global_step)
                    t.update(1)
                    # change
                    # t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        # change: every 50 epochs, evaluate on a fresh EnvFourUsers instead of test_collector
        if epoch % 50 == 0:  #  or epoch < 2000:
            env = EnvFourUsers(step_per_epoch)
            # env.seed(0)
            policy.train(False)
            collector = Collector(policy, env)
            ep = 100
            result = collector.collect(n_episode=ep)
            # result = test_episode(
            #     policy, test_collector, test_fn, epoch, episode_per_test)
            if best_epoch == -1 or best_reward < result['rew']:
                best_reward = result['rew']
                best_epoch = epoch
                if save_fn:
                    save_fn(policy)
            # print(result)
            if verbose:
                # change
                print(
                    f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, ',
                    f'best_reward: {best_reward:.6f} in #{best_epoch},\n',
                    f'ty1_succ_rate_1: {result["ty1s_1"][0]/ep:.4f}, ',
                    f'ty1_succ_rate_2: {result["ty1s_2"][0]/ep:.4f},  \n',
                    f'ty1_succ_rate_3: {result["ty1s_3"][0]/ep:.4f}, ',
                    f'ty1_succ_rate_4: {result["ty1s_4"][0]/ep:.4f}, \n',
                    f'Q_len_1: {result["ql_1"][0]/ep:.4f},',
                    f'Q_len_2: {result["ql_2"][0]/ep:.4f}, \n',
                    f'Q_len_3: {result["ql_3"][0]/ep:.4f},',
                    f'Q_len_4: {result["ql_4"][0]/ep:.4f}, \n',
                    f'energy_effi_1: {result["ee_1"][0]/ep:.4f},',
                    f'energy_effi_2: {result["ee_2"][0]/ep:.4f},\n',
                    f'energy_effi_3: {result["ee_3"][0]/ep:.4f},',
                    f'energy_effi_4: {result["ee_4"][0]/ep:.4f}\n',
                    f'avg_rate: {result["avg_r"]/ep:.4f}, '
                    f'avg_power: {result["avg_p"]/ep:.4f} dBm\n')
            # change
            training_res.append([
                (result["ee_1"][0] / ep + result["ee_2"][0] / ep +
                 result["ee_3"][0] / ep + result["ee_4"][0] / ep) / 4,
                (result["ty1s_1"][0] / ep + result["ty1s_2"][0] / ep +
                 result["ty1s_3"][0] / ep + result["ty1s_4"][0] / ep) / 4,
                (result["ql_1"][0] / ep + result["ql_2"][0] / ep +
                 result["ql_3"][0] / ep + result["ql_4"][0] / ep) / 4,
                result["rew"]
            ])
        if stop_fn and stop_fn(best_reward):
            break
    # change: export the collected evaluation metrics to an Excel workbook
    training_res = np.array(training_res)
    wb = Workbook()
    ws = wb.active
    ws.title = 'training result'
    ws['A1'] = 'testing num'
    ws['B1'] = 'energy efficiency'
    ws['C1'] = 'type 1 success rate'
    ws['D1'] = 'type 2 q length'
    ws['E1'] = 'return'
    for i in range(training_res.shape[0]):
        ws.cell(i + 2, 1).value = i + 1
        ws.cell(i + 2, 2).value = training_res[i, 0]
        ws.cell(i + 2, 3).value = training_res[i, 1]
        ws.cell(i + 2, 4).value = training_res[i, 2]
        ws.cell(i + 2, 5).value = training_res[i, 3]
    wb.save("directly_training_slot" + str(step_per_epoch) + ".xlsx")
    test_collector.collect_time = -1
    return gather_info(start_time, train_collector, test_collector,
                       best_reward)
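
The "# change" blocks above periodically evaluate the policy and dump the averaged metrics to an Excel sheet with openpyxl. Below is a minimal, self-contained sketch of that export pattern, assuming only numpy and openpyxl are available; the sample rows and file name are illustrative, not values produced by the trainer.

import numpy as np
from openpyxl import Workbook

# Hypothetical per-evaluation rows: [energy efficiency, type-1 success rate,
# queue length, return] -- stand-ins for the metrics averaged above.
training_res = np.array([
    [0.81, 0.95, 3.2, 120.0],
    [0.84, 0.96, 2.9, 131.5],
])

wb = Workbook()
ws = wb.active
ws.title = 'training result'
# Header row, then one row per evaluation round (openpyxl cells are 1-indexed).
headers = ['testing num', 'energy efficiency', 'type 1 success rate',
           'type 2 q length', 'return']
for col, name in enumerate(headers, start=1):
    ws.cell(row=1, column=col).value = name
for i, row in enumerate(training_res):
    ws.cell(row=i + 2, column=1).value = i + 1
    for j, value in enumerate(row):
        ws.cell(row=i + 2, column=j + 2).value = float(value)
wb.save('directly_training_example.xlsx')
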
Example #2
def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    repeat_per_collect: int,
    episode_per_test: Union[int, List[int]],
    batch_size: int,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.

    The "step" in trainer means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might finish before reaching ``max_epoch``.
    :param int step_per_epoch: the number of policy network update steps in
        one epoch.
    :param int collect_per_step: the number of episodes the collector collects
        before each network update, i.e., collect some episodes and then do
        one policy network update.
    :param int repeat_per_collect: the number of times the policy learns from
        each collected batch; for example, setting it to 2 means the policy
        learns each given batch of data twice.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :type episode_per_test: int or list of ints
    :param int batch_size: the batch size of the sample data that is fed into
        the policy network.
    :param function train_fn: a function that receives the current epoch and
        step index and performs some operations at the beginning of training
        in that epoch.
    :param function test_fn: a function that receives the current epoch and
        step index and performs some operations at the beginning of testing in
        that epoch.
    :param function save_fn: a function for saving the policy whenever the
        average undiscounted reward in the evaluation phase improves.
    :param function stop_fn: a function that receives the average undiscounted
        return of the test result and returns a boolean indicating whether the
        goal has been reached.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the logging interval of the writer.
    :param bool verbose: whether to print the information.
    :param bool test_in_train: whether to test during the training phase.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    env_step, gradient_step = 0, 0
    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(n_episode=collect_per_step)
                env_step += int(result["n/st"])
                data = {
                    "env_step": str(env_step),
                    "rew": f"{result['rew']:.2f}",
                    "len": str(int(result["len"])),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                    "v/ep": f"{result['v/ep']:.2f}",
                    "v/st": f"{result['v/st']:.2f}",
                }
                if writer and env_step % log_interval == 0:
                    for k in result.keys():
                        writer.add_scalar(
                            "train/" + k, result[k], global_step=env_step)
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test, writer, env_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result["rew"], test_result["rew_std"])
                    else:
                        policy.train()
                losses = policy.update(
                    0, train_collector.buffer,
                    batch_size=batch_size, repeat=repeat_per_collect)
                train_collector.reset_buffer()
                step = max([1] + [
                    len(v) for v in losses.values() if isinstance(v, list)])
                gradient_step += step
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and gradient_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=gradient_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, env_step)
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward, best_reward_std)
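
The train_fn, test_fn, stop_fn, and save_fn arguments of this trainer are plain callbacks. The sketch below shows one way such hooks could look, assuming a DQN-style policy exposing set_eps and an arbitrary reward threshold; both are assumptions for illustration, not requirements of the trainer.

import torch


def make_hooks(policy, reward_threshold=195.0):
    """Build callbacks with the signatures Example #2 expects."""

    def train_fn(epoch, env_step):
        # e.g. anneal exploration noise; set_eps is an assumption about the
        # policy class (DQN-like), not something the trainer itself needs.
        if hasattr(policy, 'set_eps'):
            policy.set_eps(max(0.05, 0.5 - env_step / 100_000))

    def test_fn(epoch, env_step):
        # Evaluate greedily during the test phase.
        if hasattr(policy, 'set_eps'):
            policy.set_eps(0.0)

    def stop_fn(mean_rewards):
        # Stop once the average test return reaches the (assumed) target.
        return mean_rewards >= reward_threshold

    def save_fn(policy):
        # Persist the best policy seen so far.
        torch.save(policy.state_dict(), 'best_policy.pth')

    return train_fn, test_fn, stop_fn, save_fn
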
Example #3
def onpolicy_trainer(
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        max_epoch: int,
        frame_per_epoch: int,
        collect_per_step: int,
        repeat_per_collect: int,
        episode_per_test: Union[int, List[int]],
        batch_size: int,
        train_fn: Optional[Callable[[int], None]] = None,
        test_fn: Optional[Callable[[int], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        log_fn: Optional[Callable[[dict], None]] = None,
        writer: Optional[SummaryWriter] = None,
        log_interval: int = 1,
        verbose: bool = True,
        **kwargs
) -> Dict[str, Union[float, str]]:
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=frame_per_epoch, desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
                if test_in_train and stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                losses = policy.learn(
                    train_collector.sample(0), batch_size, repeat_per_collect)
                train_collector.reset_buffer()
                global_step += collect_per_step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(collect_per_step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
Example #4
def offpolicy_trainer(policy: BasePolicy,
                      train_collector: Collector,
                      test_collector: Collector,
                      max_epoch: int,
                      step_per_epoch: int,
                      collect_per_step: int,
                      episode_per_test: Union[int, List[int]],
                      batch_size: int,
                      update_per_step: int = 1,
                      train_fn: Optional[Callable[[int], None]] = None,
                      test_fn: Optional[Callable[[int], None]] = None,
                      stop_fn: Optional[Callable[[float], bool]] = None,
                      save_fn: Optional[Callable[[BasePolicy], None]] = None,
                      log_fn: Optional[Callable[[dict], None]] = None,
                      writer: Optional[SummaryWriter] = None,
                      log_interval: int = 1,
                      verbose: bool = True,
                      **kwargs) -> Dict[str, Union[float, str]]:
    """A wrapper for off-policy trainer procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might finish before reaching ``max_epoch``.
    :param int step_per_epoch: the number of policy network update steps in
        one epoch.
    :param int collect_per_step: the number of frames the collector collects
        before each network update, i.e., collect some frames and then perform
        some policy network updates.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of the sample data that is fed into
        the policy network.
    :param int update_per_step: the number of times the policy network is
        updated after the frames are collected, i.e., collect some frames and
        then perform some policy network updates.
    :param function train_fn: a function that receives the current epoch index
        and performs some operations at the beginning of training in that
        epoch.
    :param function test_fn: a function that receives the current epoch index
        and performs some operations at the beginning of testing in that
        epoch.
    :param function save_fn: a function for saving the policy whenever the
        average undiscounted reward in the evaluation phase improves.
    :param function stop_fn: a function that receives the average undiscounted
        return of the test result and returns a boolean indicating whether the
        goal has been reached.
    :param function log_fn: a function that receives the env info for logging.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the logging interval of the writer.
    :param bool verbose: whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
                if test_in_train and stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                for i in range(update_per_step * min(
                        result['n/st'] // collect_per_step, t.total - t.n)):
                    global_step += 1
                    losses = policy.learn(train_collector.sample(batch_size))
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(k,
                                              result[k],
                                              global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.6f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(k,
                                              stat[k].get(),
                                              global_step=global_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward)
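
In the off-policy variants above, the number of gradient updates performed after each collect is update_per_step * min(result['n/st'] // collect_per_step, t.total - t.n): proportional to the frames just collected, but capped by the steps remaining in the epoch. A small standalone sketch of that arithmetic with illustrative numbers:

def num_updates(n_collected, collect_per_step, steps_left, update_per_step=1):
    # Mirrors the loop bound used above: one "unit" of updates per
    # collect_per_step frames actually collected, never exceeding the
    # progress-bar budget left in the current epoch.
    return update_per_step * min(n_collected // collect_per_step, steps_left)

# 10 frames collected with collect_per_step=10 and 100 steps left -> 1 update.
assert num_updates(10, 10, 100) == 1
# A vectorized env may return more frames than requested -> 2 updates.
assert num_updates(25, 10, 100) == 2
# Near the end of the epoch the remaining budget caps the count.
assert num_updates(50, 10, 3, update_per_step=2) == 6
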
Example #5
def onpolicy_trainer(policy,
                     train_collector,
                     test_collector,
                     max_epoch,
                     step_per_epoch,
                     collect_per_step,
                     repeat_per_collect,
                     episode_per_test,
                     batch_size,
                     train_fn=None,
                     test_fn=None,
                     stop_fn=None,
                     save_fn=None,
                     log_fn=None,
                     writer=None,
                     log_interval=1,
                     verbose=True,
                     task='',
                     **kwargs):
    """A wrapper for on-policy trainer procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might finish before reaching ``max_epoch``.
    :param int step_per_epoch: the number of policy network update steps in
        one epoch.
    :param int collect_per_step: the number of frames the collector collects
        before each network update, i.e., collect some frames and then do one
        policy network update.
    :param int repeat_per_collect: the number of times the policy learns from
        each collected batch; for example, setting it to 2 means the policy
        learns each given batch of data twice.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :type episode_per_test: int or list of ints
    :param int batch_size: the batch size of the sample data that is fed into
        the policy network.
    :param function train_fn: a function that receives the current epoch index
        and performs some operations at the beginning of training in that
        epoch.
    :param function test_fn: a function that receives the current epoch index
        and performs some operations at the beginning of testing in that
        epoch.
    :param function save_fn: a function for saving the policy whenever the
        average undiscounted reward in the evaluation phase improves.
    :param function stop_fn: a function that receives the average undiscounted
        return of the test result and returns a boolean indicating whether the
        goal has been reached.
    :param function log_fn: a function that receives the env info for logging.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the logging interval of the writer.
    :param bool verbose: whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
                if stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                losses = policy.learn(train_collector.sample(0), batch_size,
                                      repeat_per_collect)
                train_collector.reset_buffer()
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k + '_' + task if task else k,
                                          result[k],
                                          global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k + '_' + task if task else k,
                                          stat[k].get(),
                                          global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward)
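
The on-policy variants infer how many gradient steps the learn() call performed from the returned losses dict: list-valued entries hold one loss per minibatch update, so the step count is the length of the longest list (at least 1). A standalone sketch of that bookkeeping with made-up loss values:

def count_gradient_steps(losses):
    # Scalar losses imply a single update; list-valued losses carry one entry
    # per minibatch update, as in the trainers above.
    step = 1
    for value in losses.values():
        if isinstance(value, list):
            step = max(step, len(value))
    return step

# Illustrative output of a learn() call doing 2 repeats x 3 minibatches:
losses = {'loss': [0.9, 0.8, 0.7, 0.6, 0.5, 0.4], 'ent': 0.01}
assert count_gradient_steps(losses) == 6
assert count_gradient_steps({'loss': 0.3}) == 1
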
Example #6
def offpolicy_trainer(policy,
                      train_collector,
                      test_collector,
                      max_epoch,
                      step_per_epoch,
                      collect_per_step,
                      episode_per_test,
                      batch_size,
                      train_fn=None,
                      stop_fn=None,
                      save_fn=None,
                      test_in_train=False,
                      writer=None,
                      log_interval=10,
                      verbose=True,
                      **kwargs):
    """A wrapper for off-policy trainer procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might finish before reaching ``max_epoch``.
    :param int step_per_epoch: the number of policy network update steps in
        one epoch.
    :param int collect_per_step: the number of frames the collector collects
        before each network update, i.e., collect some frames and then do one
        policy network update.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of the sample data that is fed into
        the policy network.
    :param function train_fn: a function that receives the current epoch index
        and performs some operations at the beginning of training in that
        epoch.
    :param function test_fn: a function that receives the current epoch index
        and performs some operations at the beginning of testing in that
        epoch.
    :param function save_fn: a function for saving the policy whenever the
        average undiscounted reward in the evaluation phase improves.
    :param function log_fn: a function that receives the env info for logging.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the logging interval of the writer.
    :param bool verbose: whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = collections.defaultdict(lambda: collections.deque([], maxlen=5))
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step)
                data = {}
                if test_in_train and stop_fn and stop_fn(result['ep/reward']):
                    test_result = test_episode(policy, test_collector, episode_per_test)
                    if stop_fn and stop_fn(test_result['ep/reward']):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector, test_collector, test_result['ep/reward'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                for i in range(min(result['n/st'] // collect_per_step, t.total - t.n)):
                    global_step += 1
                    batch = policy.state_encode(train_collector.sample(batch_size))
                    losses = policy.learn(batch)
                    for k in result.keys():
                        if not k[0] in ['v', 'n']:
                            data[k] = f'{result[k]:.1f}'
                    for k in losses.keys():
                        stat[k].append(losses[k])
                        if not k[0] in ['g']:
                            data[k] = f'{np.nanmean(stat[k]):.1f}'
                    if writer and global_step % log_interval == 0:
                        for k in result.keys():
                            writer.add_scalar(k, result[k], global_step=global_step)
                        for k in losses.keys():
                            writer.add_scalar(k, np.nanmean(stat[k]), global_step=global_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, episode_per_test)
        if writer:
            writer.add_scalar('test/reward', result['ep/reward'], global_step)
        if best_epoch == -1 or best_reward < result['ep/reward']:
            best_reward = result['ep/reward']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["ep/reward"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector, best_reward)
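
Example #6 smooths the displayed losses with a fixed-length deque and np.nanmean instead of Tianshou's MovAvg. A minimal standalone sketch of that windowed average (window length 5, as above):

import collections
import numpy as np

stat = collections.defaultdict(lambda: collections.deque([], maxlen=5))

# Feed in a stream of (possibly NaN) loss values; only the last 5 are kept.
for value in [1.0, 2.0, np.nan, 4.0, 5.0, 6.0, 7.0]:
    stat['loss'].append(value)

print(len(stat['loss']))         # 5 -- the oldest values were dropped
print(np.nanmean(stat['loss']))  # mean of [nan, 4, 5, 6, 7] ignoring NaN: 5.5
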
Example #7
def offpolicy_trainer(policy,
                      train_collector,
                      test_collector,
                      max_epoch,
                      step_per_epoch,
                      collect_per_step,
                      episode_per_test,
                      batch_size,
                      train_fn=None,
                      test_fn=None,
                      stop_fn=None,
                      writer=None,
                      log_interval=1,
                      verbose=True,
                      task=''):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step)
                data = {}
                if stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                for i in range(
                        min(result['n/st'] // collect_per_step,
                            t.total - t.n)):
                    global_step += 1
                    losses = policy.learn(train_collector.sample(batch_size))
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(k + '_' + task if task else k,
                                              result[k],
                                              global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.6f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(k + '_' + task if task else k,
                                              stat[k].get(),
                                              global_step=global_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward)
Example #8
def imitation_trainer(policy, learner, expert_collector, test_collector,
                      max_epoch, step_per_epoch, collect_per_step,
                      repeat_per_collect, episode_per_test, batch_size,
                      train_fn=None, test_fn=None, stop_fn=None,
                      writer=None, task='', peer=0, peer_decay_steps=0):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
                expert_collector.collect(n_episode=collect_per_step)
                result = test_collector.collect(n_episode=episode_per_test)

                data = {}
                if stop_fn and stop_fn(result['rew']):
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                    t.set_postfix(**data)
                    return gather_info(
                        start_time, expert_collector, test_collector,
                        result['rew'])
                else:
                    policy.train()
                    if train_fn:
                        train_fn(epoch)

                decay = 1. if not peer_decay_steps else \
                    max(0., 1 - global_step / peer_decay_steps)
                losses = learner(policy, expert_collector.sample(0),
                                 batch_size, repeat_per_collect, peer * decay)
                expert_collector.reset_buffer()
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer:
                        writer.add_scalar(
                            k + '_' + task if task else k,
                            result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step:
                        writer.add_scalar(
                            k + '_' + task if task else k,
                            stat[k].get(), global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, expert_collector, test_collector, best_reward)
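
The imitation trainer above weights its "peer" term by a linearly decaying factor: full weight at step 0, reaching zero once global_step hits peer_decay_steps, and no decay at all when peer_decay_steps is 0. A small sketch of that schedule with illustrative numbers:

def peer_weight(peer, global_step, peer_decay_steps=0):
    # Same expression as in the trainer: no decay if peer_decay_steps == 0,
    # otherwise decay linearly to zero over peer_decay_steps steps.
    decay = 1.0 if not peer_decay_steps else \
        max(0.0, 1 - global_step / peer_decay_steps)
    return peer * decay

assert peer_weight(0.5, 1000) == 0.5                         # no decay configured
assert peer_weight(0.5, 500, peer_decay_steps=1000) == 0.25  # halfway through
assert peer_weight(0.5, 2000, peer_decay_steps=1000) == 0.0  # fully decayed
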
Example #9
def onpolicy_trainer(
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        max_epoch: int,
        frame_per_epoch: int,
        collect_per_step: int,
        repeat_per_collect: int,
        episode_per_test: Union[int, Sequence[int]],
        batch_size: int,
        train_fn: Optional[Callable[[int, int], None]] = None,
        test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        writer: Optional[SummaryWriter] = None,
        log_interval: int = 1,
        verbose: bool = True,
        test_in_train: bool = True,
        **kwargs) -> Dict[str, Union[float, str]]:
    """Slightly modified Tianshou `onpolicy_trainer` original method to enable
    to define the maximum number of training steps instead of number of
    episodes, for consistency with other learning frameworks.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1.0
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=frame_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, global_step)
                result = train_collector.collect(n_step=collect_per_step)
                data = {}
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test, writer, global_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result["rew"])
                    else:
                        policy.train()
                losses = policy.update(
                    0, train_collector.buffer,
                    batch_size=batch_size, repeat=repeat_per_collect)
                train_collector.reset_buffer()
                step = 1
                for v in losses.values():
                    if isinstance(v, (list, tuple)):
                        step = max(step, len(v))
                global_step += step * collect_per_step
                for k in result.keys():
                    data[k] = f"{result[k]:.2f}"
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            "train/" + k, result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(collect_per_step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, global_step)
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward = result["rew"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
                  f"best_reward: {best_reward:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
Example #10
def offpolicy_trainer_with_views(A,
                                 B,
                                 max_epoch,
                                 step_per_epoch,
                                 collect_per_step,
                                 episode_per_test,
                                 batch_size,
                                 copier=False,
                                 peer=0.,
                                 verbose=True,
                                 test_fn=None,
                                 task=''):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()

    for epoch in range(1, 1 + max_epoch):
        # train
        A.train()
        B.train()
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                for view, other_view in zip([A, B], [B, A]):
                    result = view.train_collector.collect(
                        n_step=collect_per_step)
                    data = {}
                    if view.stop_fn(result['rew']):
                        test_result = test_episode(view.policy,
                                                   view.test_collector,
                                                   test_fn, epoch,
                                                   episode_per_test)
                        if view.stop_fn(test_result['rew']):
                            for k in result.keys():
                                data[k] = f'{result[k]:.2f}'
                            t.set_postfix(**data)
                            return gather_info(start_time,
                                               view.train_collector,
                                               view.test_collector,
                                               test_result['rew'])
                        else:
                            view.policy.train()
                    for i in range(
                            min(result['n/st'] // collect_per_step,
                                t.total - t.n)):
                        global_step += 1
                        batch = view.train_collector.sample(batch_size)
                        losses = view.policy.learn(batch)

                        # Learn from demonstration
                        if copier:
                            demo = other_view.policy(batch)
                            view.learn_from_demos(batch, demo, peer=peer)

                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                            view.writer.add_scalar(k + '_' +
                                                   task if task else k,
                                                   result[k],
                                                   global_step=global_step)
                        for k in losses.keys():
                            if stat.get(k) is None:
                                stat[k] = MovAvg()
                            stat[k].add(losses[k])
                            data[k] = f'{stat[k].get():.4f}'
                            view.writer.add_scalar(k + '_' +
                                                   task if task else k,
                                                   stat[k].get(),
                                                   global_step=global_step)
                        t.update(1)
                        t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        brk = False
        for view in A, B:
            result = test_episode(view.policy, view.test_collector, test_fn,
                                  epoch, episode_per_test)
            if best_epoch == -1 or best_reward < result['rew']:
                best_reward = result['rew']
                best_epoch = epoch
            if verbose:
                print(f'Epoch #{epoch}: test_reward: {result["rew"]:.4f}, '
                      f'best_reward: {best_reward:.4f} in #{best_epoch}')
            if view.stop_fn(best_reward):
                brk = True
        if brk:
            break
    return (
        gather_info(start_time, A.train_collector, A.test_collector,
                    best_reward),
        gather_info(start_time, B.train_collector, B.test_collector,
                    best_reward),
    )
Example #11
def offpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    step_per_collect: int,
    episode_per_test: int,
    batch_size: int,
    update_per_step: Union[int, float] = 1,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
    resume_from_log: bool = False,
    reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    logger: BaseLogger = LazyLogger(),
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for off-policy trainer procedure.

    The "step" in trainer means an environment step (a.k.a. transition).

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
    :param Collector test_collector: the collector used for testing.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int step_per_epoch: the number of transitions collected per epoch.
    :param int step_per_collect: the number of transitions the collector would collect
        before the network update, i.e., the trainer will collect "step_per_collect"
        transitions and then repeatedly perform some policy network updates in each
        epoch.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to feed in the
        policy network.
    :param int/float update_per_step: the number of times the policy network is
        updated per transition after (step_per_collect) transitions are collected,
        e.g., if update_per_step is set to 0.3 and step_per_collect is 256, the
        policy will be updated round(256 * 0.3) = round(76.8) = 77 times after 256
        transitions are collected by the collector. Default to 1.
    :param function train_fn: a hook called at the beginning of training in each epoch.
        It can be used to perform custom additional operations, with the signature ``f(
        num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each epoch.
        It can be used to perform custom additional operations, with the signature ``f(
        num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean reward in
        evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
        None``.
    :param function save_checkpoint_fn: a function to save training process, with the
        signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
        save whatever you want.
    :param bool resume_from_log: resume env_step/gradient_step and other metadata from
        existing tensorboard log. Default to False.
    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
        bool``, receives the average undiscounted returns of the testing result,
        and returns a boolean indicating whether the goal has been reached.
    :param function reward_metric: a function with signature ``f(rewards: np.ndarray
        with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
        used in multi-agent RL. It must return a single scalar for each episode so
        that training can be monitored in the multi-agent setting; this function
        specifies the desired metric, e.g., the reward of agent 1 or the average
        reward over all agents.
    :param BaseLogger logger: A logger that logs statistics during
        training/testing/updating. Default to a logger that doesn't log anything.
    :param bool verbose: whether to print the information. Default to True.
    :param bool test_in_train: whether to test in the training phase. Default to True.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    if save_fn:
        warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")

    start_epoch, env_step, gradient_step = 0, 0, 0
    if resume_from_log:
        start_epoch, env_step, gradient_step = logger.restore_data()
    last_rew, last_len = 0.0, 0
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    test_result = test_episode(policy, test_collector, test_fn, start_epoch,
                               episode_per_test, logger, env_step, reward_metric)
    best_epoch = start_epoch
    best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]

    for epoch in range(1 + start_epoch, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(n_step=step_per_collect)
                if result["n/ep"] > 0 and reward_metric:
                    result["rews"] = reward_metric(result["rews"])
                env_step += int(result["n/st"])
                t.update(result["n/st"])
                logger.log_train_data(result, env_step)
                last_rew = result['rew'] if 'rew' in result else last_rew
                last_len = result['len'] if 'len' in result else last_len
                data = {
                    "env_step": str(env_step),
                    "rew": f"{last_rew:.2f}",
                    "len": str(int(last_len)),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                }
                if result["n/ep"] > 0:
                    if test_in_train and stop_fn and stop_fn(result["rew"]):
                        test_result = test_episode(
                            policy, test_collector, test_fn,
                            epoch, episode_per_test, logger, env_step)
                        if stop_fn(test_result["rew"]):
                            if save_fn:
                                save_fn(policy)
                            logger.save_data(
                                epoch, env_step, gradient_step, save_checkpoint_fn)
                            t.set_postfix(**data)
                            return gather_info(
                                start_time, train_collector, test_collector,
                                test_result["rew"], test_result["rew_std"])
                        else:
                            policy.train()
                for i in range(round(update_per_step * result["n/st"])):
                    gradient_step += 1
                    losses = policy.update(batch_size, train_collector.buffer)
                    for k in losses.keys():
                        stat[k].add(losses[k])
                        losses[k] = stat[k].get()
                        data[k] = f"{losses[k]:.3f}"
                    logger.log_update_data(losses, gradient_step)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        test_result = test_episode(policy, test_collector, test_fn, epoch,
                                   episode_per_test, logger, env_step, reward_metric)
        rew, rew_std = test_result["rew"], test_result["rew_std"]
        if best_epoch < 0 or best_reward < rew:
            best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
            if save_fn:
                save_fn(policy)
        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
                  f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward, best_reward_std)
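
As a usage sketch for this logger-based variant, the snippet below wires the epoch hooks into the call. It assumes a DQN-style policy (so `policy.set_eps` exists), pre-built `train_collector`/`test_collector`, a concrete `logger`, and a task-specific `reward_threshold`; these names and numbers are illustrative assumptions, not part of the listing above.

import os

import torch


def run_offpolicy(policy, train_collector, test_collector, logger,
                  reward_threshold=195.0, log_path="log"):
    os.makedirs(log_path, exist_ok=True)

    # anneal exploration noise at the start of each training stage
    def train_fn(epoch, env_step):
        policy.set_eps(max(0.05, 0.5 - env_step / 100000))

    # evaluate greedily
    def test_fn(epoch, env_step):
        policy.set_eps(0.0)

    # stop as soon as the evaluation reward clears the threshold
    def stop_fn(mean_rewards):
        return mean_rewards >= reward_threshold

    # keep the best-so-far weights and a resumable checkpoint
    def save_fn(pol):
        torch.save(pol.state_dict(), os.path.join(log_path, "policy.pth"))

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        torch.save({"model": policy.state_dict()},
                   os.path.join(log_path, "checkpoint.pth"))

    return offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=10000, step_per_collect=10,
        episode_per_test=10, batch_size=64, update_per_step=0.1,
        train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
        save_fn=save_fn, save_checkpoint_fn=save_checkpoint_fn,
        logger=logger)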
Example #12
def offline_trainer(
    policy: BasePolicy,
    buffer: ReplayBuffer,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    episode_per_test: Union[int, List[int]],
    batch_size: int,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for offline trainer procedure.

    The "step" in trainer means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might be finished before reaching the ``max_epoch``.
    :param int step_per_epoch: the number of policy network updates, so-called
        gradient steps, per epoch.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to
        feed in the policy network.
    :param function test_fn: a hook called at the beginning of testing in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean
        reward in evaluation phase gets better, with the signature ``f(policy:
        BasePolicy) -> None``.
    :param function stop_fn: a function with signature ``f(mean_rewards: float)
        -> bool``, receives the average undiscounted returns of the testing
        result, and returns a boolean indicating whether the goal has been reached.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter; if None is given, it will not write logs to TensorBoard.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    gradient_step = 0
    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    test_collector.reset_stat()

    for epoch in range(1, 1 + max_epoch):
        policy.train()
        with tqdm.trange(step_per_epoch, desc=f"Epoch #{epoch}",
                         **tqdm_config) as t:
            for i in t:
                gradient_step += 1
                losses = policy.update(batch_size, buffer)
                data = {"gradient_step": str(gradient_step)}
                for k in losses.keys():
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and gradient_step % log_interval == 0:
                        writer.add_scalar("train/" + k,
                                          stat[k].get(),
                                          global_step=gradient_step)
                t.set_postfix(**data)
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, gradient_step)
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, None, test_collector, best_reward,
                       best_reward_std)
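
A minimal way to drive this writer-based offline trainer, assuming an offline-capable `policy`, a pre-built `test_collector`, and an expert ReplayBuffer that was pickled to disk earlier; the file name, log directory, and the 195.0 threshold are illustrative assumptions.

import pickle

from torch.utils.tensorboard import SummaryWriter


def run_offline(policy, test_collector, buffer_path="expert_buffer.pkl"):
    # load a previously collected ReplayBuffer (assumed to have been pickled)
    with open(buffer_path, "rb") as f:
        buffer = pickle.load(f)
    writer = SummaryWriter("log/offline")
    return offline_trainer(
        policy, buffer, test_collector,
        max_epoch=5, step_per_epoch=1000, episode_per_test=10,
        batch_size=256,
        stop_fn=lambda mean_rewards: mean_rewards >= 195.0,
        writer=writer, log_interval=10)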
Example #13
def offpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    episode_per_test: Union[int, List[int]],
    batch_size: int,
    update_per_step: int = 1,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for off-policy trainer procedure.

    The "step" in trainer means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might be finished before reaching the ``max_epoch``.
    :param int step_per_epoch: the number of policy network updates, so-called
        gradient steps, per epoch.
    :param int collect_per_step: the number of frames the collector would
        collect before the network update. In other words, the trainer
        alternates between collecting this many frames and updating the policy
        network.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to
        feed in the policy network.
    :param int update_per_step: the number of times the policy network would
        be updated after frames are collected; for example, setting it to 256
        means the policy is updated 256 times after every ``collect_per_step``
        frames are collected.
    :param function train_fn: a hook called at the beginning of training in
        each epoch. It can be used to perform custom additional operations,
        with the signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each
        epoch. It can be used to perform custom additional operations, with the
        signature ``f(num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean
        reward in evaluation phase gets better, with the signature ``f(policy:
        BasePolicy) -> None``.
    :param function stop_fn: a function with signature ``f(mean_rewards: float)
        -> bool``, receives the average undiscounted returns of the testing
        result, and returns a boolean indicating whether the goal has been reached.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter; if None is given, it will not write logs to TensorBoard.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.
    :param bool test_in_train: whether to test in the training phase.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    env_step, gradient_step = 0, 0
    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f"Epoch #{epoch}",
                       **tqdm_config) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(n_step=collect_per_step)
                env_step += int(result["n/st"])
                data = {
                    "env_step": str(env_step),
                    "rew": f"{result['rew']:.2f}",
                    "len": str(int(result["len"])),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                    "v/ep": f"{result['v/ep']:.2f}",
                    "v/st": f"{result['v/st']:.2f}",
                }
                if writer and env_step % log_interval == 0:
                    for k in result.keys():
                        writer.add_scalar("train/" + k,
                                          result[k],
                                          global_step=env_step)
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test, writer,
                                               env_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result["rew"],
                                           test_result["rew_std"])
                    else:
                        policy.train()
                for i in range(update_per_step * min(
                        result["n/st"] // collect_per_step, t.total - t.n)):
                    gradient_step += 1
                    losses = policy.update(batch_size, train_collector.buffer)
                    for k in losses.keys():
                        stat[k].add(losses[k])
                        data[k] = f"{stat[k].get():.6f}"
                        if writer and gradient_step % log_interval == 0:
                            writer.add_scalar(k,
                                              stat[k].get(),
                                              global_step=gradient_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, env_step)
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward, best_reward_std)
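
Unlike the logger-based variant in Example #11, this writer-era off-policy trainer counts its progress bar in policy updates and runs roughly `update_per_step` update rounds for every `collect_per_step` collected frames. A call sketch, again assuming a DQN-style policy so `policy.set_eps` is available; the numbers are illustrative.

from torch.utils.tensorboard import SummaryWriter


def run_offpolicy_legacy(policy, train_collector, test_collector):
    writer = SummaryWriter("log/dqn")
    return offpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10,
        step_per_epoch=1000,   # counted in policy updates in this variant
        collect_per_step=10,   # frames gathered before each update round
        episode_per_test=10,
        batch_size=64,
        update_per_step=1,     # one update per collect_per_step frames
        train_fn=lambda epoch, env_step: policy.set_eps(0.1),  # explore during training
        test_fn=lambda epoch, env_step: policy.set_eps(0.0),   # act greedily in tests
        stop_fn=lambda mean_rewards: mean_rewards >= 195.0,
        writer=writer)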
Example #14
def offline_trainer(
    policy: BasePolicy,
    buffer: ReplayBuffer,
    test_collector: Collector,
    max_epoch: int,
    update_per_epoch: int,
    episode_per_test: int,
    batch_size: int,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    logger: BaseLogger = LazyLogger(),
    verbose: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for offline trainer procedure.

    The "step" in offline trainer means a gradient step.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector test_collector: the collector used for testing.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int update_per_epoch: the number of policy network updates, so-called
        gradient steps, per epoch.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to feed in
        the policy network.
    :param function test_fn: a hook called at the beginning of testing in each epoch.
        It can be used to perform custom additional operations, with the signature ``f(
        num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean reward in
        evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
        None``.
    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
        bool``, receives the average undiscounted returns of the testing result,
        and returns a boolean indicating whether the goal has been reached.
    :param function reward_metric: a function with signature ``f(rewards: np.ndarray
        with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
        used in multi-agent RL. It must return a single scalar for each episode so
        that training can be monitored in the multi-agent setting; this function
        specifies the desired metric, e.g., the reward of agent 1 or the average
        reward over all agents.
    :param BaseLogger logger: A logger that logs statistics during updating/testing.
        Default to a logger that doesn't log anything.
    :param bool verbose: whether to print the information. Default to True.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    gradient_step = 0
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    test_collector.reset_stat()
    test_result = test_episode(policy, test_collector, test_fn, 0,
                               episode_per_test, logger, gradient_step,
                               reward_metric)
    best_epoch = 0
    best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
    for epoch in range(1, 1 + max_epoch):
        policy.train()
        with tqdm.trange(update_per_epoch,
                         desc=f"Epoch #{epoch}",
                         **tqdm_config) as t:
            for i in t:
                gradient_step += 1
                losses = policy.update(batch_size, buffer)
                data = {"gradient_step": str(gradient_step)}
                for k in losses.keys():
                    stat[k].add(losses[k])
                    losses[k] = stat[k].get()
                    data[k] = f"{losses[k]:.3f}"
                logger.log_update_data(losses, gradient_step)
                t.set_postfix(**data)
        # test
        test_result = test_episode(policy, test_collector, test_fn, epoch,
                                   episode_per_test, logger, gradient_step,
                                   reward_metric)
        rew, rew_std = test_result["rew"], test_result["rew_std"]
        if best_epoch == -1 or best_reward < rew:
            best_reward, best_reward_std = rew, rew_std
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(
                f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
                f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
            )
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, None, test_collector, best_reward,
                       best_reward_std)
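
For this logger-based offline trainer the notable additions are `reward_metric` and `save_fn`. The sketch below assumes the buffer and test collector already exist and that averaging the per-agent rewards is the desired metric; all names and numbers are illustrative.

import os

import numpy as np
import torch


def run_offline_marl(policy, buffer, test_collector, log_path="log"):
    os.makedirs(log_path, exist_ok=True)

    # the collector reports one reward column per agent; reduce the
    # (num_episode, agent_num) matrix to one scalar per episode
    def reward_metric(rews: np.ndarray) -> np.ndarray:
        return rews.mean(axis=1)

    # keep the weights that achieved the best evaluation reward
    def save_fn(pol):
        torch.save(pol.state_dict(), os.path.join(log_path, "best_policy.pth"))

    return offline_trainer(
        policy, buffer, test_collector,
        max_epoch=10, update_per_epoch=500, episode_per_test=10,
        batch_size=256, save_fn=save_fn, reward_metric=reward_metric)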
Example #15
def onpolicy_trainer(policy,
                     train_collector,
                     test_collector,
                     max_epoch,
                     step_per_epoch,
                     collect_per_step,
                     repeat_per_collect,
                     episode_per_test,
                     batch_size,
                     train_fn=None,
                     test_fn=None,
                     stop_fn=None,
                     writer=None,
                     log_interval=1,
                     verbose=True,
                     task='',
                     **kwargs):
    """A wrapper for on-policy trainer procedure.

    Parameters
        * **policy** – an instance of the :class:`~tianshou.policy.BasePolicy`\
            class.
        * **train_collector** – the collector used for training.
        * **test_collector** – the collector used for testing.
        * **max_epoch** – the maximum number of epochs for training. The \
            training process might be finished before reaching ``max_epoch``.
        * **step_per_epoch** – the number of steps for updating the policy \
            network in one epoch.
        * **collect_per_step** – the number of episodes the collector would \
            collect before the network update; note that this variant collects \
            whole episodes (``n_episode``) rather than frames.
        * **repeat_per_collect** – the number of times the policy learns from \
            each collected batch; for example, setting it to 2 means the \
            policy learns each given batch of data twice.
        * **episode_per_test** – the number of episodes for one policy \
            evaluation.
        * **batch_size** – the batch size of sample data, which is going to \
            feed in the policy network.
        * **train_fn** – a function that receives the current epoch index and \
            performs some operations at the beginning of training in this \
            epoch.
        * **test_fn** – a function that receives the current epoch index and \
            performs some operations at the beginning of testing in this \
            epoch.
        * **stop_fn** – a function that receives the average undiscounted \
            returns of the testing result and returns a boolean indicating \
            whether the goal has been reached.
        * **writer** – a SummaryWriter provided from TensorBoard.
        * **log_interval** – an int indicating the log interval of the writer.
        * **verbose** – a boolean indicating whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step)
                data = {}
                if stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                losses = policy.learn(train_collector.sample(0), batch_size,
                                      repeat_per_collect)
                train_collector.reset_buffer()
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k + '_' + task if task else k,
                                          result[k],
                                          global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k + '_' + task if task else k,
                                          stat[k].get(),
                                          global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward)
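
Note that this legacy on-policy variant collects by episodes (it calls `collect(n_episode=collect_per_step)`) and advances its progress bar by policy updates. A call sketch with CartPole-style numbers, all of which are illustrative assumptions:

from torch.utils.tensorboard import SummaryWriter


def run_onpolicy_legacy(policy, train_collector, test_collector):
    writer = SummaryWriter("log/onpolicy")
    return onpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10,
        step_per_epoch=1000,    # progress-bar total, counted in policy updates
        collect_per_step=8,     # episodes (not frames) collected per iteration
        repeat_per_collect=2,   # learn each collected batch twice
        episode_per_test=10,
        batch_size=64,
        stop_fn=lambda mean_rewards: mean_rewards >= 195.0,
        writer=writer,
        task="CartPole-v0")     # suffix appended to every TensorBoard tag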
Example #16
def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    repeat_per_collect: int,
    episode_per_test: int,
    batch_size: int,
    step_per_collect: Optional[int] = None,
    episode_per_collect: Optional[int] = None,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    logger: BaseLogger = LazyLogger(),
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.

    The "step" in trainer means an environment step (a.k.a. transition).

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
    :param Collector test_collector: the collector used for testing.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int step_per_epoch: the number of transitions collected per epoch.
    :param int repeat_per_collect: the number of times the policy learns from each
        collected batch; for example, setting it to 2 means the policy learns each
        given batch of data twice.
    :param int episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to feed in the
        policy network.
    :param int step_per_collect: the number of transitions the collector would collect
        before the network update, i.e., the trainer collects "step_per_collect"
        transitions and then performs several policy network updates, repeatedly
        within each epoch.
    :param int episode_per_collect: the number of episodes the collector would collect
        before the network update, i.e., the trainer collects "episode_per_collect"
        episodes and then performs several policy network updates, repeatedly within
        each epoch.
    :param function train_fn: a hook called at the beginning of training in each epoch.
        It can be used to perform custom additional operations, with the signature ``f(
        num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each epoch.
        It can be used to perform custom additional operations, with the signature ``f(
        num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average mean reward in
        evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
        None``.
    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
        bool``, receives the average undiscounted returns of the testing result,
        and returns a boolean indicating whether the goal has been reached.
    :param function reward_metric: a function with signature ``f(rewards: np.ndarray
        with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
        used in multi-agent RL. It must return a single scalar for each episode so
        that training can be monitored in the multi-agent setting; this function
        specifies the desired metric, e.g., the reward of agent 1 or the average
        reward over all agents.
    :param BaseLogger logger: A logger that logs statistics during
        training/testing/updating. Default to a logger that doesn't log anything.
    :param bool verbose: whether to print the information. Default to True.
    :param bool test_in_train: whether to test in the training phase. Default to True.

    :return: See :func:`~tianshou.trainer.gather_info`.

    .. note::

        Only one of step_per_collect and episode_per_collect can be specified.
    """
    env_step, gradient_step = 0, 0
    last_rew, last_len = 0.0, 0
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    test_result = test_episode(policy, test_collector, test_fn, 0,
                               episode_per_test, logger, env_step,
                               reward_metric)
    best_epoch = 0
    best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f"Epoch #{epoch}",
                       **tqdm_config) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(n_step=step_per_collect,
                                                 n_episode=episode_per_collect)
                if reward_metric:
                    result["rews"] = reward_metric(result["rews"])
                env_step += int(result["n/st"])
                t.update(result["n/st"])
                logger.log_train_data(result, env_step)
                last_rew = result['rew'] if 'rew' in result else last_rew
                last_len = result['len'] if 'len' in result else last_len
                data = {
                    "env_step": str(env_step),
                    "rew": f"{last_rew:.2f}",
                    "len": str(int(last_len)),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                }
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test, logger,
                                               env_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result["rew"],
                                           test_result["rew_std"])
                    else:
                        policy.train()
                losses = policy.update(0,
                                       train_collector.buffer,
                                       batch_size=batch_size,
                                       repeat=repeat_per_collect)
                train_collector.reset_buffer()
                step = max(
                    [1] +
                    [len(v) for v in losses.values() if isinstance(v, list)])
                gradient_step += step
                for k in losses.keys():
                    stat[k].add(losses[k])
                    losses[k] = stat[k].get()
                    data[k] = f"{losses[k]:.3f}"
                logger.log_update_data(losses, gradient_step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        test_result = test_episode(policy, test_collector, test_fn, epoch,
                                   episode_per_test, logger, env_step)
        rew, rew_std = test_result["rew"], test_result["rew_std"]
        if best_epoch == -1 or best_reward < rew:
            best_reward, best_reward_std = rew, rew_std
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(
                f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
                f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
            )
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward, best_reward_std)
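
With this signature exactly one of `step_per_collect` / `episode_per_collect` should be provided. A sketch using episode-based collection; the numbers and the `logger` argument are illustrative assumptions.

def run_onpolicy(policy, train_collector, test_collector, logger):
    # the collector gathers 16 whole episodes before each update round,
    # and the policy replays each collected batch 4 times
    return onpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=50000, repeat_per_collect=4,
        episode_per_test=10, batch_size=256,
        episode_per_collect=16,
        stop_fn=lambda mean_rewards: mean_rewards >= 195.0,
        logger=logger)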
Example #17
def onpolicy_trainer_with_views(A, B, max_epoch, step_per_epoch, collect_per_step,
                                repeat_per_collect, episode_per_test, batch_size,
                                copier=False, peer=0, verbose=True, test_fn=None,
                                task='', copier_batch_size=0):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        A.train()
        B.train()
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
                for view, other_view in zip([A, B], [B, A]):
                    result = view.train_collector.collect(n_episode=collect_per_step)
                    data = {}
                    if view.stop_fn and view.stop_fn(result['rew']):
                        test_result = test_episode(
                            view.policy, view.test_collector, test_fn,
                            epoch, episode_per_test)
                        if view.stop_fn and view.stop_fn(test_result['rew']):
                            for k in result.keys():
                                data[k] = f'{result[k]:.2f}'
                            t.set_postfix(**data)
                            return gather_info(
                                start_time, view.train_collector, view.test_collector,
                                test_result['rew'])
                        else:
                            view.train()

                    batch = view.train_collector.sample(0)
                    for obs in batch.obs:
                        view.buffer.add(obs, {}, {}, {}, {}, {})  # only need obs
                    losses = view.policy.learn(batch, batch_size, repeat_per_collect)

                    # Learn from demonstration
                    if copier:
                        copier_batch, indice = view.buffer.sample(copier_batch_size)
                        # copier_batch = view.train_collector.process_fn(
                        #     copier_batch, view.buffer, indice)
                        demo = other_view.policy(copier_batch)
                        view.learn_from_demos(copier_batch, demo, peer=peer)

                    view.train_collector.reset_buffer()
                    step = 1
                    for k in losses.keys():
                        if isinstance(losses[k], list):
                            step = max(step, len(losses[k]))
                    global_step += step
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                        view.writer.add_scalar(
                            k + '_' + task if task else k,
                            result[k], global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.4f}'
                        if global_step:
                            view.writer.add_scalar(
                                k + '_' + task if task else k,
                                stat[k].get(), global_step=global_step)
                    t.update(step)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        brk = False
        for view in A, B:
            result = test_episode(
                view.policy, view.test_collector, test_fn, epoch, episode_per_test)
            if best_epoch == -1 or best_reward < result['rew']:
                best_reward = result['rew']
                best_epoch = epoch
            if verbose:
                print(f'Epoch #{epoch}: test_reward: {result["rew"]:.4f}, '
                      f'best_reward: {best_reward:.4f} in #{best_epoch}')
            if view.stop_fn(best_reward):
                brk = True
        if brk:
            break
    return (
        gather_info(start_time, A.train_collector, A.test_collector, best_reward),
        gather_info(start_time, B.train_collector, B.test_collector, best_reward),
    )
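
The two-view trainer above expects each of `A` and `B` to bundle a policy, its collectors, a demonstration buffer, a writer, and a `stop_fn`, plus `train()` and `learn_from_demos()` methods. A rough sketch of such a view object; everything here, including the `make_view` helper itself, is an assumption for illustration, and `learn_from_demos` is left as a placeholder because its loss lives outside this listing.

from types import SimpleNamespace


def make_view(policy, train_collector, test_collector, demo_buffer, writer,
              reward_threshold):
    view = SimpleNamespace(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        buffer=demo_buffer,  # stores observations consumed by the copier
        writer=writer,
        stop_fn=lambda mean_rewards: mean_rewards >= reward_threshold,
    )
    # the trainer calls view.train() to switch the policy back to train mode
    view.train = policy.train

    # placeholder: the real project supplies the peer/copier update here
    def learn_from_demos(batch, demo, peer=0):
        raise NotImplementedError("plug in the project's copier loss")

    view.learn_from_demos = learn_from_demos
    return view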