import time
from collections import OrderedDict

import numpy as np

# Note: logger, init_path, parallel_util, make_sampler, make_trainer and
# log_results are repo-local helpers assumed to be imported at module level.


def train_mf(mb_steps, policy_weight, trainer, sampler, worker, dynamics,
             policy, reward, args=None):
    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))
    network_type = {'policy': policy, 'dynamics': dynamics, 'reward': reward}

    # make the trainer and sampler
    sampler_agent = make_sampler(sampler, worker, network_type, args)
    trainer_tasks, trainer_results, trainer_agent, init_weights = \
        make_trainer(trainer, network_type, args)

    # initialize the policy with the dagger-trained policy weights
    trainer_tasks.put((parallel_util.SET_POLICY_WEIGHT, policy_weight))
    trainer_tasks.join()
    init_weights['policy'][0] = policy_weight
    sampler_agent.set_weights(init_weights)

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    current_iteration = 0

    while True:
        timer_dict['** Program Total Time **'] = time.time()

        # step 1: collect rollout data
        rollout_data = \
            sampler_agent.rollouts_using_worker_playing(use_true_env=True)
        timer_dict['Generate Rollout'] = time.time()

        # step 2: train the weights of the dynamics, reward and policy networks
        training_info = {'network_to_train': ['dynamics', 'reward', 'policy']}
        trainer_tasks.put(
            (parallel_util.TRAIN_SIGNAL,
             {'data': rollout_data['data'], 'training_info': training_info})
        )
        trainer_tasks.join()
        training_return = trainer_results.get()
        timer_dict['Train Weights'] = time.time()

        # step 3: update the weights
        sampler_agent.set_weights(training_return['network_weights'])
        timer_dict['Assign Weights'] = time.time()

        # log and print the results
        log_results(training_return, timer_dict, mb_steps)

        if training_return['totalsteps'] > args.max_timesteps:
            break
        else:
            current_iteration += 1

    # end of training
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))
def train_mb(trainer, sampler, worker, dynamics, policy, reward, args=None):
    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))
    network_type = {'policy': policy, 'dynamics': dynamics, 'reward': reward}

    # make the trainer and sampler
    sampler_agent = make_sampler(sampler, worker, network_type, args)
    trainer_tasks, trainer_results, trainer_agent, init_weights = \
        make_trainer(trainer, network_type, args)
    sampler_agent.set_weights(init_weights)

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    totalsteps = 0
    current_iteration = 0
    init_data = {}

    # start the model-based training loop
    while True:
        timer_dict['** Program Total Time **'] = time.time()

        # step 1: collect rollout data
        if current_iteration == 0 and args.random_timesteps > 0 and \
                not (args.gt_dynamics and args.gt_reward):
            # generate random rollout data first for exploration
            logger.info(
                'Generating {} random timesteps'.format(args.random_timesteps)
            )
            rollout_data = sampler_agent.rollouts_using_worker_planning(
                args.random_timesteps, use_random_action=True
            )
        else:
            rollout_data = sampler_agent.rollouts_using_worker_planning()
        timer_dict['Generate Rollout'] = time.time()

        # step 2: train the weights of the dynamics and reward networks
        training_info = {'network_to_train': ['dynamics', 'reward']}
        trainer_tasks.put(
            (parallel_util.TRAIN_SIGNAL,
             {'data': rollout_data['data'], 'training_info': training_info})
        )
        trainer_tasks.join()
        training_return = trainer_results.get()
        timer_dict['Train Weights'] = time.time()

        # step 3: update the weights
        sampler_agent.set_weights(training_return['network_weights'])
        timer_dict['Assign Weights'] = time.time()

        # log and print the results
        log_results(training_return, timer_dict)

        # accumulate the rollouts for the later policy initialization
        for key in rollout_data.keys():
            if key not in init_data:
                init_data[key] = []
            init_data[key].extend(rollout_data[key])

        # add noise to the recorded actions to encourage trpo to explore
        for i_rollout in init_data['data']:
            action = i_rollout['actions']
            i_rollout['actions'] += \
                np.random.normal(scale=0.005, size=action.shape)

        if totalsteps > args.max_timesteps or \
                training_return['replay_buffer'].get_current_size() > \
                args.mb_timesteps:
            break
        else:
            current_iteration += 1
        totalsteps = training_return['totalsteps']

    # initialize the policy network by imitating the collected planner data
    training_info = {'network_to_train': ['reward', 'policy']}
    trainer_tasks.put(
        (parallel_util.MBMF_INITIAL,
         {'data': init_data['data'], 'training_info': training_info})
    )
    trainer_tasks.join()
    training_return = trainer_results.get()
    timer_dict['Train Weights'] = time.time()

    # start the dagger iterations
    for dagger_i in range(args.dagger_iter):
        print('================= Doing dagger iteration {} =================='
              .format(dagger_i))

        # collect on-policy rollouts and label them with the expert planner
        rollout_data = sampler_agent.rollouts_using_worker_playing(
            num_timesteps=args.dagger_timesteps_per_iter, use_true_env=True)
        sampler_agent.dagger_rollouts(rollout_data['data'])
        init_data['data'] += rollout_data['data']

        trainer_tasks.put(
            (parallel_util.MBMF_INITIAL,
             {'data': init_data['data'], 'training_info': training_info})
        )
        trainer_tasks.join()
        training_return = trainer_results.get()

    # fetch the policy weights to warm-start the model-free stage
    trainer_tasks.put((parallel_util.GET_POLICY_WEIGHT, None))
    trainer_tasks.join()
    policy_weight = trainer_results.get()

    # end of training
    sampler_agent.end()
    trainer_tasks.put((parallel_util.END_SIGNAL, None))
    return totalsteps, policy_weight
def train(trainer, sampler, worker, dynamics, policy, reward, args=None):
    logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))
    network_type = {'policy': policy, 'dynamics': dynamics, 'reward': reward}

    # make the sampler and the two trainers (one trained on real rollouts,
    # one trained on rollouts from the learned dynamics)
    sampler_agent = make_sampler(sampler, worker, network_type, args)
    real_trainer_tasks, real_trainer_results, _, init_weights = \
        make_trainer(trainer, network_type, args, "real_trainer")
    fake_trainer_tasks, fake_trainer_results, _, _ = \
        make_trainer(trainer, network_type, args, "fake_trainer")
    sampler_agent.set_weights(init_weights)

    timer_dict = OrderedDict()
    timer_dict['Program Start'] = time.time()
    current_iteration = 0

    while True:
        timer_dict['** Program Total Time **'] = time.time()

        # step 1: collect rollout data from the real environment
        if current_iteration == 0 and args.random_timesteps > 0:
            # generate random rollout data first for exploration
            logger.info('Generating {} random timesteps'.format(
                args.random_timesteps))
            rollout_data = sampler_agent.rollouts_using_worker_playing(
                args.random_timesteps, use_random_action=True,
                use_true_env=True,
            )
        else:
            rollout_data = sampler_agent.rollouts_using_worker_playing(
                use_true_env=True)
        timer_dict['Generate Real Rollout'] = time.time()

        # step 2: train the weights of the dynamics and reward networks
        training_info = {'network_to_train': ['dynamics', 'reward']}
        real_trainer_tasks.put((parallel_util.TRAIN_SIGNAL, {
            'data': rollout_data['data'],
            'training_info': training_info
        }))
        real_trainer_tasks.join()
        real_training_return = real_trainer_results.get()
        timer_dict['Train Weights of Dynamics'] = time.time()
        totalsteps = real_training_return['totalsteps']

        # update the dynamics weights used by the sampler
        sampler_agent.set_weights(
            {"dynamics": real_training_return['network_weights']["dynamics"]})

        while True:
            # step 3: collect rollout data in the learned (fake) environment
            rollout_data = sampler_agent.rollouts_using_worker_playing(
                num_timesteps=args.policy_batch_size, use_true_env=False)

            # step 4: train the weights of the policy network
            training_info = {'network_to_train': ['policy']}
            fake_trainer_tasks.put((parallel_util.TRAIN_SIGNAL, {
                'data': rollout_data['data'],
                'training_info': training_info
            }))
            fake_trainer_tasks.join()
            fake_training_return = fake_trainer_results.get()
            timer_dict['Train Weights of Policy'] = time.time()

            # step 5: update the policy weights used by the sampler
            sampler_agent.set_weights(
                {"policy": fake_training_return['network_weights']["policy"]})
            timer_dict['Assign Weights'] = time.time()

            fake_totalsteps = fake_training_return['totalsteps']
            logger.info('fake timesteps: {} / {}'.format(
                fake_totalsteps, args.max_fake_timesteps))
            if fake_totalsteps > args.max_fake_timesteps:
                break
        fake_trainer_tasks.put((parallel_util.RESET_SIGNAL, None))

        # log and print the results
        log_results(real_training_return, timer_dict)

        # TODO(GD): update totalsteps?
        if totalsteps > args.max_timesteps:
            break
        else:
            current_iteration += 1

    # end of training
    sampler_agent.end()
    real_trainer_tasks.put((parallel_util.END_SIGNAL, None))
    fake_trainer_tasks.put((parallel_util.END_SIGNAL, None))
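
# Hypothetical wiring of the two-stage mb->mf pipeline (a sketch only; the
# real entry points are built from the command-line configuration elsewhere
# in the repo, so the trainer/sampler/worker/network arguments below are
# assumptions):
#
#     mb_steps, policy_weight = train_mb(trainer, sampler, worker,
#                                        dynamics, policy, reward, args=args)
#     train_mf(mb_steps, policy_weight, trainer, sampler, worker,
#              dynamics, policy, reward, args=args)
#
# train_mb returns the dagger-initialized policy weights, which train_mf
# pushes to its trainer via parallel_util.SET_POLICY_WEIGHT before entering
# the model-free training loop.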