Example #1
class Trainer(object):
    def __init__(self, args, dataset):
        self.args = args
        self.cuda = args.cuda
        self.dataset = dataset
        if args.network_type in ['seq2seq'] and args.dataset in ['msrvtt']:
            self.train_data = dataset['train']
            self.valid_data = dataset['val']
            self.test_data = dataset['test']
        else:
            raise Exception(f"Unknown network type: {args.network_type} and unknown dataset: {args.dataset} combination !!")

        if args.use_tensorboard and args.mode == 'train':
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        self.build_model()

        if self.args.load_path:
            self.load_model()

        if self.args.loss_function in ['rl','xe+rl'] and self.args.reward_type=='CIDEnt':
            self.build_load_entailment_model()

    def build_model(self):
        self.start_epoch = self.epoch = 0
        self.step = 0
        if self.args.network_type == 'seq2seq':
            self.model = Seq2seqAttention(self.args)
        else:
            raise NotImplemented(f"Network type `{self.args.network_type}` is not defined")

        if self.args.num_gpu == 1:
            self.model.cuda()
        elif self.args.num_gpu > 1:
            raise NotImplemented("`num_gpu > 1` is in progress")

        self.ce = nn.CrossEntropyLoss()
        logger.info(f"[*] # Parameters: {self.count_parameters}")

    def build_load_entailment_model(self):
        logger.info(f"Building Entailment model...")
        vocab = data.common_loader.Vocab(self.args.snli_vocab_file, self.args.max_snli_vocab_size)
        self.entailment_data = data.common_loader.SNLIBatcher(self.args.decoder_rnn_max_length, vocab)
        self.entailment_model = CoattMaxPool(self.args)
        if self.args.num_gpu == 1:
            self.entailment_model.cuda()
        self.entailment_model.load_state_dict(
            t.load(self.args.load_entailment_path, map_location=None))
        logger.info(f"[*] LOADED: {self.args.load_entailment_path}")
        


    def train(self):
        optimizer = get_optimizer(self.args.optim)
        self.optim = optimizer(
                self.model.parameters(),
                lr=self.args.lr)

        for self.epoch in range(self.start_epoch, self.args.max_epoch):  
            self.train_model()
            if self.epoch % self.args.save_epoch == 0:
                scores = self.test(mode='val')
                self.save_model(save_criteria_score=scores)


    def train_model(self):
        total_loss = 0
        model = self.model
        model.train()
 
        pbar = tqdm(total=self.train_data.num_steps, desc="train_model")

        batcher = self.train_data.get_batcher()

        for step in range(0,self.train_data.num_steps): 
            batch = next(batcher)
            if self.args.network_type == 'seq2seq':
                video_features = batch.get('video_batch')
                flengths = batch.get('video_len_batch')
                captions = batch.get('caption_batch')
                clengths = batch.get('caption_len_batch')
                video_features = to_var(self.args, video_features)
                captions = to_var(self.args, captions)
                if self.args.loss_function == 'xe':     
                    outputs = self.model(video_features, flengths, captions, clengths)
                    targets = pack_padded_sequence(captions, clengths, batch_first=True)[0]
                    loss = self.ce(outputs, targets)
                elif self.args.loss_function in ['rl', 'xe+rl']:
                    sampled_sequence, outputs = self.model.sample_rl(video_features, flengths, sampling='multinomial')
                    sampled_sequence_numpy = sampled_sequence.cpu().data.numpy()
                    argmax_sequence,_ = self.model.sample_rl(video_features, flengths, sampling='argmax')
                    argmax_sequence_numpy = argmax_sequence.cpu().data.numpy()
                    reward, seq_lengths = self.calculate_reward(sampled_sequence_numpy, batch.get('original_caption_dict'), batch.get('video_id'), 
                                                            self.train_data.vocab)
                    base_reward, _ = self.calculate_reward(argmax_sequence_numpy, batch.get('original_caption_dict'), batch.get('video_id'), 
                                                    self.train_data.vocab)
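                    # self-critical baseline: subtract the greedy (argmax) decode's reward from the sampled decode's reward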
                    reward = reward - base_reward
                    reward = Variable(torch.FloatTensor(reward).cuda(), requires_grad=True).unsqueeze(2)

                    log_prob = F.log_softmax(outputs, 2)
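                    # the one-hot scatter below masks log_prob so only the sampled token's log-probability contributes at each timestep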

                    target_one_hot = Variable(torch.FloatTensor(log_prob.size()).cuda().zero_().scatter_(2, sampled_sequence.unsqueeze(2).data, 1.0), requires_grad=True)

                    loss = -log_prob * target_one_hot * reward.expand_as(log_prob)

                    loss = loss.sum()/Variable(torch.FloatTensor(seq_lengths).cuda(), requires_grad=True).sum()
                    if self.args.loss_function == 'xe+rl':
                        outputs = pack_padded_sequence(outputs, clengths, batch_first=True)[0]
                        targets = pack_padded_sequence(captions, clengths, batch_first=True)[0]
                        ml_loss = self.ce(outputs,targets)
                        loss  = self.args.gamma_ml_rl * loss + (1-self.args.gamma_ml_rl) * ml_loss
            else:
                raise Exception(f"Unknown network type: {self.args.network_type}")
            # update
            self.optim.zero_grad()
            loss.backward()

            t.nn.utils.clip_grad_norm(
                    model.parameters(), self.args.grad_clip)
            self.optim.step()

            total_loss += loss.data
            pbar.set_description(f"train_model| loss: {loss.data[0]:5.3f}")
            
            if step % self.args.log_step == 0 and step > 0:
                cur_loss = total_loss[0] / self.args.log_step
                ppl = math.exp(cur_loss)

                logger.info(f'| epoch {self.epoch:3d} | lr {self.args.lr:8.6f} '
                            f'| loss {cur_loss:.2f} | ppl {ppl:8.2f}')

                # Tensorboard
                if self.tb is not None:
                    self.tb.scalar_summary("model/loss", cur_loss, self.step)
                    self.tb.scalar_summary("model/perplexity", ppl, self.step)

                total_loss = 0

            step += 1
            self.step += 1

            pbar.update(1)



    def test(self, mode):

        self.model.eval()
        counter = 0
        if mode == 'val':
            batcher = self.valid_data.get_batcher()
            num_steps = self.valid_data.num_steps
        elif mode == 'test':
            batcher = self.test_data.get_batcher()
            num_steps = self.test_data.num_steps
        else:
            raise Exception("Unknow mode: {}".format(mode))

        if self.args.network_type == 'seq2seq':
            gts = {}
            res = {}
            for i in range(num_steps):
                batch = next(batcher)
                video_features = batch.get('video_batch')
                flengths = batch.get('video_len_batch')    
                video_features = to_var(self.args, video_features)
                if self.args.beam_size>1:
                    predicted_targets = self.model.beam_search(video_features, flengths, self.args.beam_size)
                else:
                    predicted_targets = self.model.sample(video_features, flengths)
                    predicted_targets = predicted_targets.cpu().data.numpy()
                for k,vid in enumerate(batch.get('video_id')):
                    caption = [self.valid_data.vocab.id2word(id_) for id_ in predicted_targets[k,:]]
                    punctuation = np.argmax(np.array(caption) == '[END]')

                    if punctuation == 0 and '[END]' not in caption:
                        # no '[END]' token generated: keep the full caption
                        caption = ' '.join(caption)
                    else:
                        caption = ' '.join(caption[:punctuation])
                    if not caption:
                        caption = '[UNK]'

                    print(caption)

                    res[counter] = [caption]
                    gts[counter] = batch.get('original_caption_dict')[vid]
                    counter += 1


            scores = evaluate(gts, res, score_type='macro', tokenized=True)
            scores_dict = {}
            save_criteria_score = None
            logger.info("Results:")
            for method, score in scores:
                if mode == 'val':
                    self.tb.scalar_summary(f"test/{mode}_{method}", score, self.epoch)
                scores_dict[method] = score
                logger.info("{}:{}".format(method,score))
                if self.args.save_criteria == method:
                    save_criteria_score = score

            if mode == 'test':
                # save the result
                if not self.args.load_path.endswith('.pth'):
                    if not os.path.exists(os.path.join(self.args.model_dir,'results')):
                        os.mkdir(os.path.join(self.args.model_dir,'results'))

                    result_save_path = self.result_path
                    final_dict = {}
                    final_dict['args'] = self.args.__dict__
                    final_dict['scores'] = scores_dict
                    with open(result_save_path, 'w') as fp:
                        json.dump(final_dict, fp, indent=4, sort_keys=True)

            return save_criteria_score


    def calculate_reward(self,sampled_sequence, gts, video_ids, vocab):
        """
        :param sampled_sequence:
            sampled sequence in the form of token_ids of size : batch_size x max_steps
        :param ref_sequence:
            dictionary of reference captions for the given videos
        :param video_ids:
            list of the video_ids
        :param vocab:
            vocab class object used to convert token ids to words
        :param reward_type:
            specify the reward
        :return rewards:
            rewards obtained from the sampled seq w.r.t. ref_seq (metric scores)
        :return seq_lens
            sampled sequence lengths array of size batch_size
        """

        res = {}
        gts_tmp = {}
        seq_lens = []
        batch_size, step_size = sampled_sequence.shape
        counter = 0
        for k in range(batch_size):
            caption = [vocab.id2word(id_) for id_ in sampled_sequence[k,:]]
            # print caption
            punctuation = np.argmax(np.array(caption) == STOP_DECODING)
            if punctuation == 0 and STOP_DECODING not in caption:
                # no stop token generated: keep the full caption
                caption = ' '.join(caption)
            else:
                caption = ' '.join(caption[:punctuation])

            if not caption:
                caption = UNKNOWN_TOKEN

            res[counter] = [caption]
            gts_tmp[counter] = gts[video_ids[k]]
            counter +=1 
            seq_lens.append(len(caption.split())+1)

        metric = 'CIDEr' if self.args.reward_type == 'CIDEnt' else self.args.reward_type
        _, reward = evaluate(gts_tmp, res, metric=metric, score_type='micro', tokenized=True)[0]
        
        if self.args.reward_type == 'CIDEnt':
            entailment_scores  = self.compute_entailment_scores(gts_tmp, res)
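            # CIDEnt: penalize the CIDEr reward by lambda_threshold whenever the entailment score falls below beta_threshold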


            reward = [x-self.args.lambda_threshold if y<self.args.beta_threshold else x for x,y in zip(reward, entailment_scores)]

        reward = np.array(reward)

        reward = np.reshape(reward,[batch_size,1])

        return reward, np.array(seq_lens)


    def compute_entailment_scores(self,gts,res,length_norm=False):
        scores = []
        for key, value in res.items():
            tmp_prem = gts[key]
            tmp_hypo = [value[0] for _ in range(len(tmp_prem))] 
            batch = self.entailment_data.process_external_data(tmp_prem,tmp_hypo)
            premise = batch.get('premise_batch')
            premise_len = batch.get('premise_length')
            premise = to_var(self.args, premise)
            hypothesis = batch.get('hypothesis_batch')
            hypothesis_len = batch.get('hypothesis_length')
            hypothesis = to_var(self.args, hypothesis)
            self.entailment_model.eval()
            logits, batch_prob, preds = self.entailment_model(premise, premise_len, hypothesis, hypothesis_len)

            batch_prob = batch_prob.cpu().data.numpy()

            scores.append(batch_prob.max())



        return scores


    
    def save_model(self, save_criteria_score=None):
        t.save(self.model.state_dict(), self.path)
        logger.info(f"[*] SAVED: {self.path}")
        epochs, steps  = self.get_saved_models_info()
        
        if save_criteria_score is not None:
            if os.path.exists(os.path.join(self.args.model_dir,'checkpoint_tracker.dat')):
                checkpoint_tracker = t.load(os.path.join(self.args.model_dir,'checkpoint_tracker.dat'))

            else:
                checkpoint_tracker = {}
            key = f"{self.epoch}_{self.step}"
            value = save_criteria_score
            checkpoint_tracker[key] = value
            if len(epochs)>=self.args.max_save_num:
                low_value = 100000.0
                remove_key = None
                for key,value in checkpoint_tracker.items():
                    if low_value > value:
                        remove_key = key
                        low_value = value

                del checkpoint_tracker[remove_key]

                remove_epoch = remove_key.split("_")[0]
                paths = glob(os.path.join(self.args.model_dir,f'*_epoch{remove_epoch}_*.pth'))
                for path in paths:
                    remove_file(path)

            # save back the checkpoint tracker
            t.save(checkpoint_tracker, os.path.join(self.args.model_dir,'checkpoint_tracker.dat'))

        else:
            for epoch in epochs[:-self.args.max_save_num]:
                paths = glob(os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))

                for path in paths:
                    remove_file(path)


    def get_saved_models_info(self):
        paths = glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
            return list(set([int(
                    name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in basenames if must_contain in name]))

        basenames = [os.path.basename(path.rsplit('.', 1)[0]) for path in paths]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        steps = get_numbers(basenames, '_', 2, 'step', 'model')

        epochs.sort()
        steps.sort()


        return epochs, steps
    



    def load_model(self):
        
        if self.args.load_path.endswith('.pth'):
            map_location=None
            self.model.load_state_dict(
                t.load(self.args.load_path, map_location=map_location))
            logger.info(f"[*] LOADED: {self.args.load_path}")
        else:
            if os.path.exists(os.path.join(self.args.load_path,'checkpoint_tracker.dat')):
                checkpoint_tracker = t.load(os.path.join(self.args.load_path,'checkpoint_tracker.dat'))
                best_key = None
                best_score = -1.0
                for key,value in checkpoint_tracker.items():
                    if value>best_score:
                        best_score = value
                        best_key = key


                self.epoch = int(best_key.split("_")[0])
                self.step = int(best_key.split("_")[1])

            else:
                epochs, steps = self.get_saved_models_info()

                if len(epochs) == 0:
                    logger.info(f"[!] No checkpoint found in {self.args.model_dir}...")
                    return

                self.epoch = self.start_epoch = max(epochs)
                self.step = max(steps)
            
            if self.args.num_gpu == 0:
                map_location = lambda storage, loc: storage
            else:
                map_location = None

            self.model.load_state_dict(
                    t.load(self.load_path, map_location=map_location))
            logger.info(f"[*] LOADED: {self.load_path}")


    def create_result_path(self, filename):
        return f'{self.args.model_dir}/results/model_epoch{self.epoch}_step{self.step}_{filename}'


    @property
    def count_parameters(self):
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    @property
    def path(self):
        return f'{self.args.model_dir}/model_epoch{self.epoch}_step{self.step}.pth'

    @property
    def load_path(self):
        return f'{self.args.load_path}/model_epoch{self.epoch}_step{self.step}.pth'

    @property
    def result_path(self):
        return f'{self.args.model_dir}/results/model_epoch{self.epoch}_step{self.step}.json'

    @property
    def lr(self):
        degree = max(self.epoch - self.args.decay_after + 1, 0)
        return self.args.lr * (self.args.decay ** degree)
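
A minimal, self-contained sketch of the self-critical REINFORCE loss computed in `train_model` above: the per-token log-probabilities of the sampled caption are weighted by the sampled-minus-greedy reward and normalized by sequence length. The function name and tensor shapes are assumptions for illustration, not part of the repository's API.

import torch
import torch.nn.functional as F

def self_critical_loss(logits, sampled_ids, advantage, seq_lens):
    """logits: (batch, steps, vocab) scores of the sampled roll-out;
    sampled_ids: (batch, steps) token ids drawn by multinomial sampling;
    advantage: (batch, 1) sampled reward minus greedy (argmax) reward;
    seq_lens: (batch,) lengths of the sampled captions."""
    log_prob = F.log_softmax(logits, dim=2)
    # log-probability of each sampled token at each timestep
    token_log_prob = log_prob.gather(2, sampled_ids.unsqueeze(2)).squeeze(2)
    # broadcast the sequence-level advantage over timesteps and length-normalize
    loss = -(token_log_prob * advantage).sum() / seq_lens.float().sum()
    return loss
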
Example #2
class Trainer(object):
    """A class to wrap training code."""
    def __init__(self, args, dataset):
        """Constructor for training algorithm.

        Args:
            args: From command line, picked up by `argparse`.
            dataset: Currently only `data.text.Corpus` is supported.

        Initializes:
            - Data: train, val and test.
            - Model: shared and controller.
            - Inference: optimizers for shared and controller parameters.
            - Criticism: cross-entropy loss for training the shared model.
        """
        self.args = args
        if self.args.cuda:
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.controller_step = 0
        self.cuda = args.cuda
        self.dataset = dataset
        self.epoch = 0
        self.shared_step = 0
        self.start_epoch = 0

        # logger.info('regularizing:')
        # for regularizer in [('activation regularization',
        #                      self.args.activation_regularization),
        #                     ('temporal activation regularization',
        #                      self.args.temporal_activation_regularization),
        #                     ('norm stabilizer regularization',
        #                      self.args.norm_stabilizer_regularization)]:
        #     if regularizer[1]:
        #         logger.info(f'{regularizer[0]}')

        self.train_data = utils.batchify(dataset.train, args.batch_size,
                                         self.cuda)
        # NOTE(brendan): The validation set data is batchified twice
        # separately: once for computing rewards during the Train Controller
        # phase (valid_data, batch size == 64), and once for evaluating ppl
        # over the entire validation set (eval_data, batch size == 1)
        self.valid_data = utils.batchify(dataset.valid, args.batch_size,
                                         self.cuda)
        self.eval_data = utils.batchify(dataset.valid, args.test_batch_size,
                                        self.cuda)
        self.test_data = utils.batchify(dataset.test, args.test_batch_size,
                                        self.cuda)

        self.max_length = self.args.shared_rnn_max_length

        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        self.build_model()

        if self.args.load_path:
            self.load_model()

        shared_optimizer = _get_optimizer(self.args.shared_optim)
        controller_optimizer = _get_optimizer(self.args.controller_optim)

        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)

        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        self.ce = nn.CrossEntropyLoss()

    def build_model(self):
        """Creates and initializes the shared and controller models."""
        self.shared = models.RNN(self.args, self.dataset)
        self.controller = models.Controller(self.args)

        if self.args.num_gpu == 1:
            self.shared.cuda()
            self.controller.cuda()
        elif self.args.num_gpu > 1:
            raise NotImplementedError('`num_gpu > 1` is in progress')

    def train(self, single=False):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section 2.2, Training ENAS and Deriving
        Architectures, of the paper.

        From the paper (for Penn Treebank):

        - In the first phase, shared parameters omega are trained for 400
          steps, each on a minibatch of 64 examples.

        - In the second phase, the controller's parameters are trained for 2000
          steps.
          
        Args:
            single (bool): If True it won't train the controller and use the
                           same dag instead of derive().
        """
        # dag = utils.load_dag(self.args) if single else None
        dag = None

        # if self.args.shared_initial_step > 0:
        #     self.train_shared(self.args.shared_initial_step)
        #     self.train_controller()

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # 1. Training the shared parameters omega of the child models
            self.train_shared(dag=dag)

            # 2. Training the controller parameters theta
            if not single:
                self.train_controller()

            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = dag if dag else self.derive()
                    self.evaluate(self.eval_data,
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()

            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, hidden, dags, is_training=True):
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            output, hidden = self.shared(inputs,
                                         dag,
                                         prev_s=hidden,
                                         is_training=is_training)
            output_flat = output.view(-1, self.dataset.num_tokens)
            sample_loss = (self.ce(output_flat, targets) /
                           self.args.shared_num_sample)
            loss += sample_loss

        assert len(dags) == 1, 'there are multiple `hidden` for multiple `dags`'
        return loss, hidden

    def train_shared(self, max_step=None, dag=None):
        """Train the language model for 400 steps of minibatches of 64
        examples.

        Args:
            max_step: Used to run extra training steps as a warm-up.
            dag: If not None, is used instead of calling sample().

        BPTT is truncated at 35 timesteps.

        For each weight update, gradients are estimated by sampling M models
        from the fixed controller policy, and averaging their gradients
        computed on a batch of training data.
        """
        model = self.shared
        model.train()
        self.controller.eval()

        hidden = self.shared.init_hidden(self.args.batch_size)

        if max_step is None:
            max_step = self.args.shared_max_step
        else:
            max_step = min(self.args.shared_max_step, max_step)

        abs_max_grad = 0
        abs_max_hidden_norm = 0
        step = 0
        raw_total_loss = 0
        total_loss = 0
        train_idx = 0
        # TODO(brendan): Why - 1 - 1?
        while train_idx < self.train_data.size(0) - 1 - 1:
            if step > max_step:
                break

            if dag:
                dags = dag
            else:
                dags, sample_log_probs, sample_entropy = self.controller.sample(
                    self.args.shared_num_sample)
            inputs, targets = self.get_batch(self.train_data, train_idx,
                                             self.max_length)

            loss, hidden = self.get_loss(inputs, targets, hidden, dags)
            hidden = hidden.detach_()
            raw_total_loss += loss.data

            # loss += _apply_penalties(extra_out, self.args)

            # update
            self.shared_optim.zero_grad()
            loss.backward()

            # h1tohT = extra_out['hiddens']
            # new_abs_max_hidden_norm = utils.to_item(
            #     h1tohT.norm(dim=-1).data.max())
            # if new_abs_max_hidden_norm > abs_max_hidden_norm:
            #     abs_max_hidden_norm = new_abs_max_hidden_norm
            #     logger.info(f'max hidden {abs_max_hidden_norm}')
            # abs_max_grad = _check_abs_max_grad(abs_max_grad, model)
            torch.nn.utils.clip_grad_norm(model.parameters(),
                                          self.args.shared_grad_clip)
            self.shared_optim.step()

            total_loss += loss.data

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_shared_train(total_loss, raw_total_loss)
                raw_total_loss = 0
                total_loss = 0

            step += 1
            self.shared_step += 1
            train_idx += self.max_length

    def get_reward(self, dag, entropies, hidden, valid_idx=0):
        """Computes the perplexity of a single sampled model on a minibatch of
        validation data.
        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()

        inputs, targets = self.get_batch(self.valid_data,
                                         valid_idx,
                                         self.max_length,
                                         volatile=True)
        valid_loss, hidden = self.get_loss(inputs,
                                           targets,
                                           hidden,
                                           dag,
                                           is_training=False)
        valid_loss = utils.to_item(valid_loss.data)

        valid_ppl = math.exp(valid_loss)

        # TODO: we don't know reward_c
        if self.args.ppl_square:
            # TODO: but we do know reward_c=80 in the previous paper
            R = self.args.reward_c / valid_ppl**2
        else:
            R = self.args.reward_c / valid_ppl

        if self.args.entropy_mode == 'reward':
            rewards = R + self.args.controller_entropy_weight * entropies
        elif self.args.entropy_mode == 'regularizer':
            rewards = R * np.ones_like(entropies)
        else:
            raise NotImplementedError(
                f'Unknown entropy mode: {self.args.entropy_mode}')

        return rewards, hidden

    def train_controller(self):
        """Fixes the shared parameters and updates the controller parameters.

        The controller is updated with a score function gradient estimator
        (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
        is computed on a minibatch of validation data.

        A moving average baseline is used.

        The controller is trained for 2000 steps per epoch (i.e.,
        first (Train Shared) phase -> second (Train Controller) phase).
        """
        model = self.controller
        model.train()
        # TODO(brendan): Why can't we call shared.eval() here? Leads to loss
        # being uniformly zero for the controller.
        # self.shared.eval()

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []

        hidden = self.shared.init_hidden(self.args.batch_size)
        total_loss = 0
        valid_idx = 0
        for step in range(self.args.controller_max_step):
            # sample models, need M=10?
            loss_avg = []
            for m in range(1):
                dags, log_probs, entropies = self.controller.sample(
                    with_details=True)

                # calculate reward
                np_entropies = entropies.data.cpu().numpy()
                # NOTE(brendan): No gradients should be backpropagated to the
                # shared model during controller training, obviously.
                with _get_no_grad_ctx_mgr():
                    rewards, hidden = self.get_reward(dags, np_entropies,
                                                      hidden, valid_idx)

                #hidden = hidden[-1].detach_() # should we reset immediately? like below
                hidden = self.shared.init_hidden(self.args.batch_size)
                # discount
                # if 1 > self.args.discount > 0:
                #     rewards = discount(rewards, self.args.discount)

                reward_history.extend(rewards)
                entropy_history.extend(np_entropies)

                # moving average baseline
                if baseline is None:
                    baseline = rewards
                else:
                    decay = self.args.ema_baseline_decay
                    baseline = decay * baseline + (1 - decay) * rewards

                adv = rewards - baseline
                adv_history.extend(adv)

                # policy loss
                loss = -log_probs * utils.get_variable(
                    adv, self.cuda, requires_grad=False)
                loss_avg.append(loss)
            # if self.args.entropy_mode == 'regularizer':
            #     loss -= self.args.entropy_coeff * entropies
            loss = torch.stack(loss_avg)
            loss = loss.sum()
            #loss = loss.sum()  # or loss.mean()

            # update
            self.controller_optim.zero_grad()
            loss.backward()

            if self.args.controller_grad_clip > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              self.args.controller_grad_clip)
            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_controller_train(total_loss, adv_history,
                                                 entropy_history,
                                                 reward_history,
                                                 avg_reward_base, dags)

                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0

            self.controller_step += 1

            prev_valid_idx = valid_idx
            valid_idx = ((valid_idx + self.max_length) %
                         (self.valid_data.size(0) - 1))
            # NOTE(brendan): Whenever we wrap around to the beginning of the
            # validation data, we reset the hidden states.
            if prev_valid_idx > valid_idx:
                hidden = self.shared.init_hidden(self.args.batch_size)

    def evaluate(self, source, dag, name, batch_size=1, max_num=None):
        """Evaluate on the validation set.

        NOTE(brendan): We should not be using the test set to develop the
        algorithm (basic machine learning good practices).
        """
        self.shared.eval()
        self.controller.eval()

        data = source[:max_num * self.max_length] if max_num else source

        total_loss = 0
        hidden = self.shared.init_hidden(batch_size)

        pbar = range(0, data.size(0) - 1, self.max_length)
        for count, idx in enumerate(pbar):
            inputs, targets = self.get_batch(data, idx, volatile=True)
            output, hidden = self.shared(inputs,
                                         dag,
                                         prev_s=hidden,
                                         is_training=False)
            output_flat = output.view(-1, self.dataset.num_tokens)
            total_loss += len(inputs) * self.ce(output_flat, targets).data
            hidden = hidden.detach_()
            ppl = math.exp(
                utils.to_item(total_loss) / (count + 1) / self.max_length)

        val_loss = utils.to_item(total_loss) / len(data)
        ppl = math.exp(val_loss)

        self.tb.scalar_summary(f'eval/{name}_loss', val_loss, self.epoch)
        self.tb.scalar_summary(f'eval/{name}_ppl', ppl, self.epoch)
        logger.info(f'eval | loss: {val_loss:8.2f} | ppl: {ppl:8.2f}')

    def derive(self, sample_num=None, valid_idx=0):
        """TODO(brendan): We are always deriving based on the very first batch
        of validation data? This seems wrong...
        """
        hidden = self.shared.init_hidden(self.args.batch_size)

        if sample_num is None:
            sample_num = self.args.derive_num_sample

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        dags = [dags]  # only one sample for now
        max_R = 0
        best_dag = None
        for dag in dags:
            R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        logger.info(f'derive | max_R: {max_R:8.6f}')
        fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                 f'{max_R:6.4f}-best.png')
        path = os.path.join(self.args.model_dir, 'networks', fname)
        #utils.draw_network(best_dag, path)
        #self.tb.image_summary('derive/best', [path], self.epoch)

        return best_dag

    @property
    def shared_lr(self):
        degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
        return self.args.shared_lr * (self.args.shared_decay**degree)

    @property
    def controller_lr(self):
        return self.args.controller_lr

    def get_batch(self, source, idx, length=None, volatile=False):
        # code from
        # https://github.com/pytorch/examples/blob/master/word_language_model/main.py
        length = min(length if length else self.max_length,
                     len(source) - 1 - idx)
        data = source[idx:idx + length].clone().detach()
        target = source[idx + 1:idx + 1 + length].view(-1).clone().detach()
        return data, target

    @property
    def shared_path(self):
        return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'

    @property
    def controller_path(self):
        return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'

    def get_saved_models_info(self):
        paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in basenames if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps

    def save_model(self):
        torch.save(self.shared.state_dict(), self.shared_path)
        logger.info(f'[*] SAVED: {self.shared_path}')

        torch.save(self.controller.state_dict(), self.controller_path)
        logger.info(f'[*] SAVED: {self.controller_path}')

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob.glob(
                os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))

            for path in paths:
                utils.remove_file(path)

    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            logger.info(f'[!] No checkpoint found in {self.args.model_dir}...')
            return

        self.epoch = self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            torch.load(self.shared_path, map_location=map_location))
        logger.info(f'[*] LOADED: {self.shared_path}')

        self.controller.load_state_dict(
            torch.load(self.controller_path, map_location=map_location))
        logger.info(f'[*] LOADED: {self.controller_path}')

    def _summarize_controller_train(self, total_loss, adv_history,
                                    entropy_history, reward_history,
                                    avg_reward_base, dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.args.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        logger.info(f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
                    f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
                    f'| loss {cur_loss:.5f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss', cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward', avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy', avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv', avg_adv,
                                   self.controller_step)

            paths = []
            for dag in dags:
                fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                         f'{avg_reward:6.4f}.png')
                path = os.path.join(self.args.model_dir, 'networks', fname)
                # utils.draw_network(dag, path)
                paths.append(path)

            # self.tb.image_summary('controller/sample',
            #                       paths,
            #                       self.controller_step)

    def _summarize_shared_train(self, total_loss, raw_total_loss):
        """Logs a set of training steps."""
        cur_loss = utils.to_item(total_loss) / self.args.log_step
        # NOTE(brendan): The raw loss, without adding in the activation
        # regularization terms, should be used to compute ppl.
        cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
        ppl = math.exp(cur_raw_loss)

        logger.info(f'| epoch {self.epoch:3d} '
                    f'| lr {self.shared_lr:4.2f} '
                    f'| raw loss {cur_raw_loss:.2f} '
                    f'| loss {cur_loss:.2f} '
                    f'| ppl {ppl:8.2f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('shared/loss', cur_loss, self.shared_step)
            self.tb.scalar_summary('shared/perplexity', ppl, self.shared_step)
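
A compact sketch of the controller update that `train_controller` above describes: REINFORCE with reward R = c / valid_ppl, an entropy bonus (the 'reward' entropy mode), and an exponential-moving-average baseline. The function name and the `reward_c`, `entropy_weight`, and `ema_decay` arguments stand in for the corresponding `args` fields and are assumptions for illustration.

import numpy as np
import torch

def reinforce_step(controller_optim, log_probs, entropies, valid_ppl,
                   baseline, reward_c=80.0, entropy_weight=1e-4, ema_decay=0.95):
    """One REINFORCE update on the controller; returns the updated baseline."""
    R = reward_c / valid_ppl
    # entropy bonus added to the reward encourages exploration
    rewards = R + entropy_weight * entropies.detach().cpu().numpy()

    # exponential moving average baseline reduces the variance of the gradient estimate
    if baseline is None:
        baseline = rewards
    else:
        baseline = ema_decay * baseline + (1 - ema_decay) * rewards

    adv = torch.as_tensor(rewards - baseline, dtype=torch.float32)
    loss = (-log_probs * adv).sum()

    controller_optim.zero_grad()
    loss.backward()
    controller_optim.step()
    return baseline
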
Example #3
def main(args):

    # Step 1: init data folders
    '''if os.path.exists('save_state/'+args.regime+'/normalization_stats.pkl'):
        print('Loading normalization stats')
        x_mean, x_sd = misc.load_file('save_state/'+args.regime+'/normalization_stats.pkl')
    else:
        x_mean, x_sd = preprocess.save_normalization_stats(args.regime)
        print('x_mean: %.3f, x_sd: %.3f' % (x_mean, x_sd))'''

    val_loader=load_data(args, "val")

    tb=TensorBoard(args.model_dir)

    # Step 2: init neural networks
    print("network is:",args.net)
    if args.net == 'Reab3p16':
        model = Reab3p16(args)
    elif args.net=='RN_mlp':
        model =WildRelationNet()
    if args.gpunum > 1:
        model = nn.DataParallel(model, device_ids=range(args.gpunum))

    weights_path = args.path_weight+"/"+args.load_weight

    if os.path.exists(weights_path) and args.restore:
        pretrained_dict = torch.load(weights_path)
        model_dict = model.state_dict()
        pretrained_dict1 = {}
        for k, v in pretrained_dict.items():
            if k in model_dict:
                pretrained_dict1[k] = v
                #print(k)
        model_dict.update(pretrained_dict1)
        model.load_state_dict(model_dict)

        print('load weight')

    style_raven={65:0, 129:1, 257:2, 66:3, 132:4, 36:5, 258:6, 136:7, 264:8, 72:9, 130:10
         , 260:11, 40:12, 34:13, 49:14, 18:15, 20:16, 24:17}
    model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr,momentum=args.mo, weight_decay=5e-4)
   
    if args.gpunum>1:
        optimizer = nn.DataParallel(optimizer, device_ids=range(args.gpunum))

    iter_count = 1
    epoch_count = 1
    #iter_epoch=int(len(train_files) / args.batch_size)
    print(time.strftime('%H:%M:%S', time.localtime(time.time())), 'training')
    style_raven_len = len(style_raven)
    
    if args.rl_style=="dqn":
        dqn = DQN()
    elif args.rl_style=="ddpg":
        ram = MemoryBuffer(1000)
        ddpg = Trainer(style_raven_len*4+2, style_raven_len, 1, ram)
    alpha_1=0.1

    if args.rl_style=="dqn":
        a = dqn.choose_action([0.5] * 3)  # TODO
    elif args.rl_style=="ddpg":
        action_ = ddpg.get_exploration_action(np.zeros([style_raven_len*4+2]).astype(np.float32),alpha_1)
    if args.type_loss:loss_fn=nn.BCELoss()
    best_acc=0.0
    while True:
        since=time.time()
        print(action_)
        for i in range(style_raven_len):
            tb.scalar_summary("action/a"+str(i), action_[i], epoch_count)

        data_files = preprocess.provide_data(args.regime, style_raven_len, action_,style_raven)

        train_files = [data_file for data_file in data_files if 'train' in data_file]
        print("train_num:", len(train_files))
        train_loader = torch.utils.data.DataLoader(Dataset(args,train_files), batch_size=args.batch_size, shuffle=True,
                                                   num_workers=args.numwork)
        model.train()
        iter_epoch = int(len(train_files) / args.batch_size)
        acc_part_train=np.zeros([style_raven_len,2]).astype(np.float32)

        mean_loss_train= np.zeros([style_raven_len, 2]).astype(np.float32)
        loss_train=0
        for x, y,style,me in train_loader:
            if x.shape[0]<10:
                print(x.shape[0])
                break
            x, y ,meta = Variable(x).cuda(), Variable(y).cuda(), Variable(me).cuda()
            if args.gpunum > 1:
                optimizer.module.zero_grad()
            else:
                optimizer.zero_grad()
            if args.type_loss:
                pred_train, pred_meta= model(x)
            else:
                pred_train = model(x)
            loss_ = F.nll_loss(pred_train, y,reduce=False)
            loss=loss_.mean() if not args.type_loss else loss_.mean()+10*loss_fn(pred_meta,meta)
            loss.backward()
            if args.gpunum > 1:
                optimizer.module.step()
            else:
                optimizer.step()
            iter_count += 1
            pred = pred_train.data.max(1)[1]
            correct = pred.eq(y.data).cpu()
            loss_train+=loss.item()
            for num, style_pers in enumerate(style):
                style_pers = style_pers[:-4].split("/")[-1].split("_")[3:]
                for style_per in style_pers:
                    style_per=int(style_per)
                    if correct[num] == 1:
                        acc_part_train[style_per, 0] += 1
                    acc_part_train[style_per, 1] += 1
                    #mean_pred_train[style_per,0] += pred_train[num,y[num].item()].data.cpu()
                    #mean_pred_train[style_per, 1] += 1
                    mean_loss_train[style_per,0] += loss_[num].item()
                    mean_loss_train[style_per, 1] += 1
            accuracy_total = correct.sum() * 100.0 / len(y)

            if iter_count %10 == 0:
                iter_c = iter_count % iter_epoch
                print(time.strftime('%H:%M:%S', time.localtime(time.time())),
                      ('train_epoch:%d,iter_count:%d/%d, loss:%.3f, acc:%.1f') % (
                      epoch_count, iter_c, iter_epoch, loss, accuracy_total))
                tb.scalar_summary("train_loss",loss,iter_count)
        loss_train=loss_train/len(train_files)
        #mean_pred_train=[x[0]/ x[1] for x in mean_pred_train]
        mean_loss_train=[x[0]/ x[1] for x in mean_loss_train]
        acc_part_train = [x[0] / x[1] if x[1]!=0 else 0  for x in acc_part_train]
        print(acc_part_train)
        if epoch_count %args.lr_step ==0:
            print("change lr")
            adjust_learning_rate(optimizer, epoch_count, args.lr_step,args.gpunum)
        time_elapsed = time.time() - since
        print('train epoch in {:.0f}h {:.0f}m {:.0f}s'.format(
            time_elapsed // 3600, time_elapsed // 60 % 60, time_elapsed % 60))
        #acc_p=np.array([x[0]/x[1] for x in acc_part])
        #print(acc_p)
        with torch.no_grad():
            model.eval()
            accuracy_all = []
            iter_test=0
            acc_part_val = np.zeros([style_raven_len, 2]).astype(np.float32)
            for x, y, style,me in val_loader:
                iter_test+=1
                x, y = Variable(x).cuda(), Variable(y).cuda()
                pred,_ = model(x)
                pred = pred.data.max(1)[1]
                correct = pred.eq(y.data).cpu().numpy()
                accuracy = correct.sum() * 100.0 / len(y)
                for num, style_pers in enumerate(style):
                    style_pers = style_pers[:-4].split("/")[-1].split("_")[3:]
                    for style_per in style_pers:
                        style_per = int(style_per)
                        if correct[num] == 1:
                            acc_part_val[style_per, 0] += 1
                        acc_part_val[style_per, 1] += 1
                accuracy_all.append(accuracy)

                # if iter_test % 10 == 0:
                #
                #     print(time.strftime('%H:%M:%S', time.localtime(time.time())),
                #           ('test_iter:%d, acc:%.1f') % (
                #               iter_test, accuracy))

        accuracy_all = sum(accuracy_all) / len(accuracy_all)
        acc_part_val = [x[0] / x[1] if x[1]!=0 else 0 for x in acc_part_val ]
        baseline_rl=70
        reward=np.mean(acc_part_val)*100-baseline_rl
        tb.scalar_summary("valreward", reward,epoch_count)
        action_list = [x for x in action_]
        cur_state=np.array(acc_part_val+acc_part_train+action_list+mean_loss_train
                           +[loss_train]+[epoch_count]).astype(np.float32)
        #np.expand_dims(, axis=0)
        if args.rl_style == "dqn":
            a = dqn.choose_action(cur_state)  # TODO
        elif args.rl_style == "ddpg":
            a = ddpg.get_exploration_action(cur_state,alpha_1)

        if alpha_1<1:
            alpha_1+=0.005#0.1
        if epoch_count > 1:
            if args.rl_style == "dqn":
                dqn.store_transition(last_state, action_, reward, cur_state)
            elif args.rl_style == "ddpg":
                ram.add(last_state, action_, reward, cur_state)


        if epoch_count > 1:
            if args.rl_style == "dqn":
                dqn.learn()
            elif args.rl_style == "ddpg":
                loss_actor, loss_critic = ddpg.optimize()
                tb.scalar_summary("loss_actor", loss_actor, epoch_count)
                tb.scalar_summary("loss_critic", loss_critic, epoch_count)
            print('------------------------------------')
            print('learn q learning')
            print('------------------------------------')


        last_state=cur_state
        time_elapsed = time.time() - since
        print('test epoch in {:.0f}h {:.0f}m {:.0f}s'.format(
            time_elapsed // 3600, time_elapsed // 60 % 60, time_elapsed % 60))
        print('------------------------------------')
        print(('epoch:%d, acc:%.1f') % (epoch_count, accuracy_all))
        print('------------------------------------')
        if accuracy_all>best_acc:
            best_acc=max(best_acc,accuracy_all)
            #ddpg.save_models(args.model_dir + '/', epoch_count)
            save_state(model.state_dict(), args.model_dir + "/epochbest")
        epoch_count += 1
        if epoch_count%20==0:
            print("save weights")
            ddpg.save_models(args.model_dir+'/',epoch_count )
            save_state(model.state_dict(), args.model_dir+"/epoch"+str(epoch_count))
class Trainer(object):
    def __init__(self, args, dataset):
        self.args = args
        self.cuda = args.cuda
        self.dataset = dataset

        self.train_data = batchify(dataset.train, args.batch_size, self.cuda)
        self.valid_data = batchify(dataset.valid, args.batch_size, self.cuda)
        self.test_data = batchify(dataset.test, args.test_batch_size,
                                  self.cuda)

        self.max_length = self.args.shared_rnn_max_length

        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        self.build_model()

        if self.args.load_path:
            self.load_model()

    def build_model(self):
        self.start_epoch = self.epoch = 0
        self.shared_step, self.controller_step = 0, 0

        if self.args.network_type == 'rnn':
            self.shared = RNN(self.args, self.dataset)
        elif self.args.network_type == 'cnn':
            self.shared = CNN(self.args, self.dataset)
        else:
            raise NotImplementedError(
                f"Network type `{self.args.network_type}` is not defined")
        self.controller = Controller(self.args)

        if self.args.num_gpu == 1:
            self.shared.cuda()
            self.controller.cuda()
        elif self.args.num_gpu > 1:
            raise NotImplemented("`num_gpu > 1` is in progress")

        self.ce = nn.CrossEntropyLoss()

    def train(self):
        shared_optimizer = get_optimizer(self.args.shared_optim)
        controller_optimizer = get_optimizer(self.args.controller_optim)

        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)

        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        hidden = self.shared.init_hidden(self.args.batch_size)

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # 1. Training the shared parameters ω of the child models
            hidden = self.train_shared(hidden)

            # 2. Training the controller parameters θ
            self.train_controller()

            if self.epoch % self.args.save_epoch == 0:
                if self.epoch > 0:
                    best_dag = self.derive()
                    loss, ppl = self.test(self.test_data, best_dag,
                                          "test_best")
                self.save_model()

            if self.epoch >= self.args.shared_decay_after:
                update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, hidden, dags, with_hidden=False):
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            # previous hidden is useless
            output, hidden = self.shared(inputs, hidden, dag)
            output_flat = output.view(-1, self.dataset.num_tokens)
            sample_loss = self.ce(output_flat,
                                  targets) / self.args.shared_num_sample
            loss += sample_loss

        if with_hidden:
            assert len(
                dags) == 1, "there are multiple `hidden` for multple `dags`"
            return loss, hidden
        else:
            return loss

    def train_shared(self, hidden):
        total_loss = 0

        model = self.shared
        model.train()

        step, train_idx = 0, 0
        pbar = tqdm(total=self.train_data.size(0), desc="train_shared")

        while train_idx < self.train_data.size(0) - 1 - 1:
            if step > self.args.shared_max_step:
                break

            dags = self.controller.sample(self.args.shared_num_sample)
            inputs, targets = self.get_batch(self.train_data, train_idx,
                                             self.max_length)

            loss = self.get_loss(inputs, targets, hidden, dags)

            # update
            self.shared_optim.zero_grad()
            loss.backward()

            t.nn.utils.clip_grad_norm(model.parameters(),
                                      self.args.shared_grad_clip)
            self.shared_optim.step()

            total_loss += loss.data
            pbar.set_description(f"train_shared| loss: {loss.data[0]:5.3f}")

            if step % self.args.log_step == 0 and step > 0:
                cur_loss = total_loss[0] / self.args.log_step
                ppl = math.exp(cur_loss)

                logger.info(
                    f'| epoch {self.epoch:3d} | lr {self.shared_lr:4.2f} '
                    f'| loss {cur_loss:.2f} | ppl {ppl:8.2f}')

                # Tensorboard
                if self.tb is not None:
                    self.tb.scalar_summary("shared/loss", cur_loss,
                                           self.shared_step)
                    self.tb.scalar_summary("shared/perplexity", ppl,
                                           self.shared_step)

                total_loss = 0

            step += 1
            self.shared_step += 1

            train_idx += self.max_length
            pbar.update(self.max_length)

    def get_reward(self, dag, valid_idx=None):
        if valid_idx is None:
            valid_idx = 0

        inputs, targets = self.get_batch(self.valid_data, valid_idx,
                                         self.max_length)
        valid_loss = self.get_loss(inputs, targets, None, dag)

        valid_ppl = math.exp(valid_loss.data[0])
        R = self.args.reward_c / valid_ppl

        return R

    def train_controller(self):
        total_loss = 0

        model = self.controller
        model.train()

        pbar = trange(self.args.controller_max_step, desc="train_controller")

        baseline = None
        reward_history, adv_history, entropy_history = [], [], []

        valid_idx = 0

        for step in pbar:
            # sample models
            dags, log_probs, entropies = self.controller.sample(
                with_details=True)

            # calculate reward
            R = self.get_reward(dags, valid_idx)

            reward_history.append(R)
            entropy_history.extend(entropies)

            # moving average baseline
            if baseline is None:
                baseline = R
            else:
                decay = self.args.ema_baseline_decay
                baseline = decay * baseline + (1 - decay) * R

            adv = R - baseline
            adv_history.append(adv)
            pbar.set_description(
                f"train_controller| R: {R:8.6f} | R-b: {adv:8.6f}")

            rewards = [0] * (2 * (self.args.num_blocks - 1)) + [adv]
            # discount
            if self.args.discount == 1:
                rewards = [adv] * len(log_probs)
            elif self.args.discount > 0:
                rewards = discount(rewards, self.args.discount)
            #rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)

            # policy loss
            loss = 0
            for log_prob, reward, entropy in zip(log_probs, rewards,
                                                 entropies):
                loss = loss - log_prob * reward - self.args.entropy_coeff * entropy

            # update
            self.controller_optim.zero_grad()
            loss.backward()
            self.controller_optim.step()

            total_loss += loss.data

            if step % self.args.log_step == 0 and step > 0:
                cur_loss = total_loss[0][0] / self.args.log_step

                avg_reward = np.mean(reward_history)
                avg_entropy = np.mean(entropy_history)
                avg_adv = np.mean(adv_history)

                logger.info(
                    f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
                    f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
                    f'| loss {cur_loss:.5f}')

                # Tensorboard
                if self.tb is not None:
                    self.tb.scalar_summary("controller/loss", cur_loss,
                                           self.controller_step)
                    self.tb.scalar_summary("controller/reward", avg_reward,
                                           self.controller_step)
                    self.tb.scalar_summary("controller/entropy", avg_entropy,
                                           self.controller_step)
                    self.tb.scalar_summary("controller/adv", avg_adv,
                                           self.controller_step)

                    paths = []
                    for dag in dags:
                        fname = f"{self.epoch:03d}-{self.controller_step:06d}-{avg_reward:6.4f}.png"
                        path = os.path.join(self.args.model_dir, "networks",
                                            fname)
                        draw_network(dag, path)
                        paths.append(path)

                    self.tb.image_summary("controller/sample", paths,
                                          self.controller_step)

                reward_history, adv_history, entropy_history = [], [], []

            self.controller_step += 1

            valid_idx = (valid_idx +
                         self.max_length) % (self.valid_data.size(0) - 1)

    def test(self, source, dag, name, batch_size=1):
        self.shared.eval()
        self.controller.eval()

        total_loss = 0
        hidden = self.shared.init_hidden(batch_size)

        pbar = trange(0, source.size(0) - 1, self.max_length, desc="test")
        for count, idx in enumerate(pbar):
            data, targets = self.get_batch(source, idx, evaluation=True)
            output, hidden = self.shared(data, hidden, dag)
            output_flat = output.view(-1, self.dataset.num_tokens)
            total_loss += len(data) * self.ce(output_flat, targets).data
            hidden = detach(hidden)

            ppl = math.exp(total_loss[0] / (count + 1) / self.max_length)
            pbar.set_description(f"test| ppl: {ppl:8.2f}")

        test_loss = total_loss[0] / len(source)
        ppl = math.exp(test_loss)

        self.tb.scalar_summary(f"test/{name}_loss", test_loss, self.epoch)
        self.tb.scalar_summary(f"test/{name}_ppl", ppl, self.epoch)

        return test_loss, ppl

    def derive(self, valid_idx=0, sample_num=None):
        if sample_num is None:
            sample_num = self.args.derive_num_sample

        dags = self.controller.sample(sample_num)

        max_R, best_dag = 0, None
        pbar = tqdm(dags, desc="derive")
        for dag in pbar:
            R = self.get_reward(dag, valid_idx)
            if R > max_R:
                max_R = R
                best_dag = dag
            pbar.set_description(f"derive| max_R: {max_R:8.6f}")

        fname = f"{self.epoch:03d}-{self.controller_step:06d}-{max_R:6.4f}-best.png"
        path = os.path.join(self.args.model_dir, "networks", fname)
        draw_network(best_dag, path)
        self.tb.image_summary("derive/best", [path], self.epoch)

        return best_dag

    @property
    def shared_lr(self):
        degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
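        # exponential decay, e.g. with shared_decay=0.96 the learning rate shrinks by
        # 4% for every epoch past shared_decay_after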
        return self.args.shared_lr * (self.args.shared_decay**degree)

    @property
    def controller_lr(self):
        return self.args.controller_lr

    def get_batch(self, source, idx, length=None, evaluation=False):
        # code from https://github.com/pytorch/examples/blob/master/word_language_model/main.py
        length = min(length if length else self.max_length,
                     len(source) - 1 - idx)
        data = Variable(source[idx:idx + length], volatile=evaluation)
        target = Variable(source[idx + 1:idx + 1 + length].view(-1))
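        # targets are the inputs shifted one token ahead, flattened for the cross-entropy loss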
        return data, target

    @property
    def shared_path(self):
        return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'

    @property
    def controller_path(self):
        return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'

    def get_saved_models_info(self):
        paths = glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in basenames if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps

    def save_model(self):
        t.save(self.shared.state_dict(), self.shared_path)
        logger.info(f"[*] SAVED: {self.shared_path}")

        t.save(self.controller.state_dict(), self.controller_path)
        logger.info(f"[*] SAVED: {self.controller_path}")

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob(
                os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))

            for path in paths:
                remove_file(path)

    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            logger.info(f"[!] No checkpoint found in {self.args.model_dir}...")
            return

        self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            t.load(self.shared_path, map_location=map_location))
        logger.info(f"[*] LOADED: {self.shared_path}")

        self.controller.load_state_dict(
            t.load(self.controller_path, map_location=map_location))
        logger.info(f"[*] LOADED: {self.controller_path}")
예제 #5
0
def main(args):

    # Step 1: init data folders
    '''if os.path.exists('save_state/'+args.regime+'/normalization_stats.pkl'):           ##to load raw data and preprocess it 
        print('Loading normalization stats')
        x_mean, x_sd = misc.load_file('save_state/'+args.regime+'/normalization_stats.pkl')
    else:
        x_mean, x_sd = preprocess.save_normalization_stats(args.regime)
        print('x_mean: %.3f, x_sd: %.3f' % (x_mean, x_sd))'''

    val_loader=load_data(args, "val")              ##loading already preprocessed validation/testing data 

    tb=TensorBoard(args.model_dir)                ##The model_dir argument is the directory where model parameters, graphs, etc. are saved. It can also be used to
                                                  ##load checkpoints from that directory in order to continue training a previously saved model.

    # Step 2: init neural networks
    print("network is:",args.net)
    if args.net == 'Reab3p16':                ##use the Reab3p16 model
        model = Reab3p16(args)
    elif args.net=='RN_mlp':                  ##use the WildRelationNet model
        model =WildRelationNet()
    if args.gpunum > 1:
        model = nn.DataParallel(model, device_ids=range(args.gpunum)) ##The nn package defines a set of Modules, which you can think of as neural network layers that produce output from
                                                                       ##input and may have some trainable weights.
                                                                    ##when more than one GPU is used, model weights are saved with the DataParallel `module.` prefix
    weights_path = args.path_weight+"/"+args.load_weight               ##saved weights of the model

    if os.path.exists(weights_path) and args.restore:             ##restore pretrained weights
        pretrained_dict = torch.load(weights_path)                 ##pretrained_dict is the state dictionary of the available pre-trained model
        model_dict = model.state_dict()                           ##https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html : a state_dict is an integral entity
        pretrained_dict1 = {}                                      ##if you are interested in saving or loading models in PyTorch
        for k, v in pretrained_dict.items():                      ##filter out unnecessary keys k
            if k in model_dict:                                   ##keep only the keys that match the current model (conv2D and so forth)
                pretrained_dict1[k] = v
                #print(k)
        model_dict.update(pretrained_dict1)                        ##overwrite entries in the existing state dict
        model.load_state_dict(model_dict)                          ##load the new state dict, i.e. the new weights

        print('load weight')

    style_raven={65:0, 129:1, 257:2, 66:3, 132:4, 36:5, 258:6, 136:7, 264:8, 72:9, 130:10    ##dictionary of key:value pairs mapping style codes to contiguous indices
         , 260:11, 40:12, 34:13, 49:14, 18:15, 20:16, 24:17}

##After the weights are set, an optimizer is created for training.

##The standard way in PyTorch to train a model on multiple GPUs is to use nn.DataParallel, which copies the model to the GPUs
##and during training splits the batch among them and combines the individual outputs.
##model.cuda() by default will send your model to the "current device"

#If you need to move a model to GPU via .cuda(), please do so before constructing optimizers for it. Parameters of a model
#after .cuda() will be different objects from those before the call.

##A very popular technique used along with SGD is momentum. Instead of using only the gradient of the current
##step to guide the search, momentum also accumulates the gradients of past steps to determine the direction to go (a sketch of the update follows).
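##One common formulation of the momentum update, as a sketch (not code used in this script):
##    v = momentum * v - lr * grad      # the velocity accumulates past gradients
##    w = w + v                         # parameters move along the velocity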
    model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=args.lr,momentum=args.mo, weight_decay=5e-4) ##Adam can have convergence problems, and SGD + momentum often converges better
                                                                               ##given a longer training time; many papers in 2018 and 2019 were still using SGD
    if args.gpunum>1:
        optimizer = nn.DataParallel(optimizer, device_ids=range(args.gpunum))
                                  ##set iter-count and epoch to 1 before starting training
    iter_count = 1               ##number of batches of data the algorithm has seen so far
    epoch_count = 1              ##number of times the learning algorithm has seen the complete dataset
    #iter_epoch=int(len(train_files) / args.batch_size)
    print(time.strftime('%H:%M:%S', time.localtime(time.time())), 'training')
    style_raven_len = len(style_raven)  ##length of  style raven dict
    
    if args.rl_style=="dqn":     ##calling reinforcemt model for training
        dqn = DQN()                ##if want to use dqn model
    elif args.rl_style=="ddpg":    ##if want to use ddpg model (aiming to use this)
        ram = MemoryBuffer(1000)   
        ddpg = Trainer(style_raven_len*4+2, style_raven_len, 1, ram)        ##creating an instance of Trainer class defined  in rl folder (ddpg.py) why style_raven_len*4+2? 
    alpha_1=0.1

    if args.rl_style=="dqn":
        a = dqn.choose_action([0.5] * 3)  # TODO
    elif args.rl_style=="ddpg":
        action_ = ddpg.get_exploration_action(np.zeros([style_raven_len*4+2]).astype(np.float32),alpha_1) ##the exploration policy returns an initial action from an all-zero state
    if args.type_loss:loss_fn=nn.BCELoss()                      ##Creates a criterion that measures the Binary Cross Entropy between the target and the output.
    best_acc=0.0                                                ##initialize the best accuracy to 0.0
    while True:                                                ##training loop; runs until the process is stopped
        since=time.time()
        print(action_)                                            
        for i in range(style_raven_len):                
            tb.scalar_summary("action/a"+str(i), action_[i], epoch_count) ##saving summary such as poch counts and actions

        data_files = preprocess.provide_data(args.regime, style_raven_len, action_,style_raven) 

        train_files = [data_file for data_file in data_files if 'train' in data_file]               #creating a list of training files
        print("train_num:", len(train_files))
    
        ##torch.utils.data.DataLoader` supports both map-style and iterable-style datasets with single- or multi-process loading,
        ##customizing loading order and optional automatic batching (collation) and memory pinning
        ##shuffle true because we want independent B training batches from Dataset
        train_loader = torch.utils.data.DataLoader(Dataset(args,train_files), batch_size=args.batch_size, shuffle=True,  
                                                   num_workers=args.numwork)
        model.train()                      ##put the model in training mode
        iter_epoch = int(len(train_files) / args.batch_size)         ##setting iteration count for total dataset
        acc_part_train=np.zeros([style_raven_len,2]).astype(np.float32)       ##defining variable for saving part accuracy while training

        mean_loss_train= np.zeros([style_raven_len, 2]).astype(np.float32)     ##defining variable for saving mean loss while training
        loss_train=0
        for x, y,style,me in train_loader:                              
            if x.shape[0]<10:                             ##x.shape[0] gives the number of rows (the batch size); skip batches smaller than 10
                print(x.shape[0])
                break                                                            
            x, y ,meta = Variable(x).cuda(), Variable(y).cuda(), Variable(me).cuda()  ##wrap the tensors in Variables and move them to the GPU
            if args.gpunum > 1:                                                        
                optimizer.module.zero_grad()             ##set the gradients of the model parameters to 0; .module because of DataParallel
            else:
                optimizer.zero_grad()                    ## same as above set the gradient of the parameters to zero
            if args.type_loss:
                pred_train, pred_meta= model(x)              ##applying model to x where x is from training data
            else:
                pred_train = model(x)                        ##x is images y is actual label/category
            loss_ = F.nll_loss(pred_train, y,reduce=False)     ##per-sample training loss (reduce=False keeps one value per example)
            loss=loss_.mean() if not args.type_loss else loss_.mean()+10*loss_fn(pred_meta,meta)##If the loss is not a scalar value, use either
            loss.backward()             ##loss.mean() or loss.sum() to convert it to a scalar before calling backward(); otherwise it raises an error
        
        #When you call loss.backward(), all it does is compute gradient of loss w.r.t all the parameters in loss that have 
        ##requires_grad = True and store them in parameter.grad attribute for every parameter.
        ##optimizer.step() updates all the parameters based on parameter.grad
            if args.gpunum > 1:
                optimizer.module.step()      ##module for DataParallel
            else:
                optimizer.step()
            iter_count += 1                ##increment iter-count by 1 every batch
            pred = pred_train.data.max(1)[1]  
            correct = pred.eq(y.data).cpu()       ##compare actual and predicted category
            loss_train+=loss.item()               ##The average of the batch losses will give you an estimate of the “epoch loss” during training.
            for num, style_pers in enumerate(style):
                style_pers = style_pers[:-4].split("/")[-1].split("_")[3:]
                for style_per in style_pers:
                    style_per=int(style_per)
                    if correct[num] == 1:
                        acc_part_train[style_per, 0] += 1
                    acc_part_train[style_per, 1] += 1
                    #mean_pred_train[style_per,0] += pred_train[num,y[num].item()].data.cpu()
                    #mean_pred_train[style_per, 1] += 1
                    mean_loss_train[style_per,0] += loss_[num].item()
                    mean_loss_train[style_per, 1] += 1
            accuracy_total = correct.sum() * 100.0 / len(y)       ##calculate batch accuracy

            if iter_count %10 == 0:                        ##log every 10 iterations
                iter_c = iter_count % iter_epoch
                print(time.strftime('%H:%M:%S', time.localtime(time.time())),
                      ('train_epoch:%d,iter_count:%d/%d, loss:%.3f, acc:%.1f') % (
                      epoch_count, iter_c, iter_epoch, loss, accuracy_total))
                tb.scalar_summary("train_loss",loss,iter_count)               ##saving train loss to summary
        loss_train=loss_train/len(train_files)                             ##The average of the batch losses will give you an estimate of the “epoch loss” during training.
        #mean_pred_train=[x[0]/ x[1] for x in mean_pred_train]
        mean_loss_train=[x[0]/ x[1] for x in mean_loss_train]
        acc_part_train = [x[0] / x[1] if x[1]!=0 else 0  for x in acc_part_train]
        print(acc_part_train)
        if epoch_count %args.lr_step ==0:                 ##adjust the learning rate every lr_step epochs
            print("change lr")
            adjust_learning_rate(optimizer, epoch_count, args.lr_step,args.gpunum)
        time_elapsed = time.time() - since
        print('train epoch in {:.0f}h {:.0f}m {:.0f}s'.format(
            time_elapsed // 3600, time_elapsed // 60 % 60, time_elapsed % 60))
        #acc_p=np.array([x[0]/x[1] for x in acc_part])
        #print(acc_p)
        with torch.no_grad():
            model.eval()             ##switch the model to evaluation mode
            accuracy_all = []
            iter_test=0
            acc_part_val = np.zeros([style_raven_len, 2]).astype(np.float32)
            for x, y, style,me in val_loader:             ##using validation data
                iter_test+=1
                x, y = Variable(x).cuda(), Variable(y).cuda()
                pred,_ = model(x)
                pred = pred.data.max(1)[1]
                correct = pred.eq(y.data).cpu().numpy()
                accuracy = correct.sum() * 100.0 / len(y)   ##accuracy is calculated from how many predicted labels match
                for num, style_pers in enumerate(style):
                    style_pers = style_pers[:-4].split("/")[-1].split("_")[3:]
                    for style_per in style_pers:
                        style_per = int(style_per)
                        if correct[num] == 1:
                            acc_part_val[style_per, 0] += 1
                        acc_part_val[style_per, 1] += 1
                accuracy_all.append(accuracy)                    ##append to accuracy list

                # if iter_test % 10 == 0:
                #
                #     print(time.strftime('%H:%M:%S', time.localtime(time.time())),
                #           ('test_iter:%d, acc:%.1f') % (
                #               iter_test, accuracy))

        accuracy_all = sum(accuracy_all) / len(accuracy_all)              ##total accuracy is calculated 
        acc_part_val = [x[0] / x[1] if x[1]!=0 else 0 for x in acc_part_val ]
        baseline_rl=70                                          ##accuracy baseline for the RL reward
        reward=np.mean(acc_part_val)*100-baseline_rl          ##reward is the mean validation accuracy minus the baseline
        tb.scalar_summary("valreward", reward,epoch_count)        ##saving summary
        action_list=[x for x in a]
        cur_state=np.array(acc_part_val+acc_part_train+action_list+mean_loss_train ##assemble all computed statistics into the current state
                           +[loss_train]+[epoch_count]).astype(np.float32)
        #np.expand_dims(, axis=0)
        if args.rl_style == "dqn":
            a = dqn.choose_action(cur_state)  # TODO
        elif args.rl_style == "ddpg":                              ##passing current state to rl model's get_exploration_action
            a = ddpg.get_exploration_action(cur_state,alpha_1)

        if alpha_1<1:
            alpha_1+=0.005#0.1
        if epoch_count > 1:                                ##store the last state, action, reward and current state in memory for epoch > 1
            if args.rl_style == "dqn":dqn.store_transition(last_state, a, reward , cur_state)
            elif args.rl_style == "ddpg":ram.add(last_state, a, reward, cur_state)    


        if epoch_count > 1:
            if args.rl_style == "dqn":dqn.learn()
            elif args.rl_style == "ddpg":loss_actor, loss_critic=ddpg.optimize()      ##using rl ddpg model's optimize function to for teaching
            print('------------------------------------')
            print('learn q learning')
            print('------------------------------------')
            tb.scalar_summary("loss_actor", loss_actor, epoch_count)
            tb.scalar_summary("loss_critic", loss_critic, epoch_count)


        last_state=cur_state
        time_elapsed = time.time() - since
        print('test epoch in {:.0f}h {:.0f}m {:.0f}s'.format(
            time_elapsed // 3600, time_elapsed // 60 % 60, time_elapsed % 60))
        print('------------------------------------')
        print(('epoch:%d, acc:%.1f') % (epoch_count, accuracy_all))
        print('------------------------------------')
        if accuracy_all>best_acc:                                        ##save the best accuracy obtained from val data as best accuracy for next epoch
            best_acc=max(best_acc,accuracy_all)
            #ddpg.save_models(args.model_dir + '/', epoch_count)
            save_state(model.state_dict(), args.model_dir + "/epochbest")             ##save the current model weights as the best checkpoint
        epoch_count += 1                                                         ##increasing  epoch count by 1
        if epoch_count%20==0:              ##every 20 epochs, save the RL agent and the model weights
            print("save weights")
            ddpg.save_models(args.model_dir+'/',epoch_count )                  ##saving the model
            save_state(model.state_dict(), args.model_dir+"/epoch"+str(epoch_count))
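# A standalone sketch (NumPy only; the values below are placeholders) of how the RL
# state and reward above are assembled. It also shows why the DDPG Trainer input size
# is style_raven_len*4+2: four per-style vectors plus the two scalars loss/epoch.
import numpy as np

style_raven_len = 18                                 # len(style_raven) above
acc_part_val = np.random.rand(style_raven_len)       # per-style validation accuracy
acc_part_train = np.random.rand(style_raven_len)     # per-style training accuracy
action_list = np.random.rand(style_raven_len)        # last action from the actor
mean_loss_train = np.random.rand(style_raven_len)    # per-style mean training loss
loss_train, epoch_count = 0.9, 3.0                   # two trailing scalars

cur_state = np.concatenate([acc_part_val, acc_part_train, action_list, mean_loss_train,
                            [loss_train, epoch_count]]).astype(np.float32)
assert cur_state.shape[0] == style_raven_len * 4 + 2   # 18*4 + 2 = 74

baseline_rl = 70
reward = np.mean(acc_part_val) * 100 - baseline_rl     # positive once mean accuracy > 70%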
예제 #6
0
class Trainer(object):
    """A class to wrap training code."""
    def __init__(self, args, dataset):
        """Constructor for training algorithm.

        Args:
            args: From command line, picked up by `argparse`.
            dataset: Currently only `data.text.Corpus` is supported.

        Initializes:
            - Data: train, val and test.
            - Model: shared and controller.
            - Inference: optimizers for shared and controller parameters.
            - Criticism: cross-entropy loss for training the shared model.
        """
        self.args = args
        self.controller_step = 0
        self.cuda = args.cuda
        self.device = gpu = torch.device("cuda:0")
        self.dataset = dataset
        self.epoch = 0
        self.shared_step = 0
        self.start_epoch = 0
        self.compute_fisher = False

        logger.info('regularizing:')
        for regularizer in [('activation regularization',
                             self.args.activation_regularization),
                            ('temporal activation regularization',
                             self.args.temporal_activation_regularization),
                            ('norm stabilizer regularization',
                             self.args.norm_stabilizer_regularization)]:
            if regularizer[1]:
                logger.info(f'{regularizer[0]}')

        self.image_dataset = isinstance(dataset, Image)
        if self.image_dataset:
            self._train_data = dataset.train
            self._valid_data = dataset.valid
            self._test_data = dataset.test
            self._eval_data = dataset.valid
            self.train_data = wrap_iterator_with_name(self._train_data,
                                                      'train')
            self.valid_data = wrap_iterator_with_name(self._valid_data,
                                                      'valid')
            self.test_data = wrap_iterator_with_name(self._test_data, 'test')
            self.eval_data = wrap_iterator_with_name(self._eval_data, 'eval')

            self.max_length = 0

        else:
            self.train_data = utils.batchify(dataset.train, args.batch_size,
                                             self.cuda)
            self.valid_data = utils.batchify(dataset.valid, args.batch_size,
                                             self.cuda)
            self.eval_data = utils.batchify(dataset.valid,
                                            args.test_batch_size, self.cuda)
            self.test_data = utils.batchify(dataset.test, args.test_batch_size,
                                            self.cuda)

            self.max_length = self.args.shared_rnn_max_length

        self.train_data_size = self.train_data.size(
            0) if not self.image_dataset else len(self.train_data)
        self.valid_data_size = self.valid_data.size(
            0) if not self.image_dataset else len(self.valid_data)
        self.test_data_size = self.test_data.size(
            0) if not self.image_dataset else len(self.test_data)

        # Visualization
        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        self.draw_network = utils.draw_network

        self.build_model()

        if self.args.load_path:
            self.load_model()

        shared_optimizer = _get_optimizer(self.args.shared_optim)
        controller_optimizer = _get_optimizer(self.args.controller_optim)

        # Fisher information is tracked inside the shared model and enters the loss.

        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)

        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        self.ce = nn.CrossEntropyLoss()
        self.top_k_acc = top_k_accuracy

    def build_model(self):
        """Creates and initializes the shared and controller models."""
        if self.args.network_type == 'rnn':
            self.shared = models.RNN(self.args, self.dataset)
            self.controller = models.Controller(self.args)
        elif self.args.network_type == 'micro_cnn':
            self.shared = models.CNN(self.args, self.dataset)
            self.controller = models.CNNMicroController(self.args)
        else:
            raise NotImplementedError(f'Network type '
                                      f'`{self.args.network_type}` is not '
                                      f'defined')

        if self.args.num_gpu == 1:
            if torch.__version__ == '0.3.1':
                self.shared.cuda()
                self.controller.cuda()
            else:
                self.shared.to(self.device)
                self.controller.to(self.device)

        elif self.args.num_gpu > 1:
            raise NotImplementedError('`num_gpu > 1` is in progress')

    def train(self):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section 2.2, Training ENAS and Deriving
        Architectures, of the paper.

        From the paper (for Penn Treebank):

        - In the first phase, shared parameters omega are trained for 400
          steps, each on a minibatch of 64 examples.

        - In the second phase, the controller's parameters are trained for 2000
          steps.
        """
        if self.args.shared_initial_step > 0:
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(self.start_epoch, self.args.max_epoch):

            if self.epoch >= self.args.start_using_fisher:
                self.compute_fisher = True

            if self.args.set_fisher_zero_per_iter > 0 \
                    and self.epoch % self.args.set_fisher_zero_per_iter == 0:
                self.shared.set_fisher_zero()

            # 1. Training the shared parameters omega of the child models
            self.train_shared()

            # 2. Training the controller parameters theta
            if self.args.train_controller:
                if self.epoch < self.args.stop_training_controller:
                    self.train_controller()

            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = self.derive()
                    self.evaluate(self.eval_data,
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()

            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, dags, **kwargs):
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.

        We also compute the new WPL (weight-plasticity loss) when Fisher
        information is being tracked.
        :param **kwargs: passed into self.shared (such as hidden)
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            output, hidden, extra_out = self.shared(inputs, dag, **kwargs)
            output_flat = output.view(-1, self.dataset.num_classes)
            sample_loss = (self.ce(output_flat, targets) /
                           self.args.shared_num_sample)

            # Get WPL part
            if self.compute_fisher:
                wpl = self.shared.compute_weight_plastic_loss_with_update_fisher(
                    dag)
                wpl = 0.5 * wpl
                loss += sample_loss + wpl
                rest_loss = wpl
            else:
                loss += sample_loss
                rest_loss = Variable(torch.zeros(1))
                # logger.info(f'Loss {loss.data[0]} = '
                #             f'sample_loss {sample_loss.data[0]}')

        #assert len(dags) == 1, 'there are multiple `hidden` for multiple `dags`'
        return loss, sample_loss, rest_loss, hidden, extra_out

    def train_shared(self, max_step=None):
        """Train the language model for 400 steps of minibatches of 64
        examples.

        Args:
            max_step: Used to run extra training steps as a warm-up.

        BPTT is truncated at 35 timesteps.

        For each weight update, gradients are estimated by sampling M models
        from the fixed controller policy, and averaging their gradients
        computed on a batch of training data.
        """
        valid_ppls = []
        valid_ppls_after = []

        model = self.shared
        model.train()
        self.controller.eval()

        hidden = self.shared.init_training(self.args.batch_size)
        v_hidden = self.shared.init_training(self.args.batch_size)

        if max_step is None:
            max_step = self.args.shared_max_step
        else:
            max_step = min(self.args.shared_max_step, max_step)

        abs_max_grad = 0
        abs_max_hidden_norm = 0
        step = 0
        raw_total_loss = 0
        total_loss = 0
        total_sample_loss = 0
        total_rest_loss = 0
        train_idx = 0
        valid_idx = 0

        def _run_shared_one_batch(inputs, targets, hidden, dags,
                                  raw_total_loss):
            # global abs_max_grad
            # global abs_max_hidden_norm
            # global raw_total_loss
            loss, sample_loss, rest_loss, hidden, extra_out = self.get_loss(
                inputs, targets, dags, hidden=hidden)

            # Detach the hidden
            # Because they are input from previous state.
            hidden = utils.detach(hidden)
            raw_total_loss += sample_loss.data / self.args.num_batch_per_iter
            penalty_loss = _apply_penalties(extra_out, self.args)
            loss += penalty_loss
            rest_loss += penalty_loss
            return loss, sample_loss, rest_loss, hidden, extra_out, raw_total_loss

        def _clip_gradient(abs_max_grad, abs_max_hidden_norm):

            h1tohT = extra_out['hiddens']
            new_abs_max_hidden_norm = utils.to_item(
                h1tohT.norm(dim=-1).data.max())
            if new_abs_max_hidden_norm > abs_max_hidden_norm:
                abs_max_hidden_norm = new_abs_max_hidden_norm
                logger.info(f'max hidden {abs_max_hidden_norm}')
            abs_max_grad = _check_abs_max_grad(abs_max_grad, model)
            torch.nn.utils.clip_grad_norm(model.parameters(),
                                          self.args.shared_grad_clip)
            return abs_max_grad, abs_max_hidden_norm

        def _evaluate_valid(dag):
            hidden_eval = self.shared.init_training(self.args.batch_size)
            inputs_eval, targets_eval = self.get_batch(self.valid_data,
                                                       0,
                                                       self.max_length,
                                                       volatile=True)
            _, valid_loss_eval, _, _, _ = self.get_loss(inputs_eval,
                                                        targets_eval,
                                                        dag,
                                                        hidden=hidden_eval)
            valid_loss_eval = utils.to_item(valid_loss_eval.data)
            valid_ppl_eval = math.exp(valid_loss_eval)
            return valid_ppl_eval

        dags_eval = []
        while train_idx < self.train_data_size - 1 - 1:
            if step > max_step:
                break
            dags = self.controller.sample(self.args.shared_num_sample)
            dags_eval.append(dags[0])
            for b in range(0, self.args.num_batch_per_iter):
                # For each model, do the update for 30 batches.
                inputs, targets = self.get_batch(self.train_data, train_idx,
                                                 self.max_length)

                loss, sample_loss, rest_loss, hidden, extra_out, raw_total_loss = \
                    _run_shared_one_batch(
                        inputs, targets, hidden, dags, raw_total_loss)

                # update with complete logic
                # First, normally we compute one loss and do update accordingly.
                # if in the last batch, we compute the fisher information
                # based on two kinds of loss, complete or ce-loss only.
                self.shared_optim.zero_grad()

                # If it is the last training batch. Update the Fisher information
                if self.compute_fisher and (not self.args.shared_valid_fisher):
                    if b == self.args.num_batch_per_iter - 1:
                        sample_loss.backward()
                        if self.args.shared_ce_fisher:
                            self.shared.update_fisher(dags[0])
                            rest_loss.backward()
                        else:
                            rest_loss.backward()
                            self.shared.update_fisher(dags[0])
                    else:
                        loss.backward()
                else:
                    loss.backward()

                abs_max_grad, abs_max_hidden_norm = _clip_gradient(
                    abs_max_grad, abs_max_hidden_norm)

                self.shared_optim.step()

                total_loss += loss.data / self.args.num_batch_per_iter
                total_sample_loss += sample_loss.data / self.args.num_batch_per_iter
                total_rest_loss += rest_loss.data / self.args.num_batch_per_iter

                train_idx = ((train_idx + self.max_length) %
                             (self.train_data_size - 1))

            if self.epoch > self.args.start_evaluate_diff:
                valid_ppl_eval = _evaluate_valid(dags[0])
                valid_ppls.append(valid_ppl_eval)

            logger.info(
                f'Step {step} | '
                f'Loss {utils.to_item(total_loss) / (step + 1):.5f} = '
                f'sample_loss {utils.to_item(total_sample_loss) / (step + 1):.5f} + '
                f'wpl {utils.to_item(total_rest_loss) / (step + 1):.5f}')

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_shared_train(total_loss, raw_total_loss)
                raw_total_loss = 0
                total_loss = 0
                total_sample_loss = 0
                total_rest_loss = 0

            if self.compute_fisher:
                # Update with the validation dataset for fisher information after each step,
                # with update the optimal weights.
                v_inputs, v_targets = self.get_batch(self.valid_data,
                                                     valid_idx,
                                                     self.max_length)
                v_loss, v_sample_loss, _, v_hidden, v_extra_out, _ = _run_shared_one_batch(
                    v_inputs, v_targets, v_hidden, dags, 0)
                self.shared_optim.zero_grad()
                if self.args.shared_ce_fisher:
                    v_sample_loss.backward()
                else:
                    v_loss.backward()
                self.shared.update_fisher(dags[0], self.epoch)
                self.shared.update_optimal_weights()
                valid_idx = ((valid_idx + self.max_length) %
                             (self.valid_data_size - 1))

            step += 1
            self.shared_step += 1

        if self.epoch > self.args.start_evaluate_diff:
            for arch in dags_eval:
                valid_ppl_eval = _evaluate_valid(arch)
                valid_ppls_after.append(valid_ppl_eval)
                logger.info(f'valid_ppl {valid_ppl_eval}')
            diff = np.array(valid_ppls_after) - np.array(valid_ppls)
            logger.info(f'Mean_diff {np.mean(diff)}')
            logger.info(f'Max_diff {np.amax(diff)}')
            self.tb.scalar_summary(f'Mean difference', np.mean(diff),
                                   self.epoch)
            self.tb.scalar_summary(f'Max difference', np.amax(diff),
                                   self.epoch)
            self.tb.scalar_summary(f'Mean valid_ppl after training',
                                   np.mean(np.array(valid_ppls_after)),
                                   self.epoch)
            self.tb.scalar_summary(f'Mean valid_ppl before training',
                                   np.mean(np.array(valid_ppls)), self.epoch)
            self.tb.scalar_summary(f'std_diff', np.std(np.array(diff)),
                                   self.epoch)

    def get_reward(self, dags, entropies, hidden, valid_idx=None):
        """
        Computes the reward of a single sampled model or multiple on a minibatch of
        validation data.

        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()

        if valid_idx is None:
            valid_idx = 0

        inputs, targets = self.get_batch(self.valid_data,
                                         valid_idx,
                                         self.max_length,
                                         volatile=True)
        _, valid_loss, _, hidden, _ = self.get_loss(inputs,
                                                    targets,
                                                    dags,
                                                    hidden=hidden)
        valid_loss = utils.to_item(valid_loss.data)

        valid_ppl = math.exp(valid_loss)

        if self.args.ppl_square:
            R = self.args.reward_c / valid_ppl**2
        else:
            R = self.args.reward_c / valid_ppl

        if self.args.entropy_mode == 'reward':
            rewards = R + self.args.entropy_coeff * entropies
        elif self.args.entropy_mode == 'regularizer':
            rewards = R * np.ones_like(entropies)
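            # broadcast the scalar reward to every controller decision; the entropy
            # bonus is added to the loss as a regularizer rather than to the reward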
        else:
            raise NotImplementedError(
                f'Unknown entropy mode: {self.args.entropy_mode}')

        return rewards, hidden

    def train_controller(self):
        """Fixes the shared parameters and updates the controller parameters.

        The controller is updated with a score function gradient estimator
        (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
        is computed on a minibatch of validation data.

        A moving average baseline is used.

        The controller is trained for 2000 steps per epoch (i.e.,
        first (Train Shared) phase -> second (Train Controller) phase).
        """
        model = self.controller
        model.train()

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []
        hidden = self.shared.init_training(self.args.batch_size)
        total_loss = 0
        valid_idx = 0

        for step in range(self.args.controller_max_step):
            # print("************ train controller ****************")
            # sample models
            dags, log_probs, entropies = self.controller.sample(
                batch_size=self.args.policy_batch_size, with_details=True)

            # calculate reward
            np_entropies = entropies.data.cpu().numpy()
            with _get_no_grad_ctx_mgr():
                rewards, hidden = self.get_reward(dags, np_entropies, hidden,
                                                  valid_idx)

            # discount
            if 1 > self.args.discount > 0:
                rewards = discount(rewards, self.args.discount)

            reward_history.extend(rewards)
            entropy_history.extend(np_entropies)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = self.args.ema_baseline_decay
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.extend(adv)

            # policy loss
            loss = -log_probs * utils.get_variable(
                adv, self.cuda, requires_grad=False)
            if self.args.entropy_mode == 'regularizer':
                loss -= self.args.entropy_coeff * entropies

            loss = loss.sum()  # or loss.mean()

            # update
            self.controller_optim.zero_grad()
            loss.backward()

            if self.args.controller_grad_clip > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              self.args.controller_grad_clip)
            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_controller_train(total_loss, adv_history,
                                                 entropy_history,
                                                 reward_history,
                                                 avg_reward_base, dags)

                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0

            self.controller_step += 1

            prev_valid_idx = valid_idx
            valid_idx = ((valid_idx + self.max_length) %
                         (self.valid_data_size - 1))
            if prev_valid_idx > valid_idx:
                hidden = self.shared.init_training(self.args.batch_size)

    def evaluate(self, source, dag, name, batch_size=1, max_num=None):
        """Evaluate on the validation set.
        """
        self.shared.eval()
        self.controller.eval()

        if self.image_dataset:
            data = source
        else:
            data = source[:max_num * self.max_length]

        total_loss = 0
        hidden = self.shared.init_training(batch_size)

        pbar = range(0, self.valid_data_size - 1, self.max_length)
        for count, idx in enumerate(pbar):
            inputs, targets = self.get_batch(data, idx, volatile=True)
            output, hidden, _ = self.shared(inputs,
                                            dag,
                                            hidden=hidden,
                                            is_train=False)
            output_flat = output.view(-1, self.dataset.num_classes)
            total_loss += len(inputs) * self.ce(output_flat, targets).data
            hidden = utils.detach(hidden)
            ppl = math.exp(
                utils.to_item(total_loss) / (count + 1) / self.max_length)

        val_loss = utils.to_item(total_loss) / len(data)
        ppl = math.exp(val_loss)

        self.tb.scalar_summary(f'eval/{name}_loss', val_loss, self.epoch)
        self.tb.scalar_summary(f'eval/{name}_ppl', ppl, self.epoch)
        logger.info(f'eval | loss: {val_loss:8.2f} | ppl: {ppl:8.2f}')

    def derive(self, sample_num=None, valid_idx=0):

        if sample_num is None:
            sample_num = self.args.derive_num_sample

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
            if self.image_dataset:
                R, _ = self.get_reward([dag], entropies, valid_idx)
            else:
                hidden = self.shared.init_training(self.args.batch_size)
                R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        logger.info(f'derive | max_R: {max_R:8.6f}')
        fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                 f'{max_R:6.4f}-best.png')
        path = os.path.join(self.args.model_dir, 'networks', fname)
        success = self.draw_network(best_dag, path)
        if success:
            self.tb.image_summary('derive/best', [path], self.epoch)

        return best_dag

    def reset_dataloader_by_name(self, name):
        """ Works for only reset _DataLoaderIter by DataLoader with name """
        try:
            new_iter = wrap_iterator_with_name(
                iter(getattr(self, f'_{name}_data')), name)
            setattr(self, f'{name}_data', new_iter)
        except Exception as e:
            print(e)
        return new_iter

    @property
    def shared_lr(self):
        degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
        return self.args.shared_lr * (self.args.shared_decay**degree)

    @property
    def controller_lr(self):
        return self.args.controller_lr

    def get_batch(self, source, idx, length=None, volatile=False):
        # code from
        # https://github.com/pytorch/examples/blob/master/word_language_model/main.py

        if not self.image_dataset:
            length = min(length if length else self.max_length,
                         len(source) - 1 - idx)
            data = Variable(source[idx:idx + length], volatile=volatile)
            target = Variable(source[idx + 1:idx + 1 + length].view(-1),
                              volatile=volatile)
        else:
            # Try the dataloader logic.
            # type is _DataLoaderIter

            try:
                data, target = next(source)
            except StopIteration as e:
                print(f'{e}')
                name = source.name
                source = self.reset_dataloader_by_name(name)
                data, target = next(source)

            # data.to(self.device)
            return data.to(self.device), target.to(self.device)
        return data, target

    @property
    def shared_path(self):
        return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'

    @property
    def controller_path(self):
        return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'

    def get_saved_models_info(self):
        paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in basenames if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps

    def save_model(self):
        torch.save(self.shared.state_dict(), self.shared_path)
        logger.info(f'[*] SAVED: {self.shared_path}')

        torch.save(self.controller.state_dict(), self.controller_path)
        logger.info(f'[*] SAVED: {self.controller_path}')

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob.glob(
                os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))

            for path in paths:
                utils.remove_file(path)

    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            logger.info(f'[!] No checkpoint found in {self.args.model_dir}...')
            return

        self.epoch = self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            torch.load(self.shared_path, map_location=map_location))
        logger.info(f'[*] LOADED: {self.shared_path}')

        self.controller.load_state_dict(
            torch.load(self.controller_path, map_location=map_location))
        logger.info(f'[*] LOADED: {self.controller_path}')

    def _summarize_controller_train(self, total_loss, adv_history,
                                    entropy_history, reward_history,
                                    avg_reward_base, dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.args.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        logger.info(f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
                    f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
                    f'| loss {cur_loss:.5f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss', cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward', avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/std/reward',
                                   np.std(reward_history),
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy', avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv', avg_adv,
                                   self.controller_step)

            paths = []
            for dag in dags:
                fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                         f'{avg_reward:6.4f}.png')
                path = os.path.join(self.args.model_dir, 'networks', fname)
                self.draw_network(dag, path)
                paths.append(path)

            self.tb.image_summary('controller/sample', paths,
                                  self.controller_step)

    def _summarize_shared_train(self, total_loss, raw_total_loss):
        """Logs a set of training steps."""
        cur_loss = utils.to_item(total_loss) / self.args.log_step
        cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
        try:
            ppl = math.exp(cur_raw_loss)
        except RuntimeError as e:
            print(f"Got error {e}")

        logger.info(f'| epoch {self.epoch:3d} '
                    f'| lr {self.shared_lr:4.2f} '
                    f'| raw loss {cur_raw_loss:.2f} '
                    f'| loss {cur_loss:.2f} '
                    f'| ppl {ppl:8.2f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('shared/loss', cur_loss, self.shared_step)
            self.tb.scalar_summary('shared/perplexity', ppl, self.shared_step)
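# A minimal standalone sketch (PyTorch only; toy numbers, no entropy term) of the
# REINFORCE update used in train_controller() above: reward R = reward_c / valid_ppl,
# an exponential-moving-average baseline, and loss = -(log_probs * advantage).sum().
import torch

logits = torch.randn(5, requires_grad=True)      # stand-in for controller outputs
log_probs = torch.log_softmax(logits, dim=0)     # log-probs of the sampled decisions

R = 80.0 / 95.0                                  # reward_c / valid_ppl
baseline, decay = 0.70, 0.95
baseline = decay * baseline + (1 - decay) * R    # moving-average baseline
adv = R - baseline                               # advantage

loss = -(log_probs * adv).sum()                  # score-function (REINFORCE) objective
loss.backward()                                  # gradients accumulate in logits.grad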
예제 #7
0
class Trainer(object):
    """A class to wrap training code."""
    def __init__(self, args, dataset):
        """Constructor for training algorithm.

        Args:
            args: From command line, picked up by `argparse`.
            dataset: Currently only `data.text.Corpus` is supported.

        Initializes:
            - Data: train, val and test.
            - Model: shared and controller.
            - Inference: optimizers for shared and controller parameters.
            - Criticism: cross-entropy loss for training the shared model.
        """
        self.args = args
        self.controller_step = 0
        self.cuda = args.cuda
        self.dataset = dataset
        self.epoch = 0
        self.shared_step = 0
        self.start_epoch = 0

        logger.info('regularizing:')
        for regularizer in [('activation regularization',
                             self.args.activation_regularization),
                            ('temporal activation regularization',
                             self.args.temporal_activation_regularization),
                            ('norm stabilizer regularization',
                             self.args.norm_stabilizer_regularization)]:
            if regularizer[1]:
                logger.info('{0}'.format(regularizer[0]))

        self.train_data = utils.batchify(dataset.train, args.batch_size,
                                         self.cuda)
        # NOTE(brendan): The validation set data is batchified twice
        # separately: once for computing rewards during the Train Controller
        # phase (valid_data, batch size == 64), and once for evaluating ppl
        # over the entire validation set (eval_data, batch size == 1)
        self.valid_data = utils.batchify(dataset.valid, args.batch_size,
                                         self.cuda)
        self.eval_data = utils.batchify(dataset.valid, args.test_batch_size,
                                        self.cuda)
        self.test_data = utils.batchify(dataset.test, args.test_batch_size,
                                        self.cuda)

        self.max_length = self.args.shared_rnn_max_length  # default=35

        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        self.build_model()  # Builds the model into self.shared (an RNN or a CNN here) and then builds a Controller

        if self.args.load_path:
            self.load_model()

        shared_optimizer = _get_optimizer(self.args.shared_optim)
        controller_optimizer = _get_optimizer(self.args.controller_optim)

        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)

        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        self.ce = nn.CrossEntropyLoss()

    def build_model(self):
        """Creates and initializes the shared and controller models."""
        if self.args.network_type == 'rnn':
            self.shared = models.RNN(self.args, self.dataset)
        elif self.args.network_type == 'cnn':
            self.shared = models.CNN(self.args, self.dataset)
        else:
            raise NotImplementedError(
                'Network type `{0}` is not defined'.format(
                    self.args.network_type))
        self.controller = models.Controller(
            self.args
        )  # Builds a forward pass of Embedding(130, 100) -> LSTM(100, 100) -> a list of 25 decoders

        if self.args.num_gpu == 1:
            self.shared.cuda()
            self.controller.cuda()
        elif self.args.num_gpu > 1:
            raise NotImplementedError('`num_gpu > 1` is in progress')

    def train(self, single=False):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section 2.2, Training ENAS and Deriving
        Architectures, of the paper.

        From the paper (for Penn Treebank):

        - In the first phase, shared parameters omega are trained for 400
          steps, each on a minibatch of 64 examples.

        - In the second phase, the controller's parameters are trained for 2000
          steps.
          
        Args:
            single (bool): If True it won't train the controller and use the
                           same dag instead of derive().
        """
        dag = utils.load_dag(self.args) if single else None  # for the initial training run, dag is None

        if self.args.shared_initial_step > 0:  # self.args.shared_initial_step default=0
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(
                self.start_epoch,
                self.args.max_epoch):  # start_epoch=0,max_epoch=150
            # 1. Training the shared parameters omega of the child models
            # Train the RNN: the Controller first samples a dag, an RNN cell is built from that dag, and the cell is used for next-word prediction to obtain the loss
            self.train_shared(dag=dag)

            # 2. Training the controller parameters theta
            if not single:
                self.train_controller()

            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = dag if dag else self.derive()
                    self.evaluate(self.eval_data,
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()
            # Gradually decay the learning rate
            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, hidden, dags):
        """
        :param inputs:输入数据,[35,64]
        :param targets: 目标数据(相当于标签)[35,64] 输入的词后移一个词
        :param hidden: 隐藏层参数
        :param dags: RNN 的cell结构
        :return: decoded(35,64,10000),hidden(64,1000),extra_out{dropped_output(35,64,1000),h1tohT(35,64,1000),raw_output(35,64,1000)
        """
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            # decoded(35,64,10000),hidden(64,1000),extra_out{dropped_output(35,64,1000),h1tohT(35,64,1000),raw_output(35,64,1000)
            output, hidden, extra_out = self.shared(
                inputs, dag, hidden=hidden)  # RNN.forward
            output_flat = output.view(-1,
                                      self.dataset.num_tokens)  # (2240,10000)
            # self.ce=nn.CrossEntropyLoss()  target(2240)  shared_num_sample=1
            sample_loss = (self.ce(output_flat, targets) /
                           self.args.shared_num_sample)
            loss += sample_loss

        assert len(dags) == 1, 'there are multiple `hidden` for multiple `dags`'
        return loss, hidden, extra_out

    def train_shared(self, max_step=None, dag=None):
        """Train the language model for 400 steps of minibatches of 64
        examples.

        Args:
            max_step: Used to run extra training steps as a warm-up.
            dag: If not None, is used instead of calling sample().

        BPTT (Back Propagation Through Time) is truncated at 35 timesteps.

        For each weight update, gradients are estimated by sampling M models
        from the fixed controller policy, and averaging their gradients
        computed on a batch of training data.
        """
        model = self.shared  # models.RNN
        # Set RNN.training to True: the RNN is trained here, not the Controller.
        # https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.train
        model.train()
        # Set the controller to evaluation mode, equivalent to self.train(False).
        self.controller.eval()
        # Initialize the hidden state as an all-zero tensor.
        hidden = self.shared.init_hidden(self.args.batch_size)

        if max_step is None:
            max_step = self.args.shared_max_step  # shared_max_step=150
        else:
            max_step = min(self.args.shared_max_step, max_step)

        abs_max_grad = 0
        abs_max_hidden_norm = 0
        step = 0
        raw_total_loss = 0  # only used for logging, not part of the computation
        total_loss = 0
        train_idx = 0
        # TODO(brendan): Why - 1 - 1?
        # (Why -1 -1): train_idx indexes the 14524 batches (64 words each). The
        # training input can never be the very last batch, because the last
        # batch has no target; so the last usable batch is the second-to-last,
        # whose index is size - 2.
        # self.train_data.size(0) == 14524
        while train_idx < self.train_data.size(0) - 1 - 1:
            if step > max_step:
                break
            # The Controller samples a dag: a list containing a defaultdict that
            # stores the dag's connectivity. This only queries the Controller
            # (no training); initially the dag is effectively random.
            dags = dag if dag else self.controller.sample(
                batch_size=self.args.shared_num_sample
            )  # shared_num_sample:default=1
            # Slice a window of max_length, shape (35, 64): 35 timesteps of 64
            # words each, forming one training batch. `inputs` is the training
            # data; `targets` is each input word shifted by one, used to train the RNN.
            inputs, targets = self.get_batch(self.train_data, train_idx,
                                             self.max_length)  # max_length=35
            # get_loss runs the forward pass of the RNN cell built from the dag
            loss, hidden, extra_out = self.get_loss(inputs, targets, hidden,
                                                    dags)
            # Detaches the Tensor from the graph that created it, making it a leaf. Views cannot be detached in-place.
            hidden.detach_()
            raw_total_loss += loss.data
            # Add regularization penalty terms according to the command-line args
            loss += _apply_penalties(extra_out, self.args)

            # update
            self.shared_optim.zero_grad()
            loss.backward()  # backward pass

            h1tohT = extra_out['hiddens']
            # Logging only; unrelated to the computation
            new_abs_max_hidden_norm = utils.to_item(
                h1tohT.norm(dim=-1).data.max())
            if new_abs_max_hidden_norm > abs_max_hidden_norm:
                abs_max_hidden_norm = new_abs_max_hidden_norm
                logger.info('max hidden {0}'.format(abs_max_hidden_norm))
            # Tracks the largest gradient in the graph to detect exploding gradients; the value does not appear to be used later
            abs_max_grad = _check_abs_max_grad(abs_max_grad, model)
            # Clips gradient norm of an iterable of parameters.
            # The norm is computed over all gradients together, as if they were concatenated into a single vector.
            # Gradients are modified in-place.
            torch.nn.utils.clip_grad_norm(
                model.parameters(),
                self.args.shared_grad_clip)  # shared_grad_clip=0.25
            self.shared_optim.step()  # Performs a single optimization step.

            total_loss += loss.data
            # Logging
            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_shared_train(total_loss, raw_total_loss)
                raw_total_loss = 0
                total_loss = 0

            step += 1
            self.shared_step += 1
            train_idx += self.max_length  # max_length=35; advance to the next window

    def get_reward(self, dag, entropies, hidden, valid_idx=0):
        """Computes the perplexity of a single sampled model on a minibatch of
        validation data.
        PPL is the reciprocal of the product of the per-word conditional
        probabilities (the probability of word n+1 given the first n words),
        taken to the power 1/N, where N is the total number of words.
        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()

        inputs, targets = self.get_batch(self.valid_data,
                                         valid_idx,
                                         self.max_length,
                                         volatile=True)
        valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden,
                                              dag)  #RNN.forward
        valid_loss = utils.to_item(valid_loss.data)

        valid_ppl = math.exp(valid_loss)  # compute PPL

        # TODO: we don't know reward_c
        if self.args.ppl_square:  #default:false
            # TODO: but we do know reward_c=80 in the previous paper
            R = self.args.reward_c / valid_ppl**2
        else:
            R = self.args.reward_c / valid_ppl  # NAS (Zoph and Le, 2017), page 8, states that c is a constant

        if self.args.entropy_mode == 'reward':  # entropy_mode: default='reward'
            rewards = R + self.args.entropy_coeff * entropies  # entropy_coeff:default=1e-4
        elif self.args.entropy_mode == 'regularizer':
            rewards = R * np.ones_like(entropies)
        else:
            raise NotImplementedError('Unknown entropy mode: {0}'.format(
                self.args.entropy_mode))

        return rewards, hidden

    def train_controller(self):
        """Fixes the shared parameters and updates the controller parameters.

        The controller is updated with a score function gradient estimator
        (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
        is computed on a minibatch of validation data.

        A moving average baseline is used.

        The controller is trained for 2000 steps per epoch (i.e.,
        first (Train Shared) phase -> second (Train Controller) phase).
        """
        model = self.controller
        model.train()  # set the Controller to train mode; the Controller is trained here
        # Why is self.shared.eval() not called here? Because it makes the
        # Controller's loss stay at zero; this explanation appears to be
        # Brendan's conclusion after testing.

        avg_reward_base = None
        baseline = None
        # These are only used for logging statistics
        adv_history = []
        entropy_history = []
        reward_history = []

        hidden = self.shared.init_hidden(self.args.batch_size)
        total_loss = 0
        valid_idx = 0
        for step in range(self.args.controller_max_step):  #controller_max_step
            # sample models
            # dags: list([1]) of defaultdict([25]); log_probs: Tensor([23]); entropies: Tensor([23]); entropy = -y*log(y)
            dags, log_probs, entropies = self.controller.sample(
                with_details=True)

            # calculate reward
            np_entropies = entropies.data.cpu().numpy()
            # NOTE(brendan): No gradients should be backpropagated to the
            # shared model during controller training, obviously.
            """
            with 语句实质是上下文管理。
            1、上下文管理协议。包含方法__enter__() 和 __exit__(),支持该协议对象要实现这两个方法。
            2、上下文管理器,定义执行with语句时要建立的运行时上下文,负责执行with语句块上下文中的进入与退出操作。
            3、进入上下文的时候执行__enter__方法,如果设置as var语句,var变量接受__enter__()方法返回值。
            4、如果运行时发生了异常,就退出上下文管理器。调用管理器__exit__方法。
            """
            # 创建了一个torch.no_grad()的上下文,执行get_reward的时候是不需要计算梯度的,执行完get_reward在恢复计算梯度模式
            with _get_no_grad_ctx_mgr():
                rewards, hidden = self.get_reward(dags, np_entropies, hidden,
                                                  valid_idx)

            # Reward discounting (disabled by default)
            if 1 > self.args.discount > 0:  # discount: default=1
                rewards = discount(rewards, self.args.discount)

            reward_history.extend(rewards)
            entropy_history.extend(np_entropies)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = self.args.ema_baseline_decay  # ema_baseline_decay: default=0.95 (very important)
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.extend(adv)

            # policy loss
            loss = -log_probs * utils.get_variable(
                adv, self.cuda, requires_grad=False)
            if self.args.entropy_mode == 'regularizer':  #entropy_mode:default='reward'
                loss -= self.args.entropy_coeff * entropies

            loss = loss.sum()  # or loss.mean()

            # update
            self.controller_optim.zero_grad()
            loss.backward()

            if self.args.controller_grad_clip > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              self.args.controller_grad_clip)
            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_controller_train(total_loss, adv_history,
                                                 entropy_history,
                                                 reward_history,
                                                 avg_reward_base, dags)

                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0

            self.controller_step += 1

            prev_valid_idx = valid_idx
            valid_idx = ((valid_idx + self.max_length) %
                         (self.valid_data.size(0) - 1))
            # NOTE(brendan): Whenever we wrap around to the beginning of the
            # validation data, we reset the hidden states.
            if prev_valid_idx > valid_idx:
                hidden = self.shared.init_hidden(self.args.batch_size)

    def evaluate(self, source, dag, name, batch_size=1, max_num=None):
        """Evaluate on the validation set.

        NOTE(brendan): We should not be using the test set to develop the
        algorithm (basic machine learning good practices).
        """
        self.shared.eval()
        self.controller.eval()

        data = source[:max_num * self.max_length] if max_num else source

        total_loss = 0
        hidden = self.shared.init_hidden(batch_size)

        pbar = range(0, data.size(0) - 1, self.max_length)
        for count, idx in enumerate(pbar):
            inputs, targets = self.get_batch(data, idx, volatile=True)
            output, hidden, _ = self.shared(inputs,
                                            dag,
                                            hidden=hidden,
                                            is_train=False)
            output_flat = output.view(-1, self.dataset.num_tokens)
            total_loss += len(inputs) * self.ce(output_flat, targets).data
            hidden.detach_()
            ppl = math.exp(
                utils.to_item(total_loss) / (count + 1) / self.max_length)

        val_loss = utils.to_item(total_loss) / len(data)
        ppl = math.exp(val_loss)

        self.tb.scalar_summary('eval/{0}_loss'.format(name), val_loss,
                               self.epoch)
        self.tb.scalar_summary('eval/{0}_ppl'.format(name), ppl, self.epoch)
        logger.info('eval | loss: {0:8.2f} | ppl: {1:8.2f}'.format(
            val_loss, ppl))

    def derive(self, sample_num=None, valid_idx=0):
        """TODO(brendan): We are always deriving based on the very first batch
        of validation data? This seems wrong...
        """
        hidden = self.shared.init_hidden(self.args.batch_size)

        if sample_num is None:
            sample_num = self.args.derive_num_sample

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
            R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        logger.info('derive | max_R: {0:8.6f}'.format(max_R))
        fname = ('{0:03d}-{1:06d}-{2:6.4f}-best.png'.format(
            self.epoch, self.controller_step, max_R))
        path = os.path.join(self.args.model_dir, 'networks', fname)
        #utils.draw_network(best_dag, path)
        #self.tb.image_summary('derive/best', [path], self.epoch)

        return best_dag

    @property
    def shared_lr(self):
        degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
        return self.args.shared_lr * (self.args.shared_decay**degree)

    @property  # exposes the method as an attribute readable with dot notation
    def controller_lr(self):
        return self.args.controller_lr

    def get_batch(self, source, idx, length=None, volatile=False):
        """
        这个函数的作用是从数据集中取得length长度的数据组成一个Variable(这个操作在pytorch中已经过时了,可以直接使用Tensor来生成计算,而不用
        再使用Variable来封装Tensor来计算
        这里的batch指的是取词窗口组成的batch,length是最多取多少个batch_size的词
        :param source:数据集train_data
        :param idx: 当前数据样本索引值
        :param length:max_length=35?
        :param volatile(易变的):Volatile is recommended for purely inference mode, when you’re sure you won’t be even calling .backward()
        设定volatie选项为true的话则只是取值模式,而不会进行反向计算
        :return:
        """
        # code from
        # https://github.com/pytorch/examples/blob/master/word_language_model/main.py
        length = min(length if length else self.max_length,
                     len(source) - 1 - idx)
        # UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
        data = Variable(source[idx:idx + length],
                        volatile=volatile)  # shape (35, 64): 35 timesteps of 64 words each
        target = Variable(source[idx + 1:idx + 1 + length].view(-1),
                          volatile=volatile)  # view (35, 64) -> (2240,)
        # target is data shifted by one word: the model predicts the next word
        return data, target

    @property
    def shared_path(self):
        return '{0}/shared_epoch{1:d}_step{2:d}.pth'.format(
            self.args.model_dir, self.epoch, self.shared_step)

    @property
    def controller_path(self):
        return '{}/controller_epoch{}_step{}.pth'.format(
            self.args.model_dir, self.epoch, self.controller_step)

    def get_saved_models_info(self):
        paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in basenames if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps

    def save_model(self):
        torch.save(self.shared.state_dict(), self.shared_path)
        logger.info('[*] SAVED: {0}'.format(self.shared_path))

        torch.save(self.controller.state_dict(), self.controller_path)
        logger.info('[*] SAVED: {0}'.format(self.controller_path))

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob.glob(
                os.path.join(self.args.model_dir,
                             '*_epoch{0}_*.pth'.format(epoch)))

            for path in paths:
                utils.remove_file(path)

    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            logger.info('[!] No checkpoint found in {0}...'.format(
                self.args.model_dir))
            return

        self.epoch = self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            torch.load(self.shared_path, map_location=map_location))
        logger.info('[*] LOADED: {0}'.format(self.shared_path))

        self.controller.load_state_dict(
            torch.load(self.controller_path, map_location=map_location))
        logger.info('[*] LOADED: {0}'.format(self.controller_path))

    def _summarize_controller_train(self, total_loss, adv_history,
                                    entropy_history, reward_history,
                                    avg_reward_base, dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.args.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        logger.info(
            '| epoch {0:3d} | lr {1:.5f} | R {2:.5f} | entropy {3:.4f} | loss {4:.5f}'
            .format(self.epoch, self.controller_lr, avg_reward, avg_entropy,
                    cur_loss))

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss', cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward', avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy', avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv', avg_adv,
                                   self.controller_step)

            paths = []
            for dag in dags:
                fname = ('{0:03d}-{1:06d}-{2:6.4f}.png'.format(
                    self.epoch, self.controller_step, avg_reward))
                path = os.path.join(self.args.model_dir, 'networks', fname)
                utils.draw_network(dag, path)
                paths.append(path)

            self.tb.image_summary('controller/sample', paths,
                                  self.controller_step)

    def _summarize_shared_train(self, total_loss, raw_total_loss):
        """Logs a set of training steps."""
        cur_loss = utils.to_item(total_loss) / self.args.log_step
        # NOTE(brendan): The raw loss, without adding in the activation
        # regularization terms, should be used to compute ppl.
        cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
        ppl = math.exp(cur_raw_loss)

        logger.info(
            '| epoch {0:3d} | lr {1:4.2f} | raw loss {2:.2f} | loss {3:.2f} | ppl {4:8.2f}'
            .format(self.epoch, self.shared_lr, cur_raw_loss, cur_loss, ppl))

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('shared/loss', cur_loss, self.shared_step)
            self.tb.scalar_summary('shared/perplexity', ppl, self.shared_step)
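The controller update in train_controller above is plain REINFORCE with an exponential-moving-average baseline: the reward is reward_c / valid_ppl (plus an optional entropy bonus), the baseline is an EMA of past rewards, and the policy loss is -log_prob * (reward - baseline). A minimal self-contained sketch of that update rule, with toy numbers standing in for the sampled architectures:

import torch

def reinforce_step(log_probs, rewards, baseline, decay=0.95):
    """One REINFORCE update with an EMA baseline, mirroring train_controller:
    baseline <- decay * baseline + (1 - decay) * reward, loss = -logp * adv."""
    if baseline is None:
        baseline = rewards.detach()
    else:
        baseline = decay * baseline + (1 - decay) * rewards.detach()
    adv = rewards - baseline
    loss = (-log_probs * adv.detach()).sum()
    return loss, baseline

# Toy usage with made-up numbers: two sampled decisions, reward = c/valid_ppl
# (an entropy bonus would simply be added to `rewards` first).
log_probs = torch.tensor([-1.2, -0.7], requires_grad=True)
rewards = torch.tensor([0.8, 0.8])
loss, baseline = reinforce_step(log_probs, rewards, baseline=torch.tensor(0.5))
loss.backward()  # gradients flow only through log_probs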
예제 #8
0
def main(args):

    # Step 1: init data folders
    '''if os.path.exists('save_state/'+args.regime+'/normalization_stats.pkl'):
        print('Loading normalization stats')
        x_mean, x_sd = misc.load_file('save_state/'+args.regime+'/normalization_stats.pkl')
    else:
        x_mean, x_sd = preprocess.save_normalization_stats(args.regime)
        print('x_mean: %.3f, x_sd: %.3f' % (x_mean, x_sd))'''
    data_dir = args.datapath
    data_files = []
    for x in os.listdir(data_dir):
        for y in os.listdir(data_dir + x):
            data_files.append(data_dir + x + "/" + y)
    test_files = [
        data_file for data_file in data_files
        if 'val' in data_file and 'npz' in data_file
    ]

    train_files = [
        data_file for data_file in data_files
        if 'train' in data_file and 'npz' in data_file
    ]

    print("train_num:", len(train_files), "test_num:", len(test_files))

    train_loader = torch.utils.data.DataLoader(Dataset(args, train_files),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.numwork)  #
    test_loader = torch.utils.data.DataLoader(Dataset(args, test_files),
                                              batch_size=args.batch_size,
                                              num_workers=args.numwork)

    tb = TensorBoard(args.model_dir)

    # Step 2: init neural networks
    print("network is:", args.net)
    if args.net == 'Reab3p16':
        model = Reab3p16(args)
    else:
        raise NotImplementedError('Network `%s` is not defined' % args.net)

    if args.gpunum > 1:
        model = nn.DataParallel(model, device_ids=range(args.gpunum))

    weights_path = args.path_weight

    if os.path.exists(weights_path):
        pretrained_dict = torch.load(weights_path)
        model_dict = model.state_dict()
        pretrained_dict1 = {}
        for k, v in pretrained_dict.items():
            if k in model_dict:
                pretrained_dict1[k] = v
                #print(k)
        model_dict.update(pretrained_dict1)
        model.load_state_dict(model_dict)

        print('load weight: ' + weights_path)

    model.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.mo,
                          weight_decay=5e-4)
    #optimizer = optim.Adam(model.parameters(), lr=args.lr)
    if args.gpunum > 1:
        optimizer = nn.DataParallel(optimizer, device_ids=range(args.gpunum))

    iter_count = 1
    epoch_count = 1
    #iter_epoch=int(len(train_files) / args.batch_size)
    print(time.strftime('%H:%M:%S', time.localtime(time.time())), 'training')

    while True:
        since = time.time()

        with torch.no_grad():
            model.eval()
            accuracy_all = []

            for x, y, style, me in test_loader:
                x, y = Variable(x).cuda(), Variable(y).cuda()
                pred = model(x)
                pred = pred.data.max(1)[1]
                correct = pred.eq(y.data).cpu().numpy()
                accuracy = correct.sum() * 100.0 / len(y)

                accuracy_all.append(accuracy)

        accuracy_all = sum(accuracy_all) / len(accuracy_all)

        reward = accuracy_all * 100
        tb.scalar_summary("test_acc", reward, epoch_count)

        # np.expand_dims(, axis=0)

        time_elapsed = time.time() - since
        print('test epoch in {:.0f}h {:.0f}m {:.0f}s'.format(
            time_elapsed // 3600, time_elapsed // 60 % 60, time_elapsed % 60))
        print('------------------------------------')
        print(('epoch:%d, acc:%.1f') % (epoch_count, accuracy_all))
        print('------------------------------------')

        model.train()
        iter_epoch = int(len(train_files) / args.batch_size)
        for x, y, style, me in train_loader:
            if x.shape[0] < 10:
                print(x.shape[0])
                break
            x, y = Variable(x).cuda(), Variable(y).cuda()
            if args.gpunum > 1:
                optimizer.module.zero_grad()
            else:
                optimizer.zero_grad()
            pred = model(x)
            loss = F.nll_loss(pred, y, reduce=False)
            #train_loss=loss
            loss = loss.mean()
            loss.backward()
            if args.gpunum > 1:
                optimizer.module.step()
            else:
                optimizer.step()
            iter_count += 1
            pred = pred.data.max(1)[1]
            correct = pred.eq(y.data).cpu()
            accuracy_total = correct.sum() * 100.0 / len(y)
            if iter_count % 100 == 0:
                iter_c = iter_count % iter_epoch
                print(
                    time.strftime('%H:%M:%S', time.localtime(time.time())),
                    ('train_epoch:%d,iter_count:%d/%d, loss:%.3f, acc:%.1f') %
                    (epoch_count, iter_c, iter_epoch, loss, accuracy_total))

                tb.scalar_summary("train_loss", loss, iter_count)

        #print(acc_part_train)
        if epoch_count % args.lr_step == 0:
            print("change lr")
            adjust_learning_rate(optimizer, epoch_count, args.lr_step,
                                 args.gpunum)
        time_elapsed = time.time() - since
        print('train epoch in {:.0f}h {:.0f}m {:.0f}s'.format(
            time_elapsed // 3600, time_elapsed // 60 % 60, time_elapsed % 60))
        #acc_p=np.array([x[0]/x[1] for x in acc_part])
        #print(acc_p)

        epoch_count += 1
        if epoch_count % 1 == 0:
            print("save!!!!!!!!!!!!!!!!")
            save_state(model.state_dict(),
                       args.model_dir + "/epoch" + str(epoch_count))
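adjust_learning_rate is called above but not defined in this snippet. A plausible sketch, assuming a simple step decay (scale each parameter group's lr by 0.1 every lr_step epochs) and unwrapping the optimizer when it was wrapped in nn.DataParallel for gpunum > 1; the decay factor and the exact behaviour are assumptions, not taken from the original project:

def adjust_learning_rate(optimizer, epoch_count, lr_step, gpunum, gamma=0.1):
    """Hypothetical step decay matching the call site above: it is only called
    when epoch_count % lr_step == 0, so scale every group's lr by gamma."""
    opt = optimizer.module if gpunum > 1 else optimizer  # unwrap nn.DataParallel
    for param_group in opt.param_groups:
        param_group['lr'] *= gamma
    print('lr set to', [group['lr'] for group in opt.param_groups])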
예제 #9
0
파일: trainer.py 프로젝트: williamsz/AutoML
class Trainer(object):
    """A class to wrap training code."""
    def __init__(self,
                 dataset,
                 n_tranformers,
                 n_scalers,
                 n_constructers,
                 n_selecters,
                 n_models,
                 lstm_size,
                 temperature,
                 tanh_constant,
                 save_dir,
                 func_names,
                 controller_max_step=100,
                 controller_grad_clip=0,
                 optimizer='sgd',
                 controller_lr=0.001,
                 entropy_weight=0.001,
                 ema_baseline_decay=0.95,
                 use_tensorboard=True,
                 model_dir=None,
                 log_step=10):

        self.dataset = dataset
        self.controller_max_step = controller_max_step
        self.controller_grad_clip = controller_grad_clip
        self.n_tranformers = n_tranformers
        self.n_scalers = n_scalers
        self.n_constructers = n_constructers
        self.n_selecters = n_selecters
        self.n_models = n_models
        self.lstm_size = lstm_size
        self.temperature = temperature
        self.tanh_constant = tanh_constant
        self.save_dir = save_dir
        self.optimizer = optimizer
        self.controller_lr = controller_lr
        self.entropy_weight = entropy_weight
        self.ema_baseline_decay = ema_baseline_decay
        self.func_names = func_names
        self.use_tensorboard = use_tensorboard
        self.log_step = log_step
        self.model_dir = model_dir

        if self.use_tensorboard:
            self.tb = TensorBoard(self.model_dir)
        else:
            self.tb = None

        self.controller_step = 0

    def get_reward(self, actions):
        reward = models.fit(actions, self.dataset)
        return reward

    def random_actions(self):
        num_tokens = [
            self.n_tranformers, self.n_scalers, self.n_constructers,
            self.n_selecters, self.n_models
        ]
        skip_index = [np.random.randint(i, size=1) for i in range(1, 5)]
        func_index = [np.random.randint(i, size=1) for i in num_tokens]
        actions = []
        for x in range(4):
            actions.append(skip_index[x][0])
            actions.append(func_index[x][0])
        actions.append(func_index[-1][0])
        return actions

    def train_controller(self):

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []

        controller = models.Controller(self.n_tranformers, self.n_scalers,
                                       self.n_constructers, self.n_selecters,
                                       self.n_models, self.func_names,
                                       self.lstm_size, self.temperature,
                                       self.tanh_constant, self.save_dir)

        controller_optimizer = _get_optimizer(self.optimizer)
        controller_optim = controller_optimizer(controller.parameters(),
                                                lr=self.controller_lr)

        controller.train()
        total_loss = 0

        results_dag = []
        results_acc = []
        random_history = []
        acc_history = []

        for step in range(self.controller_max_step):
            # sample models
            dags, actions, sample_entropy, sample_log_probs = controller()
            sample_entropy = torch.sum(sample_entropy)
            sample_log_probs = torch.sum(sample_log_probs)
            # print(sample_log_probs)
            print(actions)

            random_actions = self.random_actions()
            with torch.no_grad():
                acc = self.get_reward(actions)
                random_acc = self.get_reward(torch.LongTensor(random_actions))

            random_history.append(random_acc)
            results_acc.append(acc)
            results_dag.append(dags)
            acc_history.append(acc)

            rewards = torch.tensor(acc)

            if self.entropy_weight is not None:
                rewards += self.entropy_weight * sample_entropy

            reward_history.append(rewards)
            entropy_history.append(sample_entropy)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = self.ema_baseline_decay
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.append(adv)

            # policy loss
            # NOTE: the other trainers in this document use -log_probs * adv;
            # check the sign convention of sample_log_probs here.
            loss = sample_log_probs * adv

            # update
            controller_optim.zero_grad()
            loss.backward()

            if self.controller_grad_clip > 0:
                torch.nn.utils.clip_grad_norm(controller.parameters(),
                                              self.controller_grad_clip)
            controller_optim.step()

            total_loss += loss.item()

            if ((step % self.log_step) == 0) and (step > 0):
                self._summarize_controller_train(total_loss, adv_history,
                                                 entropy_history,
                                                 reward_history, acc_history,
                                                 random_history,
                                                 avg_reward_base, dags)

                reward_history, adv_history, entropy_history, acc_history, random_history = [], [], [], [], []
                total_loss = 0
            self.controller_step += 1

        max_acc = np.max(results_acc)
        max_dag = results_dag[np.argmax(results_acc)]
        path = os.path.join(self.model_dir, 'networks', 'best.png')
        utils.draw_network(max_dag[0], path)
        # np.sort(results_acc)[-10:]
        return np.sort(list(set(results_acc)))[-10:]

    def _summarize_controller_train(self, total_loss, adv_history,
                                    entropy_history, reward_history,
                                    acc_history, random_history,
                                    avg_reward_base, dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)
        avg_acc = np.mean(acc_history)
        avg_random = np.mean(random_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        logger.info(f'| lr {self.controller_lr:.5f} '
                    f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
                    f'| loss {cur_loss:.5f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss', cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward', avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy', avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv', avg_adv,
                                   self.controller_step)
            self.tb.scalar_summary('controller/acc', avg_acc,
                                   self.controller_step)
            self.tb.scalar_summary('controller/random', avg_random,
                                   self.controller_step)

            paths = []
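train_controller above logs controller/acc against controller/random, i.e. the accuracy of pipelines proposed by the controller versus the uniformly random ones built by random_actions. A small hypothetical helper illustrating that comparison, with stand-in evaluators since models.fit and the dataset are not shown here:

import numpy as np

def compare_to_random(sample_fn, random_fn, evaluate_fn, steps=20):
    """Hypothetical helper mirroring the controller/acc vs. controller/random
    summaries: mean accuracy of controller-chosen pipelines against random ones."""
    sampled = [evaluate_fn(sample_fn()) for _ in range(steps)]
    randoms = [evaluate_fn(random_fn()) for _ in range(steps)]
    return np.mean(sampled), np.mean(randoms)

# Toy usage with made-up evaluators: a controller that always proposes action 0
# versus uniform random actions over five choices.
acc_ctrl, acc_rand = compare_to_random(
    sample_fn=lambda: 0,
    random_fn=lambda: np.random.randint(5),
    evaluate_fn=lambda action: 1.0 if action == 0 else 0.5)
print('controller acc:', acc_ctrl, '| random acc:', acc_rand)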
예제 #10
0
class Trainer(object):
    """A class to wrap training code."""
    def __init__(self, args, dataset):
        """Constructor for training algorithm.

        Args:
            args: From command line, picked up by 'argparse'
            dataset: Currently only `data.text.Corpus` is supported.

        Initializes:
            - Data: train, val and test.
            - Model: shared and controller.
            - Inference: optimizers for shared and controller parameters.
            - Criticism: cross-entropy loss for training the shared model.
        """
        # TODO: add an accuracy check
        self.args = args
        self.controller_step = 0
        self.cuda = args.cuda
        self.dataset = dataset
        self.epoch = 0
        self.shared_step = 0
        self.start_epoch = 0

        print('regularizing:')
        for regularizer in [('activation regularization',
                             self.args.activation_regularization),
                            ('temporal activation regularization',
                             self.args.temporal_activation_regularization),
                            ('norm stabilizer regularization',
                             self.args.norm_stabilizer_regularization)]:
            if regularizer[1]:
                print(f'{regularizer[0]}')

        # self.train_data = utils.batchify(dataset.train,
        #                                  args.batch_size,
        #                                  self.cuda)
        # NOTE(brendan): The validation set data is batchified twice
        # separately: once for computing rewards during the Train Controller
        # phase (valid_data, batch size == 64), and once for evaluating ppl
        # over the entire validation set (eval_data, batch size == 1)
        self.train_data = dataset.train
        self.valid_data = dataset.valid
        self.test_data = dataset.test
        # self.max_length = self.args.shared_rnn_max_length

        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        #TODO initialize controller and shared model
        self.build_model()
        # print("11111111")
        if self.args.load_path:
            print("=======load_path=======")
            self.load_model()

        shared_optimizer = _get_optimizer(self.args.shared_optim)
        controller_optimizer = _get_optimizer(self.args.controller_optim)
        print("=======make optimizer========")
        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)
        print("=======make optimizer========")
        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        self.ce = nn.CrossEntropyLoss()
        print("finish init")

    def build_model(self):
        """Creates and initializes the shared and controller models."""
        if self.args.network_type == 'rnn':
            self.shared = models.RNN(self.args, self.dataset)
        elif self.args.network_type == 'cnn':
            print("----- begin to init cnn------")
            self.shared = models.CNN(self.args, self.dataset)
            # self.shared = self.shared.cuda()
        else:
            raise NotImplementedError(f'Network type '
                                      f'`{self.args.network_type}` is not '
                                      f'defined')
        print("---- begin to init controller-----")
        self.controller = models.Controller(self.args)
        #self.controller = self.controller.cuda()
        print("===begin to cuda")
        if True:  # NOTE: always taken, so the elif below is unreachable; presumably meant to be self.args.num_gpu == 1
            print("cuda")
            self.shared.cuda()
            self.controller.cuda()
            print("finish cuda")
        elif self.args.num_gpu > 1:
            raise NotImplementedError('`num_gpu > 1` is in progress')

    def train(self):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section2.4 Training ENAS and deriving
        Architectures, of the paraer.
        """
        if self.args.shared_initial_step > 0:
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # 1. Training the shared parameters omega of the child models
            self.train_shared()

            # 2. Training the controller parameters theta
            #self.train_controller()
            if self.epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = self.derive()
                    self.evaluate(iter(self.test_data),
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()

            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = self.derive()
                    self.evaluate(iter(self.test_data),
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()

            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, dags):
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            inputs = Variable(inputs.cuda())
            targets = Variable(targets.cuda())
            # inputs = inputs.cuda()
            #targets = targets.cuda()
            #self.shared = self.shared.cuda()
            output = self.shared(inputs, dag)
            sample_loss = (self.ce(output, targets) /
                           self.args.shared_num_sample)
            loss += sample_loss

        assert len(
            dags) == 1, 'there are multiple `hidden` for multiple `dags`'
        return loss

    def train_shared(self, max_step=None):
        """Train the image classification model for 310 steps
        """
        # TODO: check whether creating a new dag for every batch is right;
        # sampling one dag per epoch might be more efficient
        model = self.shared
        model.train()
        self.controller.eval()

        if max_step is None:
            max_step = self.args.shared_max_step
        else:
            max_step = min(self.args.shared_max_step, max_step)

        step = 0
        raw_total_loss = 0
        total_loss = 0
        # train_idx = 0
        train_iter = iter(self.train_data)
        #TODO understanding how it train
        while True:
            if step > max_step:
                break
            dags = self.controller.sample(self.args.shared_num_sample)
            #print(dags)
            # TODO: batches come from an iterator, so StopIteration must be
            # handled; there may be a cleaner way to do this
            try:
                inputs, targets = next(train_iter)
            except StopIteration:
                print("====>train_shared<====== finish one epoch")
                break
                # NOTE: the line below is unreachable after `break`
                train_iter = iter(self.train_data)
            #print(dags)
            loss = self.get_loss(inputs, targets, dags)
            raw_total_loss += loss.data
            # TODO: understand the penalty terms
            # loss += _apply_penalties()
            self.shared_optim.zero_grad()
            loss.backward()

            self.shared_optim.step()

            total_loss += loss.data
            #if step % 20 == 0:
            #    print("loss, ", total_loss, step, total_loss /(step+1))

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_shared_train(total_loss, raw_total_loss)
                raw_total_loss = 0
                total_loss = 0

            step += 1
            self.shared_step += 1
            # train_idx += self.max_length
    def get_reward(self, dag, entropies, data_iter):
        """Computes the perplexity of a single sampled model on a minibatch of
        validation data.
        """
        if not isinstance(entropies, np.ndarray):
            entropies = entropies.data.cpu().numpy()
        try:
            inputs, targets = next(data_iter)
        except StopIteration:
            data_iter = iter(self.valid_data)
            inputs, targets = next(data_iter)
        # TODO: how to do validation here
        valid_loss = self.get_loss(inputs, targets, dag)
        # convert valid_loss to a Python scalar
        valid_loss = utils.to_item(valid_loss.data)

        valid_ppl = math.exp(valid_loss)

        # TODO: we don't know reward_c
        if self.args.ppl_square:
            # TODO: but we do know reward_c=80 in the previous paper; need to read it
            R = self.args.reward_c / valid_ppl**2
        else:
            R = self.args.reward_c / valid_ppl

        if self.args.entropy_mode == 'reward':
            rewards = R + self.args.entropy_coeff * entropies
        elif self.args.entropy_mode == 'regularizer':
            rewards = R * np.ones_like(entropies)
        else:
            raise NotImplementedError(
                f'Unknown entropy mode: {self.args.entropy_mode}')

        return rewards

    def train_controller(self):
        """Fixes the shared parameters and updates the controller parameters.

        The controller is updated with a score function gradient estimator
        (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
        is computed on a minibatch of validation data.

        A moving average baseline is used.

        The controller is trained for 2000 steps per epoch (i.e.,
        first (Train Shared) phase -> second (Train Controller) phase).
        """
        model = self.controller
        model.train()

        avg_reward_base = None
        baseline = None
        adv_history = []
        entropy_history = []
        reward_history = []
        valid_iter = iter(self.valid_data)
        total_loss = 0
        for step in range(self.args.controller_max_step):

            dags, log_probs, entropies = self.controller.sample(
                with_details=True)
            #print(dags)
            np_entropies = entropies.data.cpu().numpy()

            with _get_no_grad_ctx_mgr():
                rewards = self.get_reward(dags, np_entropies, valid_iter)
            if 1 > self.args.discount > 0:
                rewards = discount(rewards, self.args.discount)

            reward_history.extend(rewards)
            entropy_history.extend(np_entropies)

            # moving average baseline
            if baseline is None:
                baseline = rewards
            else:
                decay = self.args.ema_baseline_decay
                baseline = decay * baseline + (1 - decay) * rewards

            adv = rewards - baseline
            adv_history.extend(adv)

            #policy loss
            loss = -log_probs * utils.get_variable(
                adv, self.cuda, requires_grad=False)
            if self.args.entropy_mode == 'regularizer':
                # use the torch tensor `entropies` (not the numpy copy) so the
                # entropy term contributes to the gradient, as in the RNN trainer
                loss -= self.args.entropy_coeff * entropies

            loss = loss.sum()

            self.controller_optim.zero_grad()
            loss.backward()

            if self.args.controller_grad_clip > 0:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              self.args.controller_grad_clip)
            self.controller_optim.step()

            total_loss += utils.to_item(loss.data)
            #if step%20 ==0:
            #    print("total loss", total_loss, step, total_loss / (step+1))
            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_controller_train(total_loss, adv_history,
                                                 entropy_history,
                                                 reward_history,
                                                 avg_reward_base, dags)
                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0
            self.controller_step += 1

            # prev_valid_idx = valid_idx
            # valid_idx = ((valid_idx + self.max_length) %
            #             (self.valid_data.size(0) - 1))

            # NOTE(brendan): Whenever we wrap around to the beginning of the
            # validation data, we reset the hidden states.

    def evaluate(self, test_iter, dag, name, batch_size=1, max_num=None):
        """Evaluate on the validation set.
        (lianqing): what data does `source` hold here?

        NOTE: validation data is used to compute the reward, but the test set
        here is the same as the validation set.
        """
        self.shared.eval()
        self.controller.eval()
        acc = AverageMeter()
        # data = source[:max_num*self.max_length]
        total_loss = 0
        # pbar = range(0, data.size(0) - 1, self.max_length)
        count = 0
        while True:
            try:
                count += 1
                inputs, targets = next(test_iter)
            except StopIteration:
                print("========> finish evaluate on one epoch <========")
                break
                # NOTE: the two lines below are unreachable after `break`
                test_iter = iter(self.test_data)
                inputs, targets = next(test_iter)
                # inputs = Variable(inputs)
            # check what difference it makes if the controller is in train mode
            inputs = Variable(inputs.cuda())
            targets = Variable(targets.cuda())
            # inputs = inputs.cuda()
            #targets = targets.cuda()
            output = self.shared(inputs, dag, is_train=False)
            # check whether this loss (self.ce) works here
            total_loss += len(inputs) * self.ce(output, targets).data
            ppl = math.exp(utils.to_item(total_loss) / (count + 1))
            acc.update(utils.get_accuracy(targets, output))
        val_loss = utils.to_item(total_loss) / count
        ppl = math.exp(val_loss)
        # TODO: this is set up for the RNN case and needs adapting for the CNN
        #self.tb.scalar_summary(f'eval/{name}_loss', val_loss, self.epoch)
        #self.tb.scalar_summary(f'eval/{name}_ppl', ppl, self.epoch)
        print(
            f'eval | loss: {val_loss:8.2f} | ppl: {ppl:8.2f} | accuracy: {acc.avg:8.2f}'
        )

    def derive(self, sample_num=None, valid_iter=None):
        """
        pass sample_num is always to 1 test if batch_size > 1 will work ? for controller.sample
        """
        if sample_num is None:
            sample_num = self.args.derive_num_sample
        if valid_iter is None:
            valid_iter = iter(self.valid_data)
        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)
        max_R = 0
        best_dag = None
        for dag in dags:
            R = self.get_reward(dag, entropies, valid_iter)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        print(f'derive | max_R: {max_R:8.6f}')
        fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                 f'{max_R:6.4}-best.png')
        path = os.path.join(self.args.model_dir, 'networks', fname)
        # utils.draw_network(best_dag, path)
        # self.tb.image_summary('derive/best', [path], self.epoch)

        return best_dag

    @property
    def shared_lr(self):
        degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
        return self.args.shared_lr * (self.args.shared_decay**degree)

    @property
    def controller_lr(self):
        return self.args.controller_lr

    @property
    def shared_path(self):
        return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'

    @property
    def controller_path(self):
        return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'

    def get_saved_models_info(self):
        paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in basenames if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps

    def save_model(self):
        torch.save(self.shared.state_dict(), self.shared_path)
        print(f'[*] SAVED: {self.shared_path}')

        torch.save(self.controller.state_dict(), self.controller_path)
        print(f'[*] SAVED: {self.controller_path}')

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob.glob(
                os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))

            for path in paths:
                utils.remove_file(path)

    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            print(f'[!] No checkpoint found in {self.args.model_dir}...')
            return

        self.epoch = self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            torch.load(self.shared_path, map_location=map_location))
        print(f'[*] LOADED: {self.shared_path}')

        self.controller.load_state_dict(
            torch.load(self.controller_path, map_location=map_location))
        print(f'[*] LOADED: {self.controller_path}')

    def _summarize_controller_train(self, total_loss, adv_history,
                                    entropy_history, reward_history,
                                    avg_reward_base, dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.args.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        print(f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
              f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
              f'| loss {cur_loss:.5f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss', cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward', avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy', avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv', avg_adv,
                                   self.controller_step)

            paths = []
            for dag in dags:
                fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                         f'{avg_reward:6.4f}.png')
                path = os.path.join(self.args.model_dir, 'networks', fname)
                # utils.draw_network(dag, path)
                paths.append(path)

            self.tb.image_summary('controller/sample', paths,
                                  self.controller_step)

    def _summarize_shared_train(self, total_loss, raw_total_loss):
        """Logs a set of training steps."""
        cur_loss = utils.to_item(total_loss) / self.args.log_step
        # NOTE(brendan): The raw loss, without adding in the activation
        # regularization terms, should be used to compute ppl.
        cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
        ppl = math.exp(cur_raw_loss)

        print(f'| epoch {self.epoch:3d} '
              f'| lr {self.shared_lr:4.2f} '
              f'| raw loss {cur_raw_loss:.2f} '
              f'| loss {cur_loss:.2f} '
              f'| ppl {ppl:8.2f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('shared/loss', cur_loss, self.shared_step)
            self.tb.scalar_summary('shared/perplexity', ppl, self.shared_step)
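train_shared, get_reward and evaluate above all drive DataLoader iterators by hand with iter()/next() and catch StopIteration. A minimal sketch of a generator that keeps yielding batches across epoch boundaries, assuming an ordinary iterable such as a torch DataLoader:

def cycle_batches(data_loader):
    """Yield batches forever, restarting the iterator whenever the underlying
    DataLoader is exhausted (each full pass is one epoch)."""
    data_iter = iter(data_loader)
    while True:
        try:
            yield next(data_iter)
        except StopIteration:
            data_iter = iter(data_loader)
            yield next(data_iter)

# Usage inside a step-limited loop like train_shared above (sketch):
# batches = cycle_batches(self.train_data)
# for step in range(max_step):
#     inputs, targets = next(batches)
#     ...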