示例#1
0
    def add_placeholders_op(self):
        """
        Adds placeholders to the graph

        These placeholders are used as inputs by the rest of the model building and will be fed
        data during training.

        Attributes created on ``self``:
            s, sp: bool placeholders of shape
                (None, None, front + 1, 2 * side + 1, n_item_classes + 1), where
                front / side are ``env.args.visible_radius_unit_front`` /
                ``env.args.visible_radius_unit_side`` and n_item_classes is
                ``len(env.state.xmap.item_class_id)``.
            hs, hsp: placeholders returned by ``DNC.state_placeholder(self.config)``.
            slen, splen, a, past_a: int32 placeholders, rank unconstrained.
            r, lr: float32 placeholders, rank unconstrained.
            done_mask, seq_mask: bool placeholders, rank unconstrained.
        """
        # here, typically, a state shape is (5,3,22)
        args = self.env.args
        state_shape = [
            args.visible_radius_unit_front + 1,
            2 * args.visible_radius_unit_side + 1,
            len(self.env.state.xmap.item_class_id) + 1,
        ]
        # Two leading unknown dims, presumably (batch, time) -- TODO confirm.
        obs_shape = (None, None, *state_shape)

        # NOTE: creation order is kept as-is; unnamed placeholders receive
        # sequential default names, so reordering would rename graph nodes.
        #
        # NOTE(review): ``shape=None`` accepts a tensor of any rank. The
        # trailing comments describe 1-D inputs of length nb*state_history;
        # ``shape=(None,)`` would enforce that, but confirm with the feeding
        # code first (e.g. ``lr`` may be fed as a scalar).
        self.s = tf.placeholder(tf.bool, shape=obs_shape)
        self.hs = DNC.state_placeholder(self.config)
        self.slen = tf.placeholder(tf.int32, shape=None)
        self.sp = tf.placeholder(tf.bool, shape=obs_shape)
        self.hsp = DNC.state_placeholder(self.config)
        self.splen = tf.placeholder(tf.int32, shape=None)

        self.a = tf.placeholder(tf.int32, shape=None)  # (nb*state_history,)
        self.past_a = tf.placeholder(tf.int32, shape=None)  # (nb*state_history,)
        self.r = tf.placeholder(tf.float32, shape=None)  # (nb*state_history,)
        self.done_mask = tf.placeholder(tf.bool, shape=None)  # (nb*state_history,)
        self.seq_mask = tf.placeholder(tf.bool, shape=None)  # (nb*state_history,)
        self.lr = tf.placeholder(tf.float32, shape=None)
    def add_placeholders_op(self):
        """
        Adds placeholders to the graph

        These placeholders are used as inputs by the rest of the model building and will be fed
        data during training.

        With ``num_classes = len(env.state.xmap.item_class_id)`` and
        ``num_actions = 2 * (num_classes - 1) * config.ndigits * config.nway``,
        the attributes created on ``self`` are:
            s: float32, shape (None, None, num_actions + 2).
            hs: placeholder returned by ``DNC.state_placeholder(self.config)``.
            slen: int32, rank unconstrained.
            pred_flag: float32, shape (None, None) -- (nb, state_history).
            target_action: float32, shape (None, None, num_actions).
            lr: float32, rank unconstrained.
        """
        # here, typically, a state shape is (2*4*3+2)

        num_classes = len(self.env.state.xmap.item_class_id)
        ndigits = self.config.ndigits
        nway = self.config.nway
        # Width of the target-action vector; the state's last dim is this + 2.
        num_actions = 2 * (num_classes - 1) * ndigits * nway

        # NOTE: creation order is kept as-is; unnamed placeholders receive
        # sequential default names, so reordering would rename graph nodes.
        self.s = tf.placeholder(tf.float32, shape=(None, None, num_actions + 2))
        self.hs = DNC.state_placeholder(self.config)
        # ``shape=None`` accepts any rank (``(None)`` is not a tuple).
        self.slen = tf.placeholder(tf.int32, shape=None)
        self.pred_flag = tf.placeholder(tf.float32, shape=(None, None))  # (nb, state_history)
        self.target_action = tf.placeholder(
            tf.float32, shape=(None, None, num_actions))  # (nb, state_history, num_actions)
        self.lr = tf.placeholder(tf.float32, shape=None)