示例#1
0
    def validate_one_epoch(self, epoch):
        """Evaluate sampled architectures on the test loader and collect their summaries.

        For each of ``self.test_arc_per_epoch`` rounds, a fresh meter group is
        filled by iterating over ``self.test_loader``; ``self.mutator.reset()``
        is called before every batch. Everything runs under ``torch.no_grad()``.

        Args:
            epoch: Zero-based epoch index; only used (as ``epoch + 1``) in the
                log message.

        Returns:
            A list with one entry per round: the result of parsing
            ``'{' + meters.summary() + '}'`` as JSON.

        Raises:
            json.JSONDecodeError: If the wrapped summary text is not valid JSON.
        """
        # One parsed summary per evaluation round.
        meter_list = []

        with torch.no_grad():
            for arc_id in range(self.test_arc_per_epoch):
                meters = AverageMeterGroup()
                for x, y in self.test_loader:
                    x, y = to_device(x, self.device), to_device(y, self.device)
                    self.mutator.reset()
                    logits = self.model(x)
                    # Some models return a (logits, auxiliary) pair; keep the first item.
                    if isinstance(logits, tuple):
                        logits, _ = logits
                    metrics = self.metrics(logits, y)
                    loss = self.loss(logits, y)
                    metrics["loss"] = loss.item()
                    meters.update(metrics)

                # Compute the summary once and reuse it for parsing and logging.
                summary = meters.summary()
                # NOTE(review): this assumes summary() yields comma-separated
                # JSON members (e.g. '"acc": 0.9, "loss": 0.1') - confirm
                # against the AverageMeterGroup implementation in use.
                meter_list.append(json.loads('{' + summary + '}'))

                logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary  %s",
                            epoch + 1, self.num_epochs, arc_id + 1,
                            self.test_arc_per_epoch, summary)

        return meter_list
示例#2
0
    def train_one_epoch(self, epoch):
        """Run one epoch of mutator (sampler) training with a Jacobian-based reward.

        The model is switched to eval mode and the mutator to train mode. For
        each of ``self.mutator_steps`` optimizer steps, gradients are
        accumulated over ``self.mutator_steps_aggregate`` batches drawn from
        ``self.valid_loader``. The reward of each batch is
        ``eval_score(jacobian)``, where the Jacobian comes from
        ``get_batch_jacobian(self.model, x)`` and is flattened per sample.

        After the loop, the negated mean of the raw (pre-entropy-bonus) rewards
        is logged to MLflow as ``'Total reward'``, and the mutator and optimizer
        state dicts are saved to ``mutator_run_stats.pyt`` and logged as an
        artifact.

        Args:
            epoch: Zero-based epoch index; used in log messages and as the
                MLflow step.
        """
        # Train sampler (mutator)
        self.model.eval()
        self.mutator.train()
        # Running sum of the raw rewards (before the entropy bonus is added).
        total_loss = 0
        for mutator_step in range(1, self.mutator_steps + 1):
            self.mutator_optim.zero_grad()
            for step in range(1, self.mutator_steps_aggregate + 1):
                x, y = next(self.valid_loader)
                x, y = to_device(x, self.device), to_device(y, self.device)

                self.mutator.reset()
                with torch.no_grad():
                    # The output is not used below; the forward pass is kept in
                    # case later calls depend on it having run.
                    self.model(x)
                self._write_graph_status()
                jacobian = get_batch_jacobian(self.model, x)
                # Flatten everything but the leading (batch) dimension.
                jacobian = jacobian.reshape(jacobian.size(0), -1)

                reward = eval_score(jacobian)
                total_loss += reward.item()

                if self.entropy_weight:
                    reward += self.entropy_weight * self.mutator.sample_entropy.item()
                # https://arxiv.org/pdf/1707.06347.pdf
                # The baseline is an exponential moving average of the reward.
                self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
                loss = self.mutator.sample_log_prob * (reward - self.baseline)
                if self.skip_weight:
                    loss += self.skip_weight * self.mutator.sample_skip_penalty

                # Scale so the accumulated gradient averages over the aggregated batches.
                loss /= self.mutator_steps_aggregate
                loss.backward()

                cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
                if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                    # The reward is logged after conversion with .int().
                    logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d]  %s", epoch + 1, self.num_epochs,
                                mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
                                reward.int())

            nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
            self.mutator_optim.step()

        num_batches = self.mutator_steps * self.mutator_steps_aggregate
        mlflow.log_metric('Total reward', -total_loss / num_batches, epoch)

        # Single source for the checkpoint path used by both save and upload.
        checkpoint_path = 'mutator_run_stats.pyt'
        torch.save({
            'model': self.mutator.state_dict(),
            'optimizer': self.mutator_optim.state_dict(),
        }, checkpoint_path)
        mlflow.log_artifact(checkpoint_path)
