def __init__(self, z_dim, initailize_weights=True):
        """
        Build the layers that decode optical flow and an optical-flow mask.

        Args:
            z_dim: latent size; the 1x1 input projection expects
                ``2 * z_dim`` channels.
            initailize_weights: if true, ``init_weights`` is applied to all
                submodules once every layer has been registered.
        """
        super().__init__()

        def make_flow_upsampler():
            # 2-channel transposed conv with kernel 4, stride 2, padding 1:
            # output size is (H - 1) * 2 - 2 + 4 = 2H, i.e. a learned 2x upsample.
            return nn.ConvTranspose2d(
                in_channels=2,
                out_channels=2,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            )

        # 1x1 projection from 2 * z_dim channels down to 64.
        self.optical_flow_conv = conv2d(2 * z_dim, 64, kernel_size=1, stride=1)

        # Decoder stages, numbered 6 down to 2.
        # NOTE(review): the 162/98/66 input widths presumably account for
        # encoder skip connections and upsampled flow concatenated in
        # forward() — confirm against that method.
        self.img_deconv6 = deconv(64, 64)
        self.img_deconv5 = deconv(64, 32)
        self.img_deconv4 = deconv(162, 32)
        self.img_deconv3 = deconv(98, 32)
        self.img_deconv2 = deconv(98, 32)

        # Flow prediction heads, one per stage.
        self.predict_optical_flow6 = predict_flow(64)
        self.predict_optical_flow5 = predict_flow(162)
        self.predict_optical_flow4 = predict_flow(98)
        self.predict_optical_flow3 = predict_flow(98)
        self.predict_optical_flow2 = predict_flow(66)

        # 2x flow upsamplers, named after the stages they bridge.
        self.upsampled_optical_flow6_to_5 = make_flow_upsampler()
        self.upsampled_optical_flow5_to_4 = make_flow_upsampler()
        self.upsampled_optical_flow4_to_3 = make_flow_upsampler()
        self.upsampled_optical_flow3_to_2 = make_flow_upsampler()

        # Mask head: 66 input channels -> 1 output channel, size-preserving 3x3.
        self.predict_optical_flow2_mask = nn.Conv2d(
            in_channels=66,
            out_channels=1,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )

        if initailize_weights:
            init_weights(self.modules())
    def __init__(self, z_dim, action_dim, initailize_weights=True):
        """
        Build an MLP that decodes the end-effector (EE) delta from a latent.

        Args:
            z_dim: number of input features.
            action_dim: number of output features.
            initailize_weights: if true, ``init_weights`` is applied to all
                submodules after construction.
        """
        super().__init__()

        hidden_widths = [z_dim, 128, 64, 32]
        layers = []
        for n_in, n_out in zip(hidden_widths[:-1], hidden_widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
        # The final projection is left without an activation.
        layers.append(nn.Linear(hidden_widths[-1], action_dim))
        self.ee_delta_decoder = nn.Sequential(*layers)

        if initailize_weights:
            init_weights(self.modules())
    def __init__(self, z_dim, initailize_weights=True, *, proprio_dim=8):
        """
        Proprioception encoder: an MLP mapping a proprioceptive vector to
        ``2 * z_dim`` features. Taken from the selfsupervised code.

        Args:
            z_dim: latent size; the encoder outputs ``2 * z_dim`` features.
            initailize_weights: if true, ``init_weights`` is applied to all
                submodules after construction.
            proprio_dim: number of input features. Defaults to 8, the width
                that was previously hard-coded here.
        """
        super().__init__()
        self.z_dim = z_dim

        self.proprio_encoder = nn.Sequential(
            nn.Linear(proprio_dim, 32),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Linear(32, 64),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Linear(64, 128),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Linear(128, 2 * self.z_dim),
            # The output itself also passes through LeakyReLU, so negative
            # features are scaled by 0.1 rather than emitted linearly.
            nn.LeakyReLU(0.1, inplace=True),
        )

        if initailize_weights:
            init_weights(self.modules())
    def __init__(self, z_dim, initailize_weights=True):
        """
        Convolutional RGB image encoder from "Making Sense of Vision and
        Touch", adapted for 224x224 inputs.

        Args:
            z_dim: latent size. The last conv layer emits ``z_dim`` channels
                and the fully connected head emits ``2 * z_dim`` features.
            initailize_weights: if true, ``init_weights`` is applied to every
                submodule once all layers are registered.
        """
        super().__init__()
        self.z_dim = z_dim
        latent = z_dim

        # Convolutional trunk: six stride-2 stages, 3 input channels -> latent.
        self.img_conv1 = conv2d(3, 16, kernel_size=7, stride=2)
        self.img_conv2 = conv2d(16, 32, kernel_size=5, stride=2)
        self.img_conv3 = conv2d(32, 64, kernel_size=5, stride=2)
        # The remaining stages rely on conv2d's default kernel size.
        self.img_conv4 = conv2d(64, 64, stride=2)
        self.img_conv5 = conv2d(64, 128, stride=2)
        self.img_conv6 = conv2d(128, latent, stride=2)

        # Fully connected head.
        # NOTE(review): the 16 * latent input width presumably corresponds to
        # a flattened 4x4 x latent feature map for a 224x224 image — confirm
        # against forward().
        self.img_fc1 = nn.Linear(in_features=16 * latent, out_features=8 * latent)
        self.img_fc2 = nn.Linear(in_features=8 * latent, out_features=2 * latent)
        self.flatten = Flatten()

        if initailize_weights:
            init_weights(self.modules())
    def __init__(self, z_dim, initailize_weights=True):
        """
        Simplified single-channel depth encoder from "Making Sense of Vision
        and Touch".

        Args:
            z_dim: latent size. The last conv layer emits ``z_dim`` channels
                and the fully connected head emits ``2 * z_dim`` features.
            initailize_weights: if true, ``init_weights`` is applied to every
                submodule once all layers are registered.
        """
        super().__init__()
        self.z_dim = z_dim
        latent = z_dim

        # Convolutional trunk: six stride-2 stages, 1 input channel -> latent.
        self.depth_conv1 = conv2d(1, 32, kernel_size=3, stride=2)
        self.depth_conv2 = conv2d(32, 64, kernel_size=3, stride=2)
        self.depth_conv3 = conv2d(64, 64, kernel_size=4, stride=2)
        # The remaining stages rely on conv2d's default kernel size.
        self.depth_conv4 = conv2d(64, 64, stride=2)
        self.depth_conv5 = conv2d(64, 128, stride=2)
        self.depth_conv6 = conv2d(128, latent, stride=2)

        # Fully connected head.
        # NOTE(review): 16 * latent presumably assumes a flattened 4x4 final
        # feature map — confirm against the expected input size and forward().
        self.depth_fc1 = nn.Linear(in_features=16 * latent, out_features=8 * latent)
        self.depth_fc2 = nn.Linear(in_features=8 * latent, out_features=2 * latent)
        self.flatten = Flatten()

        if initailize_weights:
            init_weights(self.modules())
    def __init__(self, z_dim, initailize_weights=True):
        """
        Causal 1-D convolutional force encoder, taken from the selfsupervised
        code and adapted to a 6-channel force sensor input.

        Five CausalConv1D stages (kernel 2, stride 2), each followed by
        LeakyReLU(0.1), widen 6 channels to ``2 * z_dim``.
        NOTE(review): the original docs describe a (6, 1) input; whether a
        length-1 sequence survives five stride-2 stages depends on
        CausalConv1D's padding — confirm.

        Args:
            z_dim: latent size; the encoder emits ``2 * z_dim`` channels.
            initailize_weights: if true, ``init_weights`` is applied to every
                submodule after construction.
        """
        super().__init__()
        self.z_dim = z_dim

        channel_plan = (6, 16, 32, 64, 128, 2 * self.z_dim)
        stages = []
        for c_in, c_out in zip(channel_plan, channel_plan[1:]):
            stages += [
                CausalConv1D(c_in, c_out, kernel_size=2, stride=2),
                nn.LeakyReLU(negative_slope=0.1, inplace=True),
            ]
        self.frc_encoder = nn.Sequential(*stages)

        if initailize_weights:
            init_weights(self.modules())
    def __init__(self,
                 z_dim,
                 out_dim=3,
                 mode="color",
                 initailize_weights=True):
        """
        Decodes the state actor to predict the color image or depth image.

        Args:
            z_dim: latent size; the 1x1 input projection expects
                ``2 * z_dim`` channels.
            out_dim: number of channels produced by the final deconvolution.
            mode: ``"color"`` or ``"depth"``; selects the input widths of the
                last decoder stages and of the mask head.
            initailize_weights: if true, ``init_weights`` is applied to all
                submodules after construction.

        Raises:
            NotImplementedError: if ``mode`` is neither ``"color"`` nor
                ``"depth"``.
        """
        super().__init__()

        # Resolve every mode-dependent width once, before building any layer,
        # so an unsupported mode fails immediately with a clear message.
        if mode == "color":
            deconv1_in, flow2_in, deconvf_in = 66, 66, 26
            # NOTE(review): hard-coded 3 here while depth mode uses out_dim;
            # identical for the default out_dim=3 — confirm whether color mode
            # with out_dim != 3 is meant to be supported.
            mask_in = 3
        elif mode == "depth":
            deconv1_in, flow2_in, deconvf_in = 98, 98, 42
            mask_in = out_dim
        else:
            raise NotImplementedError(
                f"Unsupported mode {mode!r}; expected 'color' or 'depth'")
        self.mode = mode

        def make_flow_upsampler():
            # 2-channel transposed conv, kernel 4 / stride 2 / padding 1:
            # output size (H - 1) * 2 - 2 + 4 = 2H.
            return nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)

        # Layers are registered in the same order as before so that
        # self.modules() (passed to init_weights) yields them unchanged.
        self.optical_flow_conv = conv2d(2 * z_dim, 64, kernel_size=1, stride=1)

        self.img_deconv6 = deconv(64, 64)
        self.img_deconv5 = deconv(64, 32)
        self.img_deconv4 = deconv(162, 32)
        self.img_deconv3 = deconv(98, 32)
        self.img_deconv2 = deconv(98, 32)

        self.predict_optical_flow6 = predict_flow(64)
        self.predict_optical_flow5 = predict_flow(162)
        self.predict_optical_flow4 = predict_flow(98)
        self.predict_optical_flow3 = predict_flow(98)

        self.predict_optical_flow1 = predict_flow(34)
        self.predict_optical_flowf = predict_flow(out_dim)

        # Mode-dependent stages.
        self.img_deconv1 = deconv(deconv1_in, 8)
        self.predict_optical_flow2 = predict_flow(flow2_in)
        self.img_deconvf = deconv(deconvf_in, out_dim)

        self.upsampled_optical_flow6_to_5 = make_flow_upsampler()
        self.upsampled_optical_flow5_to_4 = make_flow_upsampler()
        self.upsampled_optical_flow4_to_3 = make_flow_upsampler()
        self.upsampled_optical_flow3_to_2 = make_flow_upsampler()
        self.upsampled_optical_flow2_to_1 = make_flow_upsampler()

        # Single-channel, size-preserving 3x3 mask head.
        self.predict_optical_flow_mask = nn.Conv2d(mask_in,
                                                   1,
                                                   kernel_size=3,
                                                   stride=1,
                                                   padding=1,
                                                   bias=False)

        if initailize_weights:
            init_weights(self.modules())