def validate_one_epoch(self, epoch):
    """Evaluate ``self.test_arc_per_epoch`` sampled architectures on the test set.

    For each architecture: reset the mutator to sample it, run the full test
    loader, accumulate metrics and loss in an ``AverageMeterGroup``, then
    parse the meter summary into a dict.

    Args:
        epoch: zero-based epoch index (used only for logging).

    Returns:
        list[dict]: one parsed metrics dict per evaluated architecture.
    """
    meter_list = []
    with torch.no_grad():  # pure evaluation — no gradients needed
        for arc_id in range(self.test_arc_per_epoch):
            meters = AverageMeterGroup()
            for x, y in self.test_loader:
                x, y = to_device(x, self.device), to_device(y, self.device)
                self.mutator.reset()  # sample a fresh architecture
                logits = self.model(x)
                if isinstance(logits, tuple):  # model may return (logits, aux_logits)
                    logits, _ = logits
                metrics = self.metrics(logits, y)
                loss = self.loss(logits, y)
                metrics["loss"] = loss.item()
                meters.update(metrics)
            # FIX: the original did json.loads(json.dumps('{'+summary+'}')) and
            # then json.loads() again.  json.loads(json.dumps(s)) of a str is a
            # no-op, so a single parse is exactly equivalent.
            # NOTE(review): assumes meters.summary() yields a valid JSON object
            # body (e.g. '"acc": 0.9, "loss": 0.1') — confirm against
            # AverageMeterGroup.summary().
            meter_list.append(json.loads('{' + meters.summary() + '}'))
            logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
                        epoch + 1, self.num_epochs, arc_id + 1,
                        self.test_arc_per_epoch, meters.summary())
    return meter_list
def train_one_epoch(self, epoch):
    """Train only the sampler (mutator) for one epoch using a training-free
    architecture score as the reward, then checkpoint the mutator via mlflow.

    Args:
        epoch: zero-based epoch index (used for logging / metric step).
    """
    # Train sampler (mutator)
    self.model.eval()      # child model is only evaluated, never updated here
    self.mutator.train()
    total_loss = 0
    # NOTE(review): meters is created but never updated in this variant.
    meters = AverageMeterGroup()
    for mutator_step in range(1, self.mutator_steps + 1):
        self.mutator_optim.zero_grad()
        # Accumulate gradients over several sampled architectures per optimizer step.
        for step in range(1, self.mutator_steps_aggregate + 1):
            x, y = next(self.valid_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)
            self.mutator.reset()  # sample a new architecture (keeps grad for log-probs)
            with torch.no_grad():  # forward pass only; no child-model gradients
                logits = self.model(x)
            self._write_graph_status()
            # Reward is a training-free score of the sampled architecture,
            # computed from the batch jacobian (flattened per sample).
            jacobian = get_batch_jacobian(self.model, x)
            jacobian = jacobian.reshape(jacobian.size(0), -1)
            reward = eval_score(jacobian)
            # NOTE(review): total_loss accumulates raw rewards; the logged
            # 'Total reward' below negates the mean — confirm sign convention
            # of eval_score.
            total_loss += reward.item()
            if self.entropy_weight:
                # Entropy bonus added to the reward.
                reward += self.entropy_weight * self.mutator.sample_entropy.item()
            # Moving-average baseline for variance reduction.
            # https://arxiv.org/pdf/1707.06347.pdf
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            # Policy-gradient loss: log-prob of the sampled architecture
            # scaled by the advantage (reward - baseline).
            loss = self.mutator.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                loss += self.skip_weight * self.mutator.sample_skip_penalty
            # Average over the aggregation window before accumulating gradients.
            loss /= self.mutator_steps_aggregate
            loss.backward()
            cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
            if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                # NOTE(review): reward.int() truncates the reward tensor for
                # logging — confirm this is intended rather than reward.item().
                logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s",
                            epoch + 1, self.num_epochs, mutator_step,
                            self.mutator_steps, step,
                            self.mutator_steps_aggregate, reward.int())
        nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)  # gradient clipping
        self.mutator_optim.step()
    # Log the mean (negated) reward for the epoch and checkpoint the mutator.
    mlflow.log_metric('Total reward', -total_loss / (self.mutator_steps * self.mutator_steps_aggregate), epoch)
    torch.save({
        'model': self.mutator.state_dict(),
        'optimizer': self.mutator_optim.state_dict()
    }, 'mutator_run_stats.pyt')
    mlflow.log_artifact('mutator_run_stats.pyt')
