Example #1
0
    def sample(self, batch_size=1, with_details=False, with_entropy=False):
        """Sample a batch of configurations, one decision per step.

        Each step calls ``self.forward`` with the current inputs and hidden
        state, draws one action per batch row from the softmax of the
        returned logits, and feeds that action (shifted by the total number
        of tokens of all earlier steps) in as the next step's inputs.

        Parameters
        ----------
        batch_size : int
            Number of configurations to draw. Also used as the key into
            ``self.static_inputs`` and ``self.static_init_hidden``.
        with_details : bool
            If True, also return the log-probabilities and entropies.
        with_entropy : bool
            If True, compute the entropy of each step's sampling
            distribution; otherwise each step's entropy is ``None``.

        Returns
        -------
        configs : list of dict
            list of configurations; each maps the key of ``self.spaces[i]``
            to the choice index (as ``int``) sampled at step ``i``.
        log_probs
            Only when ``with_details`` is True: log-probabilities of the
            sampled actions, stacked along axis 1.
        entropies
            Only when ``with_details`` is True: the per-step entropies
            stacked along axis 1 if ``with_entropy`` is True, otherwise a
            list of ``None`` with one entry per step.
        """
        inputs = self.static_inputs[batch_size]
        hidden = self.static_init_hidden[batch_size]

        actions = []
        entropies = []
        log_probs = []

        for idx in range(len(self.num_tokens)):
            # ``is_embed`` is True only on the first step, where the inputs
            # are the static inputs rather than a sampled action.
            logits, hidden = self.forward(inputs,
                                          hidden,
                                          idx,
                                          is_embed=(idx == 0))

            probs = F.softmax(logits, axis=-1)
            log_prob = F.log_softmax(logits, axis=-1)
            entropy = -(log_prob *
                        probs).sum(1, keepdims=False) if with_entropy else None

            action = mx.random.multinomial(probs, 1)
            # Pair every batch row index with its sampled action so that
            # gather_nd picks the log-probability of the chosen action.
            ind = mx.nd.stack(mx.nd.arange(probs.shape[0], ctx=action.context),
                              action.astype('float32'))
            selected_log_prob = F.gather_nd(log_prob, ind)

            actions.append(action[:, 0])
            entropies.append(entropy)
            log_probs.append(selected_log_prob)

            # Shift the action by the token count of all earlier steps.
            # ``detach()`` returns a new array instead of modifying its
            # receiver, so the result has to be kept for it to take effect.
            inputs = (action[:, 0] + sum(self.num_tokens[:idx])).detach()

        configs = []
        for idx in range(batch_size):
            config = {}
            for i, action in enumerate(actions):
                choice = action[idx].asscalar()
                # Only the key of the (key, space) pair is needed here.
                k, _ = self.spaces[i]
                config[k] = int(choice)
            configs.append(config)

        if with_details:
            entropies = F.stack(*entropies,
                                axis=1) if with_entropy else entropies
            return configs, F.stack(*log_probs, axis=1), entropies
        else:
            return configs