示例#3
0
    def validate_one_epoch(self, epoch):
        """Evaluate on the test loader for several rounds and log a summary per round.

        Runs ``self.test_arc_per_epoch`` rounds under ``torch.no_grad()``. In
        every round a new meter group accumulates the per-batch metrics plus
        the loss, with ``self.mutator.reset()`` called before each batch.

        Args:
            epoch: Zero-based epoch index; only used (as ``epoch + 1``) in the
                log message.
        """
        with torch.no_grad():
            for arc_number in range(1, self.test_arc_per_epoch + 1):
                arc_meters = AverageMeterGroup()
                for inputs, targets in self.test_loader:
                    inputs = to_device(inputs, self.device)
                    targets = to_device(targets, self.device)
                    self.mutator.reset()
                    output = self.model(inputs)
                    # A tuple output is unpacked as a pair; only the first item is scored.
                    if isinstance(output, tuple):
                        output, _ = output
                    batch_stats = self.metrics(output, targets)
                    batch_stats["loss"] = self.loss(output, targets).item()
                    arc_meters.update(batch_stats)

                logger.info(
                    "Test Epoch [%d/%d] Arc [%d/%d] Summary  %s",
                    epoch + 1,
                    self.num_epochs,
                    arc_number,
                    self.test_arc_per_epoch,
                    arc_meters.summary(),
                )
示例#4
0
    def train_one_epoch(self, epoch):
        """Run one epoch: first train the model, then train the mutator.

        Phase 1 puts the model in train mode and the mutator in eval mode, then
        performs ``self.child_steps`` optimizer steps on batches from
        ``self.train_loader``. When the model returns a tuple, its second item
        is scored with ``self.loss`` and added with weight ``self.aux_weight``.

        Phase 2 puts the model in eval mode and the mutator in train mode, then
        performs ``self.mutator_steps`` mutator optimizer steps, each one
        accumulating gradients over ``self.mutator_steps_aggregate`` batches
        from ``self.valid_loader``.

        Gradient norms are clipped to 5 in both phases.

        Args:
            epoch: Zero-based epoch index; only used (as ``epoch + 1``) in log
                messages.
        """
        # ---- Phase 1: sample a model and train its weights ----
        self.model.train()
        self.mutator.eval()
        child_meters = AverageMeterGroup()
        for child_step in range(1, self.child_steps + 1):
            inputs, targets = next(self.train_loader)
            inputs = to_device(inputs, self.device)
            targets = to_device(targets, self.device)
            self.optimizer.zero_grad()

            # The mutator is reset without tracking gradients during this phase.
            with torch.no_grad():
                self.mutator.reset()
            output = self.model(inputs)

            aux_loss = 0.
            if isinstance(output, tuple):
                output, aux_output = output
                aux_loss = self.loss(aux_output, targets)
            batch_stats = self.metrics(output, targets)
            total = self.loss(output, targets) + self.aux_weight * aux_loss
            total.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
            self.optimizer.step()
            batch_stats["loss"] = total.item()
            child_meters.update(batch_stats)

            if self.log_frequency is not None and child_step % self.log_frequency == 0:
                logger.info(
                    "Model Epoch [%d/%d] Step [%d/%d]  %s",
                    epoch + 1, self.num_epochs, child_step, self.child_steps,
                    child_meters,
                )

        # ---- Phase 2: train the sampler (mutator) ----
        self.model.eval()
        self.mutator.train()
        rl_meters = AverageMeterGroup()
        for outer_step in range(1, self.mutator_steps + 1):
            self.mutator_optim.zero_grad()
            for inner_step in range(1, self.mutator_steps_aggregate + 1):
                inputs, targets = next(self.valid_loader)
                inputs = to_device(inputs, self.device)
                targets = to_device(targets, self.device)

                self.mutator.reset()
                with torch.no_grad():
                    output = self.model(inputs)
                batch_stats = self.metrics(output, targets)
                reward = self.reward_function(output, targets)
                if self.entropy_weight:
                    reward += self.entropy_weight * self.mutator.sample_entropy.item()
                # Update the exponential moving average baseline before it is
                # subtracted from the reward below.
                self.baseline = (self.baseline * self.baseline_decay
                                 + reward * (1 - self.baseline_decay))
                policy_loss = self.mutator.sample_log_prob * (reward - self.baseline)
                if self.skip_weight:
                    policy_loss += self.skip_weight * self.mutator.sample_skip_penalty
                batch_stats["reward"] = reward
                # The loss is recorded before it is scaled for accumulation.
                batch_stats["loss"] = policy_loss.item()
                batch_stats["ent"] = self.mutator.sample_entropy.item()
                batch_stats["log_prob"] = self.mutator.sample_log_prob.item()
                batch_stats["baseline"] = self.baseline
                batch_stats["skip"] = self.mutator.sample_skip_penalty

                policy_loss /= self.mutator_steps_aggregate
                policy_loss.backward()
                rl_meters.update(batch_stats)

                # Global step counter across all aggregated batches of this epoch.
                cur_step = (outer_step - 1) * self.mutator_steps_aggregate + inner_step
                if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                    logger.info(
                        "RL Epoch [%d/%d] Step [%d/%d] [%d/%d]  %s",
                        epoch + 1, self.num_epochs,
                        outer_step, self.mutator_steps,
                        inner_step, self.mutator_steps_aggregate,
                        rl_meters,
                    )

            nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
            self.mutator_optim.step()