def validate_one_epoch(self, epoch):
    """Evaluate several sampled architectures on the test set, logging one
    metric summary per architecture. Returns nothing."""
    with torch.no_grad():  # evaluation only — no gradients
        for arc_index in range(self.test_arc_per_epoch):
            arc_meters = AverageMeterGroup()
            for inputs, targets in self.test_loader:
                inputs = to_device(inputs, self.device)
                targets = to_device(targets, self.device)
                # Sample a new architecture for this pass.
                self.mutator.reset()
                outputs = self.model(inputs)
                # Some models emit (logits, aux_logits); keep the main head only.
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                batch_metrics = self.metrics(outputs, targets)
                batch_metrics["loss"] = self.loss(outputs, targets).item()
                arc_meters.update(batch_metrics)
            logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
                        epoch + 1, self.num_epochs, arc_index + 1,
                        self.test_arc_per_epoch, arc_meters.summary())
def train_one_epoch(self, epoch):
    """ENAS-style epoch: first train the shared child model on the training
    split, then train the controller (mutator) with a policy gradient on the
    validation split.

    Args:
        epoch: zero-based epoch index (used only for logging).
    """
    # ---- Phase 1: sample architectures and train the child model weights ----
    self.model.train()
    self.mutator.eval()  # controller only samples here; it is not updated
    meters = AverageMeterGroup()
    for step in range(1, self.child_steps + 1):
        x, y = next(self.train_loader)
        x, y = to_device(x, self.device), to_device(y, self.device)
        self.optimizer.zero_grad()
        with torch.no_grad():  # architecture sampling needs no gradients
            self.mutator.reset()
        logits = self.model(x)
        if isinstance(logits, tuple):
            # Model returns (logits, aux_logits) when an auxiliary head is active.
            logits, aux_logits = logits
            aux_loss = self.loss(aux_logits, y)
        else:
            aux_loss = 0.
        metrics = self.metrics(logits, y)
        loss = self.loss(logits, y)
        loss = loss + self.aux_weight * aux_loss  # weighted auxiliary loss
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 5.)  # clip for stability
        self.optimizer.step()
        metrics["loss"] = loss.item()
        meters.update(metrics)
        if self.log_frequency is not None and step % self.log_frequency == 0:
            logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
                        self.num_epochs, step, self.child_steps, meters)

    # ---- Phase 2: train the sampler (mutator) with a policy gradient ----
    self.model.eval()    # child model is frozen; forward passes only
    self.mutator.train()
    meters = AverageMeterGroup()
    for mutator_step in range(1, self.mutator_steps + 1):
        self.mutator_optim.zero_grad()
        # Accumulate gradients over several sampled architectures per optimizer step.
        for step in range(1, self.mutator_steps_aggregate + 1):
            x, y = next(self.valid_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)
            self.mutator.reset()  # sample a new architecture (keeps grad for log-probs)
            with torch.no_grad():  # evaluating the child model needs no gradients
                logits = self.model(x)
            metrics = self.metrics(logits, y)
            reward = self.reward_function(logits, y)
            if self.entropy_weight:
                # Entropy bonus added to the reward.
                reward += self.entropy_weight * self.mutator.sample_entropy.item()
            # Exponential moving-average baseline for variance reduction.
            self.baseline = self.baseline * self.baseline_decay + reward * (
                1 - self.baseline_decay)
            # Policy-gradient loss: log-prob of the sampled architecture
            # scaled by the advantage (reward - baseline).
            loss = self.mutator.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                loss += self.skip_weight * self.mutator.sample_skip_penalty
            metrics["reward"] = reward
            metrics["loss"] = loss.item()
            metrics["ent"] = self.mutator.sample_entropy.item()
            metrics["log_prob"] = self.mutator.sample_log_prob.item()
            metrics["baseline"] = self.baseline
            metrics["skip"] = self.mutator.sample_skip_penalty
            # Average over the aggregation window before accumulating gradients.
            loss /= self.mutator_steps_aggregate
            loss.backward()
            meters.update(metrics)
            cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
            if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1,
                            self.num_epochs, mutator_step, self.mutator_steps,
                            step, self.mutator_steps_aggregate, meters)
        nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)  # clip for stability
        self.mutator_optim.step()
