Example #1
0
    def __init__(self, env, args):
        """Create the agents, rollout worker and replay buffer for a run.

        Args:
            env: environment handed to the rollout worker.
            args: run configuration; ``args.result_dir`` and ``args.alg``
                select the directory where plots and pickles are written.
        """
        self.env = env

        # The rollout worker is constructed with the agents, so build those first.
        self.agents = Agents(args)
        self.rolloutWorker = RolloutWorker(env, self.agents, args)
        self.buffer = ReplayBuffer(args)
        self.args = args

        # Statistics lists, all starting empty.
        self.max_counter, self.episode_rewards, self.win_rates = [], [], []

        # Directory used to save plots (plt) and pickles (pkl).
        save_dir = '/'.join((self.args.result_dir, args.alg))
        self.save_path = save_dir
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
Example #2
0
def main():
    """Train a DDPG agent (optionally with HER) and save its networks and stats.

    Configuration is read from the module-level ``DEFAULT_PARAMS`` dict, which
    is also extended in place with ``dims``, ``reward_fun`` and
    ``sample_her_transitions``. Scores are written to
    ``<results_path>/scores_<env_name>_<seed>.csv`` and ``.png``.
    """
    cfg = DEFAULT_PARAMS
    set_seeds(cfg['seed'])

    env = parallelEnv(cfg['env_name'], n=cfg['num_workers'], seed=cfg['seed'])

    cfg['dims'], cfg['reward_fun'] = dims_and_reward_fun(cfg['env_name'])
    cfg['sample_her_transitions'] = make_sample_her_transitions(
        replay_strategy=cfg['replay_strategy'],
        replay_k=4,
        reward_fun=cfg['reward_fun'],
    )

    agent = ddpgAgent(cfg)
    rollout_worker = RolloutWorker(env, agent, cfg)
    evaluation_worker = RolloutWorker(env, agent, cfg, evaluate=True)

    scores = train(agent, rollout_worker, evaluation_worker)

    # Save networks and stats.
    agent.save_checkpoint(cfg['results_path'], cfg['env_name'])

    # CSV and PNG share the same file stem.
    stem = (cfg['results_path'] + '/scores_' + cfg['env_name'] + '_'
            + str(cfg['seed']))
    np.savetxt(stem + '.csv', scores, delimiter=',')

    fig = plt.figure()
    fig.add_subplot(111)
    plt.plot(np.arange(len(scores)), scores)
    plt.savefig(stem + '.png')
    plt.show()
def main(policy_file, seed, n_test_rollouts, render):
    """Evaluate a pickled policy and log the mean of each test metric.

    Args:
        policy_file: path to a pickled policy whose ``info['env_name']``
            selects the environment and any env-specific parameters.
        seed: global seed, also applied to the evaluation worker.
        n_test_rollouts: number of ``generate_rollouts`` calls to make.
        render: truthy to render rollouts.
    """
    set_global_seeds(seed)

    # NOTE(review): pickle.load can execute arbitrary code; only load policy
    # files from trusted sources.
    with open(policy_file, 'rb') as fh:
        policy = pickle.load(fh)
    env_name = policy.info['env_name']

    # Start from the shared defaults (updated in place) and merge in any
    # env-specific parameters.
    params = config.DEFAULT_PARAMS
    if env_name in config.DEFAULT_ENV_PARAMS:
        params.update(config.DEFAULT_ENV_PARAMS[env_name])
    params['env_name'] = env_name
    params = config.prepare_params(params)
    config.log_params(params, logger=logger)

    dims = config.configure_dims(params)

    eval_params = dict(
        exploit=True,
        use_target_net=params['test_with_polyak'],
        compute_Q=True,
        rollout_batch_size=1,
        render=bool(render),
    )
    eval_params.update(
        {key: params[key] for key in ('T', 'gamma', 'noise_eps', 'random_eps')})

    evaluator = RolloutWorker(params['make_env'], policy, dims, logger,
                              **eval_params)
    evaluator.seed(seed)

    # Run the evaluation rollouts from a clean history.
    evaluator.clear_history()
    for _ in range(n_test_rollouts):
        evaluator.generate_rollouts()

    # Log the mean of every test metric.
    for key, val in evaluator.logs('test'):
        logger.record_tabular(key, np.mean(val))
    logger.dump_tabular()
