Example #1
0
    def __init__(self,
                 state_rep: state_representation.StateRepresentation,
                 embedding_size: int,
                 zero_out: bool = True):
        """Create one word embedder per dynamic environment property.

        Args:
            state_rep: Supplies the value vocabulary for each property (card counts, colors,
                shapes, selection, and the leader and follower rotations).
            embedding_size: Embedding size handed to every embedder.
            zero_out: Forwarded unchanged as the ``zero_out`` option of every embedder.
        """
        super(DynamicEnvironmentEmbedder, self).__init__()

        def _new_embedder(values):
            # Every property embedder shares the same configuration; only the vocabulary differs.
            return word_embedder.WordEmbedder(embedding_size, values, add_unk=False, zero_out=zero_out)

        # Embedders are created in a fixed order: count, color, shape, selection, leader, follower.
        self._card_count_embedder: word_embedder.WordEmbedder = _new_embedder(state_rep.get_card_counts())
        self._card_color_embedder: word_embedder.WordEmbedder = _new_embedder(state_rep.get_card_colors())
        self._card_shape_embedder: word_embedder.WordEmbedder = _new_embedder(state_rep.get_card_shapes())
        self._card_selection_embedder: word_embedder.WordEmbedder = _new_embedder(state_rep.get_card_selection())
        self._leader_rotation_embedder: word_embedder.WordEmbedder = _new_embedder(
            state_rep.get_leader_rotation())
        self._follower_rotation_embedder: word_embedder.WordEmbedder = _new_embedder(
            state_rep.get_follower_rotation())

        # Plain list holding all embedders in the same order they were created.
        self._embedders: List[word_embedder.WordEmbedder] = [
            self._card_count_embedder,
            self._card_color_embedder,
            self._card_shape_embedder,
            self._card_selection_embedder,
            self._leader_rotation_embedder,
            self._follower_rotation_embedder,
        ]
Example #2
0
    def __init__(self, args: text_encoder_args.TextEncoderArgs,
                 vocabulary: List[str],
                 dropout: float = 0.):
        """Build the word embedder, the dropout layer and the bidirectional LSTM.

        Args:
            args: Provides the word embedding size, the total hidden size and the number of
                LSTM layers.
            vocabulary: Word types handed to the word embedder.
            dropout: Probability used for the standalone dropout layer and, when the LSTM has
                more than one layer, for the dropout between LSTM layers.
        """
        super(TextEncoder, self).__init__()

        self._args = args

        embedding_size = self._args.get_word_embedding_size()
        self._embedder: word_embedder.WordEmbedder = word_embedder.WordEmbedder(embedding_size, vocabulary)

        # The LSTM is bidirectional, so each direction gets half of the configured hidden size.
        self._each_dir_hidden_size: int = int(self._args.get_hidden_size() / 2)

        self._dropout = nn.Dropout(dropout)

        # nn.LSTM applies its dropout only between stacked layers. With a single layer the value
        # has no effect and PyTorch emits a UserWarning, so it is only passed for stacked LSTMs.
        num_layers = self._args.get_number_of_layers()
        rnn_dropout = dropout if num_layers > 1 else 0.

        self._rnn: nn.LSTM = nn.LSTM(embedding_size,
                                     self._each_dir_hidden_size,
                                     num_layers,
                                     batch_first=True,
                                     bidirectional=True,
                                     dropout=rnn_dropout)
        initialization.initialize_rnn(self._rnn)
Example #3
0
    def __init__(self,
                 state_rep: state_representation.StateRepresentation,
                 embedding_size: int,
                 zero_out: bool = False):
        """Create a single embedder covering every static environment property.

        Args:
            state_rep: Supplies the value vocabulary of each static property; also stored on the
                instance.
            embedding_size: Embedding size handed to the embedder.
            zero_out: Must be falsy; kept only so that existing call sites keep working.

        Raises:
            ValueError: If ``zero_out`` is truthy.
        """
        super(StaticEnvironmentEmbedder, self).__init__()

        self._state_rep: state_representation.StateRepresentation = state_rep

        if zero_out:
            raise ValueError(
                'Zeroing out zero-value properties is no longer supported due to optimizations.'
            )

        # The per-property vocabularies are concatenated in this exact order.
        vocabulary_getters = (
            state_rep.get_terrains,
            state_rep.get_hut_colors,
            state_rep.get_hut_rotations,
            state_rep.get_windmill_rotations,
            state_rep.get_tower_rotations,
            state_rep.get_tent_rotations,
            state_rep.get_tree_types,
            state_rep.get_plant_types,
            state_rep.get_prop_types,
        )
        all_embeddings = vocabulary_getters[0]()
        for getter in vocabulary_getters[1:]:
            # Rebind with `+` instead of extending in place so the object returned by the first
            # getter is never mutated.
            all_embeddings = all_embeddings + getter()

        # NOTE(review): must_be_unique=False presumably allows values repeated across properties;
        # confirm against WordEmbedder.
        self._embedder: word_embedder.WordEmbedder = word_embedder.WordEmbedder(
            embedding_size, all_embeddings, add_unk=False, must_be_unique=False)
