def _summarize_shared_train(self, total_loss, raw_total_loss):
    """Logs a set of training steps."""
    cur_loss = utils.to_item(total_loss) / self.args.log_step
    # NOTE(brendan): The raw loss, without adding in the activation
    # regularization terms, should be used to compute ppl.
    cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
    ppl = math.exp(cur_raw_loss)
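    # The reporting calls that consume `cur_loss` and `ppl` are not shown
    # here. A minimal sketch, assuming a module-level `logger`
    # (e.g. `logger = logging.getLogger(__name__)`):
    logger.info(f'| shared step {self.shared_step:6d} '
                f'| loss {cur_loss:.2f} '
                f'| raw loss {cur_raw_loss:.2f} '
                f'| ppl {ppl:8.2f}')
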
def evaluate(self, source, dag, name, batch_size=1, max_num=None):
    """Evaluate on the validation set.

    NOTE(brendan): We should not be using the test set to develop the
    algorithm (basic machine learning good practices).
    """
    self.shared.eval()
    self.controller.eval()

    # Only truncate the evaluation data when a maximum number of BPTT
    # windows was requested; slicing with `max_num=None` would raise a
    # TypeError.
    if max_num is not None:
        data = source[:max_num * self.max_length]
    else:
        data = source

    total_loss = 0
    hidden = self.shared.init_hidden(batch_size)

    pbar = range(0, data.size(0) - 1, self.max_length)
    for count, idx in enumerate(pbar):
        inputs, targets = self.get_batch(data, idx, volatile=True)
        output, hidden, _ = self.shared(inputs,
                                        dag,
                                        hidden=hidden,
                                        is_train=False)
        output_flat = output.view(-1, self.dataset.num_tokens)
        total_loss += len(inputs) * self.ce(output_flat, targets).data
        hidden.detach_()
        # Running perplexity estimate over the windows processed so far.
        ppl = math.exp(
            utils.to_item(total_loss) / (count + 1) / self.max_length)

    val_loss = utils.to_item(total_loss) / len(data)
    ppl = math.exp(val_loss)
    return val_loss, ppl
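
# A hedged usage sketch (the attribute `self.eval_data` and the single-sample
# call to `self.controller.sample` are assumptions based on the surrounding
# code, not definitive): scoring one sampled architecture on held-out data
# might look like
#
#     dag = self.controller.sample(1)[0]
#     val_loss, ppl = self.evaluate(self.eval_data,
#                                   dag,
#                                   'sampled',
#                                   max_num=self.args.batch_size)
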
def get_reward(self, dag, entropies, hidden, valid_idx=0):
    """Computes the perplexity of a single sampled model on a minibatch of
    validation data.
    """
    if not isinstance(entropies, np.ndarray):
        entropies = entropies.data.cpu().numpy()

    inputs, targets = self.get_batch(self.valid_data,
                                     valid_idx,
                                     self.max_length,
                                     volatile=True)
    valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
    valid_loss = utils.to_item(valid_loss.data)

    valid_ppl = math.exp(valid_loss)

    # TODO: we don't know reward_c
    if self.args.ppl_square:
        # TODO: but we do know reward_c=80 in the previous paper
        R = self.args.reward_c / valid_ppl**2
    else:
        R = self.args.reward_c / valid_ppl

    if self.args.entropy_mode == 'reward':
        rewards = R + self.args.entropy_coeff * entropies
    elif self.args.entropy_mode == 'regularizer':
        rewards = R * np.ones_like(entropies)
    else:
        raise NotImplementedError(
            f'Unknown entropy mode: {self.args.entropy_mode}')

    return rewards, hidden
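
# Worked example for the reward above. reward_c is not pinned down here (see
# the TODO); the previous paper used reward_c = 80, so a sampled cell with
# valid_ppl = 100 would receive:
#
#     R = 80 / 100      = 0.8      (default)
#     R = 80 / 100 ** 2 = 0.008    (--ppl_square)
#
# With entropy_mode == 'reward', each sampled decision then gets
# R + entropy_coeff * entropy as its reward; with 'regularizer', every
# decision gets the same R and the entropy bonus is instead added to the
# policy loss in train_controller.
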
def _construct_dags(prev_nodes, activations, func_names, num_blocks):
    """Constructs a set of DAGs based on the actions, i.e., previous nodes and
    activation functions, sampled from the controller/policy pi.

    Args:
        prev_nodes: Previous node actions from the policy.
        activations: Activations sampled from the policy.
        func_names: Mapping from activation function ids to their names.
        num_blocks: Number of blocks in the target RNN cell.

    Returns:
        A list of DAGs defined by the inputs.

        RNN cell DAGs are represented in the following way:

        1. Each element (node) in a DAG is a list of `Node`s.

        2. The `Node`s in the list dag[i] correspond to the subsequent nodes
           that take the output from node i as their own input.

        3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}.
           dag[-1] always feeds dag[0].
           dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its
           weights.

        4. dag[N - 1] is the node that produces the hidden state passed to
           the next timestep. dag[N - 1] is also always a leaf node, and
           therefore is always averaged with the other leaf nodes and fed to
           the output decoder.
    """
    dags = []
    for nodes, func_ids in zip(prev_nodes, activations):
        dag = collections.defaultdict(list)

        # add first node
        dag[-1] = [Node(0, func_names[func_ids[0]])]
        dag[-2] = [Node(0, func_names[func_ids[0]])]

        # add following nodes
        for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])):
            dag[utils.to_item(idx)].append(Node(jdx + 1,
                                                func_names[func_id]))

        leaf_nodes = set(range(num_blocks)) - dag.keys()

        # merge with avg
        for idx in leaf_nodes:
            dag[idx] = [Node(num_blocks, 'avg')]

        # TODO(brendan): This is actually y^{(t)}. h^{(t)} is node N - 1 in
        # the graph, where N is the number of nodes. I.e., h^{(t)} takes
        # only one other node as its input.

        # last h[t] node
        last_node = Node(num_blocks + 1, 'h[t]')
        dag[num_blocks] = [last_node]
        dags.append(dag)

    return dags
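
# For illustration (values assumed, not taken from a real run): with
# num_blocks = 2, func_names = ['tanh', 'ReLU', 'identity', 'sigmoid'],
# prev_nodes = [[0]] and activations = [[0, 1]], _construct_dags returns a
# single dag equivalent to:
#
#     {-1: [Node(0, 'tanh')],   # x^{(t)} feeds node 0
#      -2: [Node(0, 'tanh')],   # h^{(t-1)} feeds node 0
#       0: [Node(1, 'ReLU')],   # node 0 feeds node 1
#       1: [Node(2, 'avg')],    # node 1 is a leaf, so it is averaged
#       2: [Node(3, 'h[t]')]}   # the averaged output becomes the next state
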
def train_controller(self):
    """Fixes the shared parameters and updates the controller parameters.

    The controller is updated with a score function gradient estimator
    (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
    is computed on a minibatch of validation data.

    A moving average baseline is used.

    The controller is trained for 2000 steps per epoch (i.e.,
    first (Train Shared) phase -> second (Train Controller) phase).
    """
    model = self.controller
    model.train()
    # TODO(brendan): Why can't we call shared.eval() here? Leads to loss
    # being uniformly zero for the controller.
    # self.shared.eval()

    avg_reward_base = None
    baseline = None
    adv_history = []
    entropy_history = []
    reward_history = []

    hidden = self.shared.init_hidden(self.args.batch_size)
    total_loss = 0
    valid_idx = 0
    for step in range(self.args.controller_max_step):
        # sample models
        dags, log_probs, entropies = self.controller.sample(
            with_details=True)

        # calculate reward
        np_entropies = entropies.data.cpu().numpy()
        # NOTE(brendan): No gradients should be backpropagated to the
        # shared model during controller training, obviously.
        with _get_no_grad_ctx_mgr():
            rewards, hidden = self.get_reward(dags,
                                              np_entropies,
                                              hidden,
                                              valid_idx)

        # discount
        if 1 > self.args.discount > 0:
            rewards = discount(rewards, self.args.discount)

        reward_history.extend(rewards)
        entropy_history.extend(np_entropies)

        # moving average baseline
        if baseline is None:
            baseline = rewards
        else:
            decay = self.args.ema_baseline_decay
            baseline = decay * baseline + (1 - decay) * rewards

        adv = rewards - baseline
        adv_history.extend(adv)

        # policy loss
        loss = -log_probs * utils.get_variable(adv,
                                               self.cuda,
                                               requires_grad=False)
        if self.args.entropy_mode == 'regularizer':
            loss -= self.args.entropy_coeff * entropies

        loss = loss.sum()  # or loss.mean()

        # update
        self.controller_optim.zero_grad()
        loss.backward()

        if self.args.controller_grad_clip > 0:
            torch.nn.utils.clip_grad_norm(model.parameters(),
                                          self.args.controller_grad_clip)
        self.controller_optim.step()

        total_loss += utils.to_item(loss.data)

        if ((step % self.args.log_step) == 0) and (step > 0):
            self._summarize_controller_train(total_loss,
                                             adv_history,
                                             entropy_history,
                                             reward_history,
                                             avg_reward_base,
                                             dags)

            reward_history, adv_history, entropy_history = [], [], []
            total_loss = 0

        self.controller_step += 1

        prev_valid_idx = valid_idx
        valid_idx = ((valid_idx + self.max_length) %
                     (self.valid_data.size(0) - 1))
        # NOTE(brendan): Whenever we wrap around to the beginning of the
        # validation data, we reset the hidden states.
        if prev_valid_idx > valid_idx:
            hidden = self.shared.init_hidden(self.args.batch_size)
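
# `discount` is used above but not shown here. A minimal sketch, assuming it
# computes standard discounted returns over the per-decision rewards (this is
# an assumption about its behaviour, not the definitive implementation):
def discount(x, amount):
    """Returns discounted returns: out[i] = sum_j amount**j * x[i + j]."""
    discounted = np.zeros_like(x, dtype=np.float64)
    running = 0.0
    # Walk backwards so each position accumulates its own discounted future.
    for idx in reversed(range(len(x))):
        running = x[idx] + amount * running
        discounted[idx] = running
    return discounted
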
def train_shared(self, max_step=None, dag=None):
    """Train the language model for 400 steps of minibatches of 64
    examples.

    Args:
        max_step: Used to run extra training steps as a warm-up.
        dag: If not None, is used instead of calling sample().

    BPTT is truncated at 35 timesteps.

    For each weight update, gradients are estimated by sampling M models
    from the fixed controller policy, and averaging their gradients
    computed on a batch of training data.
    """
    model = self.shared
    model.train()
    self.controller.eval()

    hidden = self.shared.init_hidden(self.args.batch_size)

    if max_step is None:
        max_step = self.args.shared_max_step
    else:
        max_step = min(self.args.shared_max_step, max_step)

    abs_max_grad = 0
    abs_max_hidden_norm = 0
    step = 0
    raw_total_loss = 0
    total_loss = 0
    train_idx = 0
    # TODO(brendan): Why - 1 - 1?
    while train_idx < self.train_data.size(0) - 1 - 1:
        if step > max_step:
            break

        dags = dag if dag else self.controller.sample(
            self.args.shared_num_sample)
        inputs, targets = self.get_batch(self.train_data,
                                         train_idx,
                                         self.max_length)

        loss, hidden, extra_out = self.get_loss(inputs,
                                                targets,
                                                hidden,
                                                dags)
        hidden.detach_()
        raw_total_loss += loss.data

        loss += _apply_penalties(extra_out, self.args)

        # update
        self.shared_optim.zero_grad()
        loss.backward()

        h1tohT = extra_out['hiddens']
        new_abs_max_hidden_norm = utils.to_item(
            h1tohT.norm(dim=-1).data.max())
        if new_abs_max_hidden_norm > abs_max_hidden_norm:
            abs_max_hidden_norm = new_abs_max_hidden_norm
            # logger.info(f'max hidden {abs_max_hidden_norm}')
        abs_max_grad = _check_abs_max_grad(abs_max_grad, model)
        torch.nn.utils.clip_grad_norm(model.parameters(),
                                      self.args.shared_grad_clip)
        self.shared_optim.step()

        total_loss += loss.data

        if ((step % self.args.log_step) == 0) and (step > 0):
            self._summarize_shared_train(total_loss, raw_total_loss)
            raw_total_loss = 0
            total_loss = 0

        step += 1
        self.shared_step += 1
        train_idx += self.max_length
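
# `_apply_penalties` is called above but not shown here. A hedged sketch,
# assuming it implements the activation regularization (AR) and temporal
# activation regularization (TAR) terms referred to in the NOTE in
# `_summarize_shared_train`; the argument names and the 'dropped' key are
# assumptions, not confirmed by the code above:
def _apply_penalties(extra_out, args):
    """Returns the regularization loss added on top of the raw LM loss."""
    penalty = 0
    # AR: penalize large (dropped-out) activations.
    if args.activation_regularization:
        penalty += (args.activation_regularization_amount *
                    extra_out['dropped'].pow(2).mean())
    # TAR: penalize large changes between successive hidden states.
    if args.temporal_activation_regularization:
        hiddens = extra_out['hiddens']
        penalty += (args.temporal_activation_regularization_amount *
                    (hiddens[1:] - hiddens[:-1]).pow(2).mean())
    return penalty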