예제 #1
0
def behavior_spec_from_proto(brain_param_proto: BrainParametersProto,
                             agent_info: AgentInfoProto) -> BehaviorSpec:
    """
    Build a BehaviorSpec from a brain-parameters proto and an agent-info proto.

    One ObservationSpec is produced per entry of ``agent_info.observations``.
    The action spec is read from ``brain_param_proto.action_spec``; when that
    message reports zero continuous AND zero discrete actions (as protos from
    communicators older than v1.3 do), the deprecated ``vector_action_*``
    fields are used instead.

    :param brain_param_proto: protobuf object holding the brain parameters.
    :param agent_info: protobuf object whose observations describe the
        observation layout.
    :return: BehaviorSpec object.
    """
    observation_specs = [
        ObservationSpec(
            tuple(obs.shape),
            tuple(map(DimensionProperty, obs.dimension_properties)),
            ObservationType(obs.observation_type),
        )
        for obs in agent_info.observations
    ]

    spec_proto = brain_param_proto.action_spec
    has_action_spec = (spec_proto.num_continuous_actions != 0
                       or spec_proto.num_discrete_actions != 0)

    if has_action_spec:
        action_spec = ActionSpec(
            spec_proto.num_continuous_actions,
            tuple(spec_proto.discrete_branch_sizes),
        )
    elif brain_param_proto.vector_action_space_type_deprecated == 1:
        # Legacy proto: type 1 passes the first deprecated size as the first
        # ActionSpec argument with no discrete branches (presumably the
        # continuous case).
        action_spec = ActionSpec(
            brain_param_proto.vector_action_size_deprecated[0], ())
    else:
        # Legacy proto: every deprecated size becomes a discrete branch.
        action_spec = ActionSpec(
            0, tuple(brain_param_proto.vector_action_size_deprecated))

    return BehaviorSpec(observation_specs, action_spec)
예제 #2
0
    def __init__(self, env_file, time_scale=10., no_graphics=False):
        """
        Create the underlying environment and record its fixed specs.

        :param env_file: forwarded to ``self._create_env``.
        :param time_scale: forwarded to ``self._create_env``.
        :param no_graphics: forwarded to ``self._create_env``.
        """
        created = self._create_env(env_file, time_scale, no_graphics)
        self.env = created
        # Reset once so the behavior names can be read from behavior_specs.
        created.reset()

        # Only the first behavior name reported by the environment is kept.
        self.behavior_name = [*created.behavior_specs][0]
        # NOTE(review): elsewhere ObservationSpec receives a tuple shape and a
        # DimensionProperty tuple; 640 / None / None here look like
        # placeholders - confirm against the ObservationSpec version in use.
        self.observation_space = ObservationSpec(640, None, None)
        self.action_space = ActionSpec(discrete_branches=(), continuous_size=90)
예제 #3
0
def create_observation_specs_with_shapes(
        shapes: List[Tuple[int, ...]]) -> List[ObservationSpec]:
    """
    Build one ObservationSpec per entry of ``shapes``.

    Every dimension of each shape is tagged ``DimensionProperty.UNSPECIFIED``
    and every spec uses ``ObservationType.DEFAULT``.

    :param shapes: observation shapes, one tuple per observation.
    :return: list of ObservationSpec in the same order as ``shapes``.
    """
    return [
        ObservationSpec(
            shape,
            (DimensionProperty.UNSPECIFIED, ) * len(shape),
            ObservationType.DEFAULT,
        )
        for shape in shapes
    ]