Example #4
0
    def __init__(self,
                 state_rep: state_representation.StateRepresentation,
                 embedding_size: int,
                 zero_out: bool = False):
        """Create a single embedder covering every dynamic environment property.

        Args:
            state_rep: Supplies the value vocabulary of each dynamic property; also stored on the
                instance.
            embedding_size: Embedding size handed to the embedder.
            zero_out: Must be falsy; kept only so that existing call sites keep working.

        Raises:
            ValueError: If ``zero_out`` is truthy.
        """
        super(DynamicEnvironmentEmbedder, self).__init__()

        self._state_rep: state_representation.StateRepresentation = state_rep

        if zero_out:
            raise ValueError(
                'Zeroing out zero-value properties is no longer supported due to optimizations.'
            )

        # One embedding lookup serves all properties, so their vocabularies are concatenated in
        # this exact order.
        vocabulary_getters = (
            state_rep.get_card_counts,
            state_rep.get_card_colors,
            state_rep.get_card_shapes,
            state_rep.get_card_selection,
            state_rep.get_leader_rotation,
            state_rep.get_follower_rotation,
        )
        all_embeddings = vocabulary_getters[0]()
        for getter in vocabulary_getters[1:]:
            # Rebind with `+` instead of extending in place so the object returned by the first
            # getter is never mutated.
            all_embeddings = all_embeddings + getter()

        # NOTE(review): must_be_unique=False presumably allows values repeated across properties;
        # confirm against WordEmbedder.
        self._embedder: word_embedder.WordEmbedder = word_embedder.WordEmbedder(
            embedding_size, all_embeddings, add_unk=False, must_be_unique=False)
    def __init__(self,
                 state_rep: state_representation.StateRepresentation,
                 embedding_size: int,
                 zero_out: bool = True):
        """Create one word embedder per static environment property.

        Args:
            state_rep: Supplies the value vocabulary for each static property.
            embedding_size: Embedding size handed to every embedder.
            zero_out: Forwarded as the ``zero_out`` option of every embedder except the terrain
                embedder, which always receives ``zero_out=False``.
        """
        super(StaticEnvironmentEmbedder, self).__init__()

        def _new_embedder(values, zero_flag):
            # Shared configuration for all property embedders; only vocabulary and flag vary.
            return word_embedder.WordEmbedder(embedding_size, values, add_unk=False, zero_out=zero_flag)

        # The terrain embedder ignores the caller's `zero_out` and is never zeroed out.
        self._terrain_embedder: word_embedder.WordEmbedder = _new_embedder(state_rep.get_terrains(), False)

        # Remaining embedders are created in a fixed order and honour `zero_out`.
        self._hut_colors_embedder: word_embedder.WordEmbedder = _new_embedder(
            state_rep.get_hut_colors(), zero_out)
        self._hut_rotations_embedder: word_embedder.WordEmbedder = _new_embedder(
            state_rep.get_hut_rotations(), zero_out)
        self._windmill_rotations_embedder: word_embedder.WordEmbedder = _new_embedder(
            state_rep.get_windmill_rotations(), zero_out)
        self._tower_rotations_embedder: word_embedder.WordEmbedder = _new_embedder(
            state_rep.get_tower_rotations(), zero_out)
        self._tent_rotations_embedder: word_embedder.WordEmbedder = _new_embedder(
            state_rep.get_tent_rotations(), zero_out)
        self._tree_types_embedder: word_embedder.WordEmbedder = _new_embedder(
            state_rep.get_tree_types(), zero_out)
        self._plant_types_embedder: word_embedder.WordEmbedder = _new_embedder(
            state_rep.get_plant_types(), zero_out)
        self._prop_types_embedder: word_embedder.WordEmbedder = _new_embedder(
            state_rep.get_prop_types(), zero_out)

        # Note that this list is ordered differently from the creation order above.
        self._embedders: List[word_embedder.WordEmbedder] = [
            self._prop_types_embedder,
            self._hut_colors_embedder,
            self._hut_rotations_embedder,
            self._tree_types_embedder,
            self._plant_types_embedder,
            self._windmill_rotations_embedder,
            self._tower_rotations_embedder,
            self._tent_rotations_embedder,
            self._terrain_embedder,
        ]
