Beispiel #1
0
    def train(self, step: int) -> torch.Tensor:
        """Run a single optimization step on the next training batch.

        Performs forward pass, NaN guard, backward pass, gradient clipping
        and an optimizer update, with periodic console/tensorboard logging.

        Args:
            step: global training step index; drives the logging interval.

        Returns:
            The training loss tensor, or ``None`` when the loss was NaN
            and the parameter update was skipped.
        """
        # reset accumulated gradients before the new backward pass
        self.optimizer.zero_grad()

        # log only every `log_interval` steps
        log_flag = step % self.log_interval == 0

        # forward model on the next training batch
        loss, meta = self.forward(*to_device(next(self.train_dataset)),
                                  log_flag)

        # skip the whole update when the loss went NaN
        # (explicit torch.isnan instead of the `loss != loss` trick)
        if torch.isnan(loss):
            log('{} cur step NAN is occured'.format(step))
            return

        loss.backward()
        self.clip_grad()  # clip gradients before the optimizer step
        self.optimizer.step()

        # periodic logging
        if log_flag:
            # console logging
            self.console_log('train', meta, step)
            # tensorboard logging
            self.tensorboard_log('train', meta, step)

        # return the loss to honor the declared -> torch.Tensor annotation
        return loss
Beispiel #2
0
    def validate(self, step: int):
        """Run validation over `valid_max_step` batches and log averaged stats.

        Args:
            step: global training step used for log messages and as the
                tensorboard global step.
        """
        total_loss = 0.
        scalar_stat = defaultdict(float)

        for idx in range(self.valid_max_step):
            # log on the configured interval and always on the final batch
            is_log_step = (idx % self.log_interval == 0
                           or idx == self.valid_max_step - 1)

            # forward pass without gradient tracking
            with torch.no_grad():
                batch_loss, meta = self.forward(*to_device(next(self.valid_dataset)), is_logging=is_log_step)
                total_loss += batch_loss

            # accumulate scalar-typed entries for averaging after the loop
            for name, (val, kind) in meta.items():
                if kind == LogType.SCALAR:
                    scalar_stat[name] += val

            # per-step console logging
            if (idx + 1) % self.log_interval == 0:
                self.console_log('valid', meta, idx + 1)

        # keep only non-scalar entries from the last batch's meta
        meta_non_scalar = {}
        for name, (val, kind) in meta.items():
            if kind != LogType.SCALAR:
                meta_non_scalar[name] = (val, kind)

        # best-effort tensorboard logging of non-scalar meta
        try:
            self.tensorboard_log('valid', meta_non_scalar, step)
        except OverflowError:
            pass

        # average accumulated values over the number of validation steps
        total_loss /= self.valid_max_step
        for name in scalar_stat:
            scalar_stat[name] /= self.valid_max_step

        # track the best validation loss seen so far
        if total_loss < self.best_valid_loss:
            self.best_valid_loss = total_loss

        # console logging of the aggregated stats
        msg = 'step {} / total stat'.format(step)
        for name, val in sorted(scalar_stat.items()):
            msg += '\t{}: {:.6f}'.format(name, val)
        log(msg)

        # tensorboard logging of the aggregated scalar stats
        for name, val in scalar_stat.items():
            self.writer.add_scalar('valid/{}'.format(name), val, global_step=step)
Beispiel #3
0
    def validate(self, step: int):
        """Run validation over `valid_max_step` batches and log averaged stats.

        Accumulates the loss and every scalar-typed meta entry across all
        validation batches, averages them, updates the best validation loss,
        and writes the results to the console and to tensorboard.

        Args:
            step: global training step used for log messages and as the
                tensorboard global step.
        """
        loss = 0.
        stat = defaultdict(float)
        # NOTE(review): removed unused local `count` — it was never
        # incremented or read anywhere in this method.

        for i in range(self.valid_max_step):
            # forward model without gradient tracking
            with torch.no_grad():
                batch_loss, meta = self.forward(*to_device(
                    next(self.valid_dataset)),
                                                is_logging=True)
                loss += batch_loss

            # accumulate scalar-typed entries for averaging after the loop
            for key, (value, log_type) in meta.items():
                if log_type == LogType.SCALAR:
                    stat[key] += value

            # console logging on the interval and always on the final batch
            if i % self.log_interval == 0 or i == self.valid_max_step - 1:
                self.console_log('valid', meta, i + 1)

        # averaging stat; 'loss' is replaced wholesale below, so skip it here
        loss /= self.valid_max_step
        for key in stat.keys():
            if key == 'loss':
                continue
            stat[key] = stat[key] / self.valid_max_step
        stat['loss'] = loss

        # update best valid loss
        if loss < self.best_valid_loss:
            self.best_valid_loss = loss

        # console logging of total stat
        msg = 'step {} / total stat'.format(step)
        for key, value in sorted(stat.items()):
            msg += '\t{}: {:.6f}'.format(key, value)
        log(msg)

        # tensorboard logging of scalar stat
        for key, value in stat.items():
            self.writer.add_scalar('valid/{}'.format(key),
                                   value,
                                   global_step=step)
Beispiel #4
0
 def repeat(iterable):
     """Endlessly re-iterate *iterable*, yielding each inner element on CUDA.

     Each pass iterates the groups of *iterable* and yields every element of
     every group after moving it to the 'cuda' device via `to_device`.
     """
     while True:
         for batch in iterable:
             yield from (to_device(item, 'cuda') for item in batch)