def train_one_epoch(self, epoch):
    """ENAS-style epoch with graph-status dumps: train the shared child model
    on the training split, then train the controller (mutator) with a policy
    gradient on the validation split.

    Args:
        epoch: zero-based epoch index (used only for logging).
    """
    # Sample model and train
    self.model.train()
    self.mutator.eval()  # controller only samples here; it is not updated
    meters = AverageMeterGroup()
    # COMMENT: phase 1 — train the child model first
    for step in range(1, self.child_steps + 1):
        x, y = next(self.train_loader)
        x, y = to_device(x, self.device), to_device(y, self.device)
        self.optimizer.zero_grad()
        with torch.no_grad():  # architecture sampling needs no gradients
            self.mutator.reset()
            self._write_graph_status()
        logits = self.model(x)
        if isinstance(logits, tuple):
            # Model returns (logits, aux_logits) when an auxiliary head is active.
            logits, aux_logits = logits
            aux_loss = self.loss(aux_logits, y)
        else:
            aux_loss = 0.
        metrics = self.metrics(logits, y)  # compute accuracy metrics
        loss = self.loss(logits, y)        # compute loss
        loss = loss + self.aux_weight * aux_loss
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 5.)  # clip for stability
        self.optimizer.step()
        metrics["loss"] = loss.item()
        meters.update(metrics)
        if self.log_frequency is not None and step % self.log_frequency == 0:
            logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
                        self.num_epochs, step, self.child_steps, meters)

    # Train sampler (mutator)
    self.model.eval()
    self.mutator.train()  # phase 2 — now train the mutator (controller)
    meters = AverageMeterGroup()
    for mutator_step in range(1, self.mutator_steps + 1):
        self.mutator_optim.zero_grad()
        for step in range(1, self.mutator_steps_aggregate + 1):
            x, y = next(self.valid_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)
            self.mutator.reset()  # sample a new architecture (keeps grad for log-probs)
            with torch.no_grad():  # evaluating the child model needs no gradients
                logits = self.model(x)
            self._write_graph_status()
            # compute accuracy metrics
            metrics = self.metrics(logits, y)
            # compute the reward; reference implementation of the reward function:
            '''
            def reward_accuracy(output, target, topk=(1,)):
                batch_size = target.size(0)
                _, predicted = torch.max(output.data, 1)
                return (predicted == target).sum().item() / batch_size
            '''
            reward = self.reward_function(logits, y)  # fraction correct in this batch
            if self.entropy_weight:  # entropy-bonus weight
                # Add the sample entropy to the reward.
                reward += self.entropy_weight * self.mutator.sample_entropy.item()
            # Moving-average baseline update, policy-gradient style.
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            loss = self.mutator.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                loss += self.skip_weight * self.mutator.sample_skip_penalty
            metrics["reward"] = reward
            metrics["loss"] = loss.item()
            metrics["ent"] = self.mutator.sample_entropy.item()
            metrics["log_prob"] = self.mutator.sample_log_prob.item()
            metrics["baseline"] = self.baseline
            metrics["skip"] = self.mutator.sample_skip_penalty
            # Average over the aggregation window before accumulating gradients.
            loss /= self.mutator_steps_aggregate
            loss.backward()
            meters.update(metrics)
            cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
            if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1,
                            self.num_epochs, mutator_step, self.mutator_steps,
                            step, self.mutator_steps_aggregate, meters)
        nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)  # clip for stability
        self.mutator_optim.step()