Example #1
    def infer_epoch(self, valid_queue, model, criterion, device):
        expect(self._is_setup, "trainer.setup should be called first")
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        objective_perfs = utils.OrderedStats()
        model.eval()

        # contextlib.nullcontext keeps autograd enabled during eval, for
        # objectives that need gradients at eval time (e.g. adversarial perturbations)
        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for step, (inputs, target) in enumerate(valid_queue):
                inputs = inputs.to(device)
                target = target.to(device)

                logits = model(inputs)
                loss = criterion(logits, target)
                perfs = self._perf_func(inputs, logits, target, model)
                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                n = inputs.size(0)
                objective_perfs.update(dict(zip(self._perf_names, perfs)), n=n)
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                if step % self.report_every == 0:
                    self.logger.info("valid %03d %e %f %f %s", step, objs.avg, top1.avg, top5.avg,
                                     "; ".join(["{}: {:.3f}".format(perf_n, v) \
                                                for perf_n, v in objective_perfs.avgs().items()]))

        return top1.avg, objs.avg, objective_perfs.avgs()
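
These examples rely on utils.AverageMeter without showing it; it matches the familiar running-average helper from the PyTorch ImageNet example. A minimal sketch consistent with the update(val, n) / .avg usage above (the project's actual class may carry more fields):

class AverageMeter:
    """Running weighted mean: update(val, n) treats val as the mean over a batch of size n."""

    def __init__(self):
        self.sum = 0.0
        self.cnt = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n   # re-weight the batch mean by the batch size
        self.cnt += n
        self.avg = self.sum / self.cnt

This matches calls such as objs.update(loss.item(), n) followed by reads of objs.avg in the loops above.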
Example #2
def _candnet_perf_use_param(
    cand_net, params, data, loss_criterion, forward_kwargs, return_outputs=False
):
    # TODO: only supports CNNs for now, since it calls utils.accuracy directly;
    #       maybe should use self._report_loss_funcs instead
    outputs = cand_net.forward_with_params(
        data[0], params=params, **forward_kwargs, mode="train"
    )
    loss = loss_criterion(data[0], outputs, data[1])
    acc = utils.accuracy(outputs, data[1], topk=(1,))[0]
    if return_outputs:
        return outputs, loss, acc
    return loss, acc
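
A hypothetical call site for the helper above, just to pin down the argument shapes; cand_net, sampled_params, inputs, targets, and criterion are assumed to come from the surrounding evaluator, and criterion takes (inputs, outputs, targets) as in the body above:

# data is an (inputs, targets) pair, indexed as data[0] / data[1] above
loss, acc = _candnet_perf_use_param(
    cand_net, sampled_params, (inputs, targets), criterion, forward_kwargs={})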
Example #3
    def train_epoch(self, train_queue, model, criterion, optimizer, device,
                    epoch):
        expect(self._is_setup, "trainer.setup should be called first")
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        model.train()

        for step, (inputs, target) in enumerate(train_queue):
            inputs = inputs.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            if self.auxiliary_head:  # assume the model returns (logits, logits_aux) in train mode
                logits, logits_aux = model(inputs)
                loss = self._obj_loss(
                    inputs,
                    logits,
                    target,
                    model,
                    add_evaluator_regularization=self.add_regularization)
                loss_aux = criterion(logits_aux, target)
                loss += self.auxiliary_weight * loss_aux
            else:
                logits = model(inputs)
                loss = self._obj_loss(
                    inputs,
                    logits,
                    target,
                    model,
                    add_evaluator_regularization=self.add_regularization)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
            optimizer.step()

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = inputs.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            del loss

            if step % self.report_every == 0:
                self.logger.info("train %03d %.3f; %.2f%%; %.2f%%", step,
                                 objs.avg, top1.avg, top5.avg)

        return top1.avg, objs.avg
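
All four examples call utils.accuracy(logits, target, topk=...) and read .item() off each returned value. A sketch in the style of the canonical PyTorch ImageNet helper, which has this signature and returns percentages (the project's real implementation may differ):

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent for each k in topk."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)   # (batch, maxk) predicted class indices
    pred = pred.t()                              # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res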
Example #4
    def infer_epoch(self, valid_queue, model, criterion, device):
        expect(self._is_setup, "trainer.setup should be called first")
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        all_perfs = []
        model.eval()

        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for step, (inputs, target) in enumerate(valid_queue):
                inputs = inputs.to(device)
                target = target.to(device)

                logits = model(inputs)
                loss = criterion(logits, target)
                perfs = self._perf_func(inputs, logits, target, model)
                all_perfs.append(perfs)
                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                n = inputs.size(0)
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)
                del loss
                if step % self.report_every == 0:
                    # transpose: a list of per-batch perf tuples
                    # -> one tuple of values per perf name
                    all_perfs_by_name = list(zip(*all_perfs))
                    # use the objective's aggregate_fn so statistics other than the
                    # mean are supported, e.g. adversarial distance median or
                    # detection mAP (see det_trainer.py)
                    obj_perfs = {
                        k: self.objective.aggregate_fn(k, False)(v)
                        for k, v in zip(self._perf_names, all_perfs_by_name)
                    }
                    self.logger.info("valid %03d %e %f %f %s", step, objs.avg, top1.avg, top5.avg,
                                     "; ".join(["{}: {:.3f}".format(perf_n, v)
                                                for perf_n, v in obj_perfs.items()]))
        all_perfs_by_name = list(zip(*all_perfs))
        obj_perfs = {
            k: self.objective.aggregate_fn(k, False)(v)
            for k, v in zip(self._perf_names, all_perfs_by_name)
        }
        return top1.avg, objs.avg, obj_perfs
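
Example #4's self.objective.aggregate_fn(perf_name, is_training) returns a callable that reduces all per-batch values of one perf to a single number. A hypothetical mean-based default, with the median/mAP cases from the inline comment as overrides an objective might provide:

import numpy as np

def aggregate_fn(self, perf_name, is_training):
    # hypothetical default: plain mean over the per-batch values; an objective
    # can override this with e.g. a median (adversarial distance) or mAP (detection)
    return lambda perfs: float(np.mean(perfs))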