Example #1
def train(env_id, num_timesteps, seed, num_cpu):
    from baselines.pposgd import pposgd_simple, cnn_policy
    import baselines.common.tf_util as U
    whoami = mpi_fork(num_cpu)
    if whoami == "parent": return
    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    logger.session().__enter__()
    if rank != 0: logger.set_level(logger.DISABLED)
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = gym.make(env_id)
    def policy_fn(name, ob_space, ac_space): #pylint: disable=W0613
        return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space)
    env = bench.Monitor(env, osp.join(logger.get_dir(), "%i.monitor.json" % rank))
    env.seed(workerseed)
    gym.logger.setLevel(logging.WARN)

    env = wrap_train(env)
    num_timesteps /= 4 # because we're wrapping the envs to do frame skip
    env.seed(workerseed)

    pposgd_simple.learn(env, policy_fn,
        max_timesteps=num_timesteps,
        timesteps_per_batch=256,
        clip_param=0.2, entcoeff=0.01,
        optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64,
        gamma=0.99, lam=0.95,
        schedule='linear'
    )
    env.close()
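wrap_train is referenced above but not defined in this excerpt. As a stand-in, a minimal frame-skip wrapper sketch using only the standard gym.Wrapper API (old 4-tuple step signature) could look like the following; the real wrap_train in the original run script also applies the usual Atari preprocessing, which is omitted here.

import gym

class FrameSkip(gym.Wrapper):
    """Repeat each action `skip` times and sum the rewards (old gym 4-tuple API)."""
    def __init__(self, env, skip=4):
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, info

def wrap_train(env):
    # hypothetical stand-in; the original also adds standard Atari preprocessing
    return FrameSkip(env, skip=4)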
Example #2
def train(env_id, num_timesteps, seed):
    from baselines.pposgd import mlp_policy, pposgd_simple
    U.make_session(num_cpu=1).__enter__()
    logger.session().__enter__()
    set_global_seeds(seed)
    env = gym.make(env_id)
    def policy_fn(name, ob_space, ac_space):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=64, num_hid_layers=2)
    env = bench.Monitor(env, osp.join(logger.get_dir(), "monitor.json"))
    env.seed(seed)
    gym.logger.setLevel(logging.WARN)
    pposgd_simple.learn(env, policy_fn, 
            max_timesteps=num_timesteps,
            timesteps_per_batch=2048,
            clip_param=0.2, entcoeff=0.0,
            optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64,
            gamma=0.99, lam=0.95,
        )
    env.close()
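A possible way to call this train function; the environment id, timestep budget, and seed below are illustrative rather than taken from the original run script.

if __name__ == '__main__':
    # illustrative values; the original script parses these from the command line
    train('Hopper-v1', num_timesteps=int(1e6), seed=0)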
Example #3
def train(env_id, num_timesteps, seed, num_cpu):
    from baselines.pposgd import pposgd_simple, cnn_policy
    import baselines.common.tf_util as U
    whoami = mpi_fork(num_cpu)
    if whoami == "parent": return
    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    logger.session().__enter__()
    if rank != 0: logger.set_level(logger.DISABLED)
    workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
    set_global_seeds(workerseed)
    env = gym.make(env_id)

    def policy_fn(name, ob_space, ac_space):  #pylint: disable=W0613
        return cnn_policy.CnnPolicy(name=name,
                                    ob_space=ob_space,
                                    ac_space=ac_space)

    env = bench.Monitor(env,
                        osp.join(logger.get_dir(), "%i.monitor.json" % rank))
    env.seed(workerseed)
    gym.logger.setLevel(logging.WARN)

    env = wrap_train(env)
    num_timesteps /= 4  # because we're wrapping the envs to do frame skip
    env.seed(workerseed)

    pposgd_simple.learn(env,
                        policy_fn,
                        max_timesteps=num_timesteps,
                        timesteps_per_batch=256,
                        clip_param=0.2,
                        entcoeff=0.01,
                        optim_epochs=4,
                        optim_stepsize=1e-3,
                        optim_batchsize=64,
                        gamma=0.99,
                        lam=0.95,
                        schedule='linear')
    env.close()
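This variant is launched through mpi_fork, which re-executes the script under MPI and returns "parent" in the launching process, so the function simply returns there. A minimal entry point might look like the following; the environment id and hyperparameter values are illustrative, not from the original script.

if __name__ == '__main__':
    # illustrative values; the original script parses these from the command line
    train('PongNoFrameskip-v4', num_timesteps=int(40e6), seed=0, num_cpu=8)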
def train(args):
    from baselines.pposgd import mlp_policy, pposgd_simple
    rank = MPI.COMM_WORLD.Get_rank()
    sess = U.single_threaded_session()
    sess.__enter__()
    log_to_disk = rank == 0 and not args.test_only and not args.evaluate
    logger.session(dir=args.exp_path,
                   format_strs=None if log_to_disk else []).__enter__()
    if rank != 0:
        logger.set_level(logger.DISABLED)
    workerseed = args.seed + 10000 * rank
    set_global_seeds(workerseed)

    if args.submit:
        env = SubmitRunEnv(visualize=args.render)
    elif args.submit_round2:
        from turnips.submit_round2_env import SubmitRunEnv2
        submit_env = env = SubmitRunEnv2()
    elif args.simwalker:
        env = SimWalker(visualize=args.render)
    else:
        env = IsolatedMyRunEnv(visualize=args.render,
                               run_logs_dir=args.run_logs_dir,
                               additional_info={'exp_name': args.exp_name},
                               step_timeout=args.step_timeout,
                               n_obstacles=args.n_obstacles,
                               higher_pelvis=args.higher_pelvis)

    env = RunEnvWrapper(env, args.diff)
    if args.simwalker and args.log_simwalker:
        cls = type(
            "h5pyEnvLoggerClone", (gym.Wrapper, ),
            dict(h5pyEnvLogger.__dict__))  # workaround for double wrap problem
        env = cls(env,
                  log_dir=args.run_logs_dir,
                  filename_prefix='simwalker_',
                  additional_info={
                      'exp_name': args.exp_name,
                      'difficulty': args.diff,
                      'seed': args.seed
                  })

    env = env_walker = Walker(env,
                              shaping_mode=args.shaping,
                              transform_inputs=args.transform_inputs,
                              obstacle_hack=not args.noobsthack,
                              max_steps=args.max_env_steps,
                              memory_size=args.memory_size,
                              swap_legs_mode=args.swap_legs_mode,
                              filter_obs=args.filter_obs,
                              add_time=args.add_time,
                              fall_penalty=args.fall_penalty,
                              fall_penalty_value=args.fall_penalty_val,
                              print_action=args.print_action,
                              new8_fix=args.new8_fix,
                              pause=args.pause,
                              noisy_obstacles=args.noisy_obstacles,
                              noisy_obstacles2=args.noisy_obstacles2,
                              noisy_fix=args.noisy_fix)

    if args.log_walker:
        env = h5pyEnvLogger(env,
                            log_dir=args.run_logs_dir,
                            filename_prefix='walker_',
                            additional_info={
                                'exp_name': args.exp_name,
                                'difficulty': args.diff,
                                'seed': args.seed
                            })
    if args.muscles:
        env = MuscleWalker(env)
    if args.repeats > 1:
        env = RepeatActionsWalker(env, args.repeats)

    def policy_fn(name, ob_space, ac_space):
        return mlp_policy.MlpPolicy(
            name=name,
            ob_space=ob_space,
            ac_space=ac_space,
            hid_size=args.hid_size,
            num_hid_layers=args.num_hid_layers,
            bound_by_sigmoid=args.bound_by_sigmoid,
            sigmoid_coef=args.sigmoid_coef,
            activation=args.activation,
            normalize_obs=not args.nonormalize_obs,
            gaussian_fixed_var=not args.nogaussian_fixed_var,
            avg_norm_symmetry=args.avg_norm_symmetry,
            symmetric_interpretation=args.symmetric_interpretation,
            stdclip=args.stdclip,
            actions=args.actions,
            gaussian_bias=args.gaussian_bias,
            gaussian_from_binary=args.gaussian_from_binary,
            parallel_value=args.parallel_value,
            pv_layers=args.pv_layers,
            pv_hid_size=args.pv_hid_size,
            three=args.three)

    if not args.test_only and not args.evaluate:
        env = bench.Monitor(env,
                            path.join(args.exp_path, "%i.monitor.json" % rank))
    env.seed(workerseed)
    gym.logger.setLevel(logging.WARN)

    current_best = float('-inf')
    current_best_completed = float('-inf')
    current_best_perc_completed = float('-inf')
    stats_f = None

    start = time.time()

    def callback(local_, global_):
        nonlocal current_best
        nonlocal current_best_completed
        nonlocal current_best_perc_completed
        nonlocal stats_f
        if rank != 0: return
        if args.test_only or args.evaluate: return

        print('ELAPSED', time.time() - start)
        print(f'{socket.gethostname()}:{args.exp_path}')

        iter_no = local_['iters_so_far']
        if iter_no % args.save_every == 0:
            U.save_state(
                path.join(args.exp_path, 'models', f'{iter_no:04d}', 'model'))

        if local_['iters_so_far'] == 0:
            stats_f = open(path.join(args.exp_path, 'simple_stats.csv'), 'w')
            cols = [
                "Iter", "EpLenMean", "EpRewMean", "EpOrigRewMean",
                "EpThisIter", "EpisodesSoFar", "TimestepsSoFar", "TimeElapsed",
                "AvgCompleted", "PercCompleted"
            ]
            for name in local_['loss_names']:
                cols.append("loss_" + name)
            stats_f.write(",".join(cols) + '\n')
        else:
            current_orig_reward = np.mean(local_['origrew_buffer'])
            if current_best < current_orig_reward:
                print(
                    f'Found better {current_best:.2f} -> {current_orig_reward:.2f}'
                )
                current_best = current_orig_reward
                U.save_state(path.join(args.exp_path, 'best', 'model'))
            U.save_state(path.join(args.exp_path, 'last', 'model'))

            avg_completed = local_["avg_completed"]
            if current_best_completed < avg_completed:
                print(
                    f'Found better completed {current_best_completed:.2f} -> {avg_completed:.2f}'
                )
                current_best_completed = avg_completed
                U.save_state(
                    path.join(args.exp_path, 'best_completed', 'model'))

            perc_completed = local_["perc_completed"]
            if current_best_perc_completed < perc_completed:
                print(
                    f'Found better perc completed {current_best_perc_completed:.2f} -> {perc_completed:.2f}'
                )
                current_best_perc_completed = perc_completed
                U.save_state(
                    path.join(args.exp_path, 'perc_completed', 'model'))

            data = [
                local_['iters_so_far'],
                np.mean(local_['len_buffer']),
                np.mean(local_['rew_buffer']),
                np.mean(local_['origrew_buffer']),
                len(local_['lens']),
                local_['episodes_so_far'],
                local_['timesteps_so_far'],
                time.time() - local_['tstart'],
                avg_completed,
                perc_completed,
            ]
            if 'meanlosses' in local_:
                for lossval in local_['meanlosses']:
                    data.append(lossval)

            stats_f.write(",".join([str(x) for x in data]) + '\n')
            stats_f.flush()

    if args.load_model is not None:
        args.load_model += '/model'
    if args.submit_round2:
        submit_round2(env,
                      submit_env,
                      policy_fn,
                      load_model_path=args.load_model,
                      stochastic=False,
                      actions=args.actions)
        #submit_env.submit()   # submit_round2(...) submits already
        sys.exit()
    if args.evaluate:
        pposgd_simple.evaluate(env,
                               policy_fn,
                               load_model_path=args.load_model,
                               n_episodes=args.n_eval_episodes,
                               stochastic=not args.nostochastic,
                               actions=args.actions,
                               execute_just=args.execute_just)
    else:
        pposgd_simple.learn(
            env,
            policy_fn,
            max_timesteps=args.max_timesteps,
            timesteps_per_batch=args.timesteps_per_batch,
            clip_param=args.clip_param,
            entcoeff=args.entcoeff,
            optim_epochs=args.optim_epochs,
            optim_stepsize=args.optim_stepsize,
            optim_batchsize=args.optim_batchsize,
            gamma=args.gamma,
            lam=args.lam,
            callback=callback,
            load_model_path=args.load_model,
            test_only=args.test_only,
            stochastic=not args.nostochastic,
            symmetric_training=args.symmetric_training,
            obs_names=env_walker.obs_names,
            single_episode=args.single_episode,
            horizon_hack=args.horizon_hack,
            running_avg_len=args.running_avg_len,
            init_three=args.init_three,
            actions=args.actions,
            symmetric_training_trick=args.symmetric_training_trick,
            bootstrap_seeds=args.bootstrap_seeds,
            seeds_fn=args.seeds_fn,
        )
    env.close()
    vis = env.osim_model.model.updVisualizer().updSimbodyVisualizer()
    vis.setBackgroundType(vis.GroundAndSky)
    vis.setShowFrameNumber(True)
    vis.zoomCameraToShowAllGeometry()
    vis.setCameraFieldOfView(1)
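The callback above writes one row of per-iteration statistics to simple_stats.csv under args.exp_path. A quick, assumed way to inspect that file offline with pandas (the 'exp' directory below is a placeholder; the column names come from the callback):

import pandas as pd

stats = pd.read_csv('exp/simple_stats.csv')  # substitute your args.exp_path
print(stats[['Iter', 'EpRewMean', 'EpOrigRewMean', 'AvgCompleted']].tail())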

if args.train:
    history = pposgd_simple.learn(
        env,
        policy_fn,
        max_timesteps=args.steps,
        timesteps_per_batch=args.batch,
        clip_param=args.clip,
        entcoeff=args.ent,
        optim_epochs=args.epochs,
        optim_stepsize=args.stepsize,
        optim_batchsize=args.optim_batch,
        adam_epsilon=1e-5,
        gamma=args.gamma,
        lam=0.95,
        schedule=args.schedule,
        callback=on_iteration_start,
        verbose=args.verbose,
    )

    env.close()

    if MPI.COMM_WORLD.Get_rank() == 0:
        plot_history(history)
        save_model()
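plot_history and save_model are not defined in this excerpt. A minimal sketch of what plot_history might do, assuming history maps metric names to per-iteration lists of values (a guess at intent, not the original code):

import matplotlib
matplotlib.use('Agg')  # headless plotting on a training machine
import matplotlib.pyplot as plt

def plot_history(history):
    for name, values in history.items():
        plt.plot(values, label=name)
    plt.xlabel('iteration')
    plt.legend()
    plt.savefig('history.png')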