def experiment(args, variant):
    """Build and train a SAC agent on the DeepBuilder environment.

    Supports two ways of starting with data:
      * ``args.resume == 1`` — reload networks, policies and both replay
        buffers from this session's saved .rklit file and continue training.
      * ``args.replay_add_sess_name`` — seed the exploration replay buffer
        with non-zero samples from another session's saved buffer.

    Parameters
    ----------
    args : namespace with session/buffer configuration (session_name,
        act_dim, box_dim, max_num_boxes, height_field_dim, resume,
        num_plays_eval, num_plays_expl, replay_add_sess_name,
        replay_add_num_samples).
    variant : dict with 'layer_size', 'replay_buffer_size',
        'trainer_kwargs' and 'algorithm_kwargs'.
    """
    core_env = env.DeepBuilderEnv(args.session_name, args.act_dim,
                                  args.box_dim, args.max_num_boxes,
                                  args.height_field_dim)
    # NOTE(review): both wrappers share the same underlying env instance —
    # presumably intentional; confirm eval/expl do not need independent state.
    eval_env = stuff.NormalizedActions(core_env)
    expl_env = stuff.NormalizedActions(core_env)
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    resumed = args.resume == 1
    if resumed:
        variant, params = doc.load_rklit_file(args.session_name)
        # A resumed run already has buffered data, so train immediately.
        variant['algorithm_kwargs']['min_num_steps_before_training'] = 0

    M = variant['layer_size']

    def _make_q(saved_key):
        # Fresh Q-network when starting cold, otherwise the saved one.
        if resumed:
            return params[saved_key]
        return FlattenMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=[M, M],
        )

    qf1 = _make_q('trainer/qf1')
    qf2 = _make_q('trainer/qf2')
    target_qf1 = _make_q('trainer/target_qf1')
    target_qf2 = _make_q('trainer/target_qf2')

    if resumed:
        policy = params['trainer/policy']
        eval_policy = params['evaluation/policy']
    else:
        policy = TanhGaussianPolicy(
            obs_dim=obs_dim,
            action_dim=action_dim,
            hidden_sizes=[M, M],
        )
        eval_policy = MakeDeterministic(policy)

    eval_path_collector = MdpPathCollector(eval_env, eval_policy)
    expl_path_collector = MdpPathCollector(expl_env, policy)

    replay_buffer_expl = EnvReplayBuffer(variant['replay_buffer_size'],
                                         expl_env)
    # Eval buffer capacity is scaled by the eval/expl play ratio.
    replay_buffer_eval = EnvReplayBuffer(
        int(variant['replay_buffer_size'] *
            (float(args.num_plays_eval) / float(args.num_plays_expl))),
        eval_env,
    )

    if resumed:
        # Restore each buffer's internal arrays and counters from the
        # saved params (keys mirror the private attribute names).
        _fields = ('actions', 'env_infos', 'next_obs', 'observations',
                   'rewards', 'size', 'terminals', 'top')
        for _buf, _prefix in ((replay_buffer_expl, 'replay_buffer_expl'),
                              (replay_buffer_eval, 'replay_buffer_eval')):
            for _f in _fields:
                setattr(_buf, '_' + _f, params[_prefix + '/' + _f])
    elif args.replay_add_sess_name != '':
        # Seed the exploration buffer from another session's saved buffer.
        _, other_params = doc.load_rklit_file(args.replay_add_sess_name)
        num_samples = int(args.replay_add_num_samples)
        replay_buffer_expl._size = 0
        replay_buffer_expl._top = 0
        print("Loading " + str(num_samples) +
              " batch samples from session " + args.replay_add_sess_name)
        zeroes = []
        offset = 0
        for i in range(num_samples):
            act = other_params['replay_buffer_expl/actions'][i]
            obs = other_params['replay_buffer_expl/observations'][i]
            # All-zero action+observation rows are unused slots in the
            # source buffer; skip them so the local buffer stays dense.
            if (act.min() == 0.0 and act.max() == 0.0
                    and obs.min() == 0.0 and obs.max() == 0.0):
                zeroes.append(i)
                continue
            # .tolist() already produces fresh Python objects, so the
            # original copy.deepcopy wrappers were redundant and removed.
            replay_buffer_expl._actions[offset] = act.tolist()
            replay_buffer_expl._next_obs[offset] = (
                other_params['replay_buffer_expl/next_obs'][i].tolist())
            replay_buffer_expl._observations[offset] = obs.tolist()
            replay_buffer_expl._rewards[offset] = (
                other_params['replay_buffer_expl/rewards'][i].tolist())
            replay_buffer_expl._terminals[offset] = (
                other_params['replay_buffer_expl/terminals'][i].tolist())
            replay_buffer_expl._size += 1
            replay_buffer_expl._top += 1
            offset += 1
        print("Detected and ignored " + str(len(zeroes)) +
              " zero samples in replay buffer. Total num samples loaded into replay buffer: " +
              str(replay_buffer_expl._size))
        # Drop the loaded arrays so they can be garbage-collected.
        other_params = {}

    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs'],
        # On resume, pretend we already did the train steps implied by the
        # number of environment steps collected so far.
        starting_train_steps=0 if not resumed else (
            params['replay_buffer_expl/top'] *
            variant['algorithm_kwargs']['num_trains_per_train_loop']),
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer_eval=replay_buffer_eval,
        replay_buffer_expl=replay_buffer_expl,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