Example #4
0
def main():
    """Train a DDPG policy with HER on a vectorized robotics env.

    Writes ``PARAMS['sample_her_transitions']`` and ``PARAMS['log_dir']``.
    Any existing contents of the log directory are deleted before a
    ``SummaryWriter`` is opened on it.
    """
    choose_gpu()
    args = parse_args()
    seed = set_seed(args.seed)
    env = make_vec_env(
        args.env,
        'robotics',
        args.num_workers,
        seed=seed,
        reward_scale=1.0,
        flatten_dict_observations=False,
    )
    env.get_images()
    # NOTE(review): the seed is set again after env creation; kept so the RNG
    # state seen by the code below is unchanged.
    seed = set_seed(args.seed)
    get_dims(env)
    PARAMS['sample_her_transitions'] = make_sample_her_transitions(
        PARAMS['distance_threshold'], PARAMS['replay_strategy'])

    log_dir = 'runs/env=%s_seed=%s' % (args.env, seed)
    PARAMS['log_dir'] = log_dir
    shutil.rmtree(log_dir, ignore_errors=True)  # start from an empty log dir
    print('logging to:', log_dir)
    writer = SummaryWriter(log_dir)

    policy = DDPG(PARAMS)
    train(policy,
          RolloutWorker(env, policy, PARAMS),
          RolloutWorker(env, policy, PARAMS, evaluate=True),
          writer)
