def get_td3_trainer(env, parameters, use_gpu):
    """Build a ``TD3Trainer`` together with its networks and action bounds.

    Args:
        env: Environment object. This function reads ``env.normalization``,
            ``env.normalization_action``, ``env.action_space.low`` and
            ``env.action_space.high``.
        parameters: Configuration object. This function reads
            ``parameters.q_network``, ``parameters.actor_network`` (each with
            ``layers`` and ``activations``) and
            ``parameters.training.use_2_q_functions``; the whole object is
            also forwarded to ``TD3Trainer``.
        use_gpu: When truthy, every network and action-range tensor is moved
            with ``.cuda()`` before the trainer is constructed.

    Returns:
        The constructed ``TD3Trainer``.
    """
    state_dim = get_num_output_features(env.normalization)
    action_dim = get_num_output_features(env.normalization_action)

    q1_network = FullyConnectedParametricDQN(
        state_dim,
        action_dim,
        parameters.q_network.layers,
        parameters.q_network.activations,
    )
    # The second Q-network is optional; None is passed through to the trainer
    # when it is disabled in the configuration.
    q2_network = None
    if parameters.training.use_2_q_functions:
        q2_network = FullyConnectedParametricDQN(
            state_dim,
            action_dim,
            parameters.q_network.layers,
            parameters.q_network.activations,
        )
    actor_network = FullyConnectedActor(
        state_dim,
        action_dim,
        parameters.actor_network.layers,
        parameters.actor_network.activations,
    )

    # Float fill values are deliberate: torch.full infers the dtype from the
    # fill value, so an integer fill would yield int64 bounds that do not
    # match the float serving bounds built below.
    min_action_range_tensor_training = torch.full((1, action_dim), -1.0)
    max_action_range_tensor_training = torch.full((1, action_dim), 1.0)
    min_action_range_tensor_serving = torch.FloatTensor(
        env.action_space.low
    ).unsqueeze(dim=0)
    max_action_range_tensor_serving = torch.FloatTensor(
        env.action_space.high
    ).unsqueeze(dim=0)

    if use_gpu:
        q1_network.cuda()
        # Explicit None check: the value is either None or a network object,
        # so object truthiness should not decide this branch.
        if q2_network is not None:
            q2_network.cuda()
        actor_network.cuda()

        # Tensor.cuda() returns a copy rather than moving in place, so rebind.
        min_action_range_tensor_training = min_action_range_tensor_training.cuda()
        max_action_range_tensor_training = max_action_range_tensor_training.cuda()
        min_action_range_tensor_serving = min_action_range_tensor_serving.cuda()
        max_action_range_tensor_serving = max_action_range_tensor_serving.cuda()

    return TD3Trainer(
        q1_network,
        actor_network,
        parameters,
        use_gpu=use_gpu,
        q2_network=q2_network,
        min_action_range_tensor_training=min_action_range_tensor_training,
        max_action_range_tensor_training=max_action_range_tensor_training,
        min_action_range_tensor_serving=min_action_range_tensor_serving,
        max_action_range_tensor_serving=max_action_range_tensor_serving,
    )
def get_td3_trainer(self, env, parameters, use_gpu):
    """Build a ``TD3Trainer`` together with its Q-network(s) and actor.

    ``self`` is not used by the body.

    Args:
        env: Environment object. This function reads ``env.normalization``
            and ``env.normalization_action``.
        parameters: Configuration object. This function reads
            ``parameters.q_network``, ``parameters.actor_network`` (each with
            ``layers`` and ``activations``) and
            ``parameters.training.use_2_q_functions``; the whole object is
            also forwarded to ``TD3Trainer``.
        use_gpu: When truthy, every network is moved with ``.cuda()`` before
            the trainer is constructed.

    Returns:
        The constructed ``TD3Trainer``.
    """
    state_dim = get_num_output_features(env.normalization)
    action_dim = get_num_output_features(env.normalization_action)

    # Both Q-networks are built from the same configuration.
    q_network_args = (
        state_dim,
        action_dim,
        parameters.q_network.layers,
        parameters.q_network.activations,
    )
    q1_network = FullyConnectedParametricDQN(*q_network_args)
    # The second Q-network is optional; None is passed through to the trainer
    # when it is disabled in the configuration.
    q2_network = None
    if parameters.training.use_2_q_functions:
        q2_network = FullyConnectedParametricDQN(*q_network_args)

    actor_network = FullyConnectedActor(
        state_dim,
        action_dim,
        parameters.actor_network.layers,
        parameters.actor_network.activations,
    )

    if use_gpu:
        q1_network.cuda()
        # Explicit None check: the value is either None or a network object,
        # so object truthiness should not decide this branch.
        if q2_network is not None:
            q2_network.cuda()
        actor_network.cuda()

    return TD3Trainer(
        q1_network,
        actor_network,
        parameters,
        q2_network=q2_network,
        use_gpu=use_gpu,
    )