def init_hidden(self, batch_size):
    # Initial hidden state for the shared model: zeros of shape (batch_size, shared_hid).
    zeros = torch.zeros(batch_size, self.shared_hid)
    return utils.get_variable(zeros, self.use_cuda, requires_grad=False)
def train_controller(self):
    """Fixes the shared parameters and updates the controller parameters.

    The controller is updated with a score function gradient estimator
    (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
    is computed on a minibatch of validation data.

    A moving average baseline is used.

    The controller is trained for 2000 steps per epoch (i.e.,
    first (Train Shared) phase -> second (Train Controller) phase).
    """
    model = self.controller
    model.train()
    # Why can't we call shared.eval() here? Leads to loss
    # being uniformly zero for the controller.
    # self.shared.eval()

    avg_reward_base = None
    baseline = None
    adv_history = []
    entropy_history = []
    reward_history = []

    hidden = self.shared.init_hidden(self.batch_size)
    total_loss = 0
    valid_idx = 0
    for step in range(2000):
        # Sample a batch of architectures (DAGs) from the controller.
        dags, log_probs, entropies = self.controller.sample(
            with_details=True)

        # Calculate the reward on a validation minibatch.
        np_entropies = entropies.data.cpu().numpy()
        # No gradients should be backpropagated to the
        # shared model during controller training, obviously.
        with _get_no_grad_ctx_mgr():
            rewards, hidden = self.get_reward(dags,
                                              np_entropies,
                                              hidden,
                                              valid_idx)

        reward_history.extend(rewards)
        entropy_history.extend(np_entropies)

        # Moving average baseline, used to reduce the variance of the
        # REINFORCE gradient estimate.
        if baseline is None:
            baseline = rewards
        else:
            decay = 0.95
            baseline = decay * baseline + (1 - decay) * rewards

        adv = rewards - baseline
        adv_history.extend(adv)

        # Policy loss: -log pi(a) * advantage, with the advantage treated
        # as a constant (no gradient flows through it).
        loss = -log_probs * utils.get_variable(adv,
                                               self.use_cuda,
                                               requires_grad=False)
        loss = loss.sum()  # or loss.mean()

        # Update the controller parameters.
        self.controller_optim.zero_grad()
        loss.backward()
        self.controller_optim.step()

        total_loss += utils.to_item(loss.data)

        # Periodically reset the running statistics.
        if ((step % 50) == 0) and (step > 0):
            reward_history, adv_history, entropy_history = [], [], []
            total_loss = 0

        self.controller_step += 1
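# A toy, self-contained illustration of the REINFORCE update performed above.
# It is not part of the Trainer class and all values are made up: three fake
# controller decisions, a scalar reward standing in for c/valid_ppl, and a
# moving-average baseline. It only shows how the score-function loss drives
# gradients into the controller parameters.
import torch

logits = torch.randn(3, requires_grad=True)      # stand-in controller outputs
log_probs = torch.log_softmax(logits, dim=-1)    # log pi(a) for each decision

reward = 0.8                    # e.g. c / valid_ppl on a validation minibatch
baseline = 0.5                  # running average of past rewards
decay = 0.95
baseline = decay * baseline + (1 - decay) * reward
adv = reward - baseline         # advantage, treated as a constant

loss = (-log_probs * adv).sum()  # score-function (REINFORCE) loss
loss.backward()                  # gradients flow only into `logits`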
def sample(self, batch_size=1, with_details=False, save_dir=None):
    """Samples a set of `args.num_blocks` many computational nodes from the
    controller, where each node is made up of an activation function, and
    each node except the last also includes a previous node.
    """
    if batch_size < 1:
        raise Exception(f'Wrong batch_size: {batch_size} < 1')

    # [B, L, H]
    inputs = self.static_inputs[batch_size]
    hidden = self.static_init_hidden[batch_size]

    activations = []
    entropies = []
    log_probs = []
    prev_nodes = []
    # The RNN controller alternately outputs an activation,
    # followed by a previous node, for each block except the last one,
    # which only gets an activation function. The last node is the output
    # node, and its previous node is the average of all leaf nodes.
    for block_idx in range(2 * (self.num_blocks - 1) + 1):
        logits, hidden = self.forward(inputs,
                                      hidden,
                                      block_idx,
                                      is_embed=(block_idx == 0))

        probs = F.softmax(logits, dim=-1)
        log_prob = F.log_softmax(logits, dim=-1)
        # .mean() for entropy?
        entropy = -(log_prob * probs).sum(1, keepdim=False)

        action = probs.multinomial(num_samples=1).data
        selected_log_prob = log_prob.gather(
            1, utils.get_variable(action, requires_grad=False))

        # why the [:, 0] here? Should it be .squeeze(), or
        # .view()? Same below with `action`.
        entropies.append(entropy)
        log_probs.append(selected_log_prob[:, 0])

        # 0: function, 1: previous node
        mode = block_idx % 2
        inputs = utils.get_variable(
            action[:, 0] + sum(self.num_tokens[:mode]),
            requires_grad=False)

        if mode == 0:
            activations.append(action[:, 0])
        elif mode == 1:
            prev_nodes.append(action[:, 0])

    prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
    activations = torch.stack(activations).transpose(0, 1)

    dags = _construct_dags(prev_nodes,
                           activations,
                           self.func_names,
                           self.num_blocks)

    if save_dir is not None:
        for idx, dag in enumerate(dags):
            utils.draw_network(dag,
                               os.path.join(save_dir, f'graph{idx}.png'))

    if with_details:
        return dags, torch.cat(log_probs), torch.cat(entropies)

    return dags
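# A hypothetical usage sketch (not taken from the repository): sample one
# architecture and inspect the returned pieces. `controller` is assumed to be
# an already-constructed Controller instance with N = num_blocks.
dags, log_probs, entropies = controller.sample(batch_size=1,
                                               with_details=True)
# dags[0] is the sampled DAG for the single batch element; log_probs and
# entropies each hold one entry per decision, i.e. 2*N - 1 values
# (N activation choices plus N - 1 previous-node choices).
print(dags[0])
print(log_probs.shape, entropies.shape)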
def init_hidden(self, batch_size):
    # The controller's recurrent state is an LSTM-style (h, c) pair,
    # both initialized to zeros of shape (batch_size, controller_hid).
    zeros = torch.zeros(batch_size, self.controller_hid)
    return (utils.get_variable(zeros, self.use_cuda, requires_grad=False),
            utils.get_variable(zeros.clone(), self.use_cuda,
                               requires_grad=False))
def _get_default_hidden(key):
    # Note: `self` is not a parameter here; it is captured from the
    # enclosing scope, so this helper is meant to be defined inside the
    # Controller (e.g. as a closure in __init__). The key is the batch
    # size, and the result is a zero tensor of shape (key, controller_hid).
    return utils.get_variable(torch.zeros(key, self.controller_hid),
                              self.use_cuda,
                              requires_grad=False)
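# A sketch of how _get_default_hidden is presumably wired up so that
# `self.static_inputs[batch_size]` in sample() lazily creates one zero
# tensor per batch size. The `keydefaultdict` helper (a defaultdict whose
# factory receives the missing key) and the __init__ wiring below are
# assumptions for illustration, not taken verbatim from the code above.
from collections import defaultdict

class keydefaultdict(defaultdict):
    def __missing__(self, key):
        if self.default_factory is None:
            raise KeyError(key)
        value = self[key] = self.default_factory(key)
        return value

# Inside Controller.__init__ (hypothetical placement):
# self.static_inputs = keydefaultdict(_get_default_hidden)
# self.static_init_hidden = keydefaultdict(self.init_hidden)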