Example #5
0
def learn(env,
          total_timesteps,
          seed=None,
          replay_strategy='future',
          policy_save_interval=5,
          clip_return=True,
          override_params=None,
          load_path=None,
          save_path=None,
          **kwargs):
    """Configure a policy and subgoal planner for ``env`` and train them.

    Args:
        env: vectorized environment. ``env.spec.id`` selects env-specific
            defaults and ``env.num_envs`` sets the rollouts per worker. The
            same env is used for training and evaluation rollouts.
        total_timesteps: training budget, converted into a number of epochs.
        seed: base seed; each MPI rank uses ``seed + 1000000 * rank``.
        replay_strategy: stored as ``params['replay_strategy']``.
        policy_save_interval: forwarded to ``train``.
        clip_return: forwarded to the policy factory.
        override_params: optional dict merged over the defaults last.
        load_path: if given, variables are loaded from ``load_path + '_pi'``
            and ``load_path + '_pln'``.
        save_path: forwarded to ``train``.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        Whatever ``train`` returns.
    """
    override_params = override_params or {}

    # Default to a single process when MPI is unavailable. Previously ``rank``
    # was only bound inside the MPI branch, so passing a seed without MPI
    # raised NameError below.
    rank = 0
    if MPI is not None:
        rank = MPI.COMM_WORLD.Get_rank()

    # Seed everything.
    rank_seed = seed + 1000000 * rank if seed is not None else None
    set_global_seeds(rank_seed)

    # prepare params
    logger.info("preparing parameters for NN models")
    # NOTE(review): this aliases and mutates the module-level defaults, so a
    # second call in the same process starts from the first call's values.
    params = config.DEFAULT_AGENT_PARAMS
    env_name = env.spec.id
    params['env_name'] = env_name
    params['replay_strategy'] = replay_strategy
    if env_name in config.DEFAULT_ENV_PARAMS:
        params.update(config.DEFAULT_ENV_PARAMS[env_name])

    params.update(**override_params)
    params['rollout_per_worker'] = env.num_envs
    params['rollout_batch_size'] = params['rollout_per_worker']
    params['num_timesteps'] = total_timesteps
    logger.save_params(params=params, filename='ddpg_params.json')

    # get policy and planner given params
    policy = config.config_params_get_policy(params=params,
                                             clip_return=clip_return)
    planner = config.config_params_get_planner(params=params)
    if load_path is not None:
        # pi and planner are stored separately.
        U.load_variables(load_path + '_pi')
        U.load_variables(load_path + '_pln')

    # Settings common to the training and evaluation workers.
    shared_params = {
        'act_rdm_dec': params['act_rdm_dec'],
        'reward_fun': params['reward_fun'],
        'goal_delta': params['goal_delta'],
        'subgoal_strategy': params['subgoal_strategy'],
        'subgoal_num': params['seq_len'] + 1,
        'subgoal_norm': env_name.startswith('Hand'),
    }
    for name in ['T', 'rollout_per_worker', 'gamma', 'noise_eps', 'random_eps']:
        shared_params[name] = params[name]

    rollout_params = dict(shared_params,
                          exploit=False,
                          use_target_net=False,
                          use_demo_states=True,
                          compute_Q=False)
    eval_params = dict(shared_params,
                       exploit=True,
                       use_target_net=params['test_with_polyak'],
                       use_demo_states=False,
                       compute_Q=True)

    eval_env = env  # evaluation reuses the training environment

    rollout_worker = RolloutWorker(env,
                                   policy,
                                   params['dims'],
                                   logger,
                                   planner=planner,
                                   monitor=True,
                                   **rollout_params)
    evaluator = RolloutWorker(eval_env,
                              policy,
                              params['dims'],
                              logger,
                              planner=planner,
                              **eval_params)

    n_cycles = params['n_cycles']
    # Budget split over cycles, rollout_worker.T (presumably the episode
    # length) and the rollouts per worker.
    n_epochs = (total_timesteps // n_cycles // rollout_worker.T
                // rollout_worker.rollout_per_worker)

    return train(save_path=save_path,
                 policy=policy,
                 planner=planner,
                 rollout_worker=rollout_worker,
                 evaluator=evaluator,
                 n_epochs=n_epochs,
                 n_test_rollouts=params['n_test_rollouts'],
                 n_cycles=n_cycles,
                 n_batches=params['n_batches'],
                 policy_save_interval=policy_save_interval)
def launch(env,
           logdir,
           n_epochs,
           num_cpu,
           seed,
           replay_strategy,
           policy_save_interval,
           clip_return,
           bc_loss,
           q_filter,
           num_demo,
           override_params=None,
           save_policies=True,
           demo_file_name=None):
    """Set up MPI, logging and a DDPG policy with demonstrations, then train it.

    Args:
        env: environment id. It selects env-specific defaults and, for
            ``'GazeboWAMemptyEnv-v2'``, the Gazebo ``RolloutWorker``
            (otherwise ``RolloutWorkerOriginal``).
        logdir: log directory configured on rank 0 (falsy keeps the logger's
            current directory if one is set).
        n_epochs: number of training epochs, forwarded to ``train``.
        num_cpu: number of MPI workers; values > 1 fork via ``mpi_fork``.
        seed: base seed; each rank uses ``seed + 1000000 * rank``.
        replay_strategy: stored as ``params['replay_strategy']``.
        policy_save_interval, save_policies: forwarded to ``train``.
        clip_return, bc_loss, q_filter, num_demo: forwarded to
            ``config.configure_ddpg``.
        override_params: optional dict merged over all defaults last.
        demo_file_name: demonstration file passed to ``train``. When None, the
            per-environment default path below is used.
    """
    # Avoid a shared mutable default argument.
    override_params = {} if override_params is None else override_params

    # Fork for multi-CPU MPI implementation.
    if num_cpu > 1:
        try:
            whoami = mpi_fork(num_cpu, ['--bind-to', 'core'])
        except CalledProcessError:
            # fancy version of mpi call failed, try simple version
            whoami = mpi_fork(num_cpu)

        if whoami == 'parent':
            sys.exit(0)
        import baselines.common.tf_util as U
        U.single_threaded_session().__enter__()
    rank = MPI.COMM_WORLD.Get_rank()

    # Configure logging
    if rank == 0:
        if logdir or logger.get_dir() is None:
            logger.configure(dir=logdir)
    else:
        logger.configure()
    logdir = logger.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)

    # Seed everything.
    rank_seed = seed + 1000000 * rank
    set_global_seeds(rank_seed)
    resource.setrlimit(resource.RLIMIT_NOFILE, (65536, 65536))

    # Prepare params.
    # NOTE(review): this aliases and mutates the module-level defaults.
    params = config.DEFAULT_PARAMS
    params['env_name'] = env
    params['replay_strategy'] = replay_strategy
    if env in config.DEFAULT_ENV_PARAMS:
        # merge env-specific parameters in
        params.update(config.DEFAULT_ENV_PARAMS[env])
    # makes it possible to override any parameter
    params.update(**override_params)
    with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
        json.dump(params, f)
    params = config.prepare_params(params)
    config.log_params(params, logger=logger)

    if num_cpu == 1:
        logger.warn()
        logger.warn('*** Warning ***')
        logger.warn(
            'You are running HER with just a single MPI worker. This will work, but the '
            'experiments that we report in Plappert et al. (2018, https://arxiv.org/abs/1802.09464) '
            'were obtained with --num_cpu 19. This makes a significant difference and if you '
            'are looking to reproduce those results, be aware of this. Please also refer to '
            'https://github.com/openai/baselines/issues/314 for further details.'
        )
        logger.warn('****************')
        logger.warn()

    dims = config.configure_dims(params)
    policy = config.configure_ddpg(dims=dims,
                                   params=params,
                                   clip_return=clip_return,
                                   bc_loss=bc_loss,
                                   q_filter=q_filter,
                                   num_demo=num_demo)

    # Worker settings. Both environment branches previously built identical
    # dicts, so they are constructed once here.
    rollout_params = {
        'exploit': False,
        'use_target_net': False,
        'use_demo_states': True,
        'compute_Q': False,
        'T': params['T'],
    }
    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'compute_Q': True,
        'T': params['T'],
    }
    # NOTE(review): the Gazebo branch used to pre-set eval 'rollout_batch_size'
    # to 1, but this loop always overwrote it with params['rollout_batch_size'].
    # That effective behaviour is kept; confirm which value is intended.
    for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
        rollout_params[name] = params[name]
        eval_params[name] = params[name]

    if params['env_name'] == 'GazeboWAMemptyEnv-v2':
        default_demo_file = '/home/rjangir/wamObjectDemoData/data_wam_double_random_100_40_25.npz'

        # Training and evaluation workers share one cached environment.
        madeEnv = config.cached_make_env(params['make_env'])
        rollout_worker = RolloutWorker(madeEnv, params['make_env'], policy,
                                       dims, logger, **rollout_params)
        rollout_worker.seed(rank_seed)

        evaluator = RolloutWorker(madeEnv, params['make_env'], policy, dims,
                                  logger, **eval_params)
        evaluator.seed(rank_seed)
    else:
        default_demo_file = '/home/rjangir/fetchDemoData/data_fetch_random_100.npz'

        rollout_worker = RolloutWorkerOriginal(params['make_env'], policy,
                                               dims, logger, **rollout_params)
        rollout_worker.seed(rank_seed)

        evaluator = RolloutWorkerOriginal(params['make_env'], policy, dims,
                                          logger, **eval_params)
        evaluator.seed(rank_seed)

    if demo_file_name is None:
        demo_file_name = default_demo_file

    train(logdir=logdir,
          policy=policy,
          rollout_worker=rollout_worker,
          evaluator=evaluator,
          n_epochs=n_epochs,
          n_test_rollouts=params['n_test_rollouts'],
          n_cycles=params['n_cycles'],
          n_batches=params['n_batches'],
          policy_save_interval=policy_save_interval,
          save_policies=save_policies,
          demo_file_name=demo_file_name)
