Example #1
def train_mf(mb_steps, policy_weight, trainer, sampler, worker, dynamics,
             policy, reward, args=None):
    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))
    network_type = {'policy': policy, 'dynamics': dynamics, 'reward': reward}

    # make the trainer and sampler
    sampler_agent = make_sampler(sampler, worker, network_type, args)
    trainer_tasks, trainer_results, trainer_agent, init_weights = \
        make_trainer(trainer, network_type, args)

    # Initialize the policy with dagger policy weight.
    trainer_tasks.put((parallel_util.SET_POLICY_WEIGHT, policy_weight))
    trainer_tasks.join()
    init_weights['policy'][0] = policy_weight
    sampler_agent.set_weights(init_weights)

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    current_iteration = 0

    while True:
        timer_dict['** Program Total Time **'] = time.time()

        # step 1: collect rollout data
        rollout_data = \
            sampler_agent.rollouts_using_worker_playing(use_true_env=True)

        timer_dict['Generate Rollout'] = time.time()

        # step 2: train the weights for dynamics and policy network
        training_info = {'network_to_train': ['dynamics', 'reward', 'policy']}
        trainer_tasks.put(
            (parallel_util.TRAIN_SIGNAL,
             {'data': rollout_data['data'], 'training_info': training_info})
        )
        trainer_tasks.join()
        training_return = trainer_results.get()
        timer_dict['Train Weights'] = time.time()

        # step 3: update the weights
        sampler_agent.set_weights(training_return['network_weights'])
        timer_dict['Assign Weights'] = time.time()

        # log and print the results
        log_results(training_return, timer_dict, mb_steps)

        if training_return['totalsteps'] > args.max_timesteps:
            break
        else:
            current_iteration += 1

    # end of training
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))
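Each of these examples drives the trainer process through the task/result queue pair returned by make_trainer. A minimal sketch of that handshake, assuming the tasks queue is a multiprocessing.JoinableQueue and the results queue a plain multiprocessing.Queue (make_trainer itself is not shown in these examples):

import multiprocessing

# Hypothetical stand-ins for the queues that make_trainer() hands back.
trainer_tasks = multiprocessing.JoinableQueue()
trainer_results = multiprocessing.Queue()

def blocking_request(signal, payload):
    # Send a (signal, payload) tuple, wait until the trainer marks the task
    # done, then read the reply; calling join() before get() is what makes
    # every TRAIN_SIGNAL round-trip in the examples synchronous.
    trainer_tasks.put((signal, payload))
    trainer_tasks.join()
    return trainer_results.get()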
Example #2
def train_mb(trainer, sampler, worker, dynamics, policy, reward, args=None):
    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))
    network_type = {'policy': policy, 'dynamics': dynamics, 'reward': reward}

    # make the trainer and sampler
    sampler_agent = make_sampler(sampler, worker, network_type, args)
    trainer_tasks, trainer_results, trainer_agent, init_weights = \
        make_trainer(trainer, network_type, args)
    sampler_agent.set_weights(init_weights)

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    totalsteps = 0
    current_iteration = 0
    init_data = {}

    # Start mb training.
    while True:
        timer_dict['** Program Total Time **'] = time.time()

        # step 1: collect rollout data
        if current_iteration == 0 and args.random_timesteps > 0 and \
                (not (args.gt_dynamics and args.gt_reward)):
            # we could first generate random rollout data for exploration
            logger.info(
                'Generating {} random timesteps'.format(args.random_timesteps)
            )
            rollout_data = sampler_agent.rollouts_using_worker_planning(
                args.random_timesteps, use_random_action=True
            )
        else:
            rollout_data = sampler_agent.rollouts_using_worker_planning()

        timer_dict['Generate Rollout'] = time.time()

        # step 2: train the weights for dynamics and policy network
        training_info = {'network_to_train': ['dynamics', 'reward']}
        trainer_tasks.put(
            (parallel_util.TRAIN_SIGNAL,
             {'data': rollout_data['data'], 'training_info': training_info})
        )
        trainer_tasks.join()
        training_return = trainer_results.get()
        timer_dict['Train Weights'] = time.time()

        # step 3: update the weights
        sampler_agent.set_weights(training_return['network_weights'])
        timer_dict['Assign Weights'] = time.time()

        # log and print the results
        log_results(training_return, timer_dict)

        for key in rollout_data.keys():
            if key not in init_data.keys():
                init_data[key] = []
            init_data[key].extend(rollout_data[key])

        # Add noise to the initial data to encourage TRPO to explore.
        import numpy as np
        for i_rollout in init_data['data']:
            action = i_rollout['actions']
            i_rollout['actions'] += np.random.normal(scale=0.005,
                                                     size=action.shape)
        # refresh the step counter before the stopping check; without this
        # the condition below would only ever see the initial value of 0
        totalsteps = training_return['totalsteps']
        if totalsteps > args.max_timesteps or \
                training_return['replay_buffer'].get_current_size() > \
                args.mb_timesteps:
            break
        else:
            current_iteration += 1

    # Initialize policy network
    training_info = {'network_to_train': ['reward', 'policy']}
    trainer_tasks.put(
        (parallel_util.MBMF_INITIAL,
         {'data': init_data['data'], 'training_info': training_info})
    )
    trainer_tasks.join()
    training_return = trainer_results.get()
    timer_dict['Train Weights'] = time.time()

    # Start dagger iteration.
    for dagger_i in range(args.dagger_iter):
        logger.info('=============== Doing dagger iteration {} ==============='
                    .format(dagger_i))
        # Collect on policy rollout.
        rollout_data = sampler_agent.rollouts_using_worker_playing(
            num_timesteps=args.dagger_timesteps_per_iter,
            use_true_env=True)
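        # (presumably) relabel the freshly collected on-policy states with the
        # planner's actions, DAgger-style, before aggregating them below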
        sampler_agent.dagger_rollouts(rollout_data['data'])
        init_data['data'] += rollout_data['data']
        trainer_tasks.put(
            (parallel_util.MBMF_INITIAL,
             {'data': init_data['data'], 'training_info': training_info})
        )
        trainer_tasks.join()
        training_return = trainer_results.get()

    trainer_tasks.put((parallel_util.GET_POLICY_WEIGHT, None))
    trainer_tasks.join()
    policy_weight = trainer_results.get()

    # end of training
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))
    return totalsteps, policy_weight
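Taken together, train_mb above looks like a model-based warm-up stage whose return values feed the model-free stage train_mf from Example #1. A hedged sketch of that chaining, assuming the trainer/sampler/worker/network objects and args are constructed elsewhere exactly as they are passed into these functions:

# Hypothetical MBMF pipeline: the step count and dagger-distilled policy
# weights returned by train_mb are handed straight to train_mf.
mb_steps, policy_weight = train_mb(trainer, sampler, worker,
                                   dynamics, policy, reward, args=args)
train_mf(mb_steps, policy_weight, trainer, sampler, worker,
         dynamics, policy, reward, args=args)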
Example #3
def train(trainer, sampler, worker, dynamics, policy, reward, args=None):
    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))
    network_type = {'policy': policy, 'dynamics': dynamics, 'reward': reward}

    # make the trainer and sampler
    sampler_agent = make_sampler(sampler, worker, network_type, args)
    real_trainer_tasks, real_trainer_results, _, init_weights = \
        make_trainer(trainer, network_type, args, "real_trainer")
    fake_trainer_tasks, fake_trainer_results, _, _ = \
            make_trainer(trainer, network_type, args, "fake_trainer")
    sampler_agent.set_weights(init_weights)

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    current_iteration = 0

    while True:
        timer_dict['** Program Total Time **'] = time.time()

        # step 1: collect rollout data
        if current_iteration == 0 and args.random_timesteps > 0:
            # we could first generate random rollout data for exploration
            logger.info('Generating {} random timesteps'.format(
                args.random_timesteps))
            rollout_data = sampler_agent.rollouts_using_worker_playing(
                args.random_timesteps,
                use_random_action=True,
                use_true_env=True,
            )
        else:
            rollout_data = sampler_agent.rollouts_using_worker_playing(
                use_true_env=True)

        timer_dict['Generate Real Rollout'] = time.time()

        # step 2: train the weights for dynamics or reward network
        training_info = {'network_to_train': ['dynamics', 'reward']}
        real_trainer_tasks.put((parallel_util.TRAIN_SIGNAL, {
            'data': rollout_data['data'],
            'training_info': training_info
        }))
        real_trainer_tasks.join()
        real_training_return = real_trainer_results.get()
        timer_dict['Train Weights of Dynamics'] = time.time()
        totalsteps = real_training_return['totalsteps']

        # set weights
        sampler_agent.set_weights(
            {"dynamics": real_training_return['network_weights']["dynamics"]})

        while True:
            # step 3: collect rollout data in fake env
            rollout_data = sampler_agent.rollouts_using_worker_playing(
                num_timesteps=args.policy_batch_size, use_true_env=False)

            # step 4: train the weights for policy network
            training_info = {'network_to_train': ['policy']}
            fake_trainer_tasks.put((parallel_util.TRAIN_SIGNAL, {
                'data': rollout_data['data'],
                'training_info': training_info
            }))
            fake_trainer_tasks.join()
            fake_training_return = fake_trainer_results.get()
            timer_dict['Train Weights of Policy'] = time.time()

            # step 5: update the weights
            sampler_agent.set_weights(
                {"policy": fake_training_return['network_weights']["policy"]})
            timer_dict['Assign Weights'] = time.time()

            fake_totalsteps = fake_training_return['totalsteps']
            logger.info('Fake-env timesteps: {} / {}'.format(
                fake_totalsteps, args.max_fake_timesteps))
            if fake_totalsteps > args.max_fake_timesteps:
                break

        fake_trainer_tasks.put((parallel_util.RESET_SIGNAL, None))

        # log and print the results
        log_results(real_training_return, timer_dict)

        # TODO(GD): update totalsteps?
        if totalsteps > args.max_timesteps:
            break
        else:
            current_iteration += 1

    # end of training
    sampler_agent.end()
    real_trainer_tasks.put((parallel_util.END_SIGNAL, None))
    fake_trainer_tasks.put((parallel_util.END_SIGNAL, None))
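On the other side of the queues, the trainer process presumably dispatches on the signals used above. A rough sketch of such a loop, written only to illustrate the protocol implied by these examples (the agent methods here are hypothetical, not the actual make_trainer implementation):

def trainer_loop(tasks, results, agent):
    while True:
        signal, payload = tasks.get()
        if signal == parallel_util.END_SIGNAL:
            tasks.task_done()
            break
        elif signal == parallel_util.TRAIN_SIGNAL:
            results.put(agent.train(payload['data'], payload['training_info']))
        elif signal == parallel_util.MBMF_INITIAL:
            results.put(agent.initialize_policy(payload['data'],
                                                payload['training_info']))
        elif signal == parallel_util.SET_POLICY_WEIGHT:
            agent.set_policy_weight(payload)
        elif signal == parallel_util.GET_POLICY_WEIGHT:
            results.put(agent.get_policy_weight())
        elif signal == parallel_util.RESET_SIGNAL:
            agent.reset()
        # one task_done() per get(), so the caller's join() can return
        tasks.task_done()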