예제 #4
0
def create_observation_specs_with_shapes(
        shapes: List[Tuple[int, ...]]) -> List[ObservationSpec]:
    """
    Build one named ObservationSpec per entry of ``shapes``.

    Rank-2 shapes are tagged ``(DimensionProperty.VARIABLE_SIZE,
    DimensionProperty.NONE)``; every other shape gets
    ``DimensionProperty.UNSPECIFIED`` on each dimension. All specs use
    ``ObservationType.DEFAULT`` and are named
    ``"observation {index} with shape {shape}"``.

    :param shapes: observation shapes, one tuple per observation.
    :return: list of ObservationSpec in the same order as ``shapes``.
    """
    specs: List[ObservationSpec] = []
    for i, shape in enumerate(shapes):
        if len(shape) == 2:
            props = (DimensionProperty.VARIABLE_SIZE, DimensionProperty.NONE)
        else:
            props = (DimensionProperty.UNSPECIFIED, ) * len(shape)
        specs.append(
            ObservationSpec(
                name=f"observation {i} with shape {shape}",
                shape=shape,
                dimension_property=props,
                observation_type=ObservationType.DEFAULT,
            ))
    return specs
예제 #5
0
 def _make_observation_specs(self) -> List[ObservationSpec]:
     """
     Build the observation specs for this environment.

     Shapes are laid out in a fixed order: ``self.num_vector`` vector
     observations of shape ``(self.vec_obs_size,)``, then ``self.num_visual``
     observations of shape ``self.vis_obs_size``, then ``self.num_var_len``
     observations of shape ``self.var_len_obs_size``. When
     ``self.goal_indices`` is not None, every spec whose position appears in
     it is rebuilt with ``ObservationType.GOAL``; its shape, dimension
     properties and name are copied unchanged.

     :return: list of ObservationSpec, one per observation, in layout order.
     """
     obs_shape: List[Any] = (
         [(self.vec_obs_size, )] * self.num_vector
         + [self.vis_obs_size] * self.num_visual
         + [self.var_len_obs_size] * self.num_var_len
     )
     obs_spec = create_observation_specs_with_shapes(obs_shape)
     if self.goal_indices is not None:
         # Build the lookup set once instead of scanning goal_indices for
         # every spec (a list would make that O(n * m)).
         goal_set = set(self.goal_indices)
         for i, spec in enumerate(obs_spec):
             if i in goal_set:
                 # Replacing by index keeps the list length unchanged, so
                 # this is safe while iterating.
                 obs_spec[i] = ObservationSpec(
                     shape=spec.shape,
                     dimension_property=spec.dimension_property,
                     observation_type=ObservationType.GOAL,
                     name=spec.name,
                 )
     return obs_spec
예제 #6
0
    # Label axis 0 of these two outputs as "batch" in the dynamic_axes map
    # handed to torch.onnx.export (named dynamic axes are left variable in the
    # exported graph rather than fixed to the dummy input's size).
    dynamic_axes.update({'continuous_actions': {0: "batch"}})
    dynamic_axes.update({'action': {0: "batch"}})

    # Trace `network` with `dummy_input` and write the ONNX graph (opset 9) to
    # EXPORT_FILE. network, dummy_input, input_names, output_names and
    # EXPORT_FILE are all defined outside this view.
    torch.onnx.export(network,
                      dummy_input,
                      EXPORT_FILE,
                      opset_version=9,
                      input_names=input_names,
                      output_names=output_names,
                      dynamic_axes=dynamic_axes)


if __name__ == '__main__':
    obs_spec = [
        ObservationSpec(shape=(16, ),
                        dimension_property=(DimensionProperty.UNSPECIFIED, ),
                        observation_type=ObservationType.DEFAULT)
    ]
    act_spec = ActionSpec(continuous_size=4, discrete_branches=())
    net_settings = NetworkSettings(normalize=False,
                                   hidden_units=256,
                                   num_layers=2,
                                   vis_encode_type=EncoderType.SIMPLE,
                                   memory=NetworkSettings.MemorySettings(
                                       sequence_length=64, memory_size=256))

    network = SerializableSimpleActor(obs_spec, net_settings, act_spec)
    state_dict = torch.load(MODEL_FILE, map_location=torch.device('cpu'))
    filtered_sd = {
        i: j
        for i, j in state_dict['Policy'].items() if 'critic' not in i