Example #7
0
def launch(args):
    """Run the goal-conditioned SAC training loop on a Fetch 3-object env.

    Every MPI worker collects rollouts and trains. Rank 0 additionally
    configures logging, aggregates evaluation results and saves models.

    Args:
        args: parsed run configuration (``algo``, ``agent``, ``seed``,
            ``cuda``, ``n_epochs``, ``n_cycles``, ``n_batches``,
            ``num_rollouts_per_mpi``, ``num_workers``, ``evaluations``,
            ``save_freq``, ...). ``env_name`` and ``env_params`` are written
            onto it, and ``multi_criteria_her`` too when
            ``algo == 'continuous'``.

    Raises:
        NotImplementedError: if ``args.agent`` is not ``"SAC"``.
    """
    rank = MPI.COMM_WORLD.Get_rank()

    t_total_init = time.time()

    # Make the environment
    if args.algo == 'continuous':
        args.env_name = 'FetchManipulate3ObjectsContinuous-v0'
        args.multi_criteria_her = True
    else:
        args.env_name = 'FetchManipulate3Objects-v0'
    env = gym.make(args.env_name)

    # Set random seeds for reproducibility (offset by rank per worker).
    worker_seed = args.seed + rank
    env.seed(worker_seed)
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)
    if args.cuda:
        torch.cuda.manual_seed(worker_seed)

    # Get saving paths (only used on rank 0).
    if rank == 0:
        logdir, model_path, bucket_path = init_storage(args)
        logger.configure(dir=logdir)
        logger.info(vars(args))

    args.env_params = get_env_params(env)

    language_goal = get_instruction() if args.algo == 'language' else None
    goal_sampler = GoalSampler(args)

    # Initialize RL Agent
    if args.agent != "SAC":
        raise NotImplementedError
    policy = RLAgent(args, env.compute_reward, goal_sampler)

    # Initialize Rollout Worker
    rollout_worker = RolloutWorker(env, policy, goal_sampler, args)

    # Main interaction loop
    episode_count = 0
    for epoch in range(args.n_epochs):
        t_init = time.time()

        # setup time_tracking
        time_dict = dict(goal_sampler=0,
                         rollout=0,
                         gs_update=0,
                         store=0,
                         norm_update=0,
                         policy_train=0,
                         lp_update=0,
                         eval=0,
                         epoch=0)

        # log current epoch
        if rank == 0:
            logger.info('\n\nEpoch #{}'.format(epoch))

        # Cycles loop
        for _ in range(args.n_cycles):

            # Sample goals
            t_i = time.time()
            goals, self_eval = goal_sampler.sample_goal(
                n_goals=args.num_rollouts_per_mpi, evaluation=False)
            if args.algo == 'language':
                language_goal_ep = np.random.choice(
                    language_goal, size=args.num_rollouts_per_mpi)
            else:
                language_goal_ep = None
            time_dict['goal_sampler'] += time.time() - t_i

            # Biased initializations are only enabled from start_biased_init on.
            if epoch < args.start_biased_init:
                biased_init = False
            else:
                biased_init = args.biased_init

            # Environment interactions
            t_i = time.time()
            episodes = rollout_worker.generate_rollout(
                goals=goals,  # list of goal configurations
                self_eval=self_eval,  # whether the agent performs self-evaluations
                true_eval=False,  # these are not offline evaluation episodes
                biased_init=biased_init,  # whether initializations should be biased
                language_goal=language_goal_ep)
            time_dict['rollout'] += time.time() - t_i

            # Goal Sampler updates
            t_i = time.time()
            episodes = goal_sampler.update(episodes, episode_count)
            time_dict['gs_update'] += time.time() - t_i

            # Storing episodes
            t_i = time.time()
            policy.store(episodes)
            time_dict['store'] += time.time() - t_i

            # Updating observation normalization
            t_i = time.time()
            for e in episodes:
                policy._update_normalizer(e)
            time_dict['norm_update'] += time.time() - t_i

            # Policy updates
            t_i = time.time()
            for _ in range(args.n_batches):
                policy.train()
            time_dict['policy_train'] += time.time() - t_i
            episode_count += args.num_rollouts_per_mpi * args.num_workers

        # Updating Learning Progress
        t_i = time.time()
        if goal_sampler.curriculum_learning and rank == 0:
            goal_sampler.update_LP()
        goal_sampler.sync()

        time_dict['lp_update'] += time.time() - t_i
        time_dict['epoch'] += time.time() - t_init
        time_dict['total'] = time.time() - t_total_init

        if args.evaluations:
            if rank == 0:
                logger.info('\tRunning eval ..')
            # Performing evaluations
            t_i = time.time()
            if args.algo == 'language':
                # NOTE(review): 35 is hard-coded; presumably the number of
                # valid goals — confirm against goal_sampler.valid_goals.
                ids = np.random.choice(np.arange(35), size=len(language_goal))
                eval_goals = goal_sampler.valid_goals[ids]
            else:
                eval_goals = goal_sampler.valid_goals
            episodes = rollout_worker.generate_rollout(
                goals=eval_goals,
                self_eval=True,  # this parameter is overridden by true_eval
                true_eval=True,  # this is offline evaluations
                biased_init=False,
                language_goal=language_goal)

            # Extract the results. ``int`` replaces the removed ``np.int`` alias
            # (identical dtype; np.int raises AttributeError on NumPy >= 1.24).
            if args.algo == 'continuous':
                results = np.array([e['rewards'][-1] == 3.
                                    for e in episodes]).astype(int)
            elif args.algo == 'language':
                results = np.array([
                    e['language_goal']
                    in sentence_from_configuration(config=e['ag'][-1], all=True)
                    for e in episodes
                ]).astype(int)
            else:
                results = np.array([
                    str(e['g'][0]) == str(e['ag'][-1]) for e in episodes
                ]).astype(int)
            rewards = np.array([e['rewards'][-1] for e in episodes])
            all_results = MPI.COMM_WORLD.gather(results, root=0)
            all_rewards = MPI.COMM_WORLD.gather(rewards, root=0)
            time_dict['eval'] += time.time() - t_i

            # Logs (gather only delivers data to rank 0)
            if rank == 0:
                assert len(all_results) == args.num_workers  # MPI test
                av_res = np.array(all_results).mean(axis=0)
                av_rewards = np.array(all_rewards).mean(axis=0)
                global_sr = np.mean(av_res)
                log_and_save(goal_sampler, epoch, episode_count, av_res,
                             av_rewards, global_sr, time_dict)

                # Saving policy models
                if epoch % args.save_freq == 0:
                    policy.save(model_path, epoch)
                    goal_sampler.save_bucket_contents(bucket_path, epoch)
                logger.info('\tEpoch #{}: SR: {}'.format(epoch, global_sr))
