Code Example #1
class Trainer(object):
    """A class to wrap training code."""
    def __init__(self, args, dataset):
        """Constructor for training algorithm.

        Args:
            args: From command line, picked up by `argparse`.
            dataset: Currently only `data.text.Corpus` is supported.

        Initializes:
            - Data: train, val and test.
            - Model: shared and controller.
            - Inference: optimizers for shared and controller parameters.
            - Criticism: cross-entropy loss for training the shared model.
        """
        self.args = args
        self.controller_step = 0
        self.cuda = args.cuda
        self.dataset = dataset
        self.epoch = 0
        self.shared_step = 0
        self.start_epoch = 0

        logger.info('regularizing:')
        for regularizer in [('activation regularization',
                             self.args.activation_regularization),
                            ('temporal activation regularization',
                             self.args.temporal_activation_regularization),
                            ('norm stabilizer regularization',
                             self.args.norm_stabilizer_regularization)]:
            if regularizer[1]:
                logger.info(f'{regularizer[0]}')

        self.train_data = utils.batchify(dataset.train, args.batch_size,
                                         self.cuda)
        # NOTE(brendan): The validation set data is batchified twice
        # separately: once for computing rewards during the Train Controller
        # phase (valid_data, batch size == 64), and once for evaluating ppl
        # over the entire validation set (eval_data, batch size == 1)
        self.valid_data = utils.batchify(dataset.valid, args.batch_size,
                                         self.cuda)
        self.eval_data = utils.batchify(dataset.valid, args.test_batch_size,
                                        self.cuda)
        self.test_data = utils.batchify(dataset.test, args.test_batch_size,
                                        self.cuda)

        self.max_length = self.args.shared_rnn_max_length
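        # BPTT truncation length (35 timesteps for Penn Treebank).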

        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        self.build_model()

        if self.args.load_path:
            self.load_model()

        shared_optimizer = _get_optimizer(self.args.shared_optim)
        controller_optimizer = _get_optimizer(self.args.controller_optim)

        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)

        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        self.ce = nn.CrossEntropyLoss()

    def build_model(self):
        """Creates and initializes the shared and controller models."""
        if self.args.network_type == 'rnn':
            self.shared = models.RNN(self.args, self.dataset)
        elif self.args.network_type == 'cnn':
            self.shared = models.CNN(self.args, self.dataset)
        else:
            raise NotImplementedError(f'Network type '
                                      f'`{self.args.network_type}` is not '
                                      f'defined')
        self.controller = models.Controller(self.args)

        if self.args.num_gpu == 1:
            self.shared.cuda()
            self.controller.cuda()
        elif self.args.num_gpu > 1:
            raise NotImplementedError('`num_gpu > 1` is in progress')

    def train(self):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section 2.2, Training ENAS and Deriving
        Architectures, of the paper.

        From the paper (for Penn Treebank):

        - In the first phase, shared parameters omega are trained for 400
          steps, each on a minibatch of 64 examples.

        - In the second phase, the controller's parameters are trained for 2000
          steps.
        """
        if self.args.shared_initial_step > 0:
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # 1. Training the shared parameters omega of the child models
            self.train_shared()

            # 2. Training the controller parameters theta
            self.train_controller()

            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = self.derive()
                    self.evaluate(self.eval_data,
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()

            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, hidden, dags):
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
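        # Average the cross-entropy over the M sampled architectures
        # (M == args.shared_num_sample).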
        for dag in dags:
            output, hidden, extra_out = self.shared(inputs, dag, hidden=hidden)
            output_flat = output.view(-1, self.dataset.num_tokens)
            sample_loss = (self.ce(output_flat, targets) /
                           self.args.shared_num_sample)
            loss += sample_loss

        assert len(dags) == 1, 'there are multiple `hidden` for multiple `dags`'
        return loss, hidden, extra_out

    def train_shared(self, max_step=None):
        """Train the language model for 400 steps of minibatches of 64
        examples.

        Args:
            max_step: Used to run extra training steps as a warm-up.

        BPTT is truncated at 35 timesteps.

        For each weight update, gradients are estimated by sampling M models
        from the fixed controller policy, and averaging their gradients
        computed on a batch of training data.
        """
        model = self.shared
        model.train()
        self.controller.eval()

        hidden = self.shared.init_hidden(self.args.batch_size)

        if max_step is None:
            max_step = self.args.shared_max_step
        else:
            max_step = min(self.args.shared_max_step, max_step)

        abs_max_grad = 0
        abs_max_hidden_norm = 0
        step = 0
        raw_total_loss = 0
        total_loss = 0
        train_idx = 0
        # TODO(brendan): Why - 1 - 1?
        while train_idx < self.train_data.size(0) - 1 - 1:
            if step > max_step:
                break

            dags = self.controller.sample(self.args.shared_num_sample)
            inputs, targets = self.get_batch(self.train_data, train_idx,
                                             self.max_length)

            loss, hidden, extra_out = self.get_loss(inputs, targets, hidden,
                                                    dags)
            hidden.detach_()
            raw_total_loss += loss.data

            loss += _apply_penalties(extra_out, self.args)

            # update
            self.shared_optim.zero_grad()
            loss.backward()

            h1tohT = extra_out['hiddens']
            new_abs_max_hidden_norm = utils.to_item(
                h1tohT.norm(dim=-1).data.max())
            if new_abs_max_hidden_norm > abs_max_hidden_norm:
                abs_max_hidden_norm = new_abs_max_hidden_norm
                logger.info(f'max hidden {abs_max_hidden_norm}')
            abs_max_grad = _check_abs_max_grad(abs_max_grad, model)
            torch.nn.utils.clip_grad_norm(model.parameters(),
                                          self.args.shared_grad_clip)
            self.shared_optim.step()

            total_loss += loss.data

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_shared_train(total_loss, raw_total_loss)
                raw_total_loss = 0
                total_loss = 0

            step += 1
            self.shared_step += 1
            train_idx += self.max_length

    def get_reward(self, dag, entropies, hidden, valid_idx=0):
        """Computes the perplexity of a single sampled model on a minibatch of
        validation data.
        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()

        inputs, targets = self.get_batch(self.valid_data,
                                         valid_idx,
                                         self.max_length,
                                         volatile=True)
        valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
        valid_loss = utils.to_item(valid_loss.data)

        valid_ppl = math.exp(valid_loss)
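        # Reward follows the paper: R = reward_c / ppl (or reward_c / ppl**2
        # when ppl_square is set), so lower validation perplexity means a
        # higher reward.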

        # TODO: we don't know reward_c
        if self.args.ppl_square:
            # TODO: but we do know reward_c=80 in the previous paper
            R = self.args.reward_c / valid_ppl**2
        else:
            R = self.args.reward_c / valid_ppl

        if self.args.entropy_mode == 'reward':
            rewards = R + self.args.entropy_coeff * entropies
        elif self.args.entropy_mode == 'regularizer':
            rewards = R * np.ones_like(entropies)
        else:
            raise NotImplementedError(
                f'Unknown entropy mode: {self.args.entropy_mode}')

        return rewards, hidden

    def train_controller(self):
        """Fixes the shared parameters and updates the controller parameters.

        The controller is updated with a score function gradient estimator
        (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
        is computed on a minibatch of validation data.

        A moving average baseline is used.

        The controller is trained for 2000 steps per epoch (i.e.,
        first (Train Shared) phase -> second (Train Controller) phase).
        """
        model = self.controller
        model.train()
        # TODO(brendan): Why can't we call shared.eval() here? Leads to loss
        # being uniformly zero for the controller.
        # self.shared.eval()

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []

        hidden = self.shared.init_hidden(self.args.batch_size)
        total_loss = 0
        valid_idx = 0
        for step in range(self.args.controller_max_step):
            # sample models
            dags, log_probs, entropies = self.controller.sample(
                with_details=True)

            # calculate reward
            np_entropies = entropies.data.cpu().numpy()
            # NOTE(brendan): No gradients should be backpropagated to the
            # shared model during controller training, obviously.
            with _get_no_grad_ctx_mgr():
                rewards, hidden = self.get_reward(dags, np_entropies, hidden,
                                                  valid_idx)

            # discount
            if 1 > self.args.discount > 0:
                rewards = discount(rewards, self.args.discount)

            reward_history.extend(rewards)
            entropy_history.extend(np_entropies)

            # moving average baseline
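            # Exponential moving average:
            #   baseline_t = decay * baseline_{t-1} + (1 - decay) * R_t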
            if baseline is None:
                baseline = rewards
            else:
                decay = self.args.ema_baseline_decay
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.extend(adv)

            # policy loss
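            # REINFORCE: loss = -log pi(a) * (R - baseline), so decisions with
            # positive advantage get their log-probability pushed up.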
            loss = -log_probs * utils.get_variable(
                adv, self.cuda, requires_grad=False)
            if self.args.entropy_mode == 'regularizer':
                loss -= self.args.entropy_coeff * entropies

            loss = loss.sum()  # or loss.mean()

            # update
            self.controller_optim.zero_grad()
            loss.backward()

            if self.args.controller_grad_clip > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              self.args.controller_grad_clip)
            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_controller_train(total_loss, adv_history,
                                                 entropy_history,
                                                 reward_history,
                                                 avg_reward_base, dags)

                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0

            self.controller_step += 1

            prev_valid_idx = valid_idx
            valid_idx = ((valid_idx + self.max_length) %
                         (self.valid_data.size(0) - 1))
            # NOTE(brendan): Whenever we wrap around to the beginning of the
            # validation data, we reset the hidden states.
            if prev_valid_idx > valid_idx:
                hidden = self.shared.init_hidden(self.args.batch_size)

    def evaluate(self, source, dag, name, batch_size=1, max_num=None):
        """Evaluate on the validation set.

        NOTE(brendan): We should not be using the test set to develop the
        algorithm (basic machine learning good practices).
        """
        self.shared.eval()
        self.controller.eval()

        data = source[:max_num * self.max_length]

        total_loss = 0
        hidden = self.shared.init_hidden(batch_size)

        pbar = range(0, data.size(0) - 1, self.max_length)
        for count, idx in enumerate(pbar):
            inputs, targets = self.get_batch(data, idx, volatile=True)
            output, hidden, _ = self.shared(inputs,
                                            dag,
                                            hidden=hidden,
                                            is_train=False)
            output_flat = output.view(-1, self.dataset.num_tokens)
            total_loss += len(inputs) * self.ce(output_flat, targets).data
            hidden.detach_()
            ppl = math.exp(
                utils.to_item(total_loss) / (count + 1) / self.max_length)

        val_loss = utils.to_item(total_loss) / len(data)
        ppl = math.exp(val_loss)

        self.tb.scalar_summary(f'eval/{name}_loss', val_loss, self.epoch)
        self.tb.scalar_summary(f'eval/{name}_ppl', ppl, self.epoch)
        logger.info(f'eval | loss: {val_loss:8.2f} | ppl: {ppl:8.2f}')

    def derive(self, sample_num=None, valid_idx=0):
        """TODO(brendan): We are always deriving based on the very first batch
        of validation data? This seems wrong...
        """
        hidden = self.shared.init_hidden(self.args.batch_size)

        if sample_num is None:
            sample_num = self.args.derive_num_sample

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
            R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        logger.info(f'derive | max_R: {max_R:8.6f}')
        fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                 f'{max_R:6.4f}-best.png')
        path = os.path.join(self.args.model_dir, 'networks', fname)
        utils.draw_network(best_dag, path)
        self.tb.image_summary('derive/best', [path], self.epoch)

        return best_dag

    @property
    def shared_lr(self):
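        # Exponential decay after `shared_decay_after` epochs:
        #   lr = shared_lr * shared_decay ** max(epoch - shared_decay_after + 1, 0)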
        degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
        return self.args.shared_lr * (self.args.shared_decay**degree)

    @property
    def controller_lr(self):
        return self.args.controller_lr

    def get_batch(self, source, idx, length=None, volatile=False):
        # code from
        # https://github.com/pytorch/examples/blob/master/word_language_model/main.py
        length = min(length if length else self.max_length,
                     len(source) - 1 - idx)
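        # Clamp `length` so the target slice (shifted forward by one token)
        # stays within the data.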
        data = Variable(source[idx:idx + length], volatile=volatile)
        target = Variable(source[idx + 1:idx + 1 + length].view(-1),
                          volatile=volatile)
        return data, target

    @property
    def shared_path(self):
        return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'

    @property
    def controller_path(self):
        return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'

    def get_saved_models_info(self):
        paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
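            # e.g. basename 'shared_epoch12_step4000' split on '_' gives
            # ['shared', 'epoch12', 'step4000']; `idx` picks the field and
            # `replace_word` strips its 'epoch'/'step' prefix.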
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in basenames if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps

    def save_model(self):
        torch.save(self.shared.state_dict(), self.shared_path)
        logger.info(f'[*] SAVED: {self.shared_path}')

        torch.save(self.controller.state_dict(), self.controller_path)
        logger.info(f'[*] SAVED: {self.controller_path}')

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob.glob(
                os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))

            for path in paths:
                utils.remove_file(path)

    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            logger.info(f'[!] No checkpoint found in {self.args.model_dir}...')
            return

        self.epoch = self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            torch.load(self.shared_path, map_location=map_location))
        logger.info(f'[*] LOADED: {self.shared_path}')

        self.controller.load_state_dict(
            torch.load(self.controller_path, map_location=map_location))
        logger.info(f'[*] LOADED: {self.controller_path}')

    def _summarize_controller_train(self, total_loss, adv_history,
                                    entropy_history, reward_history,
                                    avg_reward_base, dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.args.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        logger.info(f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
                    f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
                    f'| loss {cur_loss:.5f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss', cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward', avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy', avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv', avg_adv,
                                   self.controller_step)

            paths = []
            for dag in dags:
                fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                         f'{avg_reward:6.4f}.png')
                path = os.path.join(self.args.model_dir, 'networks', fname)
                utils.draw_network(dag, path)
                paths.append(path)

            self.tb.image_summary('controller/sample', paths,
                                  self.controller_step)

    def _summarize_shared_train(self, total_loss, raw_total_loss):
        """Logs a set of training steps."""
        cur_loss = utils.to_item(total_loss) / self.args.log_step
        # NOTE(brendan): The raw loss, without adding in the activation
        # regularization terms, should be used to compute ppl.
        cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
        ppl = math.exp(cur_raw_loss)

        logger.info(f'| epoch {self.epoch:3d} '
                    f'| lr {self.shared_lr:4.2f} '
                    f'| raw loss {cur_raw_loss:.2f} '
                    f'| loss {cur_loss:.2f} '
                    f'| ppl {ppl:8.2f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('shared/loss', cur_loss, self.shared_step)
            self.tb.scalar_summary('shared/perplexity', ppl, self.shared_step)
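A minimal usage sketch for the Trainer above (an assumption, not the project's actual entry point): it assumes a hypothetical `get_args()` argparse helper and the `data.text.Corpus` dataset named in the constructor docstring.

    # Hypothetical driver; get_args() and the data_path field are assumptions.
    from data.text import Corpus

    args = get_args()                  # argparse namespace with the fields used above
    dataset = Corpus(args.data_path)   # currently only data.text.Corpus is supported
    trainer = Trainer(args, dataset)
    trainer.train()                    # alternates Train Shared / Train Controller phases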
Code Example #2
class Trainer(object):
    def __init__(self, args, dataset):
        self.args = args
        self.cuda = args.cuda
        self.dataset = dataset

        self.train_data = batchify(dataset.train, args.batch_size, self.cuda)
        self.valid_data = batchify(dataset.valid, args.batch_size, self.cuda)
        self.test_data = batchify(dataset.test, args.test_batch_size,
                                  self.cuda)

        self.max_length = self.args.shared_rnn_max_length

        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        self.build_model()

        if self.args.load_path:
            self.load_model()

    def build_model(self):
        self.start_epoch = self.epoch = 0
        self.shared_step, self.controller_step = 0, 0

        if self.args.network_type == 'rnn':
            self.shared = RNN(self.args, self.dataset)
        elif self.args.network_type == 'cnn':
            self.shared = CNN(self.args, self.dataset)
        else:
            raise NotImplementedError(
                f"Network type `{self.args.network_type}` is not defined")
        self.controller = Controller(self.args)

        if self.args.num_gpu == 1:
            self.shared.cuda()
            self.controller.cuda()
        elif self.args.num_gpu > 1:
            raise NotImplementedError("`num_gpu > 1` is in progress")

        self.ce = nn.CrossEntropyLoss()

    def train(self):
        shared_optimizer = get_optimizer(self.args.shared_optim)
        controller_optimizer = get_optimizer(self.args.controller_optim)

        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)

        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        hidden = self.shared.init_hidden(self.args.batch_size)

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # 1. Training the shared parameters ω of the child models
            hidden = self.train_shared(hidden)

            # 2. Training the controller parameters θ
            self.train_controller()

            if self.epoch % self.args.save_epoch == 0:
                if self.epoch > 0:
                    best_dag = self.derive()
                    loss, ppl = self.test(self.test_data, best_dag,
                                          "test_best")
                self.save_model()

            if self.epoch >= self.args.shared_decay_after:
                update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, hidden, dags, with_hidden=False):
        if type(dags) != list:
            dags = [dags]

        loss = 0
        for dag in dags:
            # previous hidden is useless
            output, hidden = self.shared(inputs, hidden, dag)
            output_flat = output.view(-1, self.dataset.num_tokens)
            sample_loss = self.ce(output_flat,
                                  targets) / self.args.shared_num_sample
            loss += sample_loss

        if with_hidden:
            assert len(
                dags) == 1, "there are multiple `hidden` for multiple `dags`"
            return loss, hidden
        else:
            return loss

    def train_shared(self, hidden):
        total_loss = 0

        model = self.shared
        model.train()

        step, train_idx = 0, 0
        pbar = tqdm(total=self.train_data.size(0), desc="train_shared")

        while train_idx < self.train_data.size(0) - 1 - 1:
            if step > self.args.shared_max_step:
                break

            dags = self.controller.sample(self.args.shared_num_sample)
            inputs, targets = self.get_batch(self.train_data, train_idx,
                                             self.max_length)

            loss = self.get_loss(inputs, targets, hidden, dags)

            # update
            self.shared_optim.zero_grad()
            loss.backward()

            t.nn.utils.clip_grad_norm(model.parameters(),
                                      self.args.shared_grad_clip)
            self.shared_optim.step()

            total_loss += loss.data
            pbar.set_description(f"train_shared| loss: {loss.data[0]:5.3f}")

            if step % self.args.log_step == 0 and step > 0:
                cur_loss = total_loss[0] / self.args.log_step
                ppl = math.exp(cur_loss)

                logger.info(
                    f'| epoch {self.epoch:3d} | lr {self.shared_lr:4.2f} '
                    f'| loss {cur_loss:.2f} | ppl {ppl:8.2f}')

                # Tensorboard
                if self.tb is not None:
                    self.tb.scalar_summary("shared/loss", cur_loss,
                                           self.shared_step)
                    self.tb.scalar_summary("shared/perplexity", ppl,
                                           self.shared_step)

                total_loss = 0

            step += 1
            self.shared_step += 1

            train_idx += self.max_length
            pbar.update(self.max_length)

    def get_reward(self, dag, valid_idx=None):
        if valid_idx is None:
            valid_idx = 0

        inputs, targets = self.get_batch(self.valid_data, valid_idx,
                                         self.max_length)
        valid_loss = self.get_loss(inputs, targets, None, dag)

        valid_ppl = math.exp(valid_loss.data[0])
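        # Reward from the paper: R = c / valid_ppl, with c == args.reward_c.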
        R = self.args.reward_c / valid_ppl

        return R

    def train_controller(self):
        total_loss = 0

        model = self.controller
        model.train()

        pbar = trange(self.args.controller_max_step, desc="train_controller")

        baseline = None
        reward_history, adv_history, entropy_history = [], [], []

        valid_idx = 0

        for step in pbar:
            # sample models
            dags, log_probs, entropies = self.controller.sample(
                with_details=True)

            # calculate reward
            R = self.get_reward(dags, valid_idx)

            reward_history.append(R)
            entropy_history.extend(entropies)

            # moving average baseline
            if baseline is None:
                baseline = R
            else:
                decay = self.args.ema_baseline_decay
                baseline = decay * baseline + (1 - decay) * R

            adv = R - baseline
            adv_history.append(adv)
            pbar.set_description(
                f"train_controller| R: {R:8.6f} | R-b: {adv:8.6f}")

            rewards = [0] * (2 * (self.args.num_blocks - 1)) + [adv]
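            # Only the final decision carries the raw advantage; discount()
            # below propagates it back to earlier decisions when
            # 0 < discount < 1.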
            # discount
            if self.args.discount == 1:
                rewards = [adv] * len(log_probs)
            elif self.args.discount > 0:
                rewards = discount(rewards, self.args.discount)
            #rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)

            # policy loss
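            # REINFORCE with an entropy bonus:
            #   loss = -sum_i(log_prob_i * reward_i + entropy_coeff * entropy_i)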
            loss = 0
            for log_prob, reward, entropy in zip(log_probs, rewards,
                                                 entropies):
                loss = loss - log_prob * reward - self.args.entropy_coeff * entropy

            # update
            self.controller_optim.zero_grad()
            loss.backward()
            self.controller_optim.step()

            total_loss += loss.data

            if step % self.args.log_step == 0 and step > 0:
                cur_loss = total_loss[0][0] / self.args.log_step

                avg_reward = np.mean(reward_history)
                avg_entropy = np.mean(entropy_history)
                avg_adv = np.mean(adv_history)

                logger.info(
                    f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
                    f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
                    f'| loss {cur_loss:.5f}')

                # Tensorboard
                if self.tb is not None:
                    self.tb.scalar_summary("controller/loss", cur_loss,
                                           self.controller_step)
                    self.tb.scalar_summary("controller/reward", avg_reward,
                                           self.controller_step)
                    self.tb.scalar_summary("controller/entropy", avg_entropy,
                                           self.controller_step)
                    self.tb.scalar_summary("controller/adv", avg_adv,
                                           self.controller_step)

                    paths = []
                    for dag in dags:
                        fname = f"{self.epoch:03d}-{self.controller_step:06d}-{avg_reward:6.4f}.png"
                        path = os.path.join(self.args.model_dir, "networks",
                                            fname)
                        draw_network(dag, path)
                        paths.append(path)

                    self.tb.image_summary("controller/sample", paths,
                                          self.controller_step)

                reward_history, adv_history, entropy_history = [], [], []

            self.controller_step += 1

            valid_idx = (valid_idx +
                         self.max_length) % (self.valid_data.size(0) - 1)

    def test(self, source, dag, name, batch_size=1):
        self.shared.eval()
        self.controller.eval()

        total_loss = 0
        hidden = self.shared.init_hidden(batch_size)

        pbar = trange(0, source.size(0) - 1, self.max_length, desc="test")
        for count, idx in enumerate(pbar):
            data, targets = self.get_batch(source, idx, evaluation=True)
            output, hidden = self.shared(data, hidden, dag)
            output_flat = output.view(-1, self.dataset.num_tokens)
            total_loss += len(data) * self.ce(output_flat, targets).data
            hidden = detach(hidden)

            ppl = math.exp(total_loss[0] / (count + 1) / self.max_length)
            pbar.set_description(f"test| ppl: {ppl:8.2f}")

        test_loss = total_loss[0] / len(source)
        ppl = math.exp(test_loss)

        self.tb.scalar_summary(f"test/{name}_loss", test_loss, self.epoch)
        self.tb.scalar_summary(f"test/{name}_ppl", ppl, self.epoch)

        return test_loss, ppl

    def derive(self, valid_idx=0, sample_num=None):
        if sample_num is None:
            sample_num = self.args.derive_num_sample

        dags = self.controller.sample(sample_num)

        max_R, best_dag = 0, None
        pbar = tqdm(dags, desc="derive")
        for dag in pbar:
            R = self.get_reward(dag, valid_idx)
            if R > max_R:
                max_R = R
                best_dag = dag
            pbar.set_description(f"derive| max_R: {max_R:8.6f}")

        fname = f"{self.epoch:03d}-{self.controller_step:06d}-{max_R:6.4f}-best.png"
        path = os.path.join(self.args.model_dir, "networks", fname)
        draw_network(best_dag, path)
        self.tb.image_summary("derive/best", [path], self.epoch)

        return best_dag

    @property
    def shared_lr(self):
        degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
        return self.args.shared_lr * (self.args.shared_decay**degree)

    @property
    def controller_lr(self):
        return self.args.controller_lr

    def get_batch(self, source, idx, length=None, evaluation=False):
        # code from https://github.com/pytorch/examples/blob/master/word_language_model/main.py
        length = min(length if length else self.max_length,
                     len(source) - 1 - idx)
        data = Variable(source[idx:idx + length], volatile=evaluation)
        target = Variable(source[idx + 1:idx + 1 + length].view(-1))
        return data, target

    @property
    def shared_path(self):
        return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'

    @property
    def controller_path(self):
        return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'

    def get_saved_models_info(self):
        paths = glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in basenames if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps

    def save_model(self):
        t.save(self.shared.state_dict(), self.shared_path)
        logger.info(f"[*] SAVED: {self.shared_path}")

        t.save(self.controller.state_dict(), self.controller_path)
        logger.info(f"[*] SAVED: {self.controller_path}")

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob(
                os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))

            for path in paths:
                remove_file(path)

    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            logger.info(f"[!] No checkpoint found in {self.args.model_dir}...")
            return

        self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            t.load(self.shared_path, map_location=map_location))
        logger.info(f"[*] LOADED: {self.shared_path}")

        self.controller.load_state_dict(
            t.load(self.controller_path, map_location=map_location))
        logger.info(f"[*] LOADED: {self.controller_path}")
Code Example #3
File: trainer.py  Project: yiqisetian/ENAS-pytorch
class Trainer(object):
    """A class to wrap training code."""
    def __init__(self, args, dataset):
        """Constructor for training algorithm.

        Args:
            args: From command line, picked up by `argparse`.
            dataset: Currently only `data.text.Corpus` is supported.

        Initializes:
            - Data: train, val and test.
            - Model: shared and controller.
            - Inference: optimizers for shared and controller parameters.
            - Criticism: cross-entropy loss for training the shared model.
        """
        self.args = args
        self.controller_step = 0
        self.cuda = args.cuda
        self.dataset = dataset
        self.epoch = 0
        self.shared_step = 0
        self.start_epoch = 0

        logger.info('regularizing:')
        for regularizer in [('activation regularization',
                             self.args.activation_regularization),
                            ('temporal activation regularization',
                             self.args.temporal_activation_regularization),
                            ('norm stabilizer regularization',
                             self.args.norm_stabilizer_regularization)]:
            if regularizer[1]:
                logger.info('{0}'.format(regularizer[0]))

        self.train_data = utils.batchify(dataset.train, args.batch_size,
                                         self.cuda)
        # NOTE(brendan): The validation set data is batchified twice
        # separately: once for computing rewards during the Train Controller
        # phase (valid_data, batch size == 64), and once for evaluating ppl
        # over the entire validation set (eval_data, batch size == 1)
        self.valid_data = utils.batchify(dataset.valid, args.batch_size,
                                         self.cuda)
        self.eval_data = utils.batchify(dataset.valid, args.test_batch_size,
                                        self.cuda)
        self.test_data = utils.batchify(dataset.test, args.test_batch_size,
                                        self.cuda)

        self.max_length = self.args.shared_rnn_max_length  # default=35

        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        self.build_model()  # builds the shared model (an RNN or CNN) and the controller

        if self.args.load_path:
            self.load_model()

        shared_optimizer = _get_optimizer(self.args.shared_optim)
        controller_optimizer = _get_optimizer(self.args.controller_optim)

        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)

        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        self.ce = nn.CrossEntropyLoss()

    def build_model(self):
        """Creates and initializes the shared and controller models."""
        if self.args.network_type == 'rnn':
            self.shared = models.RNN(self.args, self.dataset)
        elif self.args.network_type == 'cnn':
            self.shared = models.CNN(self.args, self.dataset)
        else:
            raise NotImplementedError(
                'Network type `{0}` is not defined'.format(
                    self.args.network_type))
        # The controller's forward pass is Embedding(130, 100) -> LSTM(100, 100)
        # -> a list of 25 decoders.
        self.controller = models.Controller(self.args)

        if self.args.num_gpu == 1:
            self.shared.cuda()
            self.controller.cuda()
        elif self.args.num_gpu > 1:
            raise NotImplementedError('`num_gpu > 1` is in progress')

    def train(self, single=False):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section 2.2, Training ENAS and Deriving
        Architectures, of the paper.

        From the paper (for Penn Treebank):

        - In the first phase, shared parameters omega are trained for 400
          steps, each on a minibatch of 64 examples.

        - In the second phase, the controller's parameters are trained for 2000
          steps.
          
        Args:
            single (bool): If True it won't train the controller and use the
                           same dag instead of derive().
        """
        dag = utils.load_dag(self.args) if single else None  # dag is None for normal (non-single) training

        if self.args.shared_initial_step > 0:  # self.args.shared_initial_step default=0
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(
                self.start_epoch,
                self.args.max_epoch):  # start_epoch=0,max_epoch=150
            # 1. Training the shared parameters omega of the child models
            # Train the RNN: the controller first samples a dag, an RNN cell is
            # built from that dag, and the cell is used for next-word
            # prediction to obtain the loss.
            self.train_shared(dag=dag)

            # 2. Training the controller parameters theta
            if not single:
                self.train_controller()

            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = dag if dag else self.derive()
                    self.evaluate(self.eval_data,
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()
            # gradually decay the shared learning rate
            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, hidden, dags):
        """
        :param inputs:输入数据,[35,64]
        :param targets: 目标数据(相当于标签)[35,64] 输入的词后移一个词
        :param hidden: 隐藏层参数
        :param dags: RNN 的cell结构
        :return: decoded(35,64,10000),hidden(64,1000),extra_out{dropped_output(35,64,1000),h1tohT(35,64,1000),raw_output(35,64,1000)
        """
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            # decoded(35,64,10000),hidden(64,1000),extra_out{dropped_output(35,64,1000),h1tohT(35,64,1000),raw_output(35,64,1000)
            output, hidden, extra_out = self.shared(
                inputs, dag, hidden=hidden)  # RNN.forward
            output_flat = output.view(-1,
                                      self.dataset.num_tokens)  # (2240,10000)
            # self.ce=nn.CrossEntropyLoss()  target(2240)  shared_num_sample=1
            sample_loss = (self.ce(output_flat, targets) /
                           self.args.shared_num_sample)
            loss += sample_loss

        assert len(dags) == 1, 'there are multiple `hidden` for multiple `dags`'
        return loss, hidden, extra_out

    def train_shared(self, max_step=None, dag=None):
        """Train the language model for 400 steps of minibatches of 64
        examples.

        Args:
            max_step: Used to run extra training steps as a warm-up.
            dag: If not None, is used instead of calling sample().

        BPTT (back-propagation through time) is truncated at 35 timesteps.

        For each weight update, gradients are estimated by sampling M models
        from the fixed controller policy, and averaging their gradients
        computed on a batch of training data.
        """
        model = self.shared  # model.RNN
        # Set the shared RNN to training mode (we are training the RNN here,
        # not the controller); see torch.nn.Module.train:
        # https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.train
        model.train()
        # Set the controller to evaluation mode (equivalent to train(False)).
        self.controller.eval()
        # Initialize the hidden state as an all-zero tensor.
        hidden = self.shared.init_hidden(self.args.batch_size)

        if max_step is None:
            max_step = self.args.shared_max_step  # shared_max_step=150
        else:
            max_step = min(self.args.shared_max_step, max_step)

        abs_max_grad = 0
        abs_max_hidden_norm = 0
        step = 0
        raw_total_loss = 0  # used only for logging; not part of the training computation
        total_loss = 0
        train_idx = 0
        # TODO(brendan): Why - 1 - 1?
        # Answer: train_idx indexes rows of the batchified data
        # (self.train_data.size(0) == 14524, 64 tokens per row). The last row
        # has no target, so iteration must stop at the second-to-last row,
        # whose index is size - 2.
        while train_idx < self.train_data.size(0) - 1 - 1:
            if step > max_step:
                break
            # The controller samples a dag: a list containing one defaultdict
            # that stores the dag's connectivity. This only queries the
            # controller (no controller training happens here); initially the
            # sampled dag is random.
            dags = dag if dag else self.controller.sample(
                batch_size=self.args.shared_num_sample
            )  # shared_num_sample:default=1
            # Take a chunk of max_length (35) timesteps, 64 tokens each, as
            # one training batch. `inputs` is the training data; `targets`
            # holds the word following each input word, used to train the RNN.
            inputs, targets = self.get_batch(self.train_data, train_idx,
                                             self.max_length)  # max_length=35
            # get_loss runs the forward pass of the RNN cell built from the dag
            loss, hidden, extra_out = self.get_loss(inputs, targets, hidden,
                                                    dags)
            # Detaches the Tensor from the graph that created it, making it a leaf. Views cannot be detached in-place.
            hidden.detach_()
            raw_total_loss += loss.data
            # add regularization penalties according to the command-line args
            loss += _apply_penalties(extra_out, self.args)

            # update
            self.shared_optim.zero_grad()
            loss.backward()  # back-propagate

            h1tohT = extra_out['hiddens']
            # logging only; does not affect the computation
            new_abs_max_hidden_norm = utils.to_item(
                h1tohT.norm(dim=-1).data.max())
            if new_abs_max_hidden_norm > abs_max_hidden_norm:
                abs_max_hidden_norm = new_abs_max_hidden_norm
                logger.info('max hidden {0}'.format(abs_max_hidden_norm))
            # Track the maximum gradient magnitude to detect exploding
            # gradients (the value does not appear to be used afterwards).
            abs_max_grad = _check_abs_max_grad(abs_max_grad, model)
            # Clips gradient norm of an iterable of parameters.
            # The norm is computed over all gradients together, as if they were concatenated into a single vector.
            # Gradients are modified in-place.
            torch.nn.utils.clip_grad_norm(
                model.parameters(),
                self.args.shared_grad_clip)  # shared_grad_clip=0.25
            self.shared_optim.step()  # Performs a single optimization step.

            total_loss += loss.data
            # logging
            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_shared_train(total_loss, raw_total_loss)
                raw_total_loss = 0
                total_loss = 0

            step += 1
            self.shared_step += 1
            train_idx += self.max_length  # max_length=35; advance to the next chunk

    def get_reward(self, dag, entropies, hidden, valid_idx=0):
        """Computes the perplexity of a single sampled model on a minibatch of
        validation data.
        PPL is the N-th root (N = total number of words) of the inverse of the
        product of each word's conditional prediction probability, i.e. the
        probability of word n+1 given the first n words.
        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()

        inputs, targets = self.get_batch(self.valid_data,
                                         valid_idx,
                                         self.max_length,
                                         volatile=True)
        valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden,
                                              dag)  #RNN.forward
        valid_loss = utils.to_item(valid_loss.data)

        valid_ppl = math.exp(valid_loss)  # compute perplexity

        # TODO: we don't know reward_c
        if self.args.ppl_square:  #default:false
            # TODO: but we do know reward_c=80 in the previous paper
            R = self.args.reward_c / valid_ppl**2
        else:
            R = self.args.reward_c / valid_ppl  # NAS (Zoph and Le, 2017), page 8, states that c is a constant

        if self.args.entropy_mode == 'reward':  # entropy_mode default: 'reward'
            rewards = R + self.args.entropy_coeff * entropies  # entropy_coeff:default=1e-4
        elif self.args.entropy_mode == 'regularizer':
            rewards = R * np.ones_like(entropies)
        else:
            raise NotImplementedError('Unknown entropy mode: {0}'.format(
                self.args.entropy_mode))

        return rewards, hidden

    def train_controller(self):
        """Fixes the shared parameters and updates the controller parameters.

        The controller is updated with a score function gradient estimator
        (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
        is computed on a minibatch of validation data.

        A moving average baseline is used.

        The controller is trained for 2000 steps per epoch (i.e.,
        first (Train Shared) phase -> second (Train Controller) phase).
        """
        model = self.controller
        model.train()  # set the controller to training mode
        # TODO(brendan): Why can't we call shared.eval() here? Doing so leads
        # to the controller loss being uniformly zero (per Brendan's testing).
        # self.shared.eval()

        avg_reward_base = None
        baseline = None
        # these lists are only used for logging statistics
        adv_history = []
        entropy_history = []
        reward_history = []

        hidden = self.shared.init_hidden(self.args.batch_size)
        total_loss = 0
        valid_idx = 0
        for step in range(self.args.controller_max_step):  #controller_max_step
            # sample models
            # dags: a list of 1 defaultdict (25 entries); log_probs: Tensor of
            # size [23]; entropies: Tensor of size [23], where entropy = -y*log(y)
            dags, log_probs, entropies = self.controller.sample(
                with_details=True)

            # calculate reward
            np_entropies = entropies.data.cpu().numpy()
            # NOTE(brendan): No gradients should be backpropagated to the
            # shared model during controller training, obviously.
            """
            with 语句实质是上下文管理。
            1、上下文管理协议。包含方法__enter__() 和 __exit__(),支持该协议对象要实现这两个方法。
            2、上下文管理器,定义执行with语句时要建立的运行时上下文,负责执行with语句块上下文中的进入与退出操作。
            3、进入上下文的时候执行__enter__方法,如果设置as var语句,var变量接受__enter__()方法返回值。
            4、如果运行时发生了异常,就退出上下文管理器。调用管理器__exit__方法。
            """
            # 创建了一个torch.no_grad()的上下文,执行get_reward的时候是不需要计算梯度的,执行完get_reward在恢复计算梯度模式
            with _get_no_grad_ctx_mgr():
                rewards, hidden = self.get_reward(dags, np_entropies, hidden,
                                                  valid_idx)

            # discount (disabled by default: discount == 1)
            if 1 > self.args.discount > 0:
                rewards = discount(rewards, self.args.discount)

            reward_history.extend(rewards)
            entropy_history.extend(np_entropies)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = self.args.ema_baseline_decay  # ema_baseline_decay default=0.95 (very important)
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.extend(adv)

            # policy loss
            loss = -log_probs * utils.get_variable(
                adv, self.cuda, requires_grad=False)
            if self.args.entropy_mode == 'regularizer':  #entropy_mode:default='reward'
                loss -= self.args.entropy_coeff * entropies

            loss = loss.sum()  # or loss.mean()

            # update
            self.controller_optim.zero_grad()
            loss.backward()

            if self.args.controller_grad_clip > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              self.args.controller_grad_clip)
            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_controller_train(total_loss, adv_history,
                                                 entropy_history,
                                                 reward_history,
                                                 avg_reward_base, dags)

                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0

            self.controller_step += 1

            prev_valid_idx = valid_idx
            valid_idx = ((valid_idx + self.max_length) %
                         (self.valid_data.size(0) - 1))
            # NOTE(brendan): Whenever we wrap around to the beginning of the
            # validation data, we reset the hidden states.
            if prev_valid_idx > valid_idx:
                hidden = self.shared.init_hidden(self.args.batch_size)

    def evaluate(self, source, dag, name, batch_size=1, max_num=None):
        """Evaluate on the validation set.

        NOTE(brendan): We should not be using the test set to develop the
        algorithm (basic machine learning good practices).
        """
        self.shared.eval()
        self.controller.eval()

        data = source[:max_num * self.max_length]

        total_loss = 0
        hidden = self.shared.init_hidden(batch_size)

        pbar = range(0, data.size(0) - 1, self.max_length)
        for count, idx in enumerate(pbar):
            inputs, targets = self.get_batch(data, idx, volatile=True)
            output, hidden, _ = self.shared(inputs,
                                            dag,
                                            hidden=hidden,
                                            is_train=False)
            output_flat = output.view(-1, self.dataset.num_tokens)
            total_loss += len(inputs) * self.ce(output_flat, targets).data
            hidden.detach_()
            ppl = math.exp(
                utils.to_item(total_loss) / (count + 1) / self.max_length)

        val_loss = utils.to_item(total_loss) / len(data)
        ppl = math.exp(val_loss)

        self.tb.scalar_summary('eval/{0}_loss'.format(name), val_loss,
                               self.epoch)
        self.tb.scalar_summary('eval/{0}_ppl'.format(name), ppl, self.epoch)
        logger.info('eval | loss: {0:8.2f} | ppl: {1:8.2f}'.format(
            val_loss, ppl))

    def derive(self, sample_num=None, valid_idx=0):
        """TODO(brendan): We are always deriving based on the very first batch
        of validation data? This seems wrong...
        """
        hidden = self.shared.init_hidden(self.args.batch_size)

        if sample_num is None:
            sample_num = self.args.derive_num_sample

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
            R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        logger.info('derive | max_R: {0:8.6f}'.format(max_R))
        fname = ('{0:03d}-{1:06d}-{2:6.4f}-best.png'.format(
            self.epoch, self.controller_step, max_R))
        path = os.path.join(self.args.model_dir, 'networks', fname)
        #utils.draw_network(best_dag, path)
        #self.tb.image_summary('derive/best', [path], self.epoch)

        return best_dag

    @property
    def shared_lr(self):
        degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
        return self.args.shared_lr * (self.args.shared_decay**degree)

    @property  # exposes the method as a read-only attribute, accessed with dot notation
    def controller_lr(self):
        return self.args.controller_lr

    def get_batch(self, source, idx, length=None, volatile=False):
        """
        这个函数的作用是从数据集中取得length长度的数据组成一个Variable(这个操作在pytorch中已经过时了,可以直接使用Tensor来生成计算,而不用
        再使用Variable来封装Tensor来计算
        这里的batch指的是取词窗口组成的batch,length是最多取多少个batch_size的词
        :param source:数据集train_data
        :param idx: 当前数据样本索引值
        :param length:max_length=35?
        :param volatile(易变的):Volatile is recommended for purely inference mode, when you’re sure you won’t be even calling .backward()
        设定volatie选项为true的话则只是取值模式,而不会进行反向计算
        :return:
        """
        # code from
        # https://github.com/pytorch/examples/blob/master/word_language_model/main.py
        length = min(length if length else self.max_length,
                     len(source) - 1 - idx)
        # UserWarning: `volatile` was removed and now has no effect. Use
        # `with torch.no_grad():` instead.
        data = Variable(source[idx:idx + length],
                        volatile=volatile)  # shape (35, 64): 35 timesteps, 64 sequences per step
        target = Variable(source[idx + 1:idx + 1 + length].view(-1),
                          volatile=volatile)  # view (35, 64) -> (2240,)
        # target is data shifted by one step: predict the next token from the current one.
        return data, target

    @property
    def shared_path(self):
        return '{0}/shared_epoch{1:d}_step{2:d}.pth'.format(
            self.args.model_dir, self.epoch, self.shared_step)

    @property
    def controller_path(self):
        return '{}/controller_epoch{}_step{}.pth'.format(
            self.args.model_dir, self.epoch, self.controller_step)

    def get_saved_models_info(self):
        paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in basenames if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps

    def save_model(self):
        torch.save(self.shared.state_dict(), self.shared_path)
        logger.info('[*] SAVED: {0}'.format(self.shared_path))

        torch.save(self.controller.state_dict(), self.controller_path)
        logger.info('[*] SAVED: {0}'.format(self.controller_path))

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob.glob(
                os.path.join(self.args.model_dir,
                             '*_epoch{0}_*.pth'.format(epoch)))

            for path in paths:
                utils.remove_file(path)

    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            logger.info('[!] No checkpoint found in {0}...'.format(
                self.args.model_dir))
            return

        self.epoch = self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            torch.load(self.shared_path, map_location=map_location))
        logger.info('[*] LOADED: {0}'.format(self.shared_path))

        self.controller.load_state_dict(
            torch.load(self.controller_path, map_location=map_location))
        logger.info('[*] LOADED: {0}'.format(self.controller_path))

    def _summarize_controller_train(self, total_loss, adv_history,
                                    entropy_history, reward_history,
                                    avg_reward_base, dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.args.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        logger.info(
            '| epoch {0:3d} | lr {1:.5f} | R {2:.5f} | entropy {3:.4f} | loss {4:.5f}'
            .format(self.epoch, self.controller_lr, avg_reward, avg_entropy,
                    cur_loss))

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss', cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward', avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy', avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv', avg_adv,
                                   self.controller_step)

            paths = []
            for dag in dags:
                fname = ('{0:03d}-{1:06d}-{2:6.4f}.png'.format(
                    self.epoch, self.controller_step, avg_reward))
                path = os.path.join(self.args.model_dir, 'networks', fname)
                utils.draw_network(dag, path)
                paths.append(path)

            self.tb.image_summary('controller/sample', paths,
                                  self.controller_step)

    def _summarize_shared_train(self, total_loss, raw_total_loss):
        """Logs a set of training steps."""
        cur_loss = utils.to_item(total_loss) / self.args.log_step
        # NOTE(brendan): The raw loss, without adding in the activation
        # regularization terms, should be used to compute ppl.
        cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
        ppl = math.exp(cur_raw_loss)

        logger.info(
            '| epoch {0:3d} | lr {1:4.2f} | raw loss {2:.2f} | loss {3:.2f} | ppl {4:8.2f}'
            .format(self.epoch, self.shared_lr, cur_raw_loss, cur_loss, ppl))

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('shared/loss', cur_loss, self.shared_step)
            self.tb.scalar_summary('shared/perplexity', ppl, self.shared_step)
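The `get_batch` helper above still wraps tensors in `Variable(..., volatile=True)`,
which modern PyTorch ignores. Below is a minimal sketch of the same BPTT windowing
written against current PyTorch (>= 0.4), where gradient tracking is disabled with
`torch.no_grad()` instead; the function and variable names here are illustrative,
not part of the project.

import torch


def get_batch(source, idx, max_length=35):
    """Return a BPTT window and its one-step-shifted targets."""
    seq_len = min(max_length, source.size(0) - 1 - idx)
    data = source[idx:idx + seq_len]                         # (seq_len, batch)
    target = source[idx + 1:idx + 1 + seq_len].reshape(-1)   # flattened for CrossEntropyLoss
    return data, target


# Usage sketch: a fake corpus of 1000 timesteps with a batch size of 64.
source = torch.randint(0, 10000, (1000, 64))
with torch.no_grad():
    inputs, targets = get_batch(source, idx=0)
print(inputs.shape, targets.shape)  # torch.Size([35, 64]) torch.Size([2240])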
コード例 #4
0
class Trainer(object):
    """A class to wrap training code."""
    def __init__(self, args, dataset):
        """Constructor for training algorithm.

        Args:
            args: From command line, picked up by `argparse`.
            dataset: Currently only `data.text.Corpus` is supported.

        Initializes:
            - Data: train, val and test.
            - Model: shared and controller.
            - Inference: optimizers for shared and controller parameters.
            - Criticism: cross-entropy loss for training the shared model.
        """
        self.args = args
        self.controller_step = 0
        self.cuda = args.cuda
        self.device = torch.device("cuda:0" if args.num_gpu > 0 else "cpu")
        self.dataset = dataset
        self.epoch = 0
        self.shared_step = 0
        self.start_epoch = 0
        self.compute_fisher = False

        logger.info('regularizing:')
        for regularizer in [('activation regularization',
                             self.args.activation_regularization),
                            ('temporal activation regularization',
                             self.args.temporal_activation_regularization),
                            ('norm stabilizer regularization',
                             self.args.norm_stabilizer_regularization)]:
            if regularizer[1]:
                logger.info(f'{regularizer[0]}')

        self.image_dataset = isinstance(dataset, Image)
        if self.image_dataset:
            self._train_data = dataset.train
            self._valid_data = dataset.valid
            self._test_data = dataset.test
            self._eval_data = dataset.valid
            self.train_data = wrap_iterator_with_name(self._train_data,
                                                      'train')
            self.valid_data = wrap_iterator_with_name(self._valid_data,
                                                      'valid')
            self.test_data = wrap_iterator_with_name(self._test_data, 'test')
            self.eval_data = wrap_iterator_with_name(self._eval_data, 'eval')

            self.max_length = 0

        else:
            self.train_data = utils.batchify(dataset.train, args.batch_size,
                                             self.cuda)
            self.valid_data = utils.batchify(dataset.valid, args.batch_size,
                                             self.cuda)
            self.eval_data = utils.batchify(dataset.valid,
                                            args.test_batch_size, self.cuda)
            self.test_data = utils.batchify(dataset.test, args.test_batch_size,
                                            self.cuda)

            self.max_length = self.args.shared_rnn_max_length

        self.train_data_size = self.train_data.size(
            0) if not self.image_dataset else len(self.train_data)
        self.valid_data_size = self.valid_data.size(
            0) if not self.image_dataset else len(self.valid_data)
        self.test_data_size = self.test_data.size(
            0) if not self.image_dataset else len(self.test_data)

        # Visualization
        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        self.draw_network = utils.draw_network

        self.build_model()

        if self.args.load_path:
            self.load_model()

        shared_optimizer = _get_optimizer(self.args.shared_optim)
        controller_optimizer = _get_optimizer(self.args.controller_optim)

        # The Fisher information is accumulated inside the shared model, which
        # uses it to form the weight-plasticity part of the loss.

        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)

        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        self.ce = nn.CrossEntropyLoss()
        self.top_k_acc = top_k_accuracy

    def build_model(self):
        """Creates and initializes the shared and controller models."""
        if self.args.network_type == 'rnn':
            self.shared = models.RNN(self.args, self.dataset)
            self.controller = models.Controller(self.args)
        elif self.args.network_type == 'micro_cnn':
            self.shared = models.CNN(self.args, self.dataset)
            self.controller = models.CNNMicroController(self.args)
        else:
            raise NotImplementedError(f'Network type '
                                      f'`{self.args.network_type}` is not '
                                      f'defined')

        if self.args.num_gpu == 1:
            if torch.__version__ == '0.3.1':
                self.shared.cuda()
                self.controller.cuda()
            else:
                self.shared.to(self.device)
                self.controller.to(self.device)

        elif self.args.num_gpu > 1:
            raise NotImplementedError('`num_gpu > 1` is in progress')

    def train(self):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section 2.2, Training ENAS and Deriving
        Architectures, of the paper.

        From the paper (for Penn Treebank):

        - In the first phase, shared parameters omega are trained for 400
          steps, each on a minibatch of 64 examples.

        - In the second phase, the controller's parameters are trained for 2000
          steps.
        """
        if self.args.shared_initial_step > 0:
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(self.start_epoch, self.args.max_epoch):

            if self.epoch >= self.args.start_using_fisher:
                self.compute_fisher = True

            if self.args.set_fisher_zero_per_iter > 0 \
                    and self.epoch % self.args.set_fisher_zero_per_iter == 0:
                self.shared.set_fisher_zero()

            # 1. Training the shared parameters omega of the child models
            self.train_shared()

            # 2. Training the controller parameters theta
            if self.args.train_controller:
                if self.epoch < self.args.stop_training_controller:
                    self.train_controller()

            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = self.derive()
                    self.evaluate(self.eval_data,
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()

            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, dags, **kwargs):
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.

        We store, compute the new WPL.
        :param **kwargs: passed into self.shared(, such as hidden)
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            output, hidden, extra_out = self.shared(inputs, dag, **kwargs)
            output_flat = output.view(-1, self.dataset.num_classes)
            sample_loss = (self.ce(output_flat, targets) /
                           self.args.shared_num_sample)

            # Get WPL part
            if self.compute_fisher:
                wpl = self.shared.compute_weight_plastic_loss_with_update_fisher(
                    dag)
                wpl = 0.5 * wpl
                loss += sample_loss + wpl
                rest_loss = wpl
            else:
                loss += sample_loss
                rest_loss = Variable(torch.zeros(1))
                # logger.info(f'Loss {loss.data[0]} = '
                #             f'sample_loss {sample_loss.data[0]}')

        #assert len(dags) == 1, 'there are multiple `hidden` for multple `dags`'
        return loss, sample_loss, rest_loss, hidden, extra_out

    def train_shared(self, max_step=None):
        """Train the language model for 400 steps of minibatches of 64
        examples.

        Args:
            max_step: Used to run extra training steps as a warm-up.

        BPTT is truncated at 35 timesteps.

        For each weight update, gradients are estimated by sampling M models
        from the fixed controller policy, and averaging their gradients
        computed on a batch of training data.
        """
        valid_ppls = []
        valid_ppls_after = []

        model = self.shared
        model.train()
        self.controller.eval()

        hidden = self.shared.init_training(self.args.batch_size)
        v_hidden = self.shared.init_training(self.args.batch_size)

        if max_step is None:
            max_step = self.args.shared_max_step
        else:
            max_step = min(self.args.shared_max_step, max_step)

        abs_max_grad = 0
        abs_max_hidden_norm = 0
        step = 0
        raw_total_loss = 0
        total_loss = 0
        total_sample_loss = 0
        total_rest_loss = 0
        train_idx = 0
        valid_idx = 0

        def _run_shared_one_batch(inputs, targets, hidden, dags,
                                  raw_total_loss):
            # global abs_max_grad
            # global abs_max_hidden_norm
            # global raw_total_loss
            loss, sample_loss, rest_loss, hidden, extra_out = self.get_loss(
                inputs, targets, dags, hidden=hidden)

            # Detach the hidden
            # Because they are input from previous state.
            hidden = utils.detach(hidden)
            raw_total_loss += sample_loss.data / self.args.num_batch_per_iter
            penalty_loss = _apply_penalties(extra_out, self.args)
            loss += penalty_loss
            rest_loss += penalty_loss
            return loss, sample_loss, rest_loss, hidden, extra_out, raw_total_loss

        def _clip_gradient(abs_max_grad, abs_max_hidden_norm):

            h1tohT = extra_out['hiddens']
            new_abs_max_hidden_norm = utils.to_item(
                h1tohT.norm(dim=-1).data.max())
            if new_abs_max_hidden_norm > abs_max_hidden_norm:
                abs_max_hidden_norm = new_abs_max_hidden_norm
                logger.info(f'max hidden {abs_max_hidden_norm}')
            abs_max_grad = _check_abs_max_grad(abs_max_grad, model)
            torch.nn.utils.clip_grad_norm(model.parameters(),
                                          self.args.shared_grad_clip)
            return abs_max_grad, abs_max_hidden_norm

        def _evaluate_valid(dag):
            hidden_eval = self.shared.init_training(self.args.batch_size)
            inputs_eval, targets_eval = self.get_batch(self.valid_data,
                                                       0,
                                                       self.max_length,
                                                       volatile=True)
            _, valid_loss_eval, _, _, _ = self.get_loss(inputs_eval,
                                                        targets_eval,
                                                        dag,
                                                        hidden=hidden_eval)
            valid_loss_eval = utils.to_item(valid_loss_eval.data)
            valid_ppl_eval = math.exp(valid_loss_eval)
            return valid_ppl_eval

        dags_eval = []
        while train_idx < self.train_data_size - 1 - 1:
            if step > max_step:
                break
            dags = self.controller.sample(self.args.shared_num_sample)
            dags_eval.append(dags[0])
            for b in range(0, self.args.num_batch_per_iter):
                # For each model, do the update for 30 batches.
                inputs, targets = self.get_batch(self.train_data, train_idx,
                                                 self.max_length)

                loss, sample_loss, rest_loss, hidden, extra_out, raw_total_loss = \
                    _run_shared_one_batch(
                        inputs, targets, hidden, dags, raw_total_loss)

                # update with complete logic
                # First, normally we compute one loss and do update accordingly.
                # if in the last batch, we compute the fisher information
                # based on two kinds of loss, complete or ce-loss only.
                self.shared_optim.zero_grad()

                # If it is the last training batch. Update the Fisher information
                if self.compute_fisher and (not self.args.shared_valid_fisher):
                    if b == self.args.num_batch_per_iter - 1:
                        sample_loss.backward()
                        if self.args.shared_ce_fisher:
                            self.shared.update_fisher(dags[0])
                            rest_loss.backward()
                        else:
                            rest_loss.backward()
                            self.shared.update_fisher(dags[0])
                    else:
                        loss.backward()
                else:
                    loss.backward()

                abs_max_grad, abs_max_hidden_norm = _clip_gradient(
                    abs_max_grad, abs_max_hidden_norm)

                self.shared_optim.step()

                total_loss += loss.data / self.args.num_batch_per_iter
                total_sample_loss += sample_loss.data / self.args.num_batch_per_iter
                total_rest_loss += rest_loss.data / self.args.num_batch_per_iter

                train_idx = ((train_idx + self.max_length) %
                             (self.train_data_size - 1))

            if self.epoch > self.args.start_evaluate_diff:
                valid_ppl_eval = _evaluate_valid(dags[0])
                valid_ppls.append(valid_ppl_eval)

            logger.info(
                f'Step {step} | '
                f'Loss {utils.to_item(total_loss) / (step + 1):.5f} = '
                f'sample_loss {utils.to_item(total_sample_loss) / (step + 1):.5f} + '
                f'wpl {utils.to_item(total_rest_loss) / (step + 1):.5f}')

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_shared_train(total_loss, raw_total_loss)
                raw_total_loss = 0
                total_loss = 0
                total_sample_loss = 0
                total_rest_loss = 0

            if self.compute_fisher:
                # Update with the validation dataset for fisher information after each step,
                # with update the optimal weights.
                v_inputs, v_targets = self.get_batch(self.valid_data,
                                                     valid_idx,
                                                     self.max_length)
                v_loss, v_sample_loss, _, v_hidden, v_extra_out, _ = _run_shared_one_batch(
                    v_inputs, v_targets, v_hidden, dags, 0)
                self.shared_optim.zero_grad()
                if self.args.shared_ce_fisher:
                    v_sample_loss.backward()
                else:
                    v_loss.backward()
                self.shared.update_fisher(dags[0], self.epoch)
                self.shared.update_optimal_weights()
                valid_idx = ((valid_idx + self.max_length) %
                             (self.valid_data_size - 1))

            step += 1
            self.shared_step += 1

        if self.epoch > self.args.start_evaluate_diff:
            for arch in dags_eval:
                valid_ppl_eval = _evaluate_valid(arch)
                valid_ppls_after.append(valid_ppl_eval)
                logger.info(f'valid_ppl {valid_ppl_eval}')
            diff = np.array(valid_ppls_after) - np.array(valid_ppls)
            logger.info(f'Mean_diff {np.mean(diff)}')
            logger.info(f'Max_diff {np.amax(diff)}')
            self.tb.scalar_summary(f'Mean difference', np.mean(diff),
                                   self.epoch)
            self.tb.scalar_summary(f'Max difference', np.amax(diff),
                                   self.epoch)
            self.tb.scalar_summary(f'Mean valid_ppl after training',
                                   np.mean(np.array(valid_ppls_after)),
                                   self.epoch)
            self.tb.scalar_summary(f'Mean valid_ppl before training',
                                   np.mean(np.array(valid_ppls)), self.epoch)
            self.tb.scalar_summary(f'std_diff', np.std(np.array(diff)),
                                   self.epoch)

    def get_reward(self, dags, entropies, hidden, valid_idx=None):
        """
        Computes the reward of a single sampled model or multiple on a minibatch of
        validation data.

        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()

        if valid_idx is None:
            valid_idx = 0

        inputs, targets = self.get_batch(self.valid_data,
                                         valid_idx,
                                         self.max_length,
                                         volatile=True)
        _, valid_loss, _, hidden, _ = self.get_loss(inputs,
                                                    targets,
                                                    dags,
                                                    hidden=hidden)
        valid_loss = utils.to_item(valid_loss.data)

        valid_ppl = math.exp(valid_loss)

        if self.args.ppl_square:
            R = self.args.reward_c / valid_ppl**2
        else:
            R = self.args.reward_c / valid_ppl

        if self.args.entropy_mode == 'reward':
            rewards = R + self.args.entropy_coeff * entropies
        elif self.args.entropy_mode == 'regularizer':
            rewards = R * np.ones_like(entropies)
        else:
            raise NotImplementedError(
                f'Unknown entropy mode: {self.args.entropy_mode}')

        return rewards, hidden
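    # Worked example for `get_reward` above (hypothetical argument values): with
    # reward_c=80 and valid_ppl=100, the base reward is R = 80 / 100 = 0.8; with
    # ppl_square enabled it would be 80 / 100**2 = 0.008 instead.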

    def train_controller(self):
        """Fixes the shared parameters and updates the controller parameters.

        The controller is updated with a score function gradient estimator
        (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
        is computed on a minibatch of validation data.

        A moving average baseline is used.

        The controller is trained for 2000 steps per epoch (i.e.,
        first (Train Shared) phase -> second (Train Controller) phase).
        """
        model = self.controller
        model.train()

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []
        hidden = self.shared.init_training(self.args.batch_size)
        total_loss = 0
        valid_idx = 0

        for step in range(self.args.controller_max_step):
            # print("************ train controller ****************")
            # sample models
            dags, log_probs, entropies = self.controller.sample(
                batch_size=self.args.policy_batch_size, with_details=True)

            # calculate reward
            np_entropies = entropies.data.cpu().numpy()
            with _get_no_grad_ctx_mgr():
                rewards, hidden = self.get_reward(dags, np_entropies, hidden,
                                                  valid_idx)

            # discount
            if 1 > self.args.discount > 0:
                rewards = discount(rewards, self.args.discount)

            reward_history.extend(rewards)
            entropy_history.extend(np_entropies)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = self.args.ema_baseline_decay
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.extend(adv)

            # policy loss
            loss = -log_probs * utils.get_variable(
                adv, self.cuda, requires_grad=False)
            if self.args.entropy_mode == 'regularizer':
                loss -= self.args.entropy_coeff * entropies

            loss = loss.sum()  # or loss.mean()

            # update
            self.controller_optim.zero_grad()
            loss.backward()

            if self.args.controller_grad_clip > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              self.args.controller_grad_clip)
            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_controller_train(total_loss, adv_history,
                                                 entropy_history,
                                                 reward_history,
                                                 avg_reward_base, dags)

                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0

            self.controller_step += 1

            prev_valid_idx = valid_idx
            valid_idx = ((valid_idx + self.max_length) %
                         (self.valid_data_size - 1))
            if prev_valid_idx > valid_idx:
                hidden = self.shared.init_training(self.args.batch_size)

    def evaluate(self, source, dag, name, batch_size=1, max_num=None):
        """Evaluate on the validation set.
        """
        self.shared.eval()
        self.controller.eval()

        if self.image_dataset:
            data = source
        else:
            data = source[:max_num * self.max_length] if max_num else source

        total_loss = 0
        hidden = self.shared.init_training(batch_size)

        pbar = range(0, self.valid_data_size - 1, self.max_length)
        for count, idx in enumerate(pbar):
            inputs, targets = self.get_batch(data, idx, volatile=True)
            output, hidden, _ = self.shared(inputs,
                                            dag,
                                            hidden=hidden,
                                            is_train=False)
            output_flat = output.view(-1, self.dataset.num_classes)
            total_loss += len(inputs) * self.ce(output_flat, targets).data
            hidden = utils.detach(hidden)
            ppl = math.exp(
                utils.to_item(total_loss) / (count + 1) / self.max_length)

        val_loss = utils.to_item(total_loss) / len(data)
        ppl = math.exp(val_loss)

        self.tb.scalar_summary(f'eval/{name}_loss', val_loss, self.epoch)
        self.tb.scalar_summary(f'eval/{name}_ppl', ppl, self.epoch)
        logger.info(f'eval | loss: {val_loss:8.2f} | ppl: {ppl:8.2f}')

    def derive(self, sample_num=None, valid_idx=0):

        if sample_num is None:
            sample_num = self.args.derive_num_sample

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
            if self.image_dataset:
                R, _ = self.get_reward([dag], entropies, valid_idx)
            else:
                hidden = self.shared.init_training(self.args.batch_size)
                R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        logger.info(f'derive | max_R: {max_R:8.6f}')
        fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                 f'{max_R:6.4f}-best.png')
        path = os.path.join(self.args.model_dir, 'networks', fname)
        success = self.draw_network(best_dag, path)
        if success:
            self.tb.image_summary('derive/best', [path], self.epoch)

        return best_dag

    def reset_dataloader_by_name(self, name):
        """ Works for only reset _DataLoaderIter by DataLoader with name """
        try:
            new_iter = wrap_iterator_with_name(
                iter(getattr(self, f'_{name}_data')), name)
            setattr(self, f'{name}_data', new_iter)
        except Exception as e:
            print(e)
        return new_iter

    @property
    def shared_lr(self):
        degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
        return self.args.shared_lr * (self.args.shared_decay**degree)

    @property
    def controller_lr(self):
        return self.args.controller_lr

    def get_batch(self, source, idx, length=None, volatile=False):
        # code from
        # https://github.com/pytorch/examples/blob/master/word_language_model/main.py

        if not self.image_dataset:
            length = min(length if length else self.max_length,
                         len(source) - 1 - idx)
            data = Variable(source[idx:idx + length], volatile=volatile)
            target = Variable(source[idx + 1:idx + 1 + length].view(-1),
                              volatile=volatile)
        else:
            # Try the dataloader logic.
            # type is _DataLoaderIter

            try:
                data, target = next(source)
            except StopIteration as e:
                print(f'{e}')
                name = source.name
                source = self.reset_dataloader_by_name(name)
                data, target = next(source)

            # data.to(self.device)
            return data.to(self.device), target.to(self.device)
        return data, target

    @property
    def shared_path(self):
        return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'

    @property
    def controller_path(self):
        return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'

    def get_saved_models_info(self):
        paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in basenames if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps
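    # Parsing example for `get_saved_models_info` above: a basename such as
    # 'shared_epoch12_step3400' splits on '_' into ['shared', 'epoch12', 'step3400'],
    # so field 1 yields epoch 12 and field 2 yields step 3400.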

    def save_model(self):
        torch.save(self.shared.state_dict(), self.shared_path)
        logger.info(f'[*] SAVED: {self.shared_path}')

        torch.save(self.controller.state_dict(), self.controller_path)
        logger.info(f'[*] SAVED: {self.controller_path}')

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob.glob(
                os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))

            for path in paths:
                utils.remove_file(path)

    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            logger.info(f'[!] No checkpoint found in {self.args.model_dir}...')
            return

        self.epoch = self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            torch.load(self.shared_path, map_location=map_location))
        logger.info(f'[*] LOADED: {self.shared_path}')

        self.controller.load_state_dict(
            torch.load(self.controller_path, map_location=map_location))
        logger.info(f'[*] LOADED: {self.controller_path}')

    def _summarize_controller_train(self, total_loss, adv_history,
                                    entropy_history, reward_history,
                                    avg_reward_base, dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.args.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        logger.info(f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
                    f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
                    f'| loss {cur_loss:.5f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss', cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward', avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/std/reward',
                                   np.std(reward_history),
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy', avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv', avg_adv,
                                   self.controller_step)

            paths = []
            for dag in dags:
                fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                         f'{avg_reward:6.4f}.png')
                path = os.path.join(self.args.model_dir, 'networks', fname)
                self.draw_network(dag, path)
                paths.append(path)

            self.tb.image_summary('controller/sample', paths,
                                  self.controller_step)

    def _summarize_shared_train(self, total_loss, raw_total_loss):
        """Logs a set of training steps."""
        cur_loss = utils.to_item(total_loss) / self.args.log_step
        cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
        try:
            ppl = math.exp(cur_raw_loss)
        except OverflowError as e:
            print(f"Got error {e}")
            ppl = float('inf')

        logger.info(f'| epoch {self.epoch:3d} '
                    f'| lr {self.shared_lr:4.2f} '
                    f'| raw loss {cur_raw_loss:.2f} '
                    f'| loss {cur_loss:.2f} '
                    f'| ppl {ppl:8.2f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('shared/loss', cur_loss, self.shared_step)
            self.tb.scalar_summary('shared/perplexity', ppl, self.shared_step)
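Example #4 above relies on Fisher-information bookkeeping inside the shared model
(`update_fisher`, `update_optimal_weights`, `set_fisher_zero`,
`compute_weight_plastic_loss_with_update_fisher`), none of which is shown on this
page. The snippet below is a minimal, self-contained sketch of the EWC-style
penalty such methods typically implement: accumulate squared gradients as a
diagonal Fisher estimate, snapshot the weights, and penalize later drift away from
that snapshot. The class and method names are illustrative assumptions, not the
project's actual API.

import torch
import torch.nn as nn
import torch.nn.functional as F


class FisherPenalty:
    """Diagonal-Fisher (EWC-style) penalty: 0.5 * sum_i F_i * (w_i - w*_i)**2."""

    def __init__(self, model):
        self.model = model
        self.fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
        self.optimal = {n: p.detach().clone() for n, p in model.named_parameters()}

    def update_fisher(self):
        """Accumulate squared gradients of the current loss (call after .backward())."""
        for n, p in self.model.named_parameters():
            if p.grad is not None:
                self.fisher[n] += p.grad.detach() ** 2

    def update_optimal_weights(self):
        """Snapshot the weights that the penalty anchors to."""
        self.optimal = {n: p.detach().clone() for n, p in self.model.named_parameters()}

    def penalty(self):
        loss = 0.0
        for n, p in self.model.named_parameters():
            loss = loss + (self.fisher[n] * (p - self.optimal[n]) ** 2).sum()
        return 0.5 * loss


# Usage sketch: add the penalty to the task loss, as the compute_fisher branch does.
model = nn.Linear(8, 4)
ewc = FisherPenalty(model)
x, y = torch.randn(16, 8), torch.randint(0, 4, (16,))
F.cross_entropy(model(x), y).backward()
ewc.update_fisher()
ewc.update_optimal_weights()
total_loss = F.cross_entropy(model(x), y) + ewc.penalty()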
コード例 #5
0
ファイル: trainer.py プロジェクト: stsaten6/ENAS-cnn
class Trainer(object):
    """A class to wrap training code."""
    def __init__(self, args, dataset):
        """Constructor for training algorithm.

        Args:
            args: From command line, picked up by `argparse`.
            dataset: Currently only `data.text.Corpus` is supported.

        Initializes:
            - Data: train, val and test.
            - Model: shared and controller.
            - Inference: optimizers for shared and controller parameters.
            - Criticism: cross-entropy loss for training the shared model.
        """
        # TODO: add an accuracy check.
        self.args = args
        self.controller_step = 0
        self.cuda = args.cuda
        self.dataset = dataset
        self.epoch = 0
        self.shared_step = 0
        self.start_epoch = 0

        print('regularizing:')
        for regularizer in [('activation regularization',
                             self.args.activation_regularization),
                            ('temporal activation regularization',
                             self.args.temporal_activation_regularization),
                            ('norm stabilizer regularization',
                             self.args.norm_stabilizer_regularization)]:
            if regularizer[1]:
                print(f'{regularizer[0]}')

        # self.train_data = utils.batchify(dataset.train,
        #                                  args.batch_size,
        #                                  self.cuda)
        # NOTE(brendan): The validation set data is batchified twice
        # separately: once for computing rewards during the Train Controller
        # phase (valid_data, batch size == 64), and once for evaluating ppl
        # over the entire validation set (eval_data, batch size == 1)
        self.train_data = dataset.train
        self.valid_data = dataset.valid
        self.test_data = dataset.test
        # self.max_length = self.args.shared_rnn_max_length

        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        #TODO initialize controller and shared model
        self.build_model()
        # print("11111111")
        if self.args.load_path:
            print("=======load_path=======")
            self.load_model()

        shared_optimizer = _get_optimizer(self.args.shared_optim)
        controller_optimizer = _get_optimizer(self.args.controller_optim)
        print("=======make optimizer========")
        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)
        print("=======make optimizer========")
        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        self.ce = nn.CrossEntropyLoss()
        print("finish init")

    def build_model(self):
        """Creates and initializes the shared and controller models."""
        if self.args.network_type == 'rnn':
            self.shared = models.RNN(self.args, self.dataset)
        elif self.args.network_type == 'cnn':
            print("----- begin to init cnn------")
            self.shared = models.CNN(self.args, self.dataset)
            # self.shared = self.shared.cuda()
        else:
            raise NotImplementedError(f'Network type '
                                      f'`{self.args.network_type}` is not '
                                      f'defined')
        print("---- begin to init controller-----")
        self.controller = models.Controller(self.args)
        #self.controller = self.controller.cuda()
        print("===begin to cuda")
        if self.args.num_gpu == 1:
            print("cuda")
            self.shared.cuda()
            self.controller.cuda()
            print("finish cuda")
        elif self.args.num_gpu > 1:
            raise NotImplementedError('`num_gpu > 1` is in progress')

    def train(self):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section2.4 Training ENAS and deriving
        Architectures, of the paraer.
        """
        if self.args.shared_initial_step > 0:
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # 1. Training the shared parameters omega of the child models
            self.train_shared()

            # 2. Training the controller parameters theta
            #self.train_controller()
            if self.epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = self.derive()
                    self.evaluate(iter(self.test_data),
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()

            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = self.derive()
                    self.evaluate(iter(self.test_data),
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()

            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, dags):
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            inputs = Variable(inputs.cuda())
            targets = Variable(targets.cuda())
            # inputs = inputs.cuda()
            #targets = targets.cuda()
            #self.shared = self.shared.cuda()
            output = self.shared(inputs, dag)
            sample_loss = (self.ce(output, targets) /
                           self.args.shared_num_sample)
            loss += sample_loss

        assert len(
            dags) == 1, 'there are multiple `hidden` for multiple `dags`'
        return loss

    def train_shared(self, max_step=None):
        """Train the image classification model for 310 steps
        """
        #TODO check if it is right that create a new dag for every batch and may be
        #one epoch one bathc will improve efficient
        model = self.shared
        model.train()
        self.controller.eval()

        if max_step is None:
            max_step = self.args.shared_max_step
        else:
            max_step = min(self.args.shared_max_step, max_step)

        step = 0
        raw_total_loss = 0
        total_loss = 0
        # train_idx = 0
        train_iter = iter(self.train_data)
        # TODO: understand how this trains.
        while True:
            if step > max_step:
                break
            dags = self.controller.sample(self.args.shared_num_sample)
            #print(dags)
            # TODO: batches come from an iterator, so StopIteration must be handled;
            # there may be a cleaner way to do this.
            try:
                inputs, targets = next(train_iter)
            except StopIteration:
                print("====>train_shared<====== finish one epoch")
                break
            #print(dags)
            loss = self.get_loss(inputs, targets, dags)
            raw_total_loss += loss.data
            # TODO: understand the penalty terms.
            # loss += _apply_penalties()
            self.shared_optim.zero_grad()
            loss.backward()

            self.shared_optim.step()

            total_loss += loss.data
            #if step % 20 == 0:
            #    print("loss, ", total_loss, step, total_loss /(step+1))

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_shared_train(total_loss, raw_total_loss)
                raw_total_loss = 0
                total_loss = 0

            step += 1
            self.shared_step += 1
            # train_idx += self.max_length
    def get_reward(self, dag, entropies, data_iter):
        """Computes the perplexity of a single sampled model on a minibatch of
        validation data.
        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()
        try:
            inputs, targets = next(data_iter)
        except StopIteration:
            data_iter = iter(self.valid_data)
            inputs, targets = next(data_iter)
        # TODO: how should validation be done here?
        valid_loss = self.get_loss(inputs, targets, dag)
        # convert valid_loss to numpy ndarray
        valid_loss = utils.to_item(valid_loss.data)

        valid_ppl = math.exp(valid_loss)

        # TODO: we don't know reward_c.
        if self.args.ppl_square:
            # TODO: reward_c = 80 in the previous paper; need to read that paper.
            R = self.args.reward_c / valid_ppl**2
        else:
            R = self.args.reward_c / valid_ppl

        if self.args.entropy_mode == 'reward':
            rewards = R + self.args.entropy_coeff * entropies
        elif self.args.entropy_mode == 'regularizer':
            rewards = R * np.ones_like(entropies)
        else:
            raise NotImplementedError(
                f'Unknown entropy mode: {self.args.entropy_mode}')

        return rewards

    def train_controller(self):
        """Fixes the shared parameters and updates the controller parameters.

        The controller is updated with a score function gradient estimator
        (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
        is computed on a minibatch of validation data.

        A moving average baseline is used (a stripped-down sketch of this
        update appears at the end of this page).

        The controller is trained for 2000 steps per epoch (i.e.,
        first (Train Shared) phase -> second (Train Controller) phase).
        """
        model = self.controller
        model.train()

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []
        valid_iter = iter(self.valid_data)
        total_loss = 0
        for step in range(self.args.controller_max_step):

            dags, log_probs, entropies = self.controller.sample(
                with_details=True)
            #print(dags)
            np_entropies = entropies.data.cpu().numpy()

            with _get_no_grad_ctx_mgr():
                rewards = self.get_reward(dags, np_entropies, valid_iter)
            if 1 > self.args.discount > 0:
                rewards = discount(rewards, self.args.discount)

            reward_history.extend(rewards)
            entropy_history.extend(np_entropies)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = self.args.ema_baseline_decay
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.extend(adv)

            #policy loss
            loss = -log_probs * utils.get_variable(
                adv, self.cuda, requires_grad=False)
            if self.args.entropy_mode == 'regularizer':
                loss -= self.args.entropy_coeff * entropies

            loss = loss.sum()

            self.controller_optim.zero_grad()
            loss.backward()

            if self.args.controller_grad_clip > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              self.args.controller_grad_clip)
            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)
            #if step%20 ==0:
            #    print("total loss", total_loss, step, total_loss / (step+1))
            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_controller_train(total_loss, adv_history,
                                                 entropy_history,
                                                 reward_history,
                                                 avg_reward_base, dags)
                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0
            self.controller_step += 1

            # prev_valid_idx = valid_idx
            # valid_idx = ((valid_idx + self.max_length) %
            #             (self.valid_data.size(0) - 1))

            # NOTE(brendan): Whenever we wrap around to the beginning of the
            # validation data, we reset the hidden states.

    def evaluate(self, test_iter, dag, name, batch_size=1, max_num=None):
        """Evaluate on the validation set.
        (lianqing)what is the data of source ?

        NOTE: use validation to check reward but test set is the same as valid set
        """
        self.shared.eval()
        self.controller.eval()
        acc = AverageMeter()
        # data = source[:max_num*self.max_length]
        total_loss = 0
        # pbar = range(0, data.size(0) - 1, self.max_length)
        count = 0
        while True:
            try:
                inputs, targets = next(test_iter)
                count += 1
            except StopIteration:
                print("========> finish evaluate on one epoch <======")
                break
            # check what difference it makes if the controller is in train mode
            inputs = Variable(inputs.cuda())
            targets = Variable(targets.cuda())
            # inputs = inputs.cuda()
            #targets = targets.cuda()
            output = self.shared(inputs, dag, is_train=False)
            # check whether self.ce works here
            total_loss += len(inputs) * self.ce(output, targets).data
            ppl = math.exp(utils.to_item(total_loss) / (count + 1))
            acc.update(utils.get_accuracy(targets, output))
        val_loss = utils.to_item(total_loss) / count
        ppl = math.exp(val_loss)
        # TODO: this is set up for the RNN; it needs fixing for the CNN.
        #self.tb.scalar_summary(f'eval/{name}_loss', val_loss, self.epoch)
        #self.tb.scalar_summary(f'eval/{name}_ppl', ppl, self.epoch)
        print(
            f'eval | loss: {val_loss:8.2f} | ppl: {ppl:8.2f} | accuracy: {acc.avg:8.2f}'
        )

    def derive(self, sample_num=None, valid_iter=None):
        """
        sample_num is always passed as 1; test whether controller.sample works with batch_size > 1.
        """
        if sample_num is None:
            sample_num = self.args.derive_num_sample
        if valid_iter is None:
            valid_iter = iter(self.valid_data)
        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)
        max_R = 0
        best_dag = None
        for dag in dags:
            R = self.get_reward(dag, entropies, valid_iter)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        print(f'derive | max_R: {max_R:8.6f}')
        fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                 f'{max_R:6.4f}-best.png')
        path = os.path.join(self.args.model_dir, 'networks', fname)
        # utils.draw_network(best_dag, path)
        # self.tb.image_summary('derive/best', [path], self.epoch)

        return best_dag

    @property
    def shared_lr(self):
        degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
        return self.args.shared_lr * (self.args.shared_decay**degree)

    @property
    def controller_lr(self):
        return self.args.controller_lr

    @property
    def shared_path(self):
        return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'

    @property
    def controller_path(self):
        return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'

    def get_saved_models_info(self):
        paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in basenames if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps

    def save_model(self):
        torch.save(self.shared.state_dict(), self.shared_path)
        print(f'[*] SAVED: {self.shared_path}')

        torch.save(self.controller.state_dict(), self.controller_path)
        print(f'[*] SAVED: {self.controller_path}')

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob.glob(
                os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))

            for path in paths:
                utils.remove_file(path)

    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            print(f'[!] No checkpoint found in {self.args.model_dir}...')
            return

        self.epoch = self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            torch.load(self.shared_path, map_location=map_location))
        print(f'[*] LOADED: {self.shared_path}')

        self.controller.load_state_dict(
            torch.load(self.controller_path, map_location=map_location))
        print(f'[*] LOADED: {self.controller_path}')

    def _summarize_controller_train(self, total_loss, adv_history,
                                    entropy_history, reward_history,
                                    avg_reward_base, dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.args.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        print(f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
              f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
              f'| loss {cur_loss:.5f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss', cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward', avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy', avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv', avg_adv,
                                   self.controller_step)

            paths = []
            for dag in dags:
                fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                         f'{avg_reward:6.4f}.png')
                path = os.path.join(self.args.model_dir, 'networks', fname)
                # utils.draw_network(dag, path)
                paths.append(path)

            self.tb.image_summary('controller/sample', paths,
                                  self.controller_step)

    def _summarize_shared_train(self, total_loss, raw_total_loss):
        """Logs a set of training steps."""
        cur_loss = utils.to_item(total_loss) / self.args.log_step
        # NOTE(brendan): The raw loss, without adding in the activation
        # regularization terms, should be used to compute ppl.
        cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
        ppl = math.exp(cur_raw_loss)

        print(f'| epoch {self.epoch:3d} '
              f'| lr {self.shared_lr:4.2f} '
              f'| raw loss {cur_raw_loss:.2f} '
              f'| loss {cur_loss:.2f} '
              f'| ppl {ppl:8.2f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('shared/loss', cur_loss, self.shared_step)
            self.tb.scalar_summary('shared/perplexity', ppl, self.shared_step)
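All three trainers update the controller with the same REINFORCE-style rule: reward
each sampled architecture with reward_c / valid_ppl, subtract an exponential-moving-
average baseline, and minimize -log_prob * advantage. The snippet below is a
stripped-down sketch of that update loop with a toy categorical policy standing in
for the controller; every name and constant in it is illustrative, not taken from
the project.

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy "controller": a categorical policy over 4 architecture choices.
logits = nn.Parameter(torch.zeros(4))
optim = torch.optim.Adam([logits], lr=0.05)

ema_decay, baseline = 0.95, None
fake_ppl = torch.tensor([120.0, 90.0, 150.0, 60.0])  # pretend validation ppl per choice

for step in range(500):
    dist = torch.distributions.Categorical(logits=logits)
    action = dist.sample()
    reward = 80.0 / fake_ppl[action]  # R = reward_c / valid_ppl, as in get_reward()

    # Exponential-moving-average baseline, as in train_controller().
    if baseline is None:
        baseline = reward
    else:
        baseline = ema_decay * baseline + (1 - ema_decay) * reward
    advantage = reward - baseline

    loss = -dist.log_prob(action) * advantage  # REINFORCE surrogate loss
    optim.zero_grad()
    loss.backward()
    optim.step()

# The probability mass should drift toward index 3, the choice with the lowest ppl.
print(torch.softmax(logits.detach(), dim=0))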