Example #1
    def derive(self, sample_num=None, valid_idx=0):
        """TODO(brendan): We are always deriving based on the very first batch
        of validation data? This seems wrong...
        """
        hidden = self.shared.init_hidden(self.args.batch_size)

        if sample_num is None:
            sample_num = self.args.derive_num_sample

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
            R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        logger.info(f'derive | max_R: {max_R:8.6f}')
        fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                 f'{max_R:6.4f}-best.png')
        path = os.path.join(self.args.model_dir, 'networks', fname)
        utils.draw_network(best_dag, path)
        self.tb.image_summary('derive/best', [path], self.epoch)

        return best_dag
Example #2
    def _summarize_controller_train(self,
                                    total_loss,
                                    adv_history,
                                    entropy_history,
                                    reward_history,
                                    avg_reward_base,
                                    dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.args.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        logger.info(
            f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
            f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
            f'| loss {cur_loss:.5f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss',
                                   cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward',
                                   avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy',
                                   avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv',
                                   avg_adv,
                                   self.controller_step)

            paths = []
            for dag in dags:
                fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                         f'{avg_reward:6.4f}.png')
                path = os.path.join(self.args.model_dir, 'networks', fname)
                utils.draw_network(dag, path)
                paths.append(path)

            self.tb.image_summary('controller/sample',
                                  paths,
                                  self.controller_step)
Example #3
    def sample(self, batch_size=1, with_details=False, save_dir=None):
        assert batch_size >= 1

        # [B, L, H]
        inputs = torch.Tensor([self.num_total_tokens-1]).to(self.args.device).long()
        hidden = None

        activations = []
        entropies = []
        log_probs = []
        prev_nodes = []

        for block_idx in range(2 * (self.args.num_blocks - 1) + 1):
            logits, hidden = self.forward(inputs, hidden, block_idx)

            probs = F.softmax(logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
            entropy = -(log_prob * probs).sum(1, keepdim=False)  # entropy of the categorical distribution: -sum(p * log p)

            action = probs.multinomial(num_samples=1).data

            selected_log_prob = log_prob.gather(
                1, utils.get_variable(action, requires_grad=False))

            entropies.append(entropy)
            log_probs.append(selected_log_prob[:, 0])

            # 0: function, 1: previous node
            mode = block_idx % 2
            inputs = utils.get_variable(action[:, 0] + sum(self.num_tokens[:mode]), requires_grad=False)

            if mode == 0:  # function
                activations.append(action[:, 0])
            elif mode == 1:
                prev_nodes.append(action[:, 0])

        prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
        activations = torch.stack(activations).transpose(0, 1)

        dags = _construct_dags(prev_nodes, activations, self.func_names, self.args.num_blocks)

        # used later
        if save_dir is not None:
            for idx, dag in enumerate(dags):
                utils.draw_network(dag, os.path.join(save_dir, f'graph{idx}.png'))

        if with_details:
            return dags, torch.cat(log_probs), torch.cat(entropies)

        return dags
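
The entropy line above is just the entropy of the categorical distribution defined by the logits. A minimal, self-contained check in plain PyTorch (nothing here comes from the project code):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4)  # [batch, num_choices]
probs = F.softmax(logits, dim=-1)
log_prob = F.log_softmax(logits, dim=-1)

# -(p * log p) summed over the choice dimension: one entropy per row
entropy = -(log_prob * probs).sum(1, keepdim=False)

# matches the entropy of torch.distributions.Categorical
expected = torch.distributions.Categorical(logits=logits).entropy()
assert torch.allclose(entropy, expected)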
Example #4
    def derive(self, sample_num=None, valid_idx=0):
        if sample_num is None:
            sample_num = self.args.derive_num_sample  # args.derive_num_sample = 100

        # def sample(self, batch_size=1, with_details=False, save_dir=None):
        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
            R = self.get_reward(dag, entropies, valid_idx)
            if R.sum() > max_R:
                max_R = R.sum()
                best_dag = dag

        fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                 f'{max_R:6.4f}-best.png')

        dir_path = 'sample_model/'
        path = os.path.join(dir_path, fname)
        utils.draw_network(best_dag, path)

        return best_dag
Example #5
    def sample(self,
               batch_size=1,
               with_details=False,
               save_dir=None,
               construct_dag_method=None):
        """Samples a set of `args.num_blocks` many computational nodes from the
        controller, where each node is made up of an activation function, and
        each node except the last also includes a previous node.
        """
        def _construct_micro_cnn_dags(prev_nodes, activations, func_names,
                                      num_blocks):
            """Constructs a set of DAGs based on the actions, i.e., previous nodes and
            activation functions, sampled from the controller/policy pi.

            This is tailored for CNNs only, not the aforementioned RNN.

            Args:
                prev_nodes: Previous node actions from the policy.
                activations: Activations sampled from the policy.
                func_names: [normal_func_names, reduce_func_names]
                num_blocks: Number of blocks in the target RNN cell.

            Returns:
                A list of DAGs defined by the inputs.

            CNN cell DAGs are represented in the following way:

            1. the entire DAG is represented as a simple list of 2 elements:
                [ Normal-Cell, Reduction-Cell ]
            2. each element is another list with entries of the form
                [ (node_id1, node_id2, ops), ] * num_blocks
                    where each entry represents node1 -- ops --> node2

            3. node 0 represents h(t-1), i.e. the previous layer's input;
               node 1 represents h(t), i.e. the current input,
                    so the actual index for the current block starts from 2

            """
            dags = []
            for nodes, func_ids in zip(prev_nodes, activations):
                dag = []

                # compute the first node
                # dag.append(MicroNode(0, 2, func_names[func_ids[0]]))
                # dag.append(MicroNode(1, 2, func_names[func_ids[0]]))
                leaf_nodes = set(range(2, num_blocks + 2))

                # add following nodes
                for curr_idx, (prev_idx,
                               func_id) in enumerate(zip(nodes, func_ids)):
                    layer_id = curr_idx // 2 + 2
                    _prev_idx = utils.to_item(prev_idx)
                    if _prev_idx == layer_id:
                        continue
                    assert _prev_idx < layer_id, "Crucial logical error"
                    dag.append(
                        MicroNode(_prev_idx, layer_id, func_names[func_id]))
                    leaf_nodes -= set([_prev_idx])

                # add leaf node connection with concat
                # for idx in leaf_nodes:
                #     dag.append(MicroNode(idx, num_blocks, 'concat'))
                dag.sort()
                dags.append(dag)

            return dags

        construct_dag_method = construct_dag_method or _construct_micro_cnn_dags

        list_dags = []
        final_log_probs = []
        final_entropies = []
        block_num = 4 * self.args.num_blocks
        # Iterate Normal cell and Reduced cell
        for type_id in range(2):

            if batch_size < 1:
                raise Exception(f'Wrong batch_size: {batch_size} < 1')

            # [B, L, H]
            inputs = self.static_inputs[batch_size]
            hidden = self.static_init_hidden[batch_size]

            activations = []
            entropies = []
            log_probs = []
            prev_nodes = []

            for block_idx in range((0 + type_id) * block_num,
                                   (1 + type_id) * block_num):
                logits, hidden = self.forward(
                    inputs,
                    hidden,
                    block_idx,
                    is_embed=(block_idx == (0 + type_id) * block_num))

                probs = F.softmax(logits, dim=-1)
                log_prob = F.log_softmax(logits, dim=-1)
                entropy = -(log_prob * probs).sum(1, keepdim=False)

                action = probs.multinomial(num_samples=1).data
                selected_log_prob = log_prob.gather(
                    1, utils.get_variable(action, requires_grad=False))

                # TODO(brendan): why the [:, 0] here? Should it be .squeeze(),
                # or .view()? Same below with `action`.
                entropies.append(entropy)
                log_probs.append(selected_log_prob[:, 0])

                # 1: function, 0: previous node
                mode = block_idx % 2
                inputs = utils.get_variable(action[:, 0] +
                                            sum(self.num_tokens[:mode]),
                                            requires_grad=False)

                if mode == 1:
                    activations.append(action[:, 0])
                elif mode == 0:
                    prev_nodes.append(action[:, 0])

            prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
            activations = torch.stack(activations).transpose(0, 1)

            dags = construct_dag_method(
                prev_nodes, activations, self.normal_func_names if type_id == 0
                else self.reduce_func_names, self.args.num_blocks)
            if save_dir is not None:
                for idx, dag in enumerate(dags):
                    utils.draw_network(
                        dag, os.path.join(save_dir, f'graph{idx}.png'))
            # add to the final result
            list_dags.append(dags)
            final_entropies.extend(entropies)
            final_log_probs.extend(log_probs)

        list_dags = [
            MicroArchi(d1, d2) for d1, d2 in zip(list_dags[0], list_dags[1])
        ]

        if with_details:
            return list_dags, torch.cat(final_log_probs), torch.cat(
                final_entropies)

        if batch_size == 1 and len(list_dags) != 1:
            list_dags = [list_dags]
        elif batch_size != len(list_dags):
            raise RuntimeError(
                f"Sample batch_size {batch_size} does not match with len list_dags {len(list_dags)}"
            )
        return list_dags
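
To make the cell encoding from the docstring concrete, here is a minimal sketch. It assumes MicroNode is a namedtuple of (prev_idx, curr_idx, op), which matches how the code above constructs it; the op names and indices are illustrative only:

from collections import namedtuple

MicroNode = namedtuple('MicroNode', ['prev_idx', 'curr_idx', 'op'])

# One cell with num_blocks=2: nodes 0 and 1 are h(t-1) and h(t),
# so block indices start at 2.
normal_cell = [
    MicroNode(0, 2, 'sep_conv_3x3'),  # node0 --sep_conv_3x3--> node2
    MicroNode(1, 2, 'max_pool_3x3'),  # node1 --max_pool_3x3--> node2
    MicroNode(0, 3, 'identity'),      # node0 --identity--> node3
    MicroNode(2, 3, 'avg_pool_3x3'),  # node2 --avg_pool_3x3--> node3
]
# A full architecture pairs a normal cell with a reduction cell:
# architecture = [normal_cell, reduction_cell]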
Example #6
    def sample(self, batch_size=1, with_details=False, save_dir=None):
        """Samples a set of `args.num_blocks(12)` many computational nodes from the
        controller, where each node is made up of an activation function, and
        each node except the last also includes a previous node.
        将所有结点和激活函数的关系都保存起来形成了一个dag,每次sample就是产生一个dag的过程
        """
        if batch_size < 1:
            raise Exception(f'Wrong batch_size: {batch_size} < 1')

        # [B, L, H]
        inputs = self.static_inputs[batch_size]  # Variable of shape [1, 100]
        hidden = self.static_init_hidden[batch_size]  # tuple of two Variables, each of shape [1, 100]

        activations = []
        entropies = []
        log_probs = []
        prev_nodes = []
        # NOTE(brendan): The RNN controller alternately outputs an activation,
        # followed by a previous node, for each block except the last one,
        # which only gets an activation function. The last node is the output
        # node, and its previous node is the average of all leaf nodes.
        for block_idx in range(2 * (self.args.num_blocks - 1) +
                               1):  # range(23); 25 blocks in total, the last two
            # being the output node and the average of all leaf nodes.
            # 12 nodes, each deciding two things (an activation function and a
            # previous node), plus the initial node 0, give 25 blocks in total.
            # logits is the choice output: which node or which activation function.
            # hidden is the activation fed to the next LSTM cell.
            logits, hidden = self.forward(inputs,
                                          hidden,
                                          block_idx,
                                          is_embed=(block_idx == 0))

            probs = F.softmax(logits, dim=-1)
            log_prob = F.log_softmax(
                logits,
                dim=-1)  # mathematically equivalent to log(softmax(x)), but numerically more stable
            # TODO(brendan): .mean() for entropy?
            # -y * log(y)
            entropy = -(log_prob * probs).sum(1, keepdim=False)
            #Returns a tensor where each row contains num_samples indices sampled from the multinomial
            #probability distribution located in the corresponding row of tensor input.
            action = probs.multinomial(num_samples=1).data
            selected_log_prob = log_prob.gather(
                1, utils.get_variable(action, requires_grad=False))

            # TODO(brendan): why the [:, 0] here? Should it be .squeeze(), or
            # .view()? Same below with `action`. The behavior should be the
            # same; probably a holdover from an older code version.
            entropies.append(entropy)
            log_probs.append(selected_log_prob[:, 0])

            mode = block_idx % 2  # 0: function, 1: previous node
            inputs = utils.get_variable(action[:, 0] +
                                        sum(self.num_tokens[:mode]),
                                        requires_grad=False)

            if mode == 0:
                activations.append(action[:, 0])
            elif mode == 1:
                prev_nodes.append(action[:, 0])

        prev_nodes = torch.stack(prev_nodes).transpose(0, 1)  # [1, 11]
        activations = torch.stack(activations).transpose(0, 1)  # [1, 12]
        # although named dags, there is actually only one dag here
        dags = _construct_dags(prev_nodes, activations, self.func_names,
                               self.args.num_blocks)

        if save_dir is not None:
            for idx, dag in enumerate(dags):
                utils.draw_network(dag, os.path.join(save_dir,
                                                     f'graph{idx}.png'))

        # when training the Controller, return extra details along with the dag
        if with_details:  # log_probs: list[23](tensor.size([1])), entropies: list[23](tensor.size([1]))
            return dags, torch.cat(log_probs), torch.cat(
                entropies)  # torch.cat(log_probs): Tensor.size([23])

        return dags  # list([1])(defaultdict([25]))
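
The NOTE above encodes the controller's decision schedule: 2*(num_blocks - 1) + 1 alternating decisions, with the first block receiving only an activation. A standalone illustration of the schedule for num_blocks=12 (23 steps):

num_blocks = 12
steps = 2 * (num_blocks - 1) + 1  # 23
for block_idx in range(steps):
    mode = block_idx % 2  # 0: activation function, 1: previous node
    choice = 'activation function' if mode == 0 else 'previous node'
    print(f'step {block_idx:2d}: sample a {choice}')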
Example #7
    def sample(self,
               batch_size=1,
               with_details=False,
               save_dir=None):  # probably needs to be hooked into PPO
        """Samples a set of `args.num_blocks` many computational nodes from the
        controller, where each node is made up of an activation function, and
        each node except the last also includes a previous node.
        """
        if batch_size < 1:
            raise Exception(f'Wrong batch_size: {batch_size} < 1')

        # [B, L, H]
        inputs = self.static_inputs[batch_size]
        hidden = self.static_init_hidden[batch_size]

        activations = []
        entropies = []
        log_probs = []
        prev_nodes = []
        # NOTE(brendan): The RNN controller alternately outputs an activation,
        # followed by a previous node, for each block except the last one,
        # which only gets an activation function. The last node is the output
        # node, and its previous node is the average of all leaf nodes.
        for block_idx in range(2 * (self.args.num_blocks - 1) + 1):
            logits, hidden = self.forward(inputs,
                                          hidden,
                                          block_idx,
                                          is_embed=(block_idx == 0))

            probs = F.softmax(logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
            # TODO(brendan): .mean() for entropy?
            entropy = -(log_prob * probs).sum(1, keepdim=False)

            action = probs.multinomial(num_samples=1).data
            selected_log_prob = log_prob.gather(
                1, utils.get_variable(action, requires_grad=False))

            # TODO(brendan): why the [:, 0] here? Should it be .squeeze(), or
            # .view()? Same below with `action`.
            entropies.append(entropy)
            log_probs.append(selected_log_prob[:, 0])

            # 0: function, 1: previous node
            mode = block_idx % 2
            inputs = utils.get_variable(action[:, 0] +
                                        sum(self.num_tokens[:mode]),
                                        requires_grad=False)

            if mode == 0:
                activations.append(action[:, 0])
            elif mode == 1:
                prev_nodes.append(action[:, 0])

        prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
        activations = torch.stack(activations).transpose(0, 1)

        cell = _construct_cell(prev_nodes, activations, self.func_names,
                               self.args.num_blocks)

        if save_dir is not None:
            for idx, dag in enumerate(cell):
                utils.draw_network(dag,
                                   os.path.join(save_dir, f'graph{idx}.png'))

        if with_details:
            return cell, torch.cat(log_probs), torch.cat(entropies)

        return cell
Example #8
    def sample(self, batch_size=1, with_details=False, save_dir=None):
        if batch_size < 1:
            raise Exception(f"Wrong batch_size: {batch_size} < 1")

        # [B, L, H]
        inputs = self.static_inputs[batch_size]
        hidden = self.static_init_hidden[batch_size]

        log_probs, entropies = [], []
        activations, prev_nodes = [], []

        for block_idx in range(2 * (self.args.num_blocks - 1) + 1):
            # 0: function, 1: previous node
            mode = block_idx % 2

            logits, hidden = self.forward(inputs,
                                          hidden,
                                          block_idx,
                                          is_embed=block_idx == 0)

            probs = F.softmax(logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
            entropy = -(log_prob * probs).sum(1, keepdim=True)
            entropies.append(entropy.data[0][0])

            action = probs.multinomial(num_samples=1).data
            selected_log_prob = log_prob.gather(
                1, get_variable(action, requires_grad=False))
            log_probs.append(selected_log_prob[0])

            inputs = get_variable(action[:, 0] + sum(self.num_tokens[:mode]),
                                  requires_grad=False)

            if mode == 0:
                activations.append(action[:, 0])
            elif mode == 1:
                prev_nodes.append(action[:, 0])

        prev_nodes = t.stack(prev_nodes).transpose(0, 1)
        activations = t.stack(activations).transpose(0, 1)

        dags = []
        for nodes, func_ids in zip(prev_nodes, activations):
            dag = defaultdict(list)

            # add first node
            dag[-1] = [Node(0, self.func_names[func_ids[0]])]
            dag[-2] = [Node(0, self.func_names[func_ids[0]])]

            # add following nodes
            for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])):
                dag[idx].append(Node(jdx + 1, self.func_names[func_id]))

            leaf_nodes = set(range(self.args.num_blocks)) - dag.keys()

            # merge with avg
            for idx in leaf_nodes:
                dag[idx] = [Node(self.args.num_blocks, 'avg')]

            # last h[t] node
            dag[self.args.num_blocks] = [
                Node(self.args.num_blocks + 1, 'h[t]')
            ]
            dags.append(dag)

        if save_dir:
            for idx, dag in enumerate(dags):
                draw_network(dag, os.path.join(save_dir, f"graph{idx}.png"))

        if with_details:
            return dags, log_probs, entropies
        else:
            return dags
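
The dag built above maps each parent node id to the list of its children. A toy sketch of one sampled dag for num_blocks=3, assuming Node is a namedtuple('Node', ['id', 'name']) and that keys -1 and -2 stand for the two cell inputs, as the surrounding code suggests:

from collections import namedtuple

Node = namedtuple('Node', ['id', 'name'])

dag = {
    -1: [Node(0, 'tanh')],                     # x[t] feeds block 0
    -2: [Node(0, 'tanh')],                     # h[t-1] feeds block 0
    0: [Node(1, 'relu'), Node(2, 'sigmoid')],  # block 0 feeds blocks 1 and 2
    1: [Node(3, 'avg')],                       # leaf blocks are averaged...
    2: [Node(3, 'avg')],
    3: [Node(4, 'h[t]')],                      # ...into the output node h[t]
}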
Example #9
            elif mode == 1:
                prev_nodes.append(action[:, 0])

        prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
        activations = torch.stack(activations).transpose(0, 1)

        dags = _construct_dags(prev_nodes, activations, self.func_names, self.args.num_blocks)

        # used later
        if save_dir is not None:
            for idx, dag in enumerate(dags):
                utils.draw_network(dag, os.path.join(save_dir, f'graph{idx}.png'))

        if with_details:
            return dags, torch.cat(log_probs), torch.cat(entropies)

        return dags



if __name__ == '__main__':
    args, _ = config.get_args()
    c = Controller(args)
    dags = c.sample()
    print(len(dags))
    path = 'sample_model/'

    path = os.path.join(path, 'test.png')

    utils.draw_network(dags[0], path)
Example #10
    def train_controller(self):

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []

        controller = models.Controller(self.n_tranformers, self.n_scalers,
                                       self.n_constructers, self.n_selecters,
                                       self.n_models, self.func_names,
                                       self.lstm_size, self.temperature,
                                       self.tanh_constant, self.save_dir)

        controller_optimizer = _get_optimizer(self.optimizer)
        controller_optim = controller_optimizer(controller.parameters(),
                                                lr=self.controller_lr)

        controller.train()
        total_loss = 0

        results_dag = []
        results_acc = []
        random_history = []
        acc_history = []

        for step in range(self.controller_max_step):
            # sample models
            dags, actions, sample_entropy, sample_log_probs = controller()
            sample_entropy = torch.sum(sample_entropy)
            sample_log_probs = torch.sum(sample_log_probs)
            # print(sample_log_probs)
            print(actions)

            random_actions = self.random_actions()
            with torch.no_grad():
                acc = self.get_reward(actions)
                random_acc = self.get_reward(torch.LongTensor(random_actions))

            random_history.append(random_acc)
            results_acc.append(acc)
            results_dag.append(dags)
            acc_history.append(acc)

            rewards = torch.tensor(acc)

            if self.entropy_weight is not None:
                rewards += self.entropy_weight * sample_entropy

            reward_history.append(rewards)
            entropy_history.append(sample_entropy)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = self.ema_baseline_decay
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.append(adv)

            # policy loss (REINFORCE: minimize the negative advantage-weighted
            # log-probability, i.e. ascend the expected reward)
            loss = -sample_log_probs * adv

            # update
            controller_optim.zero_grad()
            loss.backward()

            if self.controller_grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(controller.parameters(),
                                               self.controller_grad_clip)
            controller_optim.step()

            total_loss += loss.item()

            if ((step % self.log_step) == 0) and (step > 0):
                self._summarize_controller_train(total_loss, adv_history,
                                                 entropy_history,
                                                 reward_history, acc_history,
                                                 random_history,
                                                 avg_reward_base, dags)

                reward_history, adv_history, entropy_history, acc_history, random_history = [], [], [], [], []
                total_loss = 0
            self.controller_step += 1

        max_acc = np.max(results_acc)
        max_dag = results_dag[np.argmax(results_acc)]
        path = os.path.join(self.model_dir, 'networks', 'best.png')
        utils.draw_network(max_dag[0], path)
        # np.sort(results_acc)[-10:]
        return np.sort(list(set(results_acc)))[-10:]
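
The baseline above is a standard exponential moving average used to reduce the variance of the REINFORCE gradient. A condensed, self-contained sketch of that update (the decay value is illustrative; the logic mirrors the loop above):

import torch

class EmaBaseline:
    """Keeps a running baseline; the advantage is reward minus baseline."""

    def __init__(self, decay=0.95):
        self.decay = decay
        self.baseline = None

    def advantage(self, reward: torch.Tensor) -> torch.Tensor:
        if self.baseline is None:
            self.baseline = reward
        else:
            self.baseline = (self.decay * self.baseline +
                             (1 - self.decay) * reward)
        return reward - self.baseline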
Example #11
    def sample(self, batch_size=1, with_details=False, save_dir=None):
        """Samples a set of `args.num_blocks` many computational nodes from the
        controller, where each node is made up of an activation function, and
        each node except the last also includes a previous node.
        """
        # TODO: make it fit both RNN and CNN
        if self.args.network_type == 'rnn':
            if batch_size < 1:
                raise Exception(f'Wrong batch_size: {batch_size} < 1')

            # [B, L, H]
            inputs = self.static_inputs[batch_size]
            hidden = self.static_init_hidden[batch_size]

            activations = []
            entropies = []
            log_probs = []
            prev_nodes = []
            # NOTE(brendan): The RNN controller alternately outputs an activation,
            # followed by a previous node, for each block except the last one,
            # which only gets an activation function. The last node is the output
            # node, and its previous node is the average of all leaf nodes.
            for block_idx in range(2*(self.args.num_blocks - 1) + 1):
                logits, hidden = self.forward(inputs,
                                              hidden,
                                              block_idx,
                                              is_embed=(block_idx == 0))

                probs = F.softmax(logits, dim=-1)
                log_prob = F.log_softmax(logits, dim=-1)
                # TODO(brendan): .mean() for entropy?
                entropy = -(log_prob * probs).sum(1, keepdim=False)

                action = probs.multinomial(num_samples=1).data
                selected_log_prob = log_prob.gather(
                    1, utils.get_variable(action, requires_grad=False))

                # TODO(brendan): why the [:, 0] here? Should it be .squeeze(), or
                # .view()? Same below with `action`.
                entropies.append(entropy)
                log_probs.append(selected_log_prob[:, 0])

                # 0: function, 1: previous node
                mode = block_idx % 2
                inputs = utils.get_variable(
                    action[:, 0] + sum(self.num_tokens[:mode]),
                    requires_grad=False)

                if mode == 0:
                    activations.append(action[:, 0])
                elif mode == 1:
                    prev_nodes.append(action[:, 0])

            prev_nodes = torch.stack(prev_nodes).transpose(0, 1)  # what is prev_nodes for? It is used for indexing...
            activations = torch.stack(activations).transpose(0, 1)

            dags = _construct_dags(prev_nodes,
                                   activations,
                                   self.func_names,
                                   sum(self.args.cnn_num_blocks),
                                   self.args)

            if save_dir is not None:
                for idx, dag in enumerate(dags):
                    utils.draw_network(dag,
                                       os.path.join(save_dir, f'graph{idx}.png'))

            if with_details:
                return dags, torch.cat(log_probs), torch.cat(entropies)
        if self.args.network_type == 'cnn':
            if batch_size < 1:
                raise Exception(f'Wrong batch_size: {batch_size} < 1')

            # [B, L, H]
            inputs = self.static_inputs[batch_size]
            hidden = self.static_init_hidden[batch_size]
            cnn_functions = []
            entropies = []
            log_probs = []
            prev_nodes = []
            # NOTE: The RNN controller alternately outputs a CNN function,
            # followed by a previous node, for each block except the last one,
            # which only gets a CNN function.
            # 11*2 + 1 = 23; the first block does not need to choose an index
            for block_idx in range(2*(sum(self.args.cnn_num_blocks) - 1) + 1):
                logits, hidden = self.forward(inputs,
                                              hidden,
                                              block_idx,
                                              is_embed=(block_idx == 0))
                probs = F.softmax(logits, dim=-1)
                log_prob = F.log_softmax(logits, dim=-1)
                # TODO: understand the policy gradient and improve the code
                entropy = -(log_prob * probs).sum(1, keepdim=False)
                # use multinomial to choose
                # TODO: maybe a skip connection could choose more than one layer
                action = probs.multinomial(num_samples=1).data
                selected_log_prob = log_prob.gather(
                    1, utils.get_variable(action, requires_grad=False))

                entropies.append(entropy)
                log_probs.append(selected_log_prob[:, 0])

                # 0: function, 1: previous node
                mode = block_idx % 2
                inputs = utils.get_variable(
                    action[:, 0] + sum(self.num_tokens[:mode]),
                    requires_grad=False)

                if mode == 0:
                    cnn_functions.append(action[:, 0])
                elif mode == 1:
                    prev_nodes.append(action[:, 0])
            prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
            cnn_functions = torch.stack(cnn_functions).transpose(0, 1)

            dags = _construct_dags(prev_nodes,
                                   cnn_functions,
                                   self.func_names,
                                   sum(self.args.cnn_num_blocks),
                                   self.args)
            if save_dir is not None:
                for idx, dag in enumerate(dags):
                    utils.draw_network(dag,
                                       os.path.join(save_dir, f'graph{idx}.png'))
            # TODO: check the with_details case
            if with_details:
                return dags, torch.cat(log_probs), torch.cat(entropies)
        return dags
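
In both branches the sampled index is shifted by sum(self.num_tokens[:mode]) before being fed back to the controller. This is because one shared embedding table covers every decision type, so a within-type index must be offset into its slice of the table. A tiny illustration with hypothetical token counts:

num_tokens = [4, 12]  # e.g. 4 functions, then 12 candidate previous nodes
action = 2            # a within-type index sampled from the softmax

# mode 0 (function): no offset, token id 2
assert action + sum(num_tokens[:0]) == 2
# mode 1 (previous node): shifted past the 4 function ids, token id 6
assert action + sum(num_tokens[:1]) == 6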
Example #12
    def sample(self,
               batch_size=1,
               with_details=False,
               save_dir=None,
               construct_dag_method=None):
        """Samples a set of `args.num_blocks` many computational nodes from the
        controller, where each node is made up of an activation function, and
        each node except the last also includes a previous node.
        """

        construct_dag_method = construct_dag_method or _construct_dags

        if batch_size < 1:
            raise Exception(f'Wrong batch_size: {batch_size} < 1')

        # [B, L, H]
        inputs = self.static_inputs[batch_size]
        hidden = self.static_init_hidden[batch_size]

        activations = []
        entropies = []
        log_probs = []
        prev_nodes = []
        for block_idx in range(2 * (self.args.num_blocks - 1) + 1):
            logits, hidden = self.forward(inputs,
                                          hidden,
                                          block_idx,
                                          is_embed=(block_idx == 0))

            probs = F.softmax(logits, dim=-1)
            log_prob = F.log_softmax(logits, dim=-1)
            entropy = -(log_prob * probs).sum(1, keepdim=False)

            action = probs.multinomial(num_samples=1).data
            selected_log_prob = log_prob.gather(
                1, utils.get_variable(action, requires_grad=False))

            entropies.append(entropy)
            log_probs.append(selected_log_prob[:, 0])

            # 0: function, 1: previous node
            mode = block_idx % 2
            inputs = utils.get_variable(action[:, 0] +
                                        sum(self.num_tokens[:mode]),
                                        requires_grad=False)

            if mode == 0:
                activations.append(action[:, 0])
            elif mode == 1:
                prev_nodes.append(action[:, 0])

        prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
        activations = torch.stack(activations).transpose(0, 1)

        dags = construct_dag_method(prev_nodes, activations, self.func_names,
                                    self.args.num_blocks)

        if save_dir is not None:
            for idx, dag in enumerate(dags):
                utils.draw_network(dag,
                                   os.path.join(save_dir, f'graph{idx}.png'))

        if with_details:
            return dags, torch.cat(log_probs), torch.cat(entropies)

        return dags
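
Since this last variant lets the caller inject the DAG constructor, a hedged usage sketch (controller and my_construct_dags are placeholders; only the keyword arguments come from the signature above):

# sample a batch of architectures plus log-probs and entropies for REINFORCE
dags, log_probs, entropies = controller.sample(batch_size=4,
                                               with_details=True)

# or plug in a custom constructor with the same
# (prev_nodes, activations, func_names, num_blocks) signature
dags = controller.sample(construct_dag_method=my_construct_dags)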