Example #8
0
    all_goals = generate_all_goals_in_goal_space()
    dict_goals = dict(zip([str(g) for g in all_goals], all_goals))

    policy_scores = []
    for vae_id in range(10):
        model_path = path + '/policy_models/model{}.pt'.format(vae_id + 1)

        # create the sac agent to interact with the environment
        if args.agent == "SAC":
            policy = RLAgent(args, env.compute_reward, goal_sampler)
            policy.load(model_path, args)
        else:
            raise NotImplementedError

        # def rollout worker
        rollout_worker = RolloutWorker(env, policy, goal_sampler, args)

        with open(path + 'vae_models/vae_model{}.pkl'.format(vae_id + 1),
                  'rb') as f:
            vae = torch.load(f)

        scores = []
        for i in range(num_eval):
            print(i)
            score = rollout(sentence_generator,
                            vae,
                            sentences,
                            inst_to_one_hot,
                            dict_goals,
                            env,
                            policy,
Example #9
0
def launch(
    env, logdir, n_epochs, num_cpu, seed, replay_strategy, policy_save_interval, clip_return, demo_file,
    override_params=None, save_policies=True
):
    """Set up MPI, logging and a DDPG policy, then train it with demonstrations.

    Args:
        env: environment id. It selects env-specific defaults and, for
            ``'FetchPickAndPlace-v0'``, the custom ``RolloutWorker``
            (otherwise ``RolloutWorkerOriginal``).
        logdir: log directory configured on rank 0 (falsy keeps the logger's
            current directory if one is set).
        n_epochs: number of training epochs, forwarded to ``train``.
        num_cpu: number of MPI workers; values > 1 fork via ``mpi_fork``.
        seed: base seed; each rank uses ``seed + 1000000 * rank``.
        replay_strategy: stored as ``params['replay_strategy']``.
        policy_save_interval, save_policies, demo_file: forwarded to ``train``.
        clip_return: forwarded to ``config.configure_ddpg``.
        override_params: optional dict merged over all defaults last.
    """
    # Avoid a shared mutable default argument.
    override_params = {} if override_params is None else override_params

    # Fork for multi-CPU MPI implementation.
    if num_cpu > 1:
        try:
            whoami = mpi_fork(num_cpu, ['--bind-to', 'core'])
        except CalledProcessError:
            # fancy version of mpi call failed, try simple version
            whoami = mpi_fork(num_cpu)

        if whoami == 'parent':
            sys.exit(0)
        import baselines.common.tf_util as U
        U.single_threaded_session().__enter__()
    rank = MPI.COMM_WORLD.Get_rank()

    # Configure logging
    if rank == 0:
        if logdir or logger.get_dir() is None:
            logger.configure(dir=logdir)
    else:
        logger.configure()
    logdir = logger.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)

    # Seed everything.
    rank_seed = seed + 1000000 * rank
    set_global_seeds(rank_seed)
    resource.setrlimit(resource.RLIMIT_NOFILE, (65536, 65536))

    # Prepare params.
    # NOTE(review): this aliases and mutates the module-level defaults.
    params = config.DEFAULT_PARAMS
    params['env_name'] = env
    params['replay_strategy'] = replay_strategy
    if env in config.DEFAULT_ENV_PARAMS:
        params.update(config.DEFAULT_ENV_PARAMS[env])  # merge env-specific parameters in
    params.update(**override_params)  # makes it possible to override any parameter
    with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
        json.dump(params, f)
    params = config.prepare_params(params)
    config.log_params(params, logger=logger)

    dims = config.configure_dims(params)
    policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)

    # Worker settings. Both environment branches previously built identical
    # dicts, so they are constructed once here.
    rollout_params = {
        'exploit': False,
        'use_target_net': False,
        'use_demo_states': True,
        'compute_Q': False,
        'T': params['T'],
        'render': 1,
    }
    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'compute_Q': True,
        'T': params['T'],
        'render': 1,
    }
    # NOTE(review): the FetchPickAndPlace branch used to pre-set eval
    # 'rollout_batch_size' to 1, but this loop always overwrote it with
    # params['rollout_batch_size']. That effective behaviour is kept.
    for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
        rollout_params[name] = params[name]
        eval_params[name] = params[name]

    if params['env_name'] == 'FetchPickAndPlace-v0':
        # Training and evaluation workers share one cached environment.
        madeEnv = config.cached_make_env(params['make_env'])
        rollout_worker = RolloutWorker(madeEnv, params['make_env'], policy, dims, logger, **rollout_params)
        rollout_worker.seed(rank_seed)

        evaluator = RolloutWorker(madeEnv, params['make_env'], policy, dims, logger, **eval_params)
        evaluator.seed(rank_seed)
    else:
        rollout_worker = RolloutWorkerOriginal(params['make_env'], policy, dims, logger, **rollout_params)
        rollout_worker.seed(rank_seed)

        evaluator = RolloutWorkerOriginal(params['make_env'], policy, dims, logger, **eval_params)
        evaluator.seed(rank_seed)

    train(
        logdir=logdir, policy=policy, rollout_worker=rollout_worker,
        evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'],
        n_cycles=params['n_cycles'], n_batches=params['n_batches'],
        policy_save_interval=policy_save_interval, save_policies=save_policies, demo_file=demo_file)
