Esempio n. 1
0
    def validate_one_epoch(self, epoch):
        """Evaluate sampled architectures on the test set and collect per-round metrics.

        Runs ``self.test_arc_per_epoch`` rounds. Each round iterates the whole
        ``self.test_loader`` under ``torch.no_grad()`` and calls
        ``self.mutator.reset()`` before every batch. Per-batch results from
        ``self.metrics`` plus the scalar ``"loss"`` are accumulated in a fresh
        ``AverageMeterGroup``, and the group's summary is logged once per round.

        Args:
            epoch: zero-based epoch index (logged as ``epoch + 1``).

        Returns:
            list: one entry per round, each obtained by parsing
            ``'{' + meters.summary() + '}'`` as JSON.
            NOTE(review): this only works if ``AverageMeterGroup.summary()``
            emits comma-separated ``"name": value`` pairs; otherwise
            ``json.loads`` raises ``json.JSONDecodeError`` — confirm against
            the meter implementation in use.
        """
        meter_list = []

        with torch.no_grad():
            for arc_id in range(self.test_arc_per_epoch):
                meters = AverageMeterGroup()
                for x, y in self.test_loader:
                    x, y = to_device(x, self.device), to_device(y, self.device)
                    # Reset before each batch, so each batch may see a freshly
                    # sampled architecture (behavior kept as in the original).
                    self.mutator.reset()
                    logits = self.model(x)
                    if isinstance(logits, tuple):
                        # Only the first element of a 2-tuple output is scored;
                        # presumably (logits, aux_logits) — TODO confirm.
                        logits, _ = logits
                    metrics = self.metrics(logits, y)
                    loss = self.loss(logits, y)
                    metrics["loss"] = loss.item()
                    meters.update(metrics)

                summary = meters.summary()
                # The summary is already a str: parse it directly, no need to
                # json.dumps it first (that round-trip returns the same str).
                meter_list.append(json.loads('{' + summary + '}'))

                logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary  %s",
                            epoch + 1, self.num_epochs, arc_id + 1,
                            self.test_arc_per_epoch, summary)

        return meter_list
Esempio n. 2
0
    def validate_one_epoch(self, epoch):
        """Evaluate the current architecture on the test set.

        Switches ``self.model`` and ``self.mutator`` to eval mode, calls
        ``self.mutator.reset()`` once, then iterates ``self.test_loader`` under
        ``torch.no_grad()``. Per-batch results from ``self.metrics`` plus the
        scalar ``"loss"`` are averaged in a single ``AverageMeterGroup``.
        Progress is logged every ``self.log_frequency`` steps unless that
        attribute is ``None``.

        Args:
            epoch: zero-based epoch index (logged as ``epoch + 1``).

        Returns:
            list: a single-element list holding the result of parsing
            ``'{' + meters.summary() + '}'`` as JSON.
            NOTE(review): this only works if ``AverageMeterGroup.summary()``
            emits comma-separated ``"name": value`` pairs; otherwise
            ``json.loads`` raises ``json.JSONDecodeError`` — confirm against
            the meter implementation in use.
        """
        # return list of meters
        meter_list = []

        self.model.eval()
        self.mutator.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            # One architecture is fixed for the whole evaluation pass.
            self.mutator.reset()
            for step, (X, y) in enumerate(self.test_loader):
                X, y = X.to(self.device), y.to(self.device)
                logits = self.model(X)
                if isinstance(logits, tuple):
                    # A tuple cannot be scored by metrics/loss; keep only the
                    # first element of a 2-tuple (presumably (logits, aux)).
                    logits, _ = logits
                metrics = self.metrics(logits, y)
                loss = self.loss(logits, y)
                metrics["loss"] = loss.item()
                meters.update(metrics)

                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.test_loader), meters)

            # The summary is already a str: parse it directly, no need to
            # json.dumps it first (that round-trip returns the same str).
            meter_list.append(json.loads('{' + meters.summary() + '}'))

        return meter_list
Esempio n. 3
0
    def validate_one_epoch(self, epoch):
        """Run ``self.test_arc_per_epoch`` test passes and log each pass's averages.

        Each pass walks the full ``self.test_loader`` without gradient tracking.
        Every batch is moved to ``self.device``, and ``self.mutator.reset()`` is
        called before the forward pass. A tuple output is unpacked into two
        values and only the first is scored. The dict returned by
        ``self.metrics`` is extended with the scalar ``"loss"`` and fed into a
        fresh ``AverageMeterGroup``, whose summary is logged when the pass ends.
        Nothing is returned.

        Args:
            epoch: zero-based epoch index; logged as ``epoch + 1``.
        """
        with torch.no_grad():
            for pass_idx in range(self.test_arc_per_epoch):
                running = AverageMeterGroup()
                for inputs, targets in self.test_loader:
                    inputs = to_device(inputs, self.device)
                    targets = to_device(targets, self.device)
                    self.mutator.reset()
                    output = self.model(inputs)
                    if isinstance(output, tuple):
                        output, _ = output
                    batch_stats = self.metrics(output, targets)
                    batch_stats["loss"] = self.loss(output, targets).item()
                    running.update(batch_stats)

                logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary  %s",
                            epoch + 1, self.num_epochs, pass_idx + 1,
                            self.test_arc_per_epoch, running.summary())