Example #2
0
    def sample(self, batch_size=1, with_details=False, with_entropy=False):
        """Sample configurations from attention-derived per-token scores.

        One self-attention pass over ``self.embedding(batch_size)`` produces
        a score per token. The scores are cut into consecutive groups of
        ``self.num_tokens[i]`` entries and one action per batch row is drawn
        from the softmax of each group.

        Parameters
        ----------
        batch_size : int
            Number of configurations to draw.
        with_details : bool
            If True, also return the log-probabilities and entropies.
        with_entropy : bool
            If True, compute the entropy of each group's sampling
            distribution; otherwise each group's entropy is ``None``.

        Returns
        -------
        configs : list of dict
            One dict per batch row, mapping the key of ``self.spaces[i]`` to
            the choice index (as ``int``) sampled for group ``i``.
        log_probs
            Only when ``with_details`` is True: log-probabilities of the
            sampled actions, stacked along axis 1.
        entropies
            Only when ``with_details`` is True: the per-group entropies
            stacked along axis 1 if ``with_entropy`` is True, otherwise a
            list of ``None`` with one entry per group.
        """
        # self-attention
        # Reshape codes (-3, 0): merge the first two axes, keep the next one.
        # Presumably (b, actions, h) -> (b * actions, h) -- TODO confirm.
        tokens = self.embedding(batch_size).reshape(-3, 0)
        qk_shape = (batch_size, self.num_total_tokens, self.hidden_size)
        v_shape = (batch_size, self.num_total_tokens, 1)
        query = self.querry(tokens).reshape(*qk_shape)  # b x actions x h
        key = self.key(tokens).reshape(*qk_shape)  # b x actions x h
        value = self.value(tokens).reshape(*v_shape)  # b x actions x 1
        scores = mx.nd.linalg_gemm2(query, key, transpose_b=True)
        # NOTE(review): the softmax runs over axis 1 of the
        # b x actions x actions score matrix, not the last axis -- confirm
        # this is the intended normalisation.
        atten = scores.softmax(axis=1)
        alphas = mx.nd.linalg_gemm2(atten, value).squeeze(axis=-1)

        actions = []
        entropies = []
        log_probs = []
        start = 0
        for width in self.num_tokens:
            # Columns [start, stop) of ``alphas`` belong to this group.
            stop = start + width
            group_probs = F.softmax(alphas[:, start:stop], axis=-1)
            group_log_probs = F.log_softmax(alphas[:, start:stop], axis=-1)
            start = stop

            if with_entropy:
                group_entropy = -(group_log_probs *
                                  group_probs).sum(1, keepdims=False)
            else:
                group_entropy = None

            drawn = mx.random.multinomial(group_probs, 1)
            # Pair every batch row index with its sampled action so that
            # gather_nd picks the log-probability of the chosen action.
            row_ids = mx.nd.arange(group_probs.shape[0], ctx=drawn.context)
            ind = mx.nd.stack(row_ids, drawn.astype('float32'))

            actions.append(drawn[:, 0])
            entropies.append(group_entropy)
            log_probs.append(F.gather_nd(group_log_probs, ind))

        configs = []
        for row in range(batch_size):
            sampled = {}
            for group, group_actions in enumerate(actions):
                picked = group_actions[row].asscalar()
                name, _ = self.spaces[group]
                sampled[name] = int(picked)
            configs.append(sampled)

        if not with_details:
            return configs

        if with_entropy:
            entropies = F.stack(*entropies, axis=1)
        return configs, F.stack(*log_probs, axis=1), entropies
Example #3
0
    def sample(self, batch_size=1, with_details=False, with_entropy=False):
        """Sample configurations using one independent decoder per decision.

        For every entry of ``self.num_tokens`` the matching decoder in
        ``self.decoders`` is called with ``batch_size`` to produce logits,
        and one action per batch row is drawn from their softmax.

        Parameters
        ----------
        batch_size : int
            Number of configurations to draw; passed to every decoder.
        with_details : bool
            If True, also return the log-probabilities and entropies.
        with_entropy : bool
            If True, compute the entropy of each decision's sampling
            distribution; otherwise each decision's entropy is ``None``.

        Returns
        -------
        configs : list of dict
            One dict per batch row, mapping the key of ``self.spaces[i]`` to
            the choice index (as ``int``) sampled for decision ``i``.
        log_probs
            Only when ``with_details`` is True: log-probabilities of the
            sampled actions, stacked along axis 1.
        entropies
            Only when ``with_details`` is True: the per-decision entropies
            stacked along axis 1 if ``with_entropy`` is True, otherwise a
            list of ``None`` with one entry per decision.
        """
        actions = []
        entropies = []
        log_probs = []

        for step in range(len(self.num_tokens)):
            step_logits = self.decoders[step](batch_size)
            step_probs = F.softmax(step_logits, axis=-1)
            step_log_probs = F.log_softmax(step_logits, axis=-1)

            if with_entropy:
                step_entropy = -(step_log_probs *
                                 step_probs).sum(1, keepdims=False)
            else:
                step_entropy = None

            drawn = mx.random.multinomial(step_probs, 1)
            # Pair every batch row index with its sampled action so that
            # gather_nd picks the log-probability of the chosen action.
            row_ids = mx.nd.arange(step_probs.shape[0], ctx=drawn.context)
            ind = mx.nd.stack(row_ids, drawn.astype('float32'))

            actions.append(drawn[:, 0])
            entropies.append(step_entropy)
            log_probs.append(F.gather_nd(step_log_probs, ind))

        configs = []
        for row in range(batch_size):
            sampled = {}
            for step, step_actions in enumerate(actions):
                picked = step_actions[row].asscalar()
                name, _ = self.spaces[step]
                sampled[name] = int(picked)
            configs.append(sampled)

        if not with_details:
            return configs

        if with_entropy:
            entropies = F.stack(*entropies, axis=1)
        return configs, F.stack(*log_probs, axis=1), entropies