class Trainer(object): """A class to wrap training code.""" def __init__(self, args, dataset): """Constructor for training algorithm. Args: args: From command line, picked up by `argparse`. dataset: Currently only `data.text.Corpus` is supported. Initializes: - Data: train, val and test. - Model: shared and controller. - Inference: optimizers for shared and controller parameters. - Criticism: cross-entropy loss for training the shared model. """ self.args = args self.controller_step = 0 self.cuda = args.cuda self.dataset = dataset self.epoch = 0 self.shared_step = 0 self.start_epoch = 0 logger.info('regularizing:') for regularizer in [('activation regularization', self.args.activation_regularization), ('temporal activation regularization', self.args.temporal_activation_regularization), ('norm stabilizer regularization', self.args.norm_stabilizer_regularization)]: if regularizer[1]: logger.info(f'{regularizer[0]}') self.train_data = utils.batchify(dataset.train, args.batch_size, self.cuda) # NOTE(brendan): The validation set data is batchified twice # separately: once for computing rewards during the Train Controller # phase (valid_data, batch size == 64), and once for evaluating ppl # over the entire validation set (eval_data, batch size == 1) self.valid_data = utils.batchify(dataset.valid, args.batch_size, self.cuda) self.eval_data = utils.batchify(dataset.valid, args.test_batch_size, self.cuda) self.test_data = utils.batchify(dataset.test, args.test_batch_size, self.cuda) self.max_length = self.args.shared_rnn_max_length if args.use_tensorboard: self.tb = TensorBoard(args.model_dir) else: self.tb = None self.build_model() if self.args.load_path: self.load_model() shared_optimizer = _get_optimizer(self.args.shared_optim) controller_optimizer = _get_optimizer(self.args.controller_optim) self.shared_optim = shared_optimizer( self.shared.parameters(), lr=self.shared_lr, weight_decay=self.args.shared_l2_reg) self.controller_optim = controller_optimizer( self.controller.parameters(), lr=self.args.controller_lr) self.ce = nn.CrossEntropyLoss() def build_model(self): """Creates and initializes the shared and controller models.""" if self.args.network_type == 'rnn': self.shared = models.RNN(self.args, self.dataset) elif self.args.network_type == 'cnn': self.shared = models.CNN(self.args, self.dataset) else: raise NotImplementedError(f'Network type ' f'`{self.args.network_type}` is not ' f'defined') self.controller = models.Controller(self.args) if self.args.num_gpu == 1: self.shared.cuda() self.controller.cuda() elif self.args.num_gpu > 1: raise NotImplementedError('`num_gpu > 1` is in progress') def train(self): """Cycles through alternately training the shared parameters and the controller, as described in Section 2.2, Training ENAS and Deriving Architectures, of the paper. From the paper (for Penn Treebank): - In the first phase, shared parameters omega are trained for 400 steps, each on a minibatch of 64 examples. - In the second phase, the controller's parameters are trained for 2000 steps. """ if self.args.shared_initial_step > 0: self.train_shared(self.args.shared_initial_step) self.train_controller() for self.epoch in range(self.start_epoch, self.args.max_epoch): # 1. Training the shared parameters omega of the child models self.train_shared() # 2. 
Training the controller parameters theta self.train_controller() if self.epoch % self.args.save_epoch == 0: with _get_no_grad_ctx_mgr(): best_dag = self.derive() self.evaluate(self.eval_data, best_dag, 'val_best', max_num=self.args.batch_size * 100) self.save_model() if self.epoch >= self.args.shared_decay_after: utils.update_lr(self.shared_optim, self.shared_lr) def get_loss(self, inputs, targets, hidden, dags): """Computes the loss for the same batch for M models. This amounts to an estimate of the loss, which is turned into an estimate for the gradients of the shared model. """ if not isinstance(dags, list): dags = [dags] loss = 0 for dag in dags: output, hidden, extra_out = self.shared(inputs, dag, hidden=hidden) output_flat = output.view(-1, self.dataset.num_tokens) sample_loss = (self.ce(output_flat, targets) / self.args.shared_num_sample) loss += sample_loss assert len(dags) == 1, 'there are multiple `hidden` for multple `dags`' return loss, hidden, extra_out def train_shared(self, max_step=None): """Train the language model for 400 steps of minibatches of 64 examples. Args: max_step: Used to run extra training steps as a warm-up. BPTT is truncated at 35 timesteps. For each weight update, gradients are estimated by sampling M models from the fixed controller policy, and averaging their gradients computed on a batch of training data. """ model = self.shared model.train() self.controller.eval() hidden = self.shared.init_hidden(self.args.batch_size) if max_step is None: max_step = self.args.shared_max_step else: max_step = min(self.args.shared_max_step, max_step) abs_max_grad = 0 abs_max_hidden_norm = 0 step = 0 raw_total_loss = 0 total_loss = 0 train_idx = 0 # TODO(brendan): Why - 1 - 1? while train_idx < self.train_data.size(0) - 1 - 1: if step > max_step: break dags = self.controller.sample(self.args.shared_num_sample) inputs, targets = self.get_batch(self.train_data, train_idx, self.max_length) loss, hidden, extra_out = self.get_loss(inputs, targets, hidden, dags) hidden.detach_() raw_total_loss += loss.data loss += _apply_penalties(extra_out, self.args) # update self.shared_optim.zero_grad() loss.backward() h1tohT = extra_out['hiddens'] new_abs_max_hidden_norm = utils.to_item( h1tohT.norm(dim=-1).data.max()) if new_abs_max_hidden_norm > abs_max_hidden_norm: abs_max_hidden_norm = new_abs_max_hidden_norm logger.info(f'max hidden {abs_max_hidden_norm}') abs_max_grad = _check_abs_max_grad(abs_max_grad, model) torch.nn.utils.clip_grad_norm(model.parameters(), self.args.shared_grad_clip) self.shared_optim.step() total_loss += loss.data if ((step % self.args.log_step) == 0) and (step > 0): self._summarize_shared_train(total_loss, raw_total_loss) raw_total_loss = 0 total_loss = 0 step += 1 self.shared_step += 1 train_idx += self.max_length def get_reward(self, dag, entropies, hidden, valid_idx=0): """Computes the perplexity of a single sampled model on a minibatch of validation data. 
""" if not isinstance(entropies, np.ndarray): entropies = entropies.data.cpu().numpy() inputs, targets = self.get_batch(self.valid_data, valid_idx, self.max_length, volatile=True) valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag) valid_loss = utils.to_item(valid_loss.data) valid_ppl = math.exp(valid_loss) # TODO: we don't know reward_c if self.args.ppl_square: # TODO: but we do know reward_c=80 in the previous paper R = self.args.reward_c / valid_ppl**2 else: R = self.args.reward_c / valid_ppl if self.args.entropy_mode == 'reward': rewards = R + self.args.entropy_coeff * entropies elif self.args.entropy_mode == 'regularizer': rewards = R * np.ones_like(entropies) else: raise NotImplementedError( f'Unkown entropy mode: {self.args.entropy_mode}') return rewards, hidden def train_controller(self): """Fixes the shared parameters and updates the controller parameters. The controller is updated with a score function gradient estimator (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl is computed on a minibatch of validation data. A moving average baseline is used. The controller is trained for 2000 steps per epoch (i.e., first (Train Shared) phase -> second (Train Controller) phase). """ model = self.controller model.train() # TODO(brendan): Why can't we call shared.eval() here? Leads to loss # being uniformly zero for the controller. # self.shared.eval() avg_reward_base = None baseline = None adv_history = [] entropy_history = [] reward_history = [] hidden = self.shared.init_hidden(self.args.batch_size) total_loss = 0 valid_idx = 0 for step in range(self.args.controller_max_step): # sample models dags, log_probs, entropies = self.controller.sample( with_details=True) # calculate reward np_entropies = entropies.data.cpu().numpy() # NOTE(brendan): No gradients should be backpropagated to the # shared model during controller training, obviously. with _get_no_grad_ctx_mgr(): rewards, hidden = self.get_reward(dags, np_entropies, hidden, valid_idx) # discount if 1 > self.args.discount > 0: rewards = discount(rewards, self.args.discount) reward_history.extend(rewards) entropy_history.extend(np_entropies) # moving average baseline if baseline is None: baseline = rewards else: decay = self.args.ema_baseline_decay baseline = decay * baseline + (1 - decay) * rewards adv = rewards - baseline adv_history.extend(adv) # policy loss loss = -log_probs * utils.get_variable( adv, self.cuda, requires_grad=False) if self.args.entropy_mode == 'regularizer': loss -= self.args.entropy_coeff * entropies loss = loss.sum() # or loss.mean() # update self.controller_optim.zero_grad() loss.backward() if self.args.controller_grad_clip > 0: torch.nn.utils.clip_grad_norm(model.parameters(), self.args.controller_grad_clip) self.controller_optim.step() total_loss += utils.to_item(loss.data) if ((step % self.args.log_step) == 0) and (step > 0): self._summarize_controller_train(total_loss, adv_history, entropy_history, reward_history, avg_reward_base, dags) reward_history, adv_history, entropy_history = [], [], [] total_loss = 0 self.controller_step += 1 prev_valid_idx = valid_idx valid_idx = ((valid_idx + self.max_length) % (self.valid_data.size(0) - 1)) # NOTE(brendan): Whenever we wrap around to the beginning of the # validation data, we reset the hidden states. if prev_valid_idx > valid_idx: hidden = self.shared.init_hidden(self.args.batch_size) def evaluate(self, source, dag, name, batch_size=1, max_num=None): """Evaluate on the validation set. 
NOTE(brendan): We should not be using the test set to develop the algorithm (basic machine learning good practices). """ self.shared.eval() self.controller.eval() data = source[:max_num * self.max_length] total_loss = 0 hidden = self.shared.init_hidden(batch_size) pbar = range(0, data.size(0) - 1, self.max_length) for count, idx in enumerate(pbar): inputs, targets = self.get_batch(data, idx, volatile=True) output, hidden, _ = self.shared(inputs, dag, hidden=hidden, is_train=False) output_flat = output.view(-1, self.dataset.num_tokens) total_loss += len(inputs) * self.ce(output_flat, targets).data hidden.detach_() ppl = math.exp( utils.to_item(total_loss) / (count + 1) / self.max_length) val_loss = utils.to_item(total_loss) / len(data) ppl = math.exp(val_loss) self.tb.scalar_summary(f'eval/{name}_loss', val_loss, self.epoch) self.tb.scalar_summary(f'eval/{name}_ppl', ppl, self.epoch) logger.info(f'eval | loss: {val_loss:8.2f} | ppl: {ppl:8.2f}') def derive(self, sample_num=None, valid_idx=0): """TODO(brendan): We are always deriving based on the very first batch of validation data? This seems wrong... """ hidden = self.shared.init_hidden(self.args.batch_size) if sample_num is None: sample_num = self.args.derive_num_sample dags, _, entropies = self.controller.sample(sample_num, with_details=True) max_R = 0 best_dag = None for dag in dags: R, _ = self.get_reward(dag, entropies, hidden, valid_idx) if R.max() > max_R: max_R = R.max() best_dag = dag logger.info(f'derive | max_R: {max_R:8.6f}') fname = (f'{self.epoch:03d}-{self.controller_step:06d}-' f'{max_R:6.4f}-best.png') path = os.path.join(self.args.model_dir, 'networks', fname) utils.draw_network(best_dag, path) self.tb.image_summary('derive/best', [path], self.epoch) return best_dag @property def shared_lr(self): degree = max(self.epoch - self.args.shared_decay_after + 1, 0) return self.args.shared_lr * (self.args.shared_decay**degree) @property def controller_lr(self): return self.args.controller_lr def get_batch(self, source, idx, length=None, volatile=False): # code from # https://github.com/pytorch/examples/blob/master/word_language_model/main.py length = min(length if length else self.max_length, len(source) - 1 - idx) data = Variable(source[idx:idx + length], volatile=volatile) target = Variable(source[idx + 1:idx + 1 + length].view(-1), volatile=volatile) return data, target @property def shared_path(self): return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth' @property def controller_path(self): return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth' def get_saved_models_info(self): paths = glob.glob(os.path.join(self.args.model_dir, '*.pth')) paths.sort() def get_numbers(items, delimiter, idx, replace_word, must_contain=''): return list( set([ int(name.split(delimiter)[idx].replace(replace_word, '')) for name in basenames if must_contain in name ])) basenames = [ os.path.basename(path.rsplit('.', 1)[0]) for path in paths ] epochs = get_numbers(basenames, '_', 1, 'epoch') shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared') controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller') epochs.sort() shared_steps.sort() controller_steps.sort() return epochs, shared_steps, controller_steps def save_model(self): torch.save(self.shared.state_dict(), self.shared_path) logger.info(f'[*] SAVED: {self.shared_path}') torch.save(self.controller.state_dict(), self.controller_path) logger.info(f'[*] SAVED: {self.controller_path}') epochs, shared_steps, 
controller_steps = self.get_saved_models_info() for epoch in epochs[:-self.args.max_save_num]: paths = glob.glob( os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth')) for path in paths: utils.remove_file(path) def load_model(self): epochs, shared_steps, controller_steps = self.get_saved_models_info() if len(epochs) == 0: logger.info(f'[!] No checkpoint found in {self.args.model_dir}...') return self.epoch = self.start_epoch = max(epochs) self.shared_step = max(shared_steps) self.controller_step = max(controller_steps) if self.args.num_gpu == 0: map_location = lambda storage, loc: storage else: map_location = None self.shared.load_state_dict( torch.load(self.shared_path, map_location=map_location)) logger.info(f'[*] LOADED: {self.shared_path}') self.controller.load_state_dict( torch.load(self.controller_path, map_location=map_location)) logger.info(f'[*] LOADED: {self.controller_path}') def _summarize_controller_train(self, total_loss, adv_history, entropy_history, reward_history, avg_reward_base, dags): """Logs the controller's progress for this training epoch.""" cur_loss = total_loss / self.args.log_step avg_adv = np.mean(adv_history) avg_entropy = np.mean(entropy_history) avg_reward = np.mean(reward_history) if avg_reward_base is None: avg_reward_base = avg_reward logger.info(f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} ' f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} ' f'| loss {cur_loss:.5f}') # Tensorboard if self.tb is not None: self.tb.scalar_summary('controller/loss', cur_loss, self.controller_step) self.tb.scalar_summary('controller/reward', avg_reward, self.controller_step) self.tb.scalar_summary('controller/reward-B_per_epoch', avg_reward - avg_reward_base, self.controller_step) self.tb.scalar_summary('controller/entropy', avg_entropy, self.controller_step) self.tb.scalar_summary('controller/adv', avg_adv, self.controller_step) paths = [] for dag in dags: fname = (f'{self.epoch:03d}-{self.controller_step:06d}-' f'{avg_reward:6.4f}.png') path = os.path.join(self.args.model_dir, 'networks', fname) utils.draw_network(dag, path) paths.append(path) self.tb.image_summary('controller/sample', paths, self.controller_step) def _summarize_shared_train(self, total_loss, raw_total_loss): """Logs a set of training steps.""" cur_loss = utils.to_item(total_loss) / self.args.log_step # NOTE(brendan): The raw loss, without adding in the activation # regularization terms, should be used to compute ppl. cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step ppl = math.exp(cur_raw_loss) logger.info(f'| epoch {self.epoch:3d} ' f'| lr {self.shared_lr:4.2f} ' f'| raw loss {cur_raw_loss:.2f} ' f'| loss {cur_loss:.2f} ' f'| ppl {ppl:8.2f}') # Tensorboard if self.tb is not None: self.tb.scalar_summary('shared/loss', cur_loss, self.shared_step) self.tb.scalar_summary('shared/perplexity', ppl, self.shared_step)
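
# NOTE: `utils.batchify` (used in `__init__` above) is defined elsewhere in
# the repo. A minimal sketch of what it likely does, assuming it follows the
# pytorch word_language_model example that `get_batch` cites: trim the token
# stream so it divides evenly by `bsz`, then fold it into `bsz` parallel
# columns so that consecutive rows are consecutive positions in the text.
def _batchify_sketch(data, bsz, use_cuda):
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)  # drop the ragged tail
    # (nbatch * bsz,) -> (bsz, nbatch) -> transpose to (nbatch, bsz)
    data = data.view(bsz, -1).t().contiguous()
    if use_cuda:
        data = data.cuda()
    return data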
class Trainer(object):
    def __init__(self, args, dataset):
        self.args = args
        self.cuda = args.cuda
        self.dataset = dataset

        self.train_data = batchify(dataset.train, args.batch_size, self.cuda)
        self.valid_data = batchify(dataset.valid, args.batch_size, self.cuda)
        self.test_data = batchify(dataset.test, args.test_batch_size,
                                  self.cuda)

        self.max_length = self.args.shared_rnn_max_length

        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None
        self.build_model()

        if self.args.load_path:
            self.load_model()

    def build_model(self):
        self.start_epoch = self.epoch = 0
        self.shared_step, self.controller_step = 0, 0

        if self.args.network_type == 'rnn':
            self.shared = RNN(self.args, self.dataset)
        elif self.args.network_type == 'cnn':
            self.shared = CNN(self.args, self.dataset)
        else:
            raise NotImplementedError(
                f"Network type `{self.args.network_type}` is not defined")

        self.controller = Controller(self.args)

        if self.args.num_gpu == 1:
            self.shared.cuda()
            self.controller.cuda()
        elif self.args.num_gpu > 1:
            raise NotImplementedError("`num_gpu > 1` is in progress")

        self.ce = nn.CrossEntropyLoss()

    def train(self):
        shared_optimizer = get_optimizer(self.args.shared_optim)
        controller_optimizer = get_optimizer(self.args.controller_optim)

        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)
        self.controller_optim = controller_optimizer(
            self.controller.parameters(), lr=self.args.controller_lr)

        hidden = self.shared.init_hidden(self.args.batch_size)

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # 1. Training the shared parameters ω of the child models
            hidden = self.train_shared(hidden)

            # 2. Training the controller parameters θ
            self.train_controller()

            if self.epoch % self.args.save_epoch == 0:
                if self.epoch > 0:
                    best_dag = self.derive()
                    loss, ppl = self.test(self.test_data, best_dag,
                                          "test_best")
                self.save_model()

            if self.epoch >= self.args.shared_decay_after:
                update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, hidden, dags, with_hidden=False):
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            # previous hidden is useless
            output, hidden = self.shared(inputs, hidden, dag)
            output_flat = output.view(-1, self.dataset.num_tokens)
            sample_loss = (self.ce(output_flat, targets) /
                           self.args.shared_num_sample)
            loss += sample_loss

        if with_hidden:
            assert len(dags) == 1, \
                "there are multiple `hidden` for multiple `dags`"
            return loss, hidden
        else:
            return loss

    def train_shared(self, hidden):
        total_loss = 0
        model = self.shared
        model.train()

        step, train_idx = 0, 0
        pbar = tqdm(total=self.train_data.size(0), desc="train_shared")

        while train_idx < self.train_data.size(0) - 1 - 1:
            if step > self.args.shared_max_step:
                break

            dags = self.controller.sample(self.args.shared_num_sample)
            inputs, targets = self.get_batch(self.train_data, train_idx,
                                             self.max_length)

            loss = self.get_loss(inputs, targets, hidden, dags)

            # update
            self.shared_optim.zero_grad()
            loss.backward()

            t.nn.utils.clip_grad_norm(model.parameters(),
                                      self.args.shared_grad_clip)
            self.shared_optim.step()

            total_loss += loss.data
            pbar.set_description(f"train_shared| loss: {loss.data[0]:5.3f}")

            if step % self.args.log_step == 0 and step > 0:
                cur_loss = total_loss[0] / self.args.log_step
                ppl = math.exp(cur_loss)

                logger.info(
                    f'| epoch {self.epoch:3d} | lr {self.shared_lr:4.2f} '
                    f'| loss {cur_loss:.2f} | ppl {ppl:8.2f}')

                # Tensorboard
                if self.tb is not None:
                    self.tb.scalar_summary("shared/loss", cur_loss,
                                           self.shared_step)
                    self.tb.scalar_summary("shared/perplexity", ppl,
                                           self.shared_step)

                total_loss = 0

            step += 1
            self.shared_step += 1
            train_idx += self.max_length
            pbar.update(self.max_length)

        # `train()` reassigns its hidden state from this call, so hand the
        # state back to the caller.
        return hidden

    def get_reward(self, dag, valid_idx=None):
        # Only a missing index should default to the first validation batch.
        if valid_idx is None:
            valid_idx = 0

        inputs, targets = self.get_batch(self.valid_data, valid_idx,
                                         self.max_length)
        valid_loss = self.get_loss(inputs, targets, None, dag)
        valid_ppl = math.exp(valid_loss.data[0])
        R = self.args.reward_c / valid_ppl
        return R

    def train_controller(self):
        total_loss = 0
        model = self.controller
        model.train()

        pbar = trange(self.args.controller_max_step, desc="train_controller")

        baseline = None
        reward_history, adv_history, entropy_history = [], [], []

        valid_idx = 0
        for step in pbar:
            # sample models
            dags, log_probs, entropies = self.controller.sample(
                with_details=True)

            # calculate reward
            R = self.get_reward(dags, valid_idx)
            reward_history.append(R)
            entropy_history.extend(entropies)

            # moving average baseline
            if baseline is None:
                baseline = R
            else:
                decay = self.args.ema_baseline_decay
                baseline = decay * baseline + (1 - decay) * R

            adv = R - baseline
            adv_history.append(adv)
            pbar.set_description(
                f"train_controller| R: {R:8.6f} | R-b: {adv:8.6f}")

            rewards = [0] * (2 * (self.args.num_blocks - 1)) + [adv]

            # discount
            if self.args.discount == 1:
                rewards = [adv] * len(log_probs)
            elif self.args.discount > 0:
                rewards = discount(rewards, self.args.discount)

            #rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)

            # policy loss
            loss = 0
            for log_prob, reward, entropy in zip(log_probs, rewards,
                                                 entropies):
                loss = (loss - log_prob * reward -
                        self.args.entropy_coeff * entropy)

            # update
            self.controller_optim.zero_grad()
            loss.backward()
            self.controller_optim.step()

            total_loss += loss.data

            if step % self.args.log_step == 0 and step > 0:
                cur_loss = total_loss[0][0] / self.args.log_step

                avg_reward = np.mean(reward_history)
                avg_entropy = np.mean(entropy_history)
                avg_adv = np.mean(adv_history)

                logger.info(
                    f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
                    f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
                    f'| loss {cur_loss:.5f}')

                # Tensorboard
                if self.tb is not None:
                    self.tb.scalar_summary("controller/loss", cur_loss,
                                           self.controller_step)
                    self.tb.scalar_summary("controller/reward", avg_reward,
                                           self.controller_step)
                    self.tb.scalar_summary("controller/entropy", avg_entropy,
                                           self.controller_step)
                    self.tb.scalar_summary("controller/adv", avg_adv,
                                           self.controller_step)

                    paths = []
                    for dag in dags:
                        fname = f"{self.epoch:03d}-{self.controller_step:06d}-{avg_reward:6.4f}.png"
                        path = os.path.join(self.args.model_dir, "networks",
                                            fname)
                        draw_network(dag, path)
                        paths.append(path)

                    self.tb.image_summary("controller/sample", paths,
                                          self.controller_step)

                reward_history, adv_history, entropy_history = [], [], []
                total_loss = 0

            self.controller_step += 1

            valid_idx = ((valid_idx + self.max_length) %
                         (self.valid_data.size(0) - 1))

    def test(self, source, dag, name, batch_size=1):
        self.shared.eval()
        self.controller.eval()

        total_loss = 0
        hidden = self.shared.init_hidden(batch_size)

        pbar = trange(0, source.size(0) - 1, self.max_length, desc="test")
        for count, idx in enumerate(pbar):
            data, targets = self.get_batch(source, idx, evaluation=True)
            output, hidden = self.shared(data, hidden, dag)
            output_flat = output.view(-1, self.dataset.num_tokens)
            total_loss += len(data) * self.ce(output_flat, targets).data
            hidden = detach(hidden)
            ppl = math.exp(total_loss[0] / (count + 1) / self.max_length)
            pbar.set_description(f"test| ppl: {ppl:8.2f}")

        test_loss = total_loss[0] / len(source)
        ppl = 
math.exp(test_loss) self.tb.scalar_summary(f"test/{name}_loss", test_loss, self.epoch) self.tb.scalar_summary(f"test/{name}_ppl", ppl, self.epoch) return test_loss, ppl def derive(self, valid_idx=0, sample_num=None): if sample_num is None: sample_num = self.args.derive_num_sample dags = self.controller.sample(sample_num) max_R, best_dag = 0, None pbar = tqdm(dags, desc="derive") for dag in pbar: R = self.get_reward(dag, valid_idx) if R > max_R: max_R = R best_dag = dag pbar.set_description(f"derive| max_R: {max_R:8.6f}") fname = f"{self.epoch:03d}-{self.controller_step:06d}-{max_R:6.4f}-best.png" path = os.path.join(self.args.model_dir, "networks", fname) draw_network(best_dag, path) self.tb.image_summary("derive/best", [path], self.epoch) return best_dag @property def shared_lr(self): degree = max(self.epoch - self.args.shared_decay_after + 1, 0) return self.args.shared_lr * (self.args.shared_decay**degree) @property def controller_lr(self): return self.args.controller_lr def get_batch(self, source, idx, length=None, evaluation=False): # code from https://github.com/pytorch/examples/blob/master/word_language_model/main.py length = min(length if length else self.max_length, len(source) - 1 - idx) data = Variable(source[idx:idx + length], volatile=evaluation) target = Variable(source[idx + 1:idx + 1 + length].view(-1)) return data, target @property def shared_path(self): return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth' @property def controller_path(self): return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth' def get_saved_models_info(self): paths = glob(os.path.join(self.args.model_dir, '*.pth')) paths.sort() def get_numbers(items, delimiter, idx, replace_word, must_contain=''): return list( set([ int(name.split(delimiter)[idx].replace(replace_word, '')) for name in basenames if must_contain in name ])) basenames = [ os.path.basename(path.rsplit('.', 1)[0]) for path in paths ] epochs = get_numbers(basenames, '_', 1, 'epoch') shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared') controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller') epochs.sort() shared_steps.sort() controller_steps.sort() return epochs, shared_steps, controller_steps def save_model(self): t.save(self.shared.state_dict(), self.shared_path) logger.info(f"[*] SAVED: {self.shared_path}") t.save(self.controller.state_dict(), self.controller_path) logger.info(f"[*] SAVED: {self.controller_path}") epochs, shared_steps, controller_steps = self.get_saved_models_info() for epoch in epochs[:-self.args.max_save_num]: paths = glob( os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth')) for path in paths: remove_file(path) def load_model(self): epochs, shared_steps, controller_steps = self.get_saved_models_info() if len(epochs) == 0: logger.info(f"[!] No checkpoint found in {self.args.model_dir}...") return self.start_epoch = max(epochs) self.shared_step = max(shared_steps) self.controller_step = max(controller_steps) if self.args.num_gpu == 0: map_location = lambda storage, loc: storage else: map_location = None self.shared.load_state_dict( t.load(self.shared_path, map_location=map_location)) logger.info(f"[*] LOADED: {self.shared_path}") self.controller.load_state_dict( t.load(self.controller_path, map_location=map_location)) logger.info(f"[*] LOADED: {self.controller_path}")
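
# NOTE: `discount` (called from both train_controller variants above) is
# defined elsewhere in the repo. A minimal sketch of standard discounted
# returns, assuming `rewards` is ordered earliest-step-first and `gamma` is
# the discount factor in (0, 1):
def _discount_sketch(rewards, gamma):
    discounted = np.zeros(len(rewards))
    running = 0.0
    for i in reversed(range(len(rewards))):
        # each step receives gamma-weighted credit from all later rewards
        running = rewards[i] + gamma * running
        discounted[i] = running
    return discounted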
class Trainer(object): """A class to wrap training code.""" def __init__(self, args, dataset): """Constructor for training algorithm. Args: args: From command line, picked up by `argparse`. dataset: Currently only `data.text.Corpus` is supported. Initializes: - Data: train, val and test. - Model: shared and controller. - Inference: optimizers for shared and controller parameters. - Criticism: cross-entropy loss for training the shared model. """ self.args = args self.controller_step = 0 self.cuda = args.cuda self.dataset = dataset self.epoch = 0 self.shared_step = 0 self.start_epoch = 0 logger.info('regularizing:') for regularizer in [('activation regularization', self.args.activation_regularization), ('temporal activation regularization', self.args.temporal_activation_regularization), ('norm stabilizer regularization', self.args.norm_stabilizer_regularization)]: if regularizer[1]: logger.info('{0}'.format(regularizer[0])) self.train_data = utils.batchify(dataset.train, args.batch_size, self.cuda) # NOTE(brendan): The validation set data is batchified twice # separately: once for computing rewards during the Train Controller # phase (valid_data, batch size == 64), and once for evaluating ppl # over the entire validation set (eval_data, batch size == 1) self.valid_data = utils.batchify(dataset.valid, args.batch_size, self.cuda) self.eval_data = utils.batchify(dataset.valid, args.test_batch_size, self.cuda) self.test_data = utils.batchify(dataset.test, args.test_batch_size, self.cuda) self.max_length = self.args.shared_rnn_max_length # default=35 if args.use_tensorboard: self.tb = TensorBoard(args.model_dir) else: self.tb = None self.build_model() # 创建一个模型存入self.shared中,这里可以是RNN或CNN,再创建一个Controler if self.args.load_path: self.load_model() shared_optimizer = _get_optimizer(self.args.shared_optim) controller_optimizer = _get_optimizer(self.args.controller_optim) self.shared_optim = shared_optimizer( self.shared.parameters(), lr=self.shared_lr, weight_decay=self.args.shared_l2_reg) self.controller_optim = controller_optimizer( self.controller.parameters(), lr=self.args.controller_lr) self.ce = nn.CrossEntropyLoss() def build_model(self): """Creates and initializes the shared and controller models.""" if self.args.network_type == 'rnn': self.shared = models.RNN(self.args, self.dataset) elif self.args.network_type == 'cnn': self.shared = models.CNN(self.args, self.dataset) else: raise NotImplementedError( 'Network type `{0}` is not defined'.format( self.args.network_type)) self.controller = models.Controller( self.args ) # 构建了一个orward:Embedding(130,100)->lstm(100,100)->decoder的列表,对应25个decoder if self.args.num_gpu == 1: self.shared.cuda() self.controller.cuda() elif self.args.num_gpu > 1: raise NotImplementedError('`num_gpu > 1` is in progress') def train(self, single=False): """Cycles through alternately training the shared parameters and the controller, as described in Section 2.2, Training ENAS and Deriving Architectures, of the paper. From the paper (for Penn Treebank): - In the first phase, shared parameters omega are trained for 400 steps, each on a minibatch of 64 examples. - In the second phase, the controller's parameters are trained for 2000 steps. Args: single (bool): If True it won't train the controller and use the same dag instead of derive(). 
""" dag = utils.load_dag(self.args) if single else None # 初始训练dag=None if self.args.shared_initial_step > 0: # self.args.shared_initial_step default=0 self.train_shared(self.args.shared_initial_step) self.train_controller() for self.epoch in range( self.start_epoch, self.args.max_epoch): # start_epoch=0,max_epoch=150 # 1. Training the shared parameters omega of the child models # 训练RNN,先用Controller随机生成一个dag,然后用这个dag构建一个RNNcell,然后用这个RNNcell去做下一个词预测,得到loss self.train_shared(dag=dag) # 2. Training the controller parameters theta if not single: self.train_controller() if self.epoch % self.args.save_epoch == 0: with _get_no_grad_ctx_mgr(): best_dag = dag if dag else self.derive() self.evaluate(self.eval_data, best_dag, 'val_best', max_num=self.args.batch_size * 100) self.save_model() #应该是逐渐降低学习率 if self.epoch >= self.args.shared_decay_after: utils.update_lr(self.shared_optim, self.shared_lr) def get_loss(self, inputs, targets, hidden, dags): """ :param inputs:输入数据,[35,64] :param targets: 目标数据(相当于标签)[35,64] 输入的词后移一个词 :param hidden: 隐藏层参数 :param dags: RNN 的cell结构 :return: decoded(35,64,10000),hidden(64,1000),extra_out{dropped_output(35,64,1000),h1tohT(35,64,1000),raw_output(35,64,1000) """ """Computes the loss for the same batch for M models. This amounts to an estimate of the loss, which is turned into an estimate for the gradients of the shared model. """ if not isinstance(dags, list): dags = [dags] loss = 0 for dag in dags: # decoded(35,64,10000),hidden(64,1000),extra_out{dropped_output(35,64,1000),h1tohT(35,64,1000),raw_output(35,64,1000) output, hidden, extra_out = self.shared( inputs, dag, hidden=hidden) # RNN.forward output_flat = output.view(-1, self.dataset.num_tokens) # (2240,10000) # self.ce=nn.CrossEntropyLoss() target(2240) shared_num_sample=1 sample_loss = (self.ce(output_flat, targets) / self.args.shared_num_sample) loss += sample_loss assert len(dags) == 1, 'there are multiple `hidden` for multple `dags`' return loss, hidden, extra_out def train_shared(self, max_step=None, dag=None): """Train the language model for 400 steps of minibatches of 64 examples. Args: max_step: Used to run extra training steps as a warm-up. dag: If not None, is used instead of calling sample(). BPTT is truncated at 35 timesteps. #基于时间的反向传播算法BPTT(Back Propagation Trough Time) For each weight update, gradients are estimated by sampling M models from the fixed controller policy, and averaging their gradients computed on a batch of training data. """ model = self.shared # model.RNN model.train( ) # set RNN.training属性为true 即当前训练的是RNN而不训练Controller https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.train self.controller.eval( ) # Sets the module in evaluation mode. This is equivalent with self.train(False). # 功能:初始化variable,即全零的Tensor hidden = self.shared.init_hidden(self.args.batch_size) if max_step is None: max_step = self.args.shared_max_step # shared_max_step=150 else: max_step = min(self.args.shared_max_step, max_step) abs_max_grad = 0 abs_max_hidden_norm = 0 step = 0 raw_total_loss = 0 # 用于统计结果的,和计算过程无关 total_loss = 0 train_idx = 0 # TODO(brendan): Why - 1 - 1?为什么-1-1? 
# TODO(为什么-1-1)这里的train_idx是批次的编号,一共14524个batch(每个batch有64个词)为了训练输入数据不可能取最后一个batch # TODO(为什么-1-1)因为如果是最后一个batch就没有target了,因此最后一个batch是倒数第二个,而倒数第二个的下标是 size-2 # self.train_data.size(0) 14524 while train_idx < self.train_data.size(0) - 1 - 1: if step > max_step: break # Controller负责sample一个dag出来,是一个list,里面有一个defaultdict,存储了dag的连接信息 # 这一步只是提取Controller的值,并没有训练,初始的时候也是随机得出来的一个dag dags = dag if dag else self.controller.sample( batch_size=self.args.shared_num_sample ) # shared_num_sample:default=1 # 提取一个max_length长度的数据集(35,64),35个批次,每个批次64个词,组成一个训练批次 # input是训练数据,target是每个输入的词后面的词,用于训练RNN的 inputs, targets = self.get_batch(self.train_data, train_idx, self.max_length) # max_length=35 # get_loss完成了由dag生成的RNNcell的前向计算 loss, hidden, extra_out = self.get_loss(inputs, targets, hidden, dags) # Detaches the Tensor from the graph that created it, making it a leaf. Views cannot be detached in-place. hidden.detach_() raw_total_loss += loss.data # 根据命令行参数加一下正则惩罚项 loss += _apply_penalties(extra_out, self.args) # update self.shared_optim.zero_grad() loss.backward() # 反向更新 h1tohT = extra_out['hiddens'] # 和日志有关,和计算无关 new_abs_max_hidden_norm = utils.to_item( h1tohT.norm(dim=-1).data.max()) if new_abs_max_hidden_norm > abs_max_hidden_norm: abs_max_hidden_norm = new_abs_max_hidden_norm logger.info('max hidden {0}'.format(abs_max_hidden_norm)) # 函数的功能是获取Tensor图中的最大梯度,来检测是否出现梯度爆炸,但好像后面没有使用 abs_max_grad = _check_abs_max_grad(abs_max_grad, model) # Clips gradient norm of an iterable of parameters. # The norm is computed over all gradients together, as if they were concatenated into a single vector. # Gradients are modified in-place. torch.nn.utils.clip_grad_norm( model.parameters(), self.args.shared_grad_clip) # shared_grad_clip=0.25 self.shared_optim.step() # Performs a single optimization step. total_loss += loss.data # 和log有关 if ((step % self.args.log_step) == 0) and (step > 0): self._summarize_shared_train(total_loss, raw_total_loss) raw_total_loss = 0 total_loss = 0 step += 1 self.shared_step += 1 train_idx += self.max_length # max_length:35,下一个batch def get_reward(self, dag, entropies, hidden, valid_idx=0): """Computes the perplexity of a single sampled model on a minibatch of validation data. 计算模型的PPL:每个词的条件预测概率(即已知前n个词预测第n+1个词的概率)的累积的倒数开N(全体词的数量)次方 """ if not isinstance(entropies, np.ndarray): entropies = entropies.data.cpu().numpy() inputs, targets = self.get_batch(self.valid_data, valid_idx, self.max_length, volatile=True) valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag) #RNN.forward valid_loss = utils.to_item(valid_loss.data) valid_ppl = math.exp(valid_loss) #计算PPL # TODO: we don't know reward_c if self.args.ppl_square: #default:false # TODO: but we do know reward_c=80 in the previous paper R = self.args.reward_c / valid_ppl**2 else: R = self.args.reward_c / valid_ppl #这个值的作用在NAS(Zoph and Le, 2017) page 8 states that c is a constant if self.args.entropy_mode == 'reward': #entroy_mode:default:reward rewards = R + self.args.entropy_coeff * entropies # entropy_coeff:default=1e-4 elif self.args.entropy_mode == 'regularizer': rewards = R * np.ones_like(entropies) else: raise NotImplementedError('Unkown entropy mode: {0}'.format( self.args.entropy_mode)) return rewards, hidden def train_controller(self): """Fixes the shared parameters and updates the controller parameters. The controller is updated with a score function gradient estimator (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl is computed on a minibatch of validation data. A moving average baseline is used. 
The controller is trained for 2000 steps per epoch (i.e., first (Train Shared) phase -> second (Train Controller) phase). """ model = self.controller model.train() # 设置Controller的train属性为true,当前训练Controller # 这里为什么不调用hared.eval()? 这是因为会导致Controller的loss一直为零。 # self.shared.eval(),上面的解释应该是Brendon这个人测试之后的结论 avg_reward_base = None baseline = None # 这几个是用于统计信息的 adv_history = [] entropy_history = [] reward_history = [] hidden = self.shared.init_hidden(self.args.batch_size) total_loss = 0 valid_idx = 0 for step in range(self.args.controller_max_step): #controller_max_step # sample models #dags:list([1])(defaultdict([25])),log_probs:Tensor.size([23]),entropies:Tensor.size([23])交叉熵:-ylogy dags, log_probs, entropies = self.controller.sample( with_details=True) # calculate reward np_entropies = entropies.data.cpu().numpy() # NOTE(brendan): No gradients should be backpropagated to the # shared model during controller training, obviously. """ with 语句实质是上下文管理。 1、上下文管理协议。包含方法__enter__() 和 __exit__(),支持该协议对象要实现这两个方法。 2、上下文管理器,定义执行with语句时要建立的运行时上下文,负责执行with语句块上下文中的进入与退出操作。 3、进入上下文的时候执行__enter__方法,如果设置as var语句,var变量接受__enter__()方法返回值。 4、如果运行时发生了异常,就退出上下文管理器。调用管理器__exit__方法。 """ # 创建了一个torch.no_grad()的上下文,执行get_reward的时候是不需要计算梯度的,执行完get_reward在恢复计算梯度模式 with _get_no_grad_ctx_mgr(): rewards, hidden = self.get_reward(dags, np_entropies, hidden, valid_idx) # discount 默认未启用 if 1 > self.args.discount > 0: #discout:default=1 rewards = discount(rewards, self.args.discount) reward_history.extend(rewards) entropy_history.extend(np_entropies) # moving average baseline if baseline is None: baseline = rewards else: decay = self.args.ema_baseline_decay #****ema_baseline_decay:default=0.95 very important baseline = decay * baseline + (1 - decay) * rewards adv = rewards - baseline adv_history.extend(adv) # policy loss loss = -log_probs * utils.get_variable( adv, self.cuda, requires_grad=False) if self.args.entropy_mode == 'regularizer': #entropy_mode:default='reward' loss -= self.args.entropy_coeff * entropies loss = loss.sum() # or loss.mean() # update self.controller_optim.zero_grad() loss.backward() if self.args.controller_grad_clip > 0: torch.nn.utils.clip_grad_norm(model.parameters(), self.args.controller_grad_clip) self.controller_optim.step() total_loss += utils.to_item(loss.data) if ((step % self.args.log_step) == 0) and (step > 0): self._summarize_controller_train(total_loss, adv_history, entropy_history, reward_history, avg_reward_base, dags) reward_history, adv_history, entropy_history = [], [], [] total_loss = 0 self.controller_step += 1 prev_valid_idx = valid_idx valid_idx = ((valid_idx + self.max_length) % (self.valid_data.size(0) - 1)) # NOTE(brendan): Whenever we wrap around to the beginning of the # validation data, we reset the hidden states. if prev_valid_idx > valid_idx: hidden = self.shared.init_hidden(self.args.batch_size) def evaluate(self, source, dag, name, batch_size=1, max_num=None): """Evaluate on the validation set. NOTE(brendan): We should not be using the test set to develop the algorithm (basic machine learning good practices). 
""" self.shared.eval() self.controller.eval() data = source[:max_num * self.max_length] total_loss = 0 hidden = self.shared.init_hidden(batch_size) pbar = range(0, data.size(0) - 1, self.max_length) for count, idx in enumerate(pbar): inputs, targets = self.get_batch(data, idx, volatile=True) output, hidden, _ = self.shared(inputs, dag, hidden=hidden, is_train=False) output_flat = output.view(-1, self.dataset.num_tokens) total_loss += len(inputs) * self.ce(output_flat, targets).data hidden.detach_() ppl = math.exp( utils.to_item(total_loss) / (count + 1) / self.max_length) val_loss = utils.to_item(total_loss) / len(data) ppl = math.exp(val_loss) self.tb.scalar_summary('eval/{0}_loss'.format(name), val_loss, self.epoch) self.tb.scalar_summary('eval/{0}_ppl'.format(name), ppl, self.epoch) logger.info('eval | loss: {0:8.2f} | ppl: {1:8.2f}'.format( val_loss, ppl)) def derive(self, sample_num=None, valid_idx=0): """TODO(brendan): We are always deriving based on the very first batch of validation data? This seems wrong... """ hidden = self.shared.init_hidden(self.args.batch_size) if sample_num is None: sample_num = self.args.derive_num_sample dags, _, entropies = self.controller.sample(sample_num, with_details=True) max_R = 0 best_dag = None for dag in dags: R, _ = self.get_reward(dag, entropies, hidden, valid_idx) if R.max() > max_R: max_R = R.max() best_dag = dag logger.info('derive | max_R: {0:8.6f}'.format(max_R)) fname = ('{0:03d}-{1:06d}-{2:6.4f}-best.png'.format( self.epoch, self.controller_step, max_R)) path = os.path.join(self.args.model_dir, 'networks', fname) #utils.draw_network(best_dag, path) #self.tb.image_summary('derive/best', [path], self.epoch) return best_dag @property def shared_lr(self): degree = max(self.epoch - self.args.shared_decay_after + 1, 0) return self.args.shared_lr * (self.args.shared_decay**degree) @property #将类方法转换为类属性,可以用 . 直接获取属性值或者对属性进行赋值 def controller_lr(self): return self.args.controller_lr def get_batch(self, source, idx, length=None, volatile=False): """ 这个函数的作用是从数据集中取得length长度的数据组成一个Variable(这个操作在pytorch中已经过时了,可以直接使用Tensor来生成计算,而不用 再使用Variable来封装Tensor来计算 这里的batch指的是取词窗口组成的batch,length是最多取多少个batch_size的词 :param source:数据集train_data :param idx: 当前数据样本索引值 :param length:max_length=35? :param volatile(易变的):Volatile is recommended for purely inference mode, when you’re sure you won’t be even calling .backward() 设定volatie选项为true的话则只是取值模式,而不会进行反向计算 :return: """ # code from # https://github.com/pytorch/examples/blob/master/word_language_model/main.py length = min(length if length else self.max_length, len(source) - 1 - idx) #UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead. 
data = Variable(source[idx:idx + length], volatile=volatile) # shape(35,64) 取35个批次,每个批次64个词 target = Variable(source[idx + 1:idx + 1 + length].view(-1), volatile=volatile) # view(35,64)->(2240) # 这里target=data+1的意思是从data中推断下一个词 return data, target @property def shared_path(self): return '{0}/shared_epoch{1:d}_step{2:d}.pth'.format( self.args.model_dir, self.epoch, self.shared_step) @property def controller_path(self): return '{}/controller_epoch{}_step{}.pth'.format( self.args.model_dir, self.epoch, self.controller_step) def get_saved_models_info(self): paths = glob.glob(os.path.join(self.args.model_dir, '*.pth')) paths.sort() def get_numbers(items, delimiter, idx, replace_word, must_contain=''): return list( set([ int(name.split(delimiter)[idx].replace(replace_word, '')) for name in basenames if must_contain in name ])) basenames = [ os.path.basename(path.rsplit('.', 1)[0]) for path in paths ] epochs = get_numbers(basenames, '_', 1, 'epoch') shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared') controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller') epochs.sort() shared_steps.sort() controller_steps.sort() return epochs, shared_steps, controller_steps def save_model(self): torch.save(self.shared.state_dict(), self.shared_path) logger.info('[*] SAVED: {0}'.format(self.shared_path)) torch.save(self.controller.state_dict(), self.controller_path) logger.info('[*] SAVED: {0}'.format(self.controller_path)) epochs, shared_steps, controller_steps = self.get_saved_models_info() for epoch in epochs[:-self.args.max_save_num]: paths = glob.glob( os.path.join(self.args.model_dir, '*_epoch{0}_*.pth'.format(epoch))) for path in paths: utils.remove_file(path) def load_model(self): epochs, shared_steps, controller_steps = self.get_saved_models_info() if len(epochs) == 0: logger.info('[!] 
No checkpoint found in {0}...'.format( self.args.model_dir)) return self.epoch = self.start_epoch = max(epochs) self.shared_step = max(shared_steps) self.controller_step = max(controller_steps) if self.args.num_gpu == 0: map_location = lambda storage, loc: storage else: map_location = None self.shared.load_state_dict( torch.load(self.shared_path, map_location=map_location)) logger.info('[*] LOADED: {0}'.format(self.shared_path)) self.controller.load_state_dict( torch.load(self.controller_path, map_location=map_location)) logger.info('[*] LOADED: {0}'.format(self.controller_path)) def _summarize_controller_train(self, total_loss, adv_history, entropy_history, reward_history, avg_reward_base, dags): """Logs the controller's progress for this training epoch.""" cur_loss = total_loss / self.args.log_step avg_adv = np.mean(adv_history) avg_entropy = np.mean(entropy_history) avg_reward = np.mean(reward_history) if avg_reward_base is None: avg_reward_base = avg_reward logger.info( '| epoch {0:3d} | lr {1:.5f} | R {2:.5f} | entropy {3:.4f} | loss {:.5f}' .format(self.epoch, self.controller_lr, avg_reward, avg_entropy, cur_loss)) # Tensorboard if self.tb is not None: self.tb.scalar_summary('controller/loss', cur_loss, self.controller_step) self.tb.scalar_summary('controller/reward', avg_reward, self.controller_step) self.tb.scalar_summary('controller/reward-B_per_epoch', avg_reward - avg_reward_base, self.controller_step) self.tb.scalar_summary('controller/entropy', avg_entropy, self.controller_step) self.tb.scalar_summary('controller/adv', avg_adv, self.controller_step) paths = [] for dag in dags: fname = ('{0:03d}-{1:06d}-{2:6.4f}.png'.format( self.epoch, self.controller_step, avg_reward)) path = os.path.join(self.args.model_dir, 'networks', fname) utils.draw_network(dag, path) paths.append(path) self.tb.image_summary('controller/sample', paths, self.controller_step) def _summarize_shared_train(self, total_loss, raw_total_loss): """Logs a set of training steps.""" cur_loss = utils.to_item(total_loss) / self.args.log_step # NOTE(brendan): The raw loss, without adding in the activation # regularization terms, should be used to compute ppl. cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step ppl = math.exp(cur_raw_loss) logger.info( '| epoch {0:3d} | lr {1:4.2f} | raw loss {2:.2f} | loss {3:.2f} | ppl {4:8.2f}' .format(self.epoch, self.shared_lr, cur_raw_loss, cur_loss, ppl)) # Tensorboard if self.tb is not None: self.tb.scalar_summary('shared/loss', cur_loss, self.shared_step) self.tb.scalar_summary('shared/perplexity', ppl, self.shared_step)
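
# NOTE: `_apply_penalties` (called from train_shared above) lives at module
# level in the original file but is not shown here. A minimal sketch under
# the following assumptions: the `extra_out` keys ('dropped', 'raw',
# 'hiddens') and the `*_amount` / `norm_stabilizer_fixed_point` argument
# names are guesses based on the three regularizers announced in __init__
# (activation regularization, temporal activation regularization, and norm
# stabilizer regularization).
def _apply_penalties_sketch(extra_out, args):
    penalty = 0
    if args.activation_regularization:
        # L2 on the dropout-masked RNN outputs (AR)
        penalty += (args.activation_regularization_amount *
                    extra_out['dropped'].pow(2).mean())
    if args.temporal_activation_regularization:
        # L2 on the change between consecutive hidden states (TAR)
        raw = extra_out['raw']
        penalty += (args.temporal_activation_regularization_amount *
                    (raw[1:] - raw[:-1]).pow(2).mean())
    if args.norm_stabilizer_regularization:
        # keep hidden-state norms close to a fixed point
        norms = extra_out['hiddens'].norm(dim=-1)
        penalty += (args.norm_stabilizer_regularization_amount *
                    (norms - args.norm_stabilizer_fixed_point).pow(2).mean())
    return penalty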
class Trainer(object): """A class to wrap training code.""" def __init__(self, args, dataset): """Constructor for training algorithm. Args: args: From command line, picked up by `argparse`. dataset: Currently only `data.text.Corpus` is supported. Initializes: - Data: train, val and test. - Model: shared and controller. - Inference: optimizers for shared and controller parameters. - Criticism: cross-entropy loss for training the shared model. """ self.args = args self.controller_step = 0 self.cuda = args.cuda self.device = gpu = torch.device("cuda:0") self.dataset = dataset self.epoch = 0 self.shared_step = 0 self.start_epoch = 0 self.compute_fisher = False logger.info('regularizing:') for regularizer in [('activation regularization', self.args.activation_regularization), ('temporal activation regularization', self.args.temporal_activation_regularization), ('norm stabilizer regularization', self.args.norm_stabilizer_regularization)]: if regularizer[1]: logger.info(f'{regularizer[0]}') self.image_dataset = isinstance(dataset, Image) if self.image_dataset: self._train_data = dataset.train self._valid_data = dataset.valid self._test_data = dataset.test self._eval_data = dataset.valid self.train_data = wrap_iterator_with_name(self._train_data, 'train') self.valid_data = wrap_iterator_with_name(self._valid_data, 'valid') self.test_data = wrap_iterator_with_name(self._test_data, 'test') self.eval_data = wrap_iterator_with_name(self._eval_data, 'eval') self.max_length = 0 else: self.train_data = utils.batchify(dataset.train, args.batch_size, self.cuda) self.valid_data = utils.batchify(dataset.valid, args.batch_size, self.cuda) self.eval_data = utils.batchify(dataset.valid, args.test_batch_size, self.cuda) self.test_data = utils.batchify(dataset.test, args.test_batch_size, self.cuda) self.max_length = self.args.shared_rnn_max_length self.train_data_size = self.train_data.size( 0) if not self.image_dataset else len(self.train_data) self.valid_data_size = self.valid_data.size( 0) if not self.image_dataset else len(self.valid_data) self.test_data_size = self.test_data.size( 0) if not self.image_dataset else len(self.test_data) # Visualization if args.use_tensorboard: self.tb = TensorBoard(args.model_dir) else: self.tb = None self.draw_network = utils.draw_network self.build_model() if self.args.load_path: self.load_model() shared_optimizer = _get_optimizer(self.args.shared_optim) controller_optimizer = _get_optimizer(self.args.controller_optim) # As fisher information, and it should be seen by this model, to get the loss. 
self.shared_optim = shared_optimizer( self.shared.parameters(), lr=self.shared_lr, weight_decay=self.args.shared_l2_reg) self.controller_optim = controller_optimizer( self.controller.parameters(), lr=self.args.controller_lr) self.ce = nn.CrossEntropyLoss() self.top_k_acc = top_k_accuracy def build_model(self): """Creates and initializes the shared and controller models.""" if self.args.network_type == 'rnn': self.shared = models.RNN(self.args, self.dataset) self.controller = models.Controller(self.args) elif self.args.network_type == 'micro_cnn': self.shared = models.CNN(self.args, self.dataset) self.controller = models.CNNMicroController(self.args) else: raise NotImplementedError(f'Network type ' f'`{self.args.network_type}` is not ' f'defined') if self.args.num_gpu == 1: if torch.__version__ == '0.3.1': self.shared.cuda() self.controller.cuda() else: self.shared.to(self.device) self.controller.to(self.device) elif self.args.num_gpu > 1: raise NotImplementedError('`num_gpu > 1` is in progress') def train(self): """Cycles through alternately training the shared parameters and the controller, as described in Section 2.2, Training ENAS and Deriving Architectures, of the paper. From the paper (for Penn Treebank): - In the first phase, shared parameters omega are trained for 400 steps, each on a minibatch of 64 examples. - In the second phase, the controller's parameters are trained for 2000 steps. """ if self.args.shared_initial_step > 0: self.train_shared(self.args.shared_initial_step) self.train_controller() for self.epoch in range(self.start_epoch, self.args.max_epoch): if self.epoch >= self.args.start_using_fisher: self.compute_fisher = True if self.args.set_fisher_zero_per_iter > 0 \ and self.epoch % self.args.set_fisher_zero_per_iter == 0: self.shared.set_fisher_zero() # 1. Training the shared parameters omega of the child models self.train_shared() # 2. Training the controller parameters theta if self.args.train_controller: if self.epoch < self.args.stop_training_controller: self.train_controller() if self.epoch % self.args.save_epoch == 0: with _get_no_grad_ctx_mgr(): best_dag = self.derive() self.evaluate(self.eval_data, best_dag, 'val_best', max_num=self.args.batch_size * 100) self.save_model() if self.epoch >= self.args.shared_decay_after: utils.update_lr(self.shared_optim, self.shared_lr) def get_loss(self, inputs, targets, dags, **kwargs): """Computes the loss for the same batch for M models. This amounts to an estimate of the loss, which is turned into an estimate for the gradients of the shared model. We store, compute the new WPL. :param **kwargs: passed into self.shared(, such as hidden) """ if not isinstance(dags, list): dags = [dags] loss = 0 for dag in dags: output, hidden, extra_out = self.shared(inputs, dag, **kwargs) output_flat = output.view(-1, self.dataset.num_classes) sample_loss = (self.ce(output_flat, targets) / self.args.shared_num_sample) # Get WPL part if self.compute_fisher: wpl = self.shared.compute_weight_plastic_loss_with_update_fisher( dag) wpl = 0.5 * wpl loss += sample_loss + wpl rest_loss = wpl else: loss += sample_loss rest_loss = Variable(torch.zeros(1)) # logger.info(f'Loss {loss.data[0]} = ' # f'sample_loss {sample_loss.data[0]}') #assert len(dags) == 1, 'there are multiple `hidden` for multple `dags`' return loss, sample_loss, rest_loss, hidden, extra_out def train_shared(self, max_step=None): """Train the language model for 400 steps of minibatches of 64 examples. Args: max_step: Used to run extra training steps as a warm-up. 
        BPTT is truncated at 35 timesteps.

        For each weight update, gradients are estimated by sampling M models
        from the fixed controller policy, and averaging their gradients
        computed on a batch of training data.
        """
        valid_ppls = []
        valid_ppls_after = []
        model = self.shared
        model.train()
        self.controller.eval()

        hidden = self.shared.init_training(self.args.batch_size)
        v_hidden = self.shared.init_training(self.args.batch_size)

        if max_step is None:
            max_step = self.args.shared_max_step
        else:
            max_step = min(self.args.shared_max_step, max_step)

        abs_max_grad = 0
        abs_max_hidden_norm = 0
        step = 0
        raw_total_loss = 0
        total_loss = 0
        total_sample_loss = 0
        total_rest_loss = 0
        train_idx = 0
        valid_idx = 0

        def _run_shared_one_batch(inputs, targets, hidden, dags,
                                  raw_total_loss):
            # global abs_max_grad
            # global abs_max_hidden_norm
            # global raw_total_loss
            loss, sample_loss, rest_loss, hidden, extra_out = self.get_loss(
                inputs, targets, dags, hidden=hidden)

            # Detach the hidden state, because it is carried over from the
            # previous step.
            hidden = utils.detach(hidden)
            raw_total_loss += sample_loss.data / self.args.num_batch_per_iter

            penalty_loss = _apply_penalties(extra_out, self.args)
            loss += penalty_loss
            rest_loss += penalty_loss
            return (loss, sample_loss, rest_loss, hidden, extra_out,
                    raw_total_loss)

        def _clip_gradient(abs_max_grad, abs_max_hidden_norm):
            h1tohT = extra_out['hiddens']
            new_abs_max_hidden_norm = utils.to_item(
                h1tohT.norm(dim=-1).data.max())
            if new_abs_max_hidden_norm > abs_max_hidden_norm:
                abs_max_hidden_norm = new_abs_max_hidden_norm
                logger.info(f'max hidden {abs_max_hidden_norm}')
            abs_max_grad = _check_abs_max_grad(abs_max_grad, model)
            torch.nn.utils.clip_grad_norm(model.parameters(),
                                          self.args.shared_grad_clip)
            return abs_max_grad, abs_max_hidden_norm

        def _evaluate_valid(dag):
            hidden_eval = self.shared.init_training(self.args.batch_size)
            inputs_eval, targets_eval = self.get_batch(self.valid_data,
                                                       0,
                                                       self.max_length,
                                                       volatile=True)
            _, valid_loss_eval, _, _, _ = self.get_loss(inputs_eval,
                                                        targets_eval,
                                                        dag,
                                                        hidden=hidden_eval)
            valid_loss_eval = utils.to_item(valid_loss_eval.data)
            valid_ppl_eval = math.exp(valid_loss_eval)
            # Callers append this value to `valid_ppls` and
            # `valid_ppls_after`, so it must be returned.
            return valid_ppl_eval

        dags_eval = []
        while train_idx < self.train_data_size - 1 - 1:
            if step > max_step:
                break

            dags = self.controller.sample(self.args.shared_num_sample)
            dags_eval.append(dags[0])
            for b in range(0, self.args.num_batch_per_iter):
                # For each model, do the update for 30 batches.
                inputs, targets = self.get_batch(self.train_data,
                                                 train_idx,
                                                 self.max_length)
                loss, sample_loss, rest_loss, hidden, extra_out, raw_total_loss = \
                    _run_shared_one_batch(
                        inputs, targets, hidden, dags, raw_total_loss)

                # Update logic: normally we compute one loss and update
                # accordingly. On the last batch, we compute the Fisher
                # information from one of two losses: the complete loss or
                # the CE loss only.
                self.shared_optim.zero_grad()
                # If it is the last training batch.
Update the Fisher information if self.compute_fisher and (not self.args.shared_valid_fisher): if b == self.args.num_batch_per_iter - 1: sample_loss.backward() if self.args.shared_ce_fisher: self.shared.update_fisher(dags[0]) rest_loss.backward() else: rest_loss.backward() self.shared.update_fisher(dags[0]) else: loss.backward() else: loss.backward() abs_max_grad, abs_max_hidden_norm = _clip_gradient( abs_max_grad, abs_max_hidden_norm) self.shared_optim.step() total_loss += loss.data / self.args.num_batch_per_iter total_sample_loss += sample_loss.data / self.args.num_batch_per_iter total_rest_loss += rest_loss.data / self.args.num_batch_per_iter train_idx = ((train_idx + self.max_length) % (self.train_data_size - 1)) if self.epoch > self.args.start_evaluate_diff: valid_ppl_eval = _evaluate_valid(dags[0]) valid_ppls.append(valid_ppl_eval) logger.info( f'Step {step}' f'Loss {utils.to_item(total_loss) / (step + 1):.5f} = ' f'sample_loss {utils.to_item(total_sample_loss) / (step + 1):.5f} + ' f'wpl {utils.to_item(total_rest_loss) / (step + 1):.5f}') if ((step % self.args.log_step) == 0) and (step > 0): self._summarize_shared_train(total_loss, raw_total_loss) raw_total_loss = 0 total_loss = 0 total_sample_loss = 0 total_rest_loss = 0 if self.compute_fisher: # Update with the validation dataset for fisher information after each step, # with update the optimal weights. v_inputs, v_targets = self.get_batch(self.valid_data, valid_idx, self.max_length) v_loss, v_sample_loss, _, v_hidden, v_extra_out, _ = _run_shared_one_batch( v_inputs, v_targets, v_hidden, dags, 0) self.shared_optim.zero_grad() if self.args.shared_ce_fisher: v_sample_loss.backward() else: v_loss.backward() self.shared.update_fisher(dags[0], self.epoch) self.shared.update_optimal_weights() valid_idx = ((valid_idx + self.max_length) % (self.valid_data_size - 1)) step += 1 self.shared_step += 1 if self.epoch > self.args.start_evaluate_diff: for arch in dags_eval: valid_ppl_eval = _evaluate_valid(arch) valid_ppls_after.append(valid_ppl_eval) logger.info(f'valid_ppl {valid_ppl_eval}') diff = np.array(valid_ppls_after) - np.array(valid_ppls) logger.info(f'Mean_diff {np.mean(diff)}') logger.info(f'Max_diff {np.amax(diff)}') self.tb.scalar_summary(f'Mean difference', np.mean(diff), self.epoch) self.tb.scalar_summary(f'Max difference', np.amax(diff), self.epoch) self.tb.scalar_summary(f'Mean valid_ppl after training', np.mean(np.array(valid_ppls_after)), self.epoch) self.tb.scalar_summary(f'Mean valid_ppl before training', np.mean(np.array(valid_ppls)), self.epoch) self.tb.scalar_summary(f'std_diff', np.std(np.array(diff)), self.epoch) def get_reward(self, dags, entropies, hidden, valid_idx=None): """ Computes the reward of a single sampled model or multiple on a minibatch of validation data. 
""" if not isinstance(entropies, np.ndarray): entropies = entropies.data.cpu().numpy() if valid_idx is None: valid_idx = 0 inputs, targets = self.get_batch(self.valid_data, valid_idx, self.max_length, volatile=True) _, valid_loss, _, hidden, _ = self.get_loss(inputs, targets, dags, hidden=hidden) valid_loss = utils.to_item(valid_loss.data) valid_ppl = math.exp(valid_loss) if self.args.ppl_square: R = self.args.reward_c / valid_ppl**2 else: R = self.args.reward_c / valid_ppl if self.args.entropy_mode == 'reward': rewards = R + self.args.entropy_coeff * entropies elif self.args.entropy_mode == 'regularizer': rewards = R * np.ones_like(entropies) else: raise NotImplementedError( f'Unkown entropy mode: {self.args.entropy_mode}') return rewards, hidden def train_controller(self): """Fixes the shared parameters and updates the controller parameters. The controller is updated with a score function gradient estimator (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl is computed on a minibatch of validation data. A moving average baseline is used. The controller is trained for 2000 steps per epoch (i.e., first (Train Shared) phase -> second (Train Controller) phase). """ model = self.controller model.train() avg_reward_base = None baseline = None adv_history = [] entropy_history = [] reward_history = [] hidden = self.shared.init_training(self.args.batch_size) total_loss = 0 valid_idx = 0 for step in range(self.args.controller_max_step): # print("************ train controller ****************") # sample models dags, log_probs, entropies = self.controller.sample( batch_size=self.args.policy_batch_size, with_details=True) # calculate reward np_entropies = entropies.data.cpu().numpy() with _get_no_grad_ctx_mgr(): rewards, hidden = self.get_reward(dags, np_entropies, hidden, valid_idx) # discount if 1 > self.args.discount > 0: rewards = discount(rewards, self.args.discount) reward_history.extend(rewards) entropy_history.extend(np_entropies) # moving average baseline if baseline is None: baseline = rewards else: decay = self.args.ema_baseline_decay baseline = decay * baseline + (1 - decay) * rewards adv = rewards - baseline adv_history.extend(adv) # policy loss loss = -log_probs * utils.get_variable( adv, self.cuda, requires_grad=False) if self.args.entropy_mode == 'regularizer': loss -= self.args.entropy_coeff * entropies loss = loss.sum() # or loss.mean() # update self.controller_optim.zero_grad() loss.backward() if self.args.controller_grad_clip > 0: torch.nn.utils.clip_grad_norm(model.parameters(), self.args.controller_grad_clip) self.controller_optim.step() total_loss += utils.to_item(loss.data) if ((step % self.args.log_step) == 0) and (step > 0): self._summarize_controller_train(total_loss, adv_history, entropy_history, reward_history, avg_reward_base, dags) reward_history, adv_history, entropy_history = [], [], [] total_loss = 0 self.controller_step += 1 prev_valid_idx = valid_idx valid_idx = ((valid_idx + self.max_length) % (self.valid_data_size - 1)) if prev_valid_idx > valid_idx: hidden = self.shared.init_training(self.args.batch_size) def evaluate(self, source, dag, name, batch_size=1, max_num=None): """Evaluate on the validation set. 
""" self.shared.eval() self.controller.eval() if self.image_dataset: data = source else: data = source[:max_num * self.max_length] total_loss = 0 hidden = self.shared.init_training(batch_size) pbar = range(0, self.valid_data_size - 1, self.max_length) for count, idx in enumerate(pbar): inputs, targets = self.get_batch(data, idx, volatile=True) output, hidden, _ = self.shared(inputs, dag, hidden=hidden, is_train=False) output_flat = output.view(-1, self.dataset.num_classes) total_loss += len(inputs) * self.ce(output_flat, targets).data hidden = utils.detach(hidden) ppl = math.exp( utils.to_item(total_loss) / (count + 1) / self.max_length) val_loss = utils.to_item(total_loss) / len(data) ppl = math.exp(val_loss) self.tb.scalar_summary(f'eval/{name}_loss', val_loss, self.epoch) self.tb.scalar_summary(f'eval/{name}_ppl', ppl, self.epoch) logger.info(f'eval | loss: {val_loss:8.2f} | ppl: {ppl:8.2f}') def derive(self, sample_num=None, valid_idx=0): if sample_num is None: sample_num = self.args.derive_num_sample dags, _, entropies = self.controller.sample(sample_num, with_details=True) max_R = 0 best_dag = None for dag in dags: if self.image_dataset: R, _ = self.get_reward([dag], entropies, valid_idx) else: hidden = self.shared.init_training(self.args.batch_size) R, _ = self.get_reward(dag, entropies, hidden, valid_idx) if R.max() > max_R: max_R = R.max() best_dag = dag logger.info(f'derive | max_R: {max_R:8.6f}') fname = (f'{self.epoch:03d}-{self.controller_step:06d}-' f'{max_R:6.4f}-best.png') path = os.path.join(self.args.model_dir, 'networks', fname) success = self.draw_network(best_dag, path) if success: self.tb.image_summary('derive/best', [path], self.epoch) return best_dag def reset_dataloader_by_name(self, name): """ Works for only reset _DataLoaderIter by DataLoader with name """ try: new_iter = wrap_iterator_with_name( iter(getattr(self, f'_{name}_data')), name) setattr(self, f'{name}_data', new_iter) except Exception as e: print(e) return new_iter @property def shared_lr(self): degree = max(self.epoch - self.args.shared_decay_after + 1, 0) return self.args.shared_lr * (self.args.shared_decay**degree) @property def controller_lr(self): return self.args.controller_lr def get_batch(self, source, idx, length=None, volatile=False): # code from # https://github.com/pytorch/examples/blob/master/word_language_model/main.py if not self.image_dataset: length = min(length if length else self.max_length, len(source) - 1 - idx) data = Variable(source[idx:idx + length], volatile=volatile) target = Variable(source[idx + 1:idx + 1 + length].view(-1), volatile=volatile) else: # Try the dataloader logic. 
    def get_batch(self, source, idx, length=None, volatile=False):
        # code from
        # https://github.com/pytorch/examples/blob/master/word_language_model/main.py
        if not self.image_dataset:
            length = min(length if length else self.max_length,
                         len(source) - 1 - idx)
            data = Variable(source[idx:idx + length], volatile=volatile)
            target = Variable(source[idx + 1:idx + 1 + length].view(-1),
                              volatile=volatile)
        else:
            # Dataloader logic: `source` is a _DataLoaderIter, which is
            # reset by name when it is exhausted.
            try:
                data, target = next(source)
            except StopIteration as e:
                logger.info(f'{e}')
                name = source.name
                source = self.reset_dataloader_by_name(name)
                data, target = next(source)
            return data.to(self.device), target.to(self.device)
        return data, target

    @property
    def shared_path(self):
        return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'

    @property
    def controller_path(self):
        return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'

    def get_saved_models_info(self):
        paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word,
                        must_contain=''):
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in items if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step',
                                       'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps

    def save_model(self):
        torch.save(self.shared.state_dict(), self.shared_path)
        logger.info(f'[*] SAVED: {self.shared_path}')

        torch.save(self.controller.state_dict(), self.controller_path)
        logger.info(f'[*] SAVED: {self.controller_path}')

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob.glob(
                os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))
            for path in paths:
                utils.remove_file(path)
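    # A hedged sketch of the filename convention that
    # `get_saved_models_info` parses (hypothetical checkpoint name):
    #
    #     >>> name = 'shared_epoch12_step4000'
    #     >>> int(name.split('_')[1].replace('epoch', ''))
    #     12
    #     >>> int(name.split('_')[2].replace('step', ''))
    #     4000
    #
    # so `*_epoch{E}_step{S}.pth` encodes everything needed to resume, and
    # `save_model` can prune all files older than `max_save_num` epochs.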
    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            logger.info(f'[!] No checkpoint found in {self.args.model_dir}...')
            return

        self.epoch = self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            torch.load(self.shared_path, map_location=map_location))
        logger.info(f'[*] LOADED: {self.shared_path}')

        self.controller.load_state_dict(
            torch.load(self.controller_path, map_location=map_location))
        logger.info(f'[*] LOADED: {self.controller_path}')

    def _summarize_controller_train(self,
                                    total_loss,
                                    adv_history,
                                    entropy_history,
                                    reward_history,
                                    avg_reward_base,
                                    dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.args.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        logger.info(f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
                    f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
                    f'| loss {cur_loss:.5f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss',
                                   cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward',
                                   avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/std/reward',
                                   np.std(reward_history),
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy',
                                   avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv',
                                   avg_adv,
                                   self.controller_step)

            paths = []
            for dag in dags:
                fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                         f'{avg_reward:6.4f}.png')
                path = os.path.join(self.args.model_dir, 'networks', fname)
                self.draw_network(dag, path)
                paths.append(path)

            self.tb.image_summary('controller/sample',
                                  paths,
                                  self.controller_step)

    def _summarize_shared_train(self, total_loss, raw_total_loss):
        """Logs a set of training steps."""
        cur_loss = utils.to_item(total_loss) / self.args.log_step
        # The raw loss, without the regularization penalties, is what should
        # be used to compute ppl.
        cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
        try:
            ppl = math.exp(cur_raw_loss)
        except OverflowError as e:
            logger.info(f'ppl overflowed: {e}')
            ppl = float('inf')

        logger.info(f'| epoch {self.epoch:3d} '
                    f'| lr {self.shared_lr:4.2f} '
                    f'| raw loss {cur_raw_loss:.2f} '
                    f'| loss {cur_loss:.2f} '
                    f'| ppl {ppl:8.2f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('shared/loss',
                                   cur_loss,
                                   self.shared_step)
            self.tb.scalar_summary('shared/perplexity',
                                   ppl,
                                   self.shared_step)
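# A hedged worked example of the loss -> perplexity conversion used in the
# summaries above (hypothetical loss value):
#
#     >>> import math
#     >>> cur_raw_loss = 4.5        # mean cross-entropy in nats per token
#     >>> round(math.exp(cur_raw_loss), 2)
#     90.02
#
# Perplexity is exp(mean cross-entropy), so lower loss maps exponentially
# to lower ppl; the raw (penalty-free) loss is used so ppl stays comparable
# across regularization settings.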
class Trainer(object):
    """A class to wrap training code (CNN / image-classification variant)."""

    def __init__(self, args, dataset):
        """Constructor for training algorithm.

        Args:
            args: From command line, picked up by `argparse`.
            dataset: Currently only `data.text.Corpus` is supported.

        Initializes:
            - Data: train, val and test.
            - Model: shared and controller.
            - Inference: optimizers for shared and controller parameters.
            - Criticism: cross-entropy loss for training the shared model.
        """
        # TODO: also track and check accuracy.
        self.args = args
        self.controller_step = 0
        self.cuda = args.cuda
        self.dataset = dataset
        self.epoch = 0
        self.shared_step = 0
        self.start_epoch = 0

        print('regularizing:')
        for regularizer in [('activation regularization',
                             self.args.activation_regularization),
                            ('temporal activation regularization',
                             self.args.temporal_activation_regularization),
                            ('norm stabilizer regularization',
                             self.args.norm_stabilizer_regularization)]:
            if regularizer[1]:
                print(f'{regularizer[0]}')

        # NOTE: unlike the RNN variant above, the data here is kept as
        # DataLoaders rather than batchified tensors, and `max_length` is
        # not used.
        self.train_data = dataset.train
        self.valid_data = dataset.valid
        self.test_data = dataset.test

        if args.use_tensorboard:
            self.tb = TensorBoard(args.model_dir)
        else:
            self.tb = None

        # Initialize the controller and shared models.
        self.build_model()

        if self.args.load_path:
            self.load_model()

        shared_optimizer = _get_optimizer(self.args.shared_optim)
        controller_optimizer = _get_optimizer(self.args.controller_optim)

        self.shared_optim = shared_optimizer(
            self.shared.parameters(),
            lr=self.shared_lr,
            weight_decay=self.args.shared_l2_reg)

        self.controller_optim = controller_optimizer(
            self.controller.parameters(),
            lr=self.args.controller_lr)

        self.ce = nn.CrossEntropyLoss()

    def build_model(self):
        """Creates and initializes the shared and controller models."""
        if self.args.network_type == 'rnn':
            self.shared = models.RNN(self.args, self.dataset)
        elif self.args.network_type == 'cnn':
            self.shared = models.CNN(self.args, self.dataset)
        else:
            raise NotImplementedError(f'Network type '
                                      f'`{self.args.network_type}` is not '
                                      f'defined')

        self.controller = models.Controller(self.args)

        if self.args.num_gpu == 1:
            self.shared.cuda()
            self.controller.cuda()
        elif self.args.num_gpu > 1:
            raise NotImplementedError('`num_gpu > 1` is in progress')
    def train(self):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section 2.2, Training ENAS and Deriving
        Architectures, of the paper.
        """
        if self.args.shared_initial_step > 0:
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # 1. Training the shared parameters omega of the child models
            self.train_shared()

            # 2. Training the controller parameters theta
            self.train_controller()

            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = self.derive()
                    self.evaluate(iter(self.test_data),
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()

            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)

    def get_loss(self, inputs, targets, dags):
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.
        """
        if not isinstance(dags, list):
            dags = [dags]

        inputs = Variable(inputs.cuda())
        targets = Variable(targets.cuda())

        loss = 0
        for dag in dags:
            output = self.shared(inputs, dag)
            sample_loss = (self.ce(output, targets) /
                           self.args.shared_num_sample)
            loss += sample_loss

        assert len(dags) == 1, \
            'there are multiple `hidden` for multiple `dags`'
        return loss

    def train_shared(self, max_step=None):
        """Trains the image classification model for 310 steps."""
        # TODO: check whether sampling a new dag for every batch is right;
        # one dag per epoch might be more efficient.
        model = self.shared
        model.train()
        self.controller.eval()

        if max_step is None:
            max_step = self.args.shared_max_step
        else:
            max_step = min(self.args.shared_max_step, max_step)

        step = 0
        raw_total_loss = 0
        total_loss = 0
        train_iter = iter(self.train_data)

        while True:
            if step > max_step:
                break

            dags = self.controller.sample(self.args.shared_num_sample)
            try:
                inputs, targets = next(train_iter)
            except StopIteration:
                # The DataLoader is exhausted: end this shared phase here.
                print('====>train_shared<====== finish one epoch')
                break

            loss = self.get_loss(inputs, targets, dags)
            raw_total_loss += loss.data

            # TODO: apply the regularization penalties here as well
            # (cf. _apply_penalties in the RNN variant).

            self.shared_optim.zero_grad()
            loss.backward()
            self.shared_optim.step()

            total_loss += loss.data

            if ((step % self.args.log_step) == 0) and (step > 0):
                self._summarize_shared_train(total_loss, raw_total_loss)
                raw_total_loss = 0
                total_loss = 0

            step += 1
            self.shared_step += 1
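    # For intuition, a hedged sketch of the M-sample gradient estimate that
    # `get_loss` implements when shared_num_sample == M (M > 1 is
    # hypothetical here; the assert above currently limits it to 1):
    #
    #     # loss = (1/M) * sum_m ce(shared(x, dag_m), y)
    #     # so loss.backward() accumulates the average of the M per-dag
    #     # gradients into the shared parameters.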
""" if not isinstance(entropies, np.ndarray): entropies = entropies.data.cpu().numpy() try: inputs, targets = data_iter.next() except StopIteration: data_iter = iter(self.valid_data) inputs, targets = data_iter.next() #TODO 怎么做volidate valid_loss = self.get_loss(inputs, targets, dag) # convert valid_loss to numpy ndarray valid_loss = utils.to_item(valid_loss.data) valid_ppl = math.exp(valid_loss) # TODO we don't knoe reward_c if self.args.ppl_square: #TODO: but we do know reward_c =80 in the previous paper need to read previous paper R = self.args.reward_c / valid_ppl**2 else: R = self.args.reward_c / valid_ppl if self.args.entropy_mode == 'reward': rewards = R + self.args.entropy_coeff * entropies elif self.args.entropy_mode == 'regularizer': rewards = R * np.ones_like(entropies) else: raise NotImplementedError( f'Unknown entropy mode: {self.args.entropy_mode}') return rewards def train_controller(self): """Fixes the shared parameters and updates the controller parameters. The controller is updated with a score function gradient estimator (i.e., REINFORCE), with the reward being c/valid_ppl. where valid_ppl is computed on a minibatch of vlaidation data. A moving average baseline is used. The controller is trained for 2000 steps per epoch (i.e., first (Train Shared) phase -. Second (Train Controller) phase). """ model = self.controller model.train() avg_reward_base = None baseline = None adv_history = [] entropy_history = [] reward_history = [] valid_iter = iter(self.valid_data) total_loss = 0 for step in range(self.args.controller_max_step): dags, log_probs, entropies = self.controller.sample( with_details=True) #print(dags) np_entropies = entropies.data.cpu().numpy() with _get_no_grad_ctx_mgr(): rewards = self.get_reward(dags, np_entropies, valid_iter) if 1 > self.args.discount > 0: rewards = discount(rewards, self.args.discount) reward_history.extend(rewards) entropy_history.extend(np_entropies) # moving average baseline if baseline is None: baseline = rewards else: decay = self.args.ema_baseline_decay baseline = decay * baseline + (1 - decay) * rewards adv = rewards - baseline adv_history.extend(adv) #policy loss loss = -log_probs * utils.get_variable( adv, self.cuda, requires_grad=False) if self.args.entropy_mode == 'regularizer': loss -= self.args.entropy_coeff * np_entropies loss = loss.sum() self.controller_optim.zero_grad() loss.backward() if self.args.controller_grad_clip > 0: torch.nn.utils.clip_grad_norm(model.parameters(), self.args.controller_grad_clip) self.controller_optim.step() total_loss += utils.to_item(loss.data) #if step%20 ==0: # print("total loss", total_loss, step, total_loss / (step+1)) if ((step % self.args.log_step) == 0) and (step > 0): self._summarize_controller_train(total_loss, adv_history, entropy_history, reward_history, avg_reward_base, dags) reward_history, adv_history, entropy_history = [], [], [] total_loss = 0 self.controller_step += 1 # prev_valid_idx = valid_idx # valid_idx = ((valid_idx + self.max_length) % # (self.valid_data.size(0) - 1)) # NOTE(brendan): Whenever we wrap around to the beginning of the # validation data, we reset the hidden states. def evaluate(self, test_iter, dag, name, batch_size=1, max_num=None): """Evaluate on the validation set. (lianqing)what is the data of source ? 
    def evaluate(self, test_iter, dag, name, batch_size=1, max_num=None):
        """Evaluate on the validation set.

        NOTE: the reward uses the validation set, but here the test set is
        the same as the validation set.
        TODO(lianqing): document what `test_iter` yields.
        """
        self.shared.eval()
        self.controller.eval()
        acc = AverageMeter()
        total_loss = 0
        count = 0
        while True:
            try:
                inputs, targets = next(test_iter)
            except StopIteration:
                # One full pass over the evaluation data is done.
                print('========> finish evaluate on one epoch <======')
                break
            count += 1

            # TODO: check whether evaluating while the controller is being
            # trained makes any difference.
            inputs = Variable(inputs.cuda())
            targets = Variable(targets.cuda())
            output = self.shared(inputs, dag, is_train=False)
            total_loss += len(inputs) * self.ce(output, targets).data
            acc.update(utils.get_accuracy(targets, output))

        val_loss = utils.to_item(total_loss) / count
        ppl = math.exp(val_loss)

        # TODO: these summaries were written for the RNN variant; adapt
        # them for the CNN case.
        # self.tb.scalar_summary(f'eval/{name}_loss', val_loss, self.epoch)
        # self.tb.scalar_summary(f'eval/{name}_ppl', ppl, self.epoch)
        print(f'eval | loss: {val_loss:8.2f} | ppl: {ppl:8.2f} '
              f'| accuracy: {acc.avg:8.2f}')

    def derive(self, sample_num=None, valid_iter=None):
        """Derives the best architecture from `sample_num` samples.

        TODO: sample_num is always set to 1; test whether batch_size > 1
        works for controller.sample.
        """
        if sample_num is None:
            sample_num = self.args.derive_num_sample
        if valid_iter is None:
            valid_iter = iter(self.valid_data)

        dags, _, entropies = self.controller.sample(sample_num,
                                                    with_details=True)

        max_R = 0
        best_dag = None
        for dag in dags:
            R = self.get_reward(dag, entropies, valid_iter)
            if R.max() > max_R:
                max_R = R.max()
                best_dag = dag

        print(f'derive | max_R: {max_R:8.6f}')
        fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                 f'{max_R:6.4f}-best.png')
        path = os.path.join(self.args.model_dir, 'networks', fname)
        # utils.draw_network(best_dag, path)
        # self.tb.image_summary('derive/best', [path], self.epoch)

        return best_dag

    @property
    def shared_lr(self):
        degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
        return self.args.shared_lr * (self.args.shared_decay**degree)

    @property
    def controller_lr(self):
        return self.args.controller_lr

    @property
    def shared_path(self):
        return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'

    @property
    def controller_path(self):
        return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'

    def get_saved_models_info(self):
        paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
        paths.sort()

        def get_numbers(items, delimiter, idx, replace_word,
                        must_contain=''):
            return list(
                set([
                    int(name.split(delimiter)[idx].replace(replace_word, ''))
                    for name in items if must_contain in name
                ]))

        basenames = [
            os.path.basename(path.rsplit('.', 1)[0]) for path in paths
        ]
        epochs = get_numbers(basenames, '_', 1, 'epoch')
        shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
        controller_steps = get_numbers(basenames, '_', 2, 'step',
                                       'controller')

        epochs.sort()
        shared_steps.sort()
        controller_steps.sort()

        return epochs, shared_steps, controller_steps
    def save_model(self):
        torch.save(self.shared.state_dict(), self.shared_path)
        print(f'[*] SAVED: {self.shared_path}')

        torch.save(self.controller.state_dict(), self.controller_path)
        print(f'[*] SAVED: {self.controller_path}')

        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        for epoch in epochs[:-self.args.max_save_num]:
            paths = glob.glob(
                os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))
            for path in paths:
                utils.remove_file(path)

    def load_model(self):
        epochs, shared_steps, controller_steps = self.get_saved_models_info()

        if len(epochs) == 0:
            print(f'[!] No checkpoint found in {self.args.model_dir}...')
            return

        self.epoch = self.start_epoch = max(epochs)
        self.shared_step = max(shared_steps)
        self.controller_step = max(controller_steps)

        if self.args.num_gpu == 0:
            map_location = lambda storage, loc: storage
        else:
            map_location = None

        self.shared.load_state_dict(
            torch.load(self.shared_path, map_location=map_location))
        print(f'[*] LOADED: {self.shared_path}')

        self.controller.load_state_dict(
            torch.load(self.controller_path, map_location=map_location))
        print(f'[*] LOADED: {self.controller_path}')

    def _summarize_controller_train(self,
                                    total_loss,
                                    adv_history,
                                    entropy_history,
                                    reward_history,
                                    avg_reward_base,
                                    dags):
        """Logs the controller's progress for this training epoch."""
        cur_loss = total_loss / self.args.log_step

        avg_adv = np.mean(adv_history)
        avg_entropy = np.mean(entropy_history)
        avg_reward = np.mean(reward_history)

        if avg_reward_base is None:
            avg_reward_base = avg_reward

        print(f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
              f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
              f'| loss {cur_loss:.5f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('controller/loss',
                                   cur_loss,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward',
                                   avg_reward,
                                   self.controller_step)
            self.tb.scalar_summary('controller/reward-B_per_epoch',
                                   avg_reward - avg_reward_base,
                                   self.controller_step)
            self.tb.scalar_summary('controller/entropy',
                                   avg_entropy,
                                   self.controller_step)
            self.tb.scalar_summary('controller/adv',
                                   avg_adv,
                                   self.controller_step)

            paths = []
            for dag in dags:
                fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
                         f'{avg_reward:6.4f}.png')
                path = os.path.join(self.args.model_dir, 'networks', fname)
                # utils.draw_network(dag, path)
                paths.append(path)

            self.tb.image_summary('controller/sample',
                                  paths,
                                  self.controller_step)

    def _summarize_shared_train(self, total_loss, raw_total_loss):
        """Logs a set of training steps."""
        cur_loss = utils.to_item(total_loss) / self.args.log_step
        # NOTE(brendan): The raw loss, without adding in the activation
        # regularization terms, should be used to compute ppl.
        cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
        ppl = math.exp(cur_raw_loss)

        print(f'| epoch {self.epoch:3d} '
              f'| lr {self.shared_lr:4.2f} '
              f'| raw loss {cur_raw_loss:.2f} '
              f'| loss {cur_loss:.2f} '
              f'| ppl {ppl:8.2f}')

        # Tensorboard
        if self.tb is not None:
            self.tb.scalar_summary('shared/loss',
                                   cur_loss,
                                   self.shared_step)
            self.tb.scalar_summary('shared/perplexity',
                                   ppl,
                                   self.shared_step)