Example #1
File: nao_train.py Project: m8e/NAS4Text
    def controller_eval_step(self, ctrl_dataloader, epoch, subset='training'):
        model = self.controller.epd
        ground_truth_perf_list = []
        ground_truth_arch_seq_list = []
        predict_value_list = []
        arch_seq_list = []

        timer = StopwatchMeter()  # Renamed to avoid shadowing the stdlib `time` module.
        timer.start()

        def _safe_extend(l, v):
            # Flatten tensor `v` into list `l`: squeeze() yields either a
            # Python scalar (append) or a list (extend).
            v = v.data.squeeze().tolist()
            if isinstance(v, (float, int)):
                l.append(v)
            else:
                l.extend(v)

        with th.cuda.device(self.hparams.epd_device):
            for step, sample in enumerate(ctrl_dataloader):
                model.eval()
                sample = nao_utils.prepare_ctrl_sample(sample, evaluation=True)
                predict_value, logits, arch = model(
                    sample['encoder_input'])  # target_variable=None
                _safe_extend(predict_value_list, predict_value)
                _safe_extend(arch_seq_list, arch)
                _safe_extend(ground_truth_perf_list, sample['encoder_target'])
                _safe_extend(ground_truth_arch_seq_list,
                             sample['decoder_target'])

        pairwise_acc = nao_utils.pairwise_accuracy(ground_truth_perf_list,
                                                   predict_value_list)
        hamming_dis = nao_utils.hamming_distance(ground_truth_arch_seq_list,
                                                 arch_seq_list)

        if pairwise_acc > self._ctrl_best_pa[subset]:
            self._ctrl_best_pa[subset] = pairwise_acc

        timer.stop()
        logging.info(
            '| ctrl eval ({}) | epoch {:03d} | PA {:<6.6f} | BestPA {:<6.6f} |'
            ' HD {:<6.6f} | {:<6.2f} secs'.format(subset, epoch, pairwise_acc,
                                                  self._ctrl_best_pa[subset],
                                                  hamming_dis, timer.sum))
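The two metrics logged above come from nao_utils, whose implementation is not shown here. As a rough sketch (not the project's actual code), pairwise_accuracy presumably measures how often the predictor ranks two architectures in the same order as their true performance, and hamming_distance the mean number of differing token positions between predicted and ground-truth architecture sequences:

from itertools import combinations

def pairwise_accuracy_sketch(true_perf, pred_perf):
    # Fraction of pairs whose relative order agrees between ground-truth
    # and predicted performance; 1.0 means a perfect ranking.
    pairs = list(combinations(zip(true_perf, pred_perf), 2))
    correct = sum((t_i > t_j) == (p_i > p_j)
                  for (t_i, p_i), (t_j, p_j) in pairs)
    return correct / len(pairs) if pairs else 0.0

def hamming_distance_sketch(true_seqs, pred_seqs):
    # Mean count of positions at which two token sequences differ.
    dists = [sum(a != b for a, b in zip(s, t))
             for s, t in zip(true_seqs, pred_seqs)]
    return sum(dists) / len(dists) if dists else 0.0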
Example #2
File: nao_train.py Project: m8e/NAS4Text
    def controller_generate_step(self, old_arches, log_compare_perf=False):
        epd = self.controller.epd

        old_arches = old_arches[:self.hparams.num_remain_top]

        # print('#old_arches:', [a.blocks for a in old_arches])

        new_arches = []
        mapped_old_perf_list = []
        final_old_perf_list, final_new_perf_list = [], []

        predict_lambda = 0
        topk_arches = [
            self._parse_arch_to_seq(arch)
            for arch in old_arches[:self.hparams.num_pred_top]
        ]
        topk_arches_loader = nao_utils.make_tensor_dataloader(
            [th.LongTensor(topk_arches)],
            self.hparams.ctrl_batch_size,
            shuffle=False)

        with th.cuda.device(self.hparams.epd_device):
            while len(new_arches) + len(
                    old_arches) < self.hparams.num_seed_arch:
                # [NOTE]: As predict_lambda gets larger, it grows faster.
                if predict_lambda < 50:
                    predict_lambda += self.hparams.lambda_step
                elif predict_lambda >= 10000000:
                    # FIXME: Temporary solution: stop generating once lambda gets too large.
                    break
                else:
                    predict_lambda += predict_lambda / 50
                logging.info(
                    'Generating new architectures using gradient descent with step size {}'
                    .format(predict_lambda))

                new_arch_seq_list = []
                # Reset per pass; otherwise the zip() below would pair fresh
                # arch sequences with stale first-pass predictions.
                mapped_old_perf_list = []
                for step, (encoder_input, ) in enumerate(topk_arches_loader):
                    epd.eval()
                    epd.zero_grad()
                    encoder_input = common.make_variable(encoder_input,
                                                         volatile=False,
                                                         cuda=True)
                    new_arch_seq, ret_dict = epd.generate_new_arch(
                        encoder_input, predict_lambda)
                    new_arch_seq_list.extend(
                        new_arch_seq.data.squeeze().tolist())
                    mapped_old_perf_list.extend(
                        ret_dict['predict_value'].squeeze().tolist())
                    del ret_dict

                for i, (arch_seq, mapped_perf) in enumerate(
                        zip(new_arch_seq_list, mapped_old_perf_list)):
                    # Insert new arches (skip same and invalid).
                    # [NOTE]: Reduce the "ctrl_trade_off" value to let it generate different architectures.
                    arch = self.controller.parse_seq_to_arch(arch_seq)
                    if arch is None:
                        continue

                    if not self._arch_contains(
                            arch, old_arches) and not self._arch_contains(
                                arch, new_arches):
                        new_arches.append(arch)
                        # Test the new arch: the same sequence fills all three
                        # tensor slots expected by prepare_ctrl_sample.
                        sample = nao_utils.prepare_ctrl_sample(
                            [
                                th.LongTensor([arch_seq]),
                                th.LongTensor([arch_seq]),
                                th.LongTensor([arch_seq])
                            ],
                            perf=False,
                        )
                        predict_value, _, _ = epd(sample['encoder_input'],
                                                  sample['decoder_input'])
                        final_old_perf_list.append(mapped_perf)
                        final_new_perf_list.append(predict_value.item())
                    if len(new_arches) + len(
                            old_arches) >= self.hparams.num_seed_arch:
                        break
                logging.info('{} new arches generated so far'.format(
                    len(new_arches)))

        # Compare old and new perf.
        if log_compare_perf:
            print('Old and new performances:')
            _s_old, _s_new = 0.0, 0.0
            for _old, _new in zip(final_old_perf_list, final_new_perf_list):
                print('old = {}, new = {}, old - new = {}'.format(
                    _old, _new, _old - _new))
                _s_old += _old
                _s_new += _new
            if final_new_perf_list:  # Guard against ZeroDivisionError.
                _s_old /= len(final_new_perf_list)
                _s_new /= len(final_new_perf_list)
                print('Average: old = {}, new = {}, old - new = {}'.format(
                    _s_old, _s_new, _s_old - _s_new))

        self.arch_pool = old_arches + new_arches
        return self.arch_pool
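generate_new_arch is the heart of this step: following the NAO approach, it moves an architecture's continuous encoder representation along the gradient of the performance predictor and decodes the perturbed embedding back into a sequence, with predict_lambda as the step size. A minimal sketch of that idea, assuming hypothetical encoder/predictor/decoder attributes on the EPD model (the real interfaces may differ):

def generate_new_arch_sketch(epd, encoder_input, predict_lambda):
    # Encode the seed architecture; `emb` is a non-leaf tensor, so ask
    # autograd to keep its gradient around.
    emb = epd.encoder(encoder_input)       # hypothetical attribute
    emb.retain_grad()

    # Predict performance and backprop to get d(predict_value)/d(emb).
    predict_value = epd.predictor(emb)     # hypothetical attribute
    predict_value.sum().backward()

    # Step the embedding toward higher predicted performance.
    new_emb = emb + predict_lambda * emb.grad

    # Decode the perturbed embedding into a new architecture sequence.
    new_arch_seq = epd.decoder(new_emb)    # hypothetical attribute
    return new_arch_seq, {'predict_value': predict_value.detach()}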
Example #3
File: nao_train.py Project: m8e/NAS4Text
    def controller_train_step(self,
                              old_arches,
                              old_arches_perf,
                              split_test=True):
        logging.info('Training Encoder-Predictor-Decoder')
        augment = self.hparams.augment
        augment_rep = self.hparams.augment_rep

        perf = self._normalized_perf(old_arches_perf)

        if split_test:
            train_ap, test_ap = self._shuffle_and_split(old_arches,
                                                        perf,
                                                        test_size=0.1)
            if augment:
                train_ap = nao_utils.arch_augmentation(
                    *train_ap,
                    augment_rep=augment_rep,
                    focus_top=self.hparams.focus_top)
            train_arches, train_bleus = train_ap
            train_ap = ([self._parse_arch_to_seq(arch)
                         for arch in train_arches], train_bleus)
            test_arches, test_bleus = test_ap
            test_ap = ([self._parse_arch_to_seq(arch)
                        for arch in test_arches], test_bleus)

            ctrl_dataloader = nao_utils.make_ctrl_dataloader(
                *train_ap,
                batch_size=self.hparams.ctrl_batch_size,
                shuffle=True,
                sos_id=self.controller.epd.sos_id)
            test_ctrl_dataloader = nao_utils.make_ctrl_dataloader(
                *test_ap,
                batch_size=self.hparams.ctrl_batch_size,
                shuffle=False,
                sos_id=self.controller.epd.sos_id)
        else:
            if augment:
                old_arches, perf = nao_utils.arch_augmentation(
                    old_arches,
                    perf,
                    augment_rep=augment_rep,
                    focus_top=self.hparams.focus_top)
            arch_seqs = [self._parse_arch_to_seq(arch) for arch in old_arches]
            ctrl_dataloader = nao_utils.make_ctrl_dataloader(
                arch_seqs,
                perf,
                batch_size=self.hparams.ctrl_batch_size,
                shuffle=True,
                sos_id=self.controller.epd.sos_id)
            test_ctrl_dataloader = ctrl_dataloader

        epochs = range(1, self.hparams.ctrl_train_epochs + 1)
        if tqdm is not None:
            epochs = tqdm(list(epochs))

        step = 0
        with th.cuda.device(self.hparams.epd_device):
            for epoch in epochs:
                for epoch_step, sample in enumerate(ctrl_dataloader):
                    self.controller.epd.train()

                    sample = nao_utils.prepare_ctrl_sample(sample,
                                                           evaluation=False)

                    # print('#Expected global range:', self.controller.expected_global_range())
                    # print('#Expected node range:', self.controller.expected_index_range())
                    # print('#Expected op range:', self.controller.expected_op_range(False), self.controller.expected_op_range(True))
                    # print('#encoder_input', sample['encoder_input'].shape, sample['encoder_input'][0])
                    # print('#encoder_target', sample['encoder_target'].shape, sample['encoder_target'][0])
                    # print('#decoder_input', sample['decoder_input'].shape, sample['decoder_input'][0])
                    # print('#decoder_target', sample['decoder_target'].shape, sample['decoder_target'][0])

                    # FIXME: Use ParallelModel here?
                    predict_value, logits, arch = self.controller.epd(
                        sample['encoder_input'], sample['decoder_input'])

                    # print('#predict_value', predict_value.shape, predict_value.tolist())
                    # print('#logits', logits.shape)
                    # print('$arch', arch.shape, arch)

                    # Loss and optimize.
                    loss_1 = F.mse_loss(predict_value.squeeze(),
                                        sample['encoder_target'].squeeze())
                    logits_size = logits.size()
                    n = logits_size[0] * logits_size[1]
                    loss_2 = F.cross_entropy(logits.contiguous().view(n, -1),
                                             sample['decoder_target'].view(n))

                    loss = self.hparams.ctrl_trade_off * loss_1 + (
                        1 - self.hparams.ctrl_trade_off) * loss_2

                    self.ctrl_optimizer.zero_grad()
                    loss.backward()
                    grad_norm = th.nn.utils.clip_grad_norm_(
                        self.controller.epd.parameters(),
                        self.hparams.ctrl_clip_norm)
                    self.ctrl_optimizer.step()

                    # TODO: Better logging here.
                    if step % self.hparams.ctrl_log_freq == 0:
                        print(
                            '| ctrl | epoch {:03d} | step {:03d} | loss={:5.6f} '
                            '| mse={:5.6f} | cse={:5.6f} | gnorm={:5.6f}'.
                            format(
                                epoch,
                                step,
                                loss.item(),
                                loss_1.item(),
                                loss_2.item(),
                                grad_norm,
                            ))

                    step += 1

                # TODO: Add evaluation and controller model saving here.
                if epoch % self.hparams.ctrl_eval_freq == 0:
                    if ctrl_dataloader is not test_ctrl_dataloader:
                        self.controller_eval_step(test_ctrl_dataloader,
                                                  epoch,
                                                  subset='test')
                        # [NOTE]: Can omit eval on training set to speed up.
                        if epoch % self.hparams.ctrl_eval_train_freq == 0:
                            self.controller_eval_step(ctrl_dataloader,
                                                      epoch,
                                                      subset='training')
                    else:
                        self.controller_eval_step(test_ctrl_dataloader,
                                                  epoch,
                                                  subset='training')
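Taken together, the three methods form one NAO iteration: evaluate the pool, fit the encoder-predictor-decoder on (architecture, performance) pairs, then grow the pool with gradient-generated candidates. A hypothetical outer loop (evaluate_arch_pool and num_nao_iters are placeholders, not part of the shown code) might look like:

# `trainer` is an instance of the class the methods above belong to.
for nao_iter in range(trainer.hparams.num_nao_iters):   # hypothetical hparam
    # 1. Evaluate the current architecture pool (e.g. child-model BLEU).
    perf = evaluate_arch_pool(trainer.arch_pool)        # hypothetical helper

    # 2. Sort architectures by performance, best first.
    ranked = sorted(zip(trainer.arch_pool, perf),
                    key=lambda ap: ap[1], reverse=True)
    arches, arch_perf = map(list, zip(*ranked))

    # 3. Fit the encoder-predictor-decoder on (arch, perf) pairs.
    trainer.controller_train_step(arches, arch_perf)

    # 4. Generate new candidates by gradient steps in embedding space.
    trainer.arch_pool = trainer.controller_generate_step(arches)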