def _create_dual_maxq_label_tensor(self, method="duality_based"):
    """Build `self.dual_maxq_tensor`, a dual approximation of the max-Q label.

    Walks `self._vars_tf` in order, treating even-indexed entries as
    multiplicative weights and odd-indexed entries as additive weights, and
    reads each one from its `"<var.name>_ph"` entry in
    `self._dummy_network_var_ph`. The first weight is split at row
    `self.state_dim`: the leading rows are multiplied with
    `self._next_state_tensor` and folded into the first bias, while the
    remaining rows become the first entry of the weight list handed to the
    solver (presumably the action part of the input -- confirm with the
    network definition).

    Args:
      method: `"duality_based"` uses `dual_method.create_dual_approx`,
        `"ibp"` uses `dual_ibp_method.create_dual_ibp_approx`, and any other
        value takes the element-wise minimum of both.

    Returns:
      None. The result is stored on `self.dual_maxq_tensor`.
    """

    def placeholder_for(variable):
        # Placeholders are keyed by the variable name with a "_ph" suffix.
        return self._dummy_network_var_ph["{}_ph".format(variable.name)]

    def tile_over_batch(bias):
        # Repeat a bias along a new leading axis, once per batch entry.
        return tf.tile(tf.expand_dims(bias, axis=0), [self.batch_size, 1])

    w_transpose_list = []
    b_transpose_list = []
    num_layers = 1
    for idx, var in enumerate(self._vars_tf):
        var_ph = placeholder_for(var)
        if idx % 2 == 0:
            # Even index: multiplicative weights.
            if idx == 0:
                # Leading rows act on the next state; the rest go to the
                # solver as the first weight matrix.
                wx_transpose = var_ph[:self.state_dim, :]
                w_transpose_list.append(var_ph[self.state_dim:, :])
            else:
                w_transpose_list.append(var_ph)
            num_layers += 1
        elif idx == 1:
            # First additive weights: absorb the next-state contribution.
            tiled_bias = tile_over_batch(var_ph)
            state_term = tf.matmul(self._next_state_tensor, wx_transpose)
            b_transpose_list.append(tiled_bias + state_term)
        else:
            # Remaining odd indices: plain additive weights.
            b_transpose_list.append(tile_over_batch(var_ph))

    action_tensor_center = tf.zeros(
        shape=[self.batch_size, self.action_dim])
    l_infty_norm_bound = np.max(self.action_max)

    # Both solvers are called with the same positional arguments.
    solver_args = (num_layers, self.batch_size, l_infty_norm_bound,
                   w_transpose_list, b_transpose_list, action_tensor_center)

    if method == "duality_based":
        self.dual_maxq_tensor = dual_method.create_dual_approx(*solver_args)
    elif method == "ibp":
        # IBP dual solver.
        self.dual_maxq_tensor = dual_ibp_method.create_dual_ibp_approx(
            *solver_args)
    else:
        # Mixed method: minimum of the two upper bounds.
        duality_bound = dual_method.create_dual_approx(*solver_args)
        ibp_bound = dual_ibp_method.create_dual_ibp_approx(*solver_args)
        self.dual_maxq_tensor = tf.minimum(duality_bound, ibp_bound)
def testCreate_Dual_Approx(self):
    """Check types, shapes and values returned by `create_dual_approx`.

    Calls `dual_method.create_dual_approx(..., return_full_info=True)` on a
    small fixed network (a 4x3 and a 3x1 transposed weight matrix, batch
    size 2, identical rows in every batch entry) and verifies, for each
    returned item, its Python type, its static shape and its evaluated
    value to within 1e-4.
    """
    num_layers = 3
    batch_size = 2
    action_max = 1.0

    def tiled_row(row):
        # float32 row vector repeated once per batch entry.
        row_np = np.array(row).reshape([1, len(row)]).astype(np.float32)
        return tf.tile(tf.convert_to_tensor(row_np), [batch_size, 1])

    def check_tensor(tensor, expected_shape):
        self.assertIsInstance(tensor, tf.Tensor)
        self.assertEqual(expected_shape, tensor.shape)

    def check_tensor_list(tensors, expected_shapes):
        # Each per-layer list must hold `num_layers - 1` tensors.
        self.assertIsInstance(tensors, list)
        self.assertEqual(num_layers - 1, len(tensors))
        for tensor, expected_shape in zip(tensors, expected_shapes):
            check_tensor(tensor, expected_shape)

    def check_values(expected_per_example, actual):
        # Every batch entry is expected to carry the same values.
        expected = np.array([expected_per_example] * batch_size)
        self.assertArrayNear(expected.flatten(), actual.flatten(), err=1e-4)

    def check_value_list(expected_per_layer, actual_list):
        for expected_per_example, actual in zip(expected_per_layer,
                                                actual_list):
            check_values(expected_per_example, actual)

    action_tensor_center = tiled_row([1.0, 2.0, 3.0, 4.0])
    W_T_list = [
        tf.convert_to_tensor(
            np.array([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0],
                      [7.0, 8.0, 9.0],
                      [10.0, 11.0, 12.0]]).astype(np.float32)),
        tf.convert_to_tensor(
            np.array([[-1.0], [-3.0], [-5.0]]).astype(np.float32)),
    ]
    b_T_list = [tiled_row([1.0, -0.5, 0.1]), tiled_row([2.0])]

    (neg_J_tilde, l_list, u_list, D_list, Nu_list, gamma_list, psi, l_ip1,
     u_ip1, Nu_hat_1) = dual_method.create_dual_approx(
         num_layers, batch_size, action_max, W_T_list, b_T_list,
         action_tensor_center, return_full_info=True)

    # --- Static structure: types, list lengths and shapes. ---
    check_tensor(neg_J_tilde, (2, 1))
    check_tensor_list(l_list, [(2, 4), (2, 3)])
    check_tensor_list(u_list, [(2, 4), (2, 3)])
    check_tensor_list(D_list, [(2, 4), (2, 3)])
    check_tensor_list(Nu_list, [(2, 3, 1), (2, 3, 1)])
    check_tensor_list(gamma_list, [(2, 1), (2, 1)])
    check_tensor(psi, (2, 1))
    check_tensor(l_ip1, (2, 1))
    check_tensor(u_ip1, (2, 1))
    check_tensor(Nu_hat_1, (2, 4, 1))

    # --- Evaluated values: fetch everything in a single session run. ---
    (neg_J_tilde_np, l_list_np, u_list_np, D_list_np, Nu_list_np,
     gamma_list_np, psi_np, l_ip1_np, u_ip1_np,
     Nu_hat_1_np) = self.sess.run(
         (neg_J_tilde, l_list, u_list, D_list, Nu_list, gamma_list, psi,
          l_ip1, u_ip1, Nu_hat_1))

    check_values([-508.], neg_J_tilde_np)
    check_value_list([[0., 0., 0., 0.], [49., 53.5, 60.1]], l_list_np)
    check_value_list([[0., 0., 0., 0.], [93., 105.5, 120.1]], u_list_np)
    check_value_list([[0., 0., 0., 0.], [1., 1., 1.]], D_list_np)
    check_value_list([[[0.], [0.], [0.]], [[-1.], [-3.], [-5.]]],
                     Nu_list_np)
    check_value_list([[0.], [2.]], gamma_list_np)
    check_values([-758.], psi_np)
    check_values([0.], l_ip1_np)
    check_values([0.], u_ip1_np)
    check_values([[-22.], [-49.], [-76.], [-103.]], Nu_hat_1_np)