Example #1
    def __init__(
            self,
            observation_space,
            g_action_space,
            l_action_space,
            pretrain_path,
            output_size=512,
            obs_transform=ResizeCenterCropper(size=(256, 256)),
    ):
        super().__init__(
            observation_space,
            g_action_space,
            l_action_space,
            pretrain_path,
            output_size=output_size,
            obs_transform=obs_transform,
        )

        self.net = ObjectNavGraphSLAMNet(
            observation_space=observation_space,
            g_action_space=g_action_space,
            output_size=output_size,
            obs_transform=obs_transform,
            pretrain_path=pretrain_path,
        )
Example #2
 def __init__(
     self,
     observation_space,
     action_space,
     hidden_size=512,
     num_recurrent_layers=2,
     rnn_type="LSTM",
     resnet_baseplanes=32,
     backbone="resnet50",
     normalize_visual_inputs=False,
     obs_transform=ResizeCenterCropper(size=(256, 256)),
     force_blind_policy=False,
 ):
     super().__init__(
         PointNavResNetNet(
             observation_space=observation_space,
             action_space=action_space,
             hidden_size=hidden_size,
             num_recurrent_layers=num_recurrent_layers,
             rnn_type=rnn_type,
             backbone=backbone,
             resnet_baseplanes=resnet_baseplanes,
             normalize_visual_inputs=normalize_visual_inputs,
             obs_transform=obs_transform,
             force_blind_policy=force_blind_policy,
         ),
         action_space.n,
     )
Example #3
 def __init__(
         self,
         observation_space,
         action_space,
         goal_sensor_uuid="pointgoal_with_gps_compass",
         hidden_size=512,
         num_recurrent_layers=2,
         rnn_type="LSTM",
         resnet_baseplanes=32,
         backbone="resnet50",
         normalize_visual_inputs=False,
         obs_transform=ResizeCenterCropper(size=(256, 256)),
 ):
     super().__init__(
         PointNavResNetNet(
             observation_space=observation_space,
             action_space=action_space,
             goal_sensor_uuid=goal_sensor_uuid,
             hidden_size=hidden_size,
             num_recurrent_layers=num_recurrent_layers,
             rnn_type=rnn_type,
             backbone=backbone,
             resnet_baseplanes=resnet_baseplanes,
             normalize_visual_inputs=normalize_visual_inputs,
             obs_transform=obs_transform,
         ),
         action_space.n,
     )
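The two policy variants above differ only in whether a goal_sensor_uuid is forwarded to PointNavResNetNet. A minimal construction sketch; the wrapper class name MyPointNavPolicy and the envs variable are illustrative, not from the source:

    # Hypothetical usage: wrap one of the __init__ methods above in a Policy
    # subclass and build it from a vectorized env's spaces.
    policy = MyPointNavPolicy(
        observation_space=envs.observation_spaces[0],
        action_space=envs.action_spaces[0],
        hidden_size=512,
        obs_transform=ResizeCenterCropper(size=(256, 256)),  # the default crop used throughout
    )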
Example #4
    def __init__(
        self,
        observation_space,
        hidden_size,
        goal_sensor_uuid=None,
        detach=False,
        additional_sensors=[],  # low-dim sensors corresponding to registered names
    ):
        super().__init__()
        self.goal_sensor_uuid = goal_sensor_uuid
        self.additional_sensors = additional_sensors
        self._n_input_goal = 0
        # if goal_sensor_uuid is not None and goal_sensor_uuid != "no_sensor":
        #     self.goal_sensor_uuid = goal_sensor_uuid
        #     self._initialize_goal_encoder(observation_space)
        self._hidden_size = hidden_size

        resnet_baseplanes = 32
        backbone = "resnet18"
        visual_resnet = ResNetEncoder(
            observation_space,
            baseplanes=resnet_baseplanes,
            ngroups=resnet_baseplanes // 2,
            make_backbone=getattr(resnet, backbone),
            normalize_visual_inputs=False,
            obs_transform=ResizeCenterCropper(size=(256, 256)),
            backbone_only=True,
        )

        self.detach = detach
        self.visual_resnet = visual_resnet
        self.visual_encoder = nn.Sequential(
            Flatten(),
            nn.Linear(
                np.prod(visual_resnet.output_shape), hidden_size
            ),
            nn.Sigmoid()
        )

        self.visual_decoder = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
        )

        self.train()
Example #5
    def __init__(
            self,
            observation_space,
            g_action_space,
            l_action_space,
            pretrain_path,
            output_size=512,
            obs_transform=ResizeCenterCropper(size=(256, 256)),
    ):
        super().__init__()

        self.net = ObjectNavSLAMNet(
            observation_space=observation_space,
            g_action_space=g_action_space,
            output_size=output_size,
            obs_transform=obs_transform,
            pretrain_path=pretrain_path,
        )

        self.num_local_actions = g_action_space.shape[0]
        self.num_global_actions = g_action_space.shape[0]
        # print("num_global_actions: %d" % self.num_global_actions) #2

        self.global_action_distribution = DiagGaussian(self.net.output_size,
                                                       self.num_global_actions)
        self.critic = CriticHead(self.net.output_size)
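This snippet only builds the modules. A sketch of how DiagGaussian and CriticHead are typically combined in an act()-style method; the exact forward signature of self.net is not shown in the example, so treat it as an assumption:

    # Illustrative only: the usual actor-critic pattern with these modules.
    def act(self, observations, rnn_hidden_states, prev_actions, masks):
        features, rnn_hidden_states = self.net(
            observations, rnn_hidden_states, prev_actions, masks)
        distribution = self.global_action_distribution(features)
        value = self.critic(features)
        action = distribution.sample()
        return value, action, distribution.log_probs(action), rnn_hidden_states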
Example #6
    def __init__(
            self,
            observation_space,
            baseplanes=32,
            ngroups=32,
            spatial_size=128,
            make_backbone=None,
            use_if_available=["rgb", "depth"],
            normalize_visual_inputs=True,
            obs_transform=ResizeCenterCropper(size=(256, 256)),
    ):
        super().__init__()

        self.obs_transform = obs_transform
        if self.obs_transform is not None:
            observation_space = self.obs_transform.transform_observation_space(
                observation_space)

        self._inputs = list(
            filter(lambda x: x in observation_space.spaces, use_if_available))
        self._input_sizes = [0] * len(self._inputs)
        for i, mode in enumerate(self._inputs):
            self._input_sizes[i] = observation_space.spaces[mode].shape[2]

        self.running_mean_and_var = nn.Sequential()

        if not self.is_blind:
            spatial_size = observation_space.spaces[
                self._inputs[0]].shape[0] // 2

            input_channels = sum(
                self._input_sizes)  # self._n_input_depth + self._n_input_rgb
            self.backbone = make_backbone(input_channels, baseplanes, ngroups)

            final_spatial = int(spatial_size *
                                self.backbone.final_spatial_compress)
            after_compression_flat_size = 2048
            num_compression_channels = int(
                round(after_compression_flat_size / (final_spatial**2)))
            self.compression = nn.Sequential(
                nn.Conv2d(
                    self.backbone.final_channels,
                    num_compression_channels,
                    kernel_size=3,
                    padding=1,
                    bias=False,
                ),
                nn.GroupNorm(1, num_compression_channels),
                nn.ReLU(True),
            )

            self.output_shape = (
                num_compression_channels,
                final_spatial,
                final_spatial,
            )
Example #7
    def __init__(
            self,
            observation_space,
            action_space,
            goal_sensor_uuid,
            hidden_size,
            num_recurrent_layers,
            rnn_type,
            backbone,
            resnet_baseplanes,
            normalize_visual_inputs,
            obs_transform=ResizeCenterCropper(size=(256, 256)),
    ):
        super().__init__()
        self.goal_sensor_uuid = goal_sensor_uuid

        self.prev_action_embedding = nn.Embedding(action_space.n + 1, 32)
        self._n_prev_action = 32

        self._n_input_goal = (
            observation_space.spaces[self.goal_sensor_uuid].shape[0] + 1)
        self.tgt_embeding = nn.Linear(self._n_input_goal, 32)
        self._n_input_goal = 32

        self._hidden_size = hidden_size

        rnn_input_size = self._n_input_goal + self._n_prev_action
        self.visual_encoder = ResNetEncoder(
            observation_space,
            baseplanes=resnet_baseplanes,
            ngroups=resnet_baseplanes // 2,
            make_backbone=getattr(resnet, backbone),
            normalize_visual_inputs=normalize_visual_inputs,
            obs_transform=obs_transform,
        )

        if not self.visual_encoder.is_blind:
            self.visual_fc = nn.Sequential(
                Flatten(),
                nn.Linear(np.prod(self.visual_encoder.output_shape),
                          hidden_size),
                nn.ReLU(True),
            )

        self.state_encoder = RNNStateEncoder(
            (0 if self.is_blind else self._hidden_size) + rnn_input_size,
            self._hidden_size,
            rnn_type=rnn_type,
            num_layers=num_recurrent_layers,
        )

        self.train()
Example #8
 def __init__(
         self,
         input_channels,
         out_channels,
         obs_transform=ResizeCenterCropper(size=(256, 256)),
 ):
     super(MapEncoder, self).__init__()
     self.maxpool = nn.MaxPool2d(2)
     self.conv1 = nn.Conv2d(input_channels, 32, 3, stride=1, padding=1)
     self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1)
     self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
     self.conv4 = nn.Conv2d(128, 64, 3, stride=1, padding=1)
     self.conv5 = nn.Conv2d(64, 32, 3, stride=1, padding=1)
     self.fc = nn.Linear(7200, out_channels)
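The forward pass of MapEncoder is not included above. A plausible sketch, assuming self.maxpool follows the first four conv layers so that a 240x240 map (the crop size used elsewhere on this page) flattens to 32 * 15 * 15 = 7200 features, matching the nn.Linear input size; it requires import torch.nn.functional as F:

 def forward(self, x):                        # x: (N, input_channels, 240, 240)
     x = self.maxpool(F.relu(self.conv1(x)))  # -> (N, 32, 120, 120)
     x = self.maxpool(F.relu(self.conv2(x)))  # -> (N, 64, 60, 60)
     x = self.maxpool(F.relu(self.conv3(x)))  # -> (N, 128, 30, 30)
     x = self.maxpool(F.relu(self.conv4(x)))  # -> (N, 64, 15, 15)
     x = F.relu(self.conv5(x))                # -> (N, 32, 15, 15)
     return self.fc(x.flatten(start_dim=1))   # 32 * 15 * 15 = 7200 -> out_channels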
Example #9
    def __init__(
            self,
            observation_space,
            g_action_space,
            output_size,
            pretrain_path,
            obs_transform=ResizeCenterCropper(size=(240, 240)),
    ):
        super().__init__()

        self._output_size = output_size

        if ObjectGoalSensor.cls_uuid in observation_space.spaces:
            self._n_object_categories = (int(
                observation_space.spaces[ObjectGoalSensor.cls_uuid].high[0]) +
                                         2)
            self.obj_categories_embedding = nn.Embedding(
                self._n_object_categories, 256)
            hidden_size = 256

        # current pose embedding
        curr_pose_dim = observation_space.spaces["curr_pose"].shape[0]
        # print("curr_pose_dim: ", curr_pose_dim)
        self.curr_pose_embedding = nn.Linear(curr_pose_dim, 256)
        hidden_size += 256

        map_dim = observation_space.spaces["map_sum"].shape[2]
        self.map_encoder = MapEncoder(
            map_dim,
            2048,
            obs_transform=obs_transform,
        )

        hidden_size += 2048

        self.linear1 = nn.Linear(hidden_size, 1024)
        self.linear2 = nn.Linear(1024, self._output_size)
Example #10
    def __init__(
            self,
            observation_space,
            baseplanes=32,
            ngroups=32,
            spatial_size=128,
            make_backbone=None,
            normalize_visual_inputs=False,
            obs_transform=ResizeCenterCropper(size=(256, 256)),
    ):
        super().__init__()

        self.obs_transform = obs_transform
        if self.obs_transform is not None:
            observation_space = self.obs_transform.transform_observation_space(
                observation_space)

        if "rgb" in observation_space.spaces:
            self._n_input_rgb = observation_space.spaces["rgb"].shape[2]
            spatial_size = observation_space.spaces["rgb"].shape[0] // 2
        else:
            self._n_input_rgb = 0

        if "depth" in observation_space.spaces:
            self._n_input_depth = observation_space.spaces["depth"].shape[2]
            spatial_size = observation_space.spaces["depth"].shape[0] // 2
        else:
            self._n_input_depth = 0

        if normalize_visual_inputs:
            self.running_mean_and_var = RunningMeanAndVar(self._n_input_depth +
                                                          self._n_input_rgb)
        else:
            self.running_mean_and_var = nn.Sequential()

        if not self.is_blind:
            input_channels = self._n_input_depth + self._n_input_rgb
            self.backbone = make_backbone(input_channels, baseplanes, ngroups)

            final_spatial = int(spatial_size *
                                self.backbone.final_spatial_compress)
            after_compression_flat_size = 2048
            num_compression_channels = int(
                round(after_compression_flat_size / (final_spatial**2)))
            self.compression = nn.Sequential(
                nn.Conv2d(
                    self.backbone.final_channels,
                    num_compression_channels,
                    kernel_size=3,
                    padding=1,
                    bias=False,
                ),
                nn.GroupNorm(1, num_compression_channels),
                nn.ReLU(True),
            )

            self.output_shape = (
                num_compression_channels,
                final_spatial,
                final_spatial,
            )
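For orientation, the sizes these lines produce for the default 256x256 crop with an RGB-only input, assuming a backbone whose final_spatial_compress is 1/32 (treat that value as an assumption, not something stated in the snippet):

    spatial_size = 256 // 2                                           # 128
    final_spatial = int(spatial_size * (1 / 32))                      # 4
    num_compression_channels = int(round(2048 / final_spatial ** 2))  # 128
    # output_shape == (128, 4, 4), i.e. 2048 features once flattened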
Example #11
    def __init__(
        self,
        observation_space,
        action_space,
        hidden_size,
        num_recurrent_layers,
        rnn_type,
        backbone,
        resnet_baseplanes,
        normalize_visual_inputs,
        obs_transform=ResizeCenterCropper(size=(256, 256)),
        force_blind_policy=False,
    ):
        super().__init__()

        self.prev_action_embedding = nn.Embedding(action_space.n + 1, 32)
        self._n_prev_action = 32
        rnn_input_size = self._n_prev_action

        if (IntegratedPointGoalGPSAndCompassSensor.cls_uuid
                in observation_space.spaces):
            n_input_goal = (observation_space.spaces[
                IntegratedPointGoalGPSAndCompassSensor.cls_uuid].shape[0] + 1)
            self.tgt_embeding = nn.Linear(n_input_goal, 32)
            rnn_input_size += 32

        if ObjectGoalSensor.cls_uuid in observation_space.spaces:
            self._n_object_categories = (int(
                observation_space.spaces[ObjectGoalSensor.cls_uuid].high[0]) +
                                         1)
            self.obj_categories_embedding = nn.Embedding(
                self._n_object_categories, 32)
            rnn_input_size += 32

        if EpisodicGPSSensor.cls_uuid in observation_space.spaces:
            input_gps_dim = observation_space.spaces[
                EpisodicGPSSensor.cls_uuid].shape[0]
            self.gps_embedding = nn.Linear(input_gps_dim, 32)
            rnn_input_size += 32

        if PointGoalSensor.cls_uuid in observation_space.spaces:
            input_pointgoal_dim = observation_space.spaces[
                PointGoalSensor.cls_uuid].shape[0]
            self.pointgoal_embedding = nn.Linear(input_pointgoal_dim, 32)
            rnn_input_size += 32

        if HeadingSensor.cls_uuid in observation_space.spaces:
            input_heading_dim = (
                observation_space.spaces[HeadingSensor.cls_uuid].shape[0] + 1)
            assert input_heading_dim == 2, "Expected heading with 2D rotation."
            self.heading_embedding = nn.Linear(input_heading_dim, 32)
            rnn_input_size += 32

        if ProximitySensor.cls_uuid in observation_space.spaces:
            input_proximity_dim = observation_space.spaces[
                ProximitySensor.cls_uuid].shape[0]
            self.proximity_embedding = nn.Linear(input_proximity_dim, 32)
            rnn_input_size += 32

        if EpisodicCompassSensor.cls_uuid in observation_space.spaces:
            assert (observation_space.spaces[EpisodicCompassSensor.cls_uuid].
                    shape[0] == 1), "Expected compass with 2D rotation."
            input_compass_dim = 2  # cos and sin of the angle
            self.compass_embedding = nn.Linear(input_compass_dim, 32)
            rnn_input_size += 32

        if ImageGoalSensor.cls_uuid in observation_space.spaces:
            goal_observation_space = spaces.Dict(
                {"rgb": observation_space.spaces[ImageGoalSensor.cls_uuid]})
            self.goal_visual_encoder = ResNetEncoder(
                goal_observation_space,
                baseplanes=resnet_baseplanes,
                ngroups=resnet_baseplanes // 2,
                make_backbone=getattr(resnet, backbone),
                normalize_visual_inputs=normalize_visual_inputs,
                obs_transform=obs_transform,
            )

            self.goal_visual_fc = nn.Sequential(
                Flatten(),
                nn.Linear(np.prod(self.goal_visual_encoder.output_shape),
                          hidden_size),
                nn.ReLU(True),
            )

            rnn_input_size += hidden_size

        self._hidden_size = hidden_size

        self.visual_encoder = ResNetEncoder(
            observation_space if not force_blind_policy else spaces.Dict({}),
            baseplanes=resnet_baseplanes,
            ngroups=resnet_baseplanes // 2,
            make_backbone=getattr(resnet, backbone),
            normalize_visual_inputs=normalize_visual_inputs,
            obs_transform=obs_transform,
        )

        if not self.visual_encoder.is_blind:
            self.visual_fc = nn.Sequential(
                Flatten(),
                nn.Linear(np.prod(self.visual_encoder.output_shape),
                          hidden_size),
                nn.ReLU(True),
            )

        self.state_encoder = RNNStateEncoder(
            (0 if self.is_blind else self._hidden_size) + rnn_input_size,
            self._hidden_size,
            rnn_type=rnn_type,
            num_layers=num_recurrent_layers,
        )

        self.train()
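The RNN input width depends on which sensors the observation space contains. A small worked example, assuming only the integrated pointgoal GPS+compass sensor is present:

    # Illustrative arithmetic (not from the source).
    rnn_input_size = 32 + 32        # prev-action embedding + pointgoal embedding = 64
    state_encoder_input = 512 + 64  # hidden_size + rnn_input_size (0 + 64 if the policy is blind)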
Example #12
    def __init__(
            self,
            observation_space,
            output_size,
            obs_transform: nn.Module = ResizeCenterCropper(size=(256, 256)),
    ):
        super().__init__()

        self.obs_transform = obs_transform
        if self.obs_transform is not None:
            observation_space = obs_transform.transform_observation_space(
                observation_space)

        if "rgb" in observation_space.spaces:
            self._n_input_rgb = observation_space.spaces["rgb"].shape[2]
        else:
            self._n_input_rgb = 0

        if "depth" in observation_space.spaces:
            self._n_input_depth = observation_space.spaces["depth"].shape[2]
        else:
            self._n_input_depth = 0

        # kernel size for different CNN layers
        self._cnn_layers_kernel_size = [(8, 8), (4, 4), (3, 3)]

        # strides for different CNN layers
        self._cnn_layers_stride = [(4, 4), (2, 2), (1, 1)]

        if self._n_input_rgb > 0:
            cnn_dims = np.array(observation_space.spaces["rgb"].shape[:2],
                                dtype=np.float32)
        elif self._n_input_depth > 0:
            cnn_dims = np.array(observation_space.spaces["depth"].shape[:2],
                                dtype=np.float32)

        if self.is_blind:
            self.cnn = nn.Sequential()
        else:
            for kernel_size, stride in zip(self._cnn_layers_kernel_size,
                                           self._cnn_layers_stride):
                cnn_dims = self._conv_output_dim(
                    dimension=cnn_dims,
                    padding=np.array([0, 0], dtype=np.float32),
                    dilation=np.array([1, 1], dtype=np.float32),
                    kernel_size=np.array(kernel_size, dtype=np.float32),
                    stride=np.array(stride, dtype=np.float32),
                )

            self.cnn = nn.Sequential(
                nn.Conv2d(
                    in_channels=self._n_input_rgb + self._n_input_depth,
                    out_channels=32,
                    kernel_size=self._cnn_layers_kernel_size[0],
                    stride=self._cnn_layers_stride[0],
                ),
                nn.ReLU(True),
                nn.Conv2d(
                    in_channels=32,
                    out_channels=64,
                    kernel_size=self._cnn_layers_kernel_size[1],
                    stride=self._cnn_layers_stride[1],
                ),
                nn.ReLU(True),
                nn.Conv2d(
                    in_channels=64,
                    out_channels=32,
                    kernel_size=self._cnn_layers_kernel_size[2],
                    stride=self._cnn_layers_stride[2],
                ),
                #  nn.ReLU(True),
                Flatten(),
                nn.Linear(32 * cnn_dims[0] * cnn_dims[1], output_size),
                nn.ReLU(True),
            )

        self.layer_init()
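_conv_output_dim and layer_init are not part of this snippet. A minimal sketch of the standard convolution output-size rule that _conv_output_dim presumably evaluates per spatial dimension, floor((dim + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1:

    def _conv_output_dim(self, dimension, padding, dilation, kernel_size, stride):
        # Hypothetical helper matching the call above (standard Conv2d size rule).
        out_dimension = []
        for i in range(len(dimension)):
            out_dimension.append(
                int(
                    np.floor(
                        ((dimension[i] + 2 * padding[i]
                          - dilation[i] * (kernel_size[i] - 1) - 1)
                         / stride[i]) + 1
                    )
                )
            )
        return tuple(out_dimension)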
Example #13
    def __init__(
            self,
            observation_space,
            g_action_space,
            output_size,
            pretrain_path,
            obs_transform=ResizeCenterCropper(size=(240, 240)),
    ):
        super().__init__()

        self._output_size = output_size

        # goal encoder
        if ObjectGoalSensor.cls_uuid in observation_space.spaces:
            # self.fasttext = torchtext.vocab.FastText()
            # self.obj_categories_encoder = nn.Linear(
            #     300, 512
            # )
            # hidden_size = 512
            self._n_object_categories = (int(
                observation_space.spaces[ObjectGoalSensor.cls_uuid].high[0]) +
                                         2)
            self.obj_categories_embedding = nn.Embedding(
                self._n_object_categories, 256)
            hidden_size = 256
        # print('obj_categories_embedding: ', self.obj_categories_embedding) # 22

        curr_pose_dim = observation_space.spaces["curr_pose"].shape[0]
        # print("curr_pose_dim: ", curr_pose_dim)
        self.curr_pose_embedding = nn.Linear(curr_pose_dim, 256)
        hidden_size += 256

        # map encoder
        map_dim = observation_space.spaces["map_sum"].shape[2]
        self.map_encoder = MapEncoder(
            map_dim,
            2048,
            obs_transform=obs_transform,
        )

        hidden_size += 2048

        # scene priors
        self.semantic_encoder = SemanticMap_Encoder(1, 256)
        self.graphcnn = GraphRCNN(512, 21)
        self.edge = torch.tensor(
            [[
                0, 8, 0, 19, 0, 11, 0, 6, 0, 1, 0, 5, 0, 16, 0, 3, 0, 10, 0,
                15, 0, 7, 0, 9, 0, 20, 0, 12, 5, 20, 6, 20, 8, 20, 4, 20, 16,
                20, 7, 20, 15, 20, 18, 20, 8, 14, 7, 14, 1, 14, 2, 14, 9,
                14, 6, 14, 3, 14, 14, 17, 5, 14, 13, 14, 10, 14, 14, 15, 4,
                14, 7, 8, 4, 7, 2, 7, 7, 10, 7, 9, 1, 7, 6, 7, 7, 16, 7, 15, 7,
                17, 5, 7, 7, 13, 7, 18, 3, 8, 1, 3, 3, 6, 2, 3, 3, 10, 3,
                15, 3, 4, 4, 8, 0, 4, 4, 6, 4, 11, 4, 18, 4, 16, 4, 12, 4, 5,
                2, 8, 2, 9, 2, 17, 1, 2, 2, 10, 5, 8, 8, 11, 8, 15, 8, 19, 6,
                8, 1, 8, 8, 16, 8, 9, 8, 12, 8, 13, 8, 18, 8, 10, 2,
                15, 15, 16, 1, 15, 9, 15, 15, 17, 12, 15, 13, 15, 1, 5, 1,
                16, 1, 11, 1, 10, 1, 9, 16, 19, 18, 19, 3, 19, 6, 11, 6, 18, 5,
                6, 6, 16, 13, 17, 10, 17, 12, 17, 8, 17, 10, 11, 5,
                11, 10, 13, 12, 13, 6, 12, 11, 12, 12, 16, 16, 18, 0, 18
            ],
             [
                 8, 0, 19, 0, 11, 0, 6, 0, 1, 0, 5, 0, 16, 0, 3, 0, 10, 0, 15,
                 0, 7, 0, 9, 0, 20, 0, 12, 0, 20, 5, 20, 6, 20, 8, 20, 4, 20,
                 16, 20, 7, 20, 15, 20, 18, 14, 8, 14, 7, 14, 1, 14, 2, 14, 9,
                 14, 6, 14, 3, 17, 14, 14, 5, 14, 13, 14, 10, 15, 14, 14, 4, 8,
                 7, 7, 4, 7, 2, 10, 7, 9, 7, 7, 1, 7, 6, 16, 7, 15, 7, 17, 7,
                 7, 5, 13, 7, 18, 7, 8, 3, 3, 1, 6, 3, 3, 2, 10, 3, 15, 3, 4,
                 3, 8, 4, 4, 0, 6, 4, 11, 4, 18, 4, 16, 4, 12, 4, 5, 4, 8, 2,
                 9, 2, 17, 2, 2, 1, 10, 2, 8, 5, 11, 8, 15, 8, 19, 8, 8, 6, 8,
                 1, 16, 8, 9, 8, 12, 8, 13, 8, 18, 8, 10, 8, 15, 2, 16, 15, 15,
                 1, 15, 9, 17, 15, 15, 12, 15, 13, 5, 1, 16, 1, 11, 1, 10, 1,
                 9, 1, 19, 16, 19, 18, 19, 3, 11, 6, 18, 6, 6, 5, 16, 6, 17,
                 13, 17, 10, 17, 12, 17, 8, 11, 10, 11, 5, 13, 10, 13, 12, 12,
                 6, 12, 11, 16, 12, 18, 16, 18, 0
             ]],
            dtype=torch.long)

        self.edge_type = torch.tensor([
            2, 2, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1,
            2, 2, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 2, 2, 2, 2, 0, 0,
            2, 2, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 1, 1, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2,
            2, 2, 2, 2, 2, 2, 0, 0, 2, 2, 2, 2, 1, 1, 0, 0, 2, 2, 2, 2, 2, 2,
            2, 2, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 2, 0, 0,
            2, 2, 0, 0, 2, 2, 0, 0, 2, 2, 2, 2, 0, 0, 2, 2, 0, 0, 2, 2, 2, 2,
            2, 2, 2, 2, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 1, 1,
            0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 0, 0, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2,
            2, 2, 2, 2
        ],
                                      dtype=torch.long)

        self.graph_fc = nn.Linear(21 * 21, 256)

        hidden_size += 256

        self.linear1 = nn.Linear(hidden_size, 1024)
        self.linear2 = nn.Linear(1024, self._output_size)