Example #6
0
    def __init__(self,
                 args: model_args.ModelArgs,
                 input_vocabulary: List[str],
                 auxiliaries: List[auxiliary.Auxiliary],
                 load_pretrained: bool = True,
                 end_to_end: bool = False):
        """Assemble the action generator and, when needed, its plan predictor.

        Args:
            args: Model arguments; the decoder arguments they expose configure the submodules
                built here.
            input_vocabulary: Passed through to the plan predictor.
            auxiliaries: Passed through to the plan predictor.
            load_pretrained: If set, a plan predictor is always built, and pretrained parameters
                enabled in the decoder arguments are loaded at the end of construction.
            end_to_end: If set, a plan predictor is built as part of this model.
        """
        super(ActionGeneratorModel, self).__init__()
        self._args: model_args.ModelArgs = args
        self._end_to_end = end_to_end

        # The plan predictor exists only for end-to-end models or when pretrained weights are loaded.
        self._plan_predictor: Optional[plan_predictor_model.PlanPredictorModel] = None
        if self._end_to_end or load_pretrained:
            self._plan_predictor = plan_predictor_model.PlanPredictorModel(args, input_vocabulary, auxiliaries)

        decoder_args = self._args.get_decoder_args()
        recurrent = decoder_args.use_recurrence()
        state_size = decoder_args.get_state_internal_size()
        num_actions = len(agent_actions.AGENT_ACTIONS)

        # The recurrent components stay None unless the decoder is configured to use recurrence.
        self._output_layer = self._rnn = self._action_embedder = None

        if recurrent:
            action_embedding_size = decoder_args.get_action_embedding_size()
            hidden_size = decoder_args.get_hidden_size()

            self._action_embedder = word_embedder.WordEmbedder(
                action_embedding_size,
                list(map(str, agent_actions.AGENT_ACTIONS)),
                add_unk=False)

            # LSTM input size is the state size plus the action embedding size.
            self._rnn = nn.LSTM(state_size + action_embedding_size,
                                hidden_size,
                                decoder_args.get_num_layers(),
                                batch_first=True)

            # Separate output layer over the LSTM hidden size plus the state size.
            self._output_layer = nn.Linear(hidden_size + state_size, num_actions)
            torch.nn.init.orthogonal_(self._output_layer.weight,
                                      torch.nn.init.calculate_gain("leaky_relu"))
            self._output_layer.bias.data.fill_(0)

        # One input channel for every enabled map distribution.
        channel_flags = (
            decoder_args.use_trajectory_distribution(),
            decoder_args.use_goal_probabilities(),
            decoder_args.use_obstacle_probabilities(),
            decoder_args.use_avoid_probabilities(),
        )
        distribution_num_channels: int = sum(1 for enabled in channel_flags if enabled)

        # Without recurrence the third argument is the number of actions instead of the state size.
        self._map_distribution_embedder = map_distribution_embedder.MapDistributionEmbedder(
            distribution_num_channels,
            state_size,
            state_size if recurrent else num_actions,
            decoder_args.get_crop_size(),
            decoder_args.convolution_encode_map_distributions(),
            recurrent)

        if not load_pretrained:
            return

        # Load the full generator first, then the plan predictor on its own.
        if decoder_args.pretrained_generator():
            initialization.load_pretrained_parameters(
                decoder_args.pretrained_action_generator_filepath(), module=self)
        if decoder_args.pretrained_plan_predictor():
            initialization.load_pretrained_parameters(
                decoder_args.pretrained_plan_predictor_filepath(), module=self._plan_predictor)