Example #10
0
class Runner:
    """Drive the collect / store / train loop for a multi-agent learner.

    Holds the agents, a rollout worker that generates episodes with them, and
    a replay buffer from which training mini-batches are sampled.
    """

    def __init__(self, env, args):
        self.env = env

        self.agents = Agents(args)
        self.rolloutWorker = RolloutWorker(env, self.agents, args)
        self.buffer = ReplayBuffer(args)
        self.args = args
        self.max_counter = []
        self.episode_rewards = []
        self.win_rates = []

        # Directory used to save plots (plt) and pickles (pkl).
        self.save_path = self.args.result_dir + '/' + args.alg
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    @staticmethod
    def _merge_episodes(episodes):
        """Concatenate episode dicts key by key along axis 0 into one batch.

        Keys are taken from the first episode; every episode must provide
        them. One concatenate per key replaces the previous repeated pairwise
        concatenation, which copied the growing batch once per episode.
        Raises IndexError if ``episodes`` is empty.
        """
        return {
            key: np.concatenate([episode[key] for episode in episodes], axis=0)
            for key in episodes[0]
        }

    def run(self, num):
        """Train for ``args.n_epoch`` epochs, evaluating every ``args.evaluate_cycle``.

        Each epoch collects ``args.n_episodes`` episodes, stores them in the
        replay buffer as one batch, then runs ``args.train_steps`` training
        steps on mini-batches sampled from the buffer.

        Args:
            num: run index, used only in progress messages.
        """
        train_steps = 0
        for epoch in range(self.args.n_epoch):
            print('Run {}, train epoch {}'.format(num, epoch))
            # all_gain / all_loss come from the last episode generated in an
            # earlier epoch, hence no evaluation at epoch 0.
            if epoch % self.args.evaluate_cycle == 0 and epoch != 0:
                self.evaluate(all_gain, all_loss)

            # Collect self.args.n_episodes episodes.
            episodes = []
            for episode_idx in range(self.args.n_episodes):
                print("Generate episode {}".format(episode_idx))
                episode, _reward, _info, win_number, all_gain, all_loss = \
                    self.rolloutWorker.generate_episode(episode_idx, 0)
                episodes.append(episode)
                print('win_number:', win_number)

            # Per the original note, each episode value is a
            # (1, episode_len, n_agents, dim) array; stack them on axis 0.
            episode_batch = self._merge_episodes(episodes)

            self.buffer.store_episode(episode_batch)
            for _ in range(self.args.train_steps):
                mini_batch = self.buffer.sample(
                    min(self.buffer.current_size, self.args.batch_size))
                self.agents.train(mini_batch, train_steps)
                train_steps += 1

    def evaluate(self, all_gain, all_loss):
        """Run ``args.evaluate_epoch`` evaluation episodes and plot gain/loss.

        Args:
            all_gain, all_loss: series to plot. Each evaluation episode
                replaces them with the series it returns, so the last
                episode's values are what gets plotted.
        """
        episode_rewards = 0
        win_number = 0
        for epoch in range(self.args.evaluate_epoch):
            print('evaluate_epoch{}'.format(epoch))
            _, episode_reward, _info, win_number, all_gain, all_loss = \
                self.rolloutWorker.generate_episode(epoch, win_number, evaluate=True)
            print('win_number', win_number)
            episode_rewards += episode_reward

        plt.figure()
        plt.plot(all_gain, lw=1.5, label='reward')
        plt.plot(all_loss, lw=1.5, label='loss')
        plt.grid(True)
        plt.legend(loc=0)  # loc=0: let matplotlib choose the legend position
        plt.axis('tight')
        plt.xlabel('index')
        plt.ylabel('packets')
        plt.title('reward and loss')
        plt.show()
        # NOTE(review): episode_rewards is accumulated but not returned; an
        # earlier version returned
        # (win_number / (evaluate_epoch * episode_limit), episode_rewards / 2).