Example #1
import torch


def default_eval_fn(evaluator, batch):
    model = evaluator.model
    teacher = evaluator.teacher

    inputs, targets = split_batch(batch)
    outputs = model(inputs)

    # Replace the ground-truth targets with teacher predictions,
    # so the metric compares the student outputs against the teacher(s).
    if isinstance(teacher, torch.nn.ModuleList):
        # Multiple teachers: each teacher is paired with its own task.
        targets = [
            task.predict(tea(inputs))
            for (tea, task) in zip(teacher, evaluator.task)
        ]
    else:
        t_outputs = teacher(inputs)
        targets = evaluator.task.predict(t_outputs)
    evaluator.metric.update(outputs, targets)
Example #2
import time


def step_fn(self, engine, batch):
    model = self.model
    start_time = time.perf_counter()

    batch = move_to_device(batch, self.device)
    inputs, targets = split_batch(batch)
    outputs = model(inputs)

    # Compute the named loss terms and sum them into a single scalar.
    loss_dict = self.task.get_loss(outputs, targets)
    loss = sum(loss_dict.values())

    # Standard optimization step.
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    step_time = time.perf_counter() - start_time

    # Report each loss term plus the total loss, step time and learning rate.
    metrics = {loss_name: loss_value.item()
               for (loss_name, loss_value) in loss_dict.items()}
    metrics.update({
        'total_loss': loss.item(),
        'step_time': step_time,
        'lr': float(self.optimizer.param_groups[0]['lr'])
    })
    return metrics
Example #3
def default_eval_fn(evaluator, batch):
    # Forward the batch and update the metric with predictions and ground truth.
    model = evaluator.model
    inputs, targets = split_batch(batch)
    outputs = model(inputs)
    evaluator.metric.update(outputs, targets)
Example #4
def wrapper(model, batch):
    # `pred_fn` and `attach_to` are free variables here, presumably captured
    # from the enclosing scope that builds this wrapper.
    inputs, targets = split_batch(batch)
    outputs = model(inputs)
    outputs, targets = attach_to(outputs, targets)
    return inputs, targets, pred_fn(outputs)
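
All four snippets rely on the helper split_batch to separate a dataloader batch into model inputs and targets; its implementation is not shown on this page. A minimal sketch of the behavior the examples appear to assume (the first element of a sequence batch is the input, the rest are the targets) could look like the following, though the actual library helper may differ:

def split_batch(batch):
    # Hypothetical sketch, not the library's actual implementation.
    if isinstance(batch, (list, tuple)):
        inputs, *targets = batch
        # Unwrap a single target so callers get a plain (inputs, target) pair.
        if len(targets) == 1:
            targets = targets[0]
        return inputs, targets
    # A bare tensor batch has no separate targets.
    return batch, None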