def experiment(args, variant):
    """Build and train a SAC agent on the DeepBuilder environment,
    pre-filling the exploration replay buffer from up to two saved sessions.

    NOTE(review): this redefines ``experiment`` and shadows the earlier
    definition in this module — confirm only one is meant to be live.

    Parameters
    ----------
    args : namespace with session/buffer configuration
        (replay_add_sess_name_1/_2, replay_add_num_samples_1/_2, ...).
    variant : dict with 'layer_size', 'trainer_kwargs', 'algorithm_kwargs'.
    """
    core_env = env.DeepBuilderEnv(args.session_name, args.act_dim,
                                  args.box_dim, args.max_num_boxes,
                                  args.height_field_dim)
    eval_env = stuff.NormalizedActions(core_env)
    expl_env = stuff.NormalizedActions(core_env)
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    M = variant['layer_size']

    def _make_q():
        # One hidden-layer-pair Q-network over (obs, action).
        return FlattenMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=[M, M],
        )

    qf1 = _make_q()
    qf2 = _make_q()
    target_qf1 = _make_q()
    target_qf2 = _make_q()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=[M, M],
    )
    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(eval_env, eval_policy)
    expl_path_collector = MdpPathCollector(expl_env, policy)

    # NOTE(review): hard-coded capacities — 222726 (+21 headroom) expl
    # samples and a 21-slot eval buffer. Presumably sized for a specific
    # dataset; TODO parameterize via variant['replay_buffer_size'].
    replay_buffer = EnvReplayBuffer(222726 + 21, expl_env)
    replay_buffer_eval = EnvReplayBuffer(21, eval_env)

    def _is_zero_sample(act, obs):
        # All-zero action+observation rows are unused slots in the
        # saved source buffer.
        return (act.min() == 0.0 and act.max() == 0.0
                and obs.min() == 0.0 and obs.max() == 0.0)

    def _copy_sample(dst_idx, src, src_idx):
        # Copy one transition row from a loaded session into the buffer
        # and advance the buffer's counters.
        replay_buffer._actions[dst_idx] = (
            src['replay_buffer_expl/actions'][src_idx])
        replay_buffer._next_obs[dst_idx] = (
            src['replay_buffer_expl/next_obs'][src_idx])
        replay_buffer._observations[dst_idx] = (
            src['replay_buffer_expl/observations'][src_idx])
        replay_buffer._rewards[dst_idx] = (
            src['replay_buffer_expl/rewards'][src_idx])
        replay_buffer._terminals[dst_idx] = (
            src['replay_buffer_expl/terminals'][src_idx])
        replay_buffer._size += 1
        replay_buffer._top += 1

    if args.replay_add_sess_name_1 != '':
        _, other_params = doc.load_rklit_file(args.replay_add_sess_name_1)
        num_samples = int(args.replay_add_num_samples_1)
        replay_buffer._size = 0
        replay_buffer._top = 0
        offset = 0
        print("Loading " + str(num_samples) +
              " batch samples from session " + args.replay_add_sess_name_1)
        for i in range(num_samples):
            act = other_params['replay_buffer_expl/actions'][i]
            obs = other_params['replay_buffer_expl/observations'][i]
            if not _is_zero_sample(act, obs):
                # BUG FIX: the original wrote at index ``i`` while still
                # advancing ``offset``; after any zero sample was skipped
                # that left holes, and the second session's data (written
                # at ``offset``) overwrote first-session samples. Writing
                # at ``offset`` keeps the buffer densely packed.
                _copy_sample(offset, other_params, i)
                offset += 1

        if args.replay_add_sess_name_2 != '':
            _, other_params = doc.load_rklit_file(args.replay_add_sess_name_2)
            num_samples = int(args.replay_add_num_samples_2)
            print("Loading " + str(num_samples) +
                  " batch samples from session " + args.replay_add_sess_name_2)
            # NOTE(review): starts at index 21021 — presumably the portion
            # of session 2 not duplicated in session 1; confirm with the
            # session owner.
            for i in range(21021, num_samples):
                act = other_params['replay_buffer_expl/actions'][i]
                obs = other_params['replay_buffer_expl/observations'][i]
                if not _is_zero_sample(act, obs):
                    _copy_sample(offset, other_params, i)
                    offset += 1

        # (A commented-out third-session loader was removed here.)
        # Release the large loaded arrays before training starts.
        del other_params
        print("Detected and removed " +
              str(replay_buffer._max_replay_buffer_size -
                  replay_buffer._size) +
              " zero samples. Final size of replay buffer: " +
              str(replay_buffer._size))

    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs']
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer_expl=replay_buffer,
        replay_buffer_eval=replay_buffer_eval,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()