示例#5
0
    def train_one_epoch(self, epoch):
        """Run one epoch: first train the model, then train the mutator.

        Phase 1 puts the model in train mode and the mutator in eval mode, then
        performs ``self.child_steps`` optimizer steps on batches from
        ``self.train_loader``. When the model returns a tuple, its second item
        is scored with ``self.loss`` and added with weight ``self.aux_weight``.

        Phase 2 puts the model in eval mode and the mutator in train mode, then
        performs ``self.mutator_steps`` mutator optimizer steps, each one
        accumulating gradients over ``self.mutator_steps_aggregate`` batches
        from ``self.valid_loader``.

        ``self._write_graph_status()`` is called once per batch in both phases,
        and gradient norms are clipped to 5 in both phases.

        Args:
            epoch: Zero-based epoch index; only used (as ``epoch + 1``) in log
                messages.
        """
        # Sample model and train
        self.model.train()
        self.mutator.eval()
        meters = AverageMeterGroup()
        # Phase 1: train the model first.
        for step in range(1, self.child_steps + 1):
            x, y = next(self.train_loader)
            x, y = to_device(x, self.device), to_device(y, self.device)
            self.optimizer.zero_grad()

            # The mutator is reset without tracking gradients during this phase.
            with torch.no_grad():
                self.mutator.reset()
            self._write_graph_status()
            logits = self.model(x)

            if isinstance(logits, tuple):
                logits, aux_logits = logits
                aux_loss = self.loss(aux_logits, y)
            else:
                aux_loss = 0.

            # Per-batch metrics (accuracy, according to the original author's note).
            metrics = self.metrics(logits, y)

            # Main loss plus the weighted auxiliary loss.
            loss = self.loss(logits, y)
            loss = loss + self.aux_weight * aux_loss

            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
            self.optimizer.step()

            metrics["loss"] = loss.item()
            meters.update(metrics)

            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Model Epoch [%d/%d] Step [%d/%d]  %s", epoch + 1,
                            self.num_epochs, step, self.child_steps, meters)

        # Train sampler (mutator)
        self.model.eval()
        self.mutator.train()
        # Phase 2: then train the mutator.
        meters = AverageMeterGroup()
        for mutator_step in range(1, self.mutator_steps + 1):
            self.mutator_optim.zero_grad()
            for step in range(1, self.mutator_steps_aggregate + 1):
                x, y = next(self.valid_loader)
                x, y = to_device(x, self.device), to_device(y, self.device)

                self.mutator.reset()

                with torch.no_grad():
                    logits = self.model(x)

                self._write_graph_status()

                # Per-batch metrics (accuracy, according to the original author's note).
                metrics = self.metrics(logits, y)

                # Reward for the sampled architecture on this batch.
                # The original author quoted this reference reward function,
                # which returns the fraction (not the count) of correct top-1
                # predictions in the batch:
                #
                #     def reward_accuracy(output, target, topk=(1,)):
                #         batch_size = target.size(0)
                #         _, predicted = torch.max(output.data, 1)
                #         return (predicted == target).sum().item() / batch_size
                #
                # NOTE(review): whether self.reward_function is that function
                # is not visible here - confirm where it is assigned.
                reward = self.reward_function(logits, y)

                # Optional bonus: the mutator's sample entropy scaled by entropy_weight.
                if self.entropy_weight:
                    reward += self.entropy_weight * self.mutator.sample_entropy.item()

                # The baseline is an exponential moving average of the reward;
                # it is updated before being subtracted below.
                self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)

                # Policy-gradient-like objective (per the original author's
                # note): the sample log-probability weighted by the reward
                # minus the baseline.
                loss = self.mutator.sample_log_prob * (reward - self.baseline)

                if self.skip_weight:
                    loss += self.skip_weight * self.mutator.sample_skip_penalty

                metrics["reward"] = reward
                # The loss is recorded before it is scaled for accumulation.
                metrics["loss"] = loss.item()
                metrics["ent"] = self.mutator.sample_entropy.item()
                metrics["log_prob"] = self.mutator.sample_log_prob.item()
                metrics["baseline"] = self.baseline
                metrics["skip"] = self.mutator.sample_skip_penalty

                # Scale so the accumulated gradient averages over the aggregated batches.
                loss /= self.mutator_steps_aggregate
                loss.backward()
                meters.update(metrics)

                # Global step counter across all aggregated batches of this epoch.
                cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
                if self.log_frequency is not None and cur_step % self.log_frequency == 0:
                    logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d]  %s", epoch + 1, self.num_epochs,
                                mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
                                meters)

            nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
            self.mutator_optim.step()