def _create_dual_maxq_label_tensor(self, method="duality_based"):
        """Build the dual approximation of the max-Q label tensor.

        Walks ``self._vars_tf`` in order, treating even-indexed entries as
        multiplicative weights and odd-indexed entries as additive biases,
        and reads each one through its ``"<var.name>_ph"`` entry in
        ``self._dummy_network_var_ph``. The first weight is split at row
        ``self.state_dim``: the leading rows are multiplied with
        ``self._next_state_tensor`` and added to the first bias, while the
        remaining rows are kept as the first weight handed to the solver.
        Every bias is tiled ``self.batch_size`` times along a new leading
        axis.

        Args:
            method: ``"duality_based"`` uses
                ``dual_method.create_dual_approx``; ``"ibp"`` uses
                ``dual_ibp_method.create_dual_ibp_approx``; any other value
                builds both and keeps their element-wise minimum.

        The result is stored in ``self.dual_maxq_tensor``; nothing is
        returned.
        """
        placeholders = self._dummy_network_var_ph
        weight_blocks = []
        bias_blocks = []
        state_weight_block = None

        for idx, variable in enumerate(self._vars_tf):
            param_ph = placeholders["{}_ph".format(variable.name)]

            if idx % 2:
                # Odd position: additive term, repeated for each batch row.
                tiled_bias = tf.tile(
                    tf.expand_dims(param_ph, axis=0), [self.batch_size, 1])
                if idx == 1:
                    # Fold the next-state contribution of the first layer
                    # into its bias.
                    tiled_bias = tiled_bias + tf.matmul(
                        self._next_state_tensor, state_weight_block)
                bias_blocks.append(tiled_bias)
                continue

            # Even position: multiplicative term.
            if idx == 0:
                # Rows up to state_dim act on the state; the rest are the
                # weights passed on to the solver.
                state_weight_block = param_ph[:self.state_dim, :]
                param_ph = param_ph[self.state_dim:, :]
            weight_blocks.append(param_ph)

        # One more layer than there are weight matrices.
        num_layers = len(weight_blocks) + 1

        action_tensor_center = tf.zeros(
            shape=[self.batch_size, self.action_dim])
        l_infty_norm_bound = np.max(self.action_max)

        solver_args = (num_layers, self.batch_size, l_infty_norm_bound,
                       weight_blocks, bias_blocks, action_tensor_center)

        if method == "duality_based":
            self.dual_maxq_tensor = dual_method.create_dual_approx(
                *solver_args)
            return

        if method == "ibp":
            # ibp dual solver
            self.dual_maxq_tensor = dual_ibp_method.create_dual_ibp_approx(
                *solver_args)
            return

        # Mixed method: both solvers give an upper bound, keep the smaller.
        duality_bound = dual_method.create_dual_approx(*solver_args)
        ibp_bound = dual_ibp_method.create_dual_ibp_approx(*solver_args)
        self.dual_maxq_tensor = tf.minimum(duality_bound, ibp_bound)
    def testCreate_Dual_Approx(self):
        """Check shapes and values of ``dual_method.create_dual_approx``.

        Feeds a fixed two-weight network (a 4x3 and a 3x1 transposed weight,
        with biases and the action centre tiled over a batch of 2) to the
        solver with ``return_full_info=True``. Every returned tensor, and
        every per-layer list of tensors, is checked for its type, its static
        shape, and its evaluated value within an absolute tolerance of 1e-4.
        """
        num_layers = 3
        batch_size = 2
        action_max = 1.0
        tolerance = 1e-4

        def tiled_row(values):
            """Return `values` as a float32 row repeated once per batch row."""
            row = np.array(values).reshape([1, -1]).astype(np.float32)
            return tf.tile(tf.convert_to_tensor(row), [batch_size, 1])

        action_tensor_center = tiled_row([1.0, 2.0, 3.0, 4.0])
        W_T_list = [
            tf.convert_to_tensor(
                np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
                          [10.0, 11.0, 12.0]]).astype(np.float32)),
            tf.convert_to_tensor(
                np.array([[-1.0], [-3.0], [-5.0]]).astype(np.float32)),
        ]
        b_T_list = [tiled_row([1.0, -0.5, 0.1]), tiled_row([2.0])]

        (neg_J_tilde, l_list, u_list, D_list, Nu_list, gamma_list, psi, l_ip1,
         u_ip1,
         Nu_hat_1) = dual_method.create_dual_approx(num_layers,
                                                    batch_size,
                                                    action_max,
                                                    W_T_list,
                                                    b_T_list,
                                                    action_tensor_center,
                                                    return_full_info=True)

        # (tensor, expected static shape, expected value)
        single_outputs = [
            (neg_J_tilde, (2, 1), [[-508.], [-508.]]),
            (psi, (2, 1), [[-758.], [-758.]]),
            (l_ip1, (2, 1), [[0.], [0.]]),
            (u_ip1, (2, 1), [[0.], [0.]]),
            (Nu_hat_1, (2, 4, 1),
             [[[-22.], [-49.], [-76.], [-103.]]] * batch_size),
        ]
        # (list of tensors, expected shape per entry, expected value per
        # entry); each list must hold num_layers - 1 tensors.
        layer_outputs = [
            (l_list, [(2, 4), (2, 3)],
             [np.zeros((2, 4)), [[49., 53.5, 60.1]] * batch_size]),
            (u_list, [(2, 4), (2, 3)],
             [np.zeros((2, 4)), [[93., 105.5, 120.1]] * batch_size]),
            (D_list, [(2, 4), (2, 3)],
             [np.zeros((2, 4)), np.ones((2, 3))]),
            (Nu_list, [(2, 3, 1), (2, 3, 1)],
             [np.zeros((2, 3, 1)), [[[-1.], [-3.], [-5.]]] * batch_size]),
            (gamma_list, [(2, 1), (2, 1)],
             [[[0.], [0.]], [[2.], [2.]]]),
        ]

        # Pass 1: types and static shapes, before anything is evaluated.
        for tensor, shape, _ in single_outputs:
            self.assertIsInstance(tensor, tf.Tensor)
            self.assertEqual(shape, tensor.shape)
        for tensors, shapes, _ in layer_outputs:
            self.assertIsInstance(tensors, list)
            self.assertEqual(num_layers - 1, len(tensors))
            for tensor, shape in zip(tensors, shapes):
                self.assertIsInstance(tensor, tf.Tensor)
                self.assertEqual(shape, tensor.shape)

        def assert_near(expected, actual):
            """Compare flattened arrays within the absolute tolerance."""
            self.assertArrayNear(np.array(expected).flatten(),
                                 actual.flatten(),
                                 err=tolerance)

        # Pass 2: evaluated values.
        for tensor, _, expected in single_outputs:
            assert_near(expected, self.sess.run(tensor))
        for tensors, _, expected_per_layer in layer_outputs:
            evaluated = self.sess.run(tensors)
            for expected, actual in zip(expected_per_layer, evaluated):
                assert_near(expected, actual)