示例#1
0
 def _log_basic(self, stats, misc):
     """Emit a one-line training summary via ``logger.info``.

     The collected training and validation losses in ``stats`` are averaged
     with ``np.mean``. Learning rate and throughput figures are read from
     ``misc``. The elapsed wall time comes from ``self._timer``, and the
     current ``self.step`` is included.
     """
     mean_tr = np.mean(stats['tr_loss'])
     mean_val = np.mean(stats['val_loss'])
     lr, it_per_s, mvx_per_s = (
         misc[key] for key in ('learning_rate', 'tr_speed', 'tr_speed_vx')
     )
     elapsed = pretty_string_time(self._timer.t_passed)
     pieces = (
         f'step={self.step:06d}, tr_loss={mean_tr:.3f}, val_loss={mean_val:.3f}, ',
         f'lr={lr:.2e}, {it_per_s:.2f} it/s, {mvx_per_s:.2f} MVx/s, {elapsed}',
     )
     logger.info(''.join(pieces))
示例#2
0
    def train(self, max_steps: int = 1, max_runtime=3600 * 24 * 7) -> None:
        """Train the network for up to ``max_steps`` steps or ``max_runtime`` seconds.

        Each pass over ``self.train_loader`` is one epoch. After each epoch:
        - validation is run if ``self.valid_dataset`` is set;
        - a summary line is logged;
        - tracker plots are written to ``self.save_path``;
        - scalars, previews and sample images go to tensorboard if ``self.tb``
          is set.

        The model is saved after every epoch. It is also saved with suffix
        ``'_best'`` whenever the validation loss improves, and with suffix
        ``'_final'`` once the training loop exits.

        Args:
            max_steps: Training stops once ``self.step`` reaches this value.
            max_runtime: Training stops once this many seconds have passed
                since the call.

        Raises:
            Exception: Errors raised during an epoch are re-raised unless
                ``self.ignore_errors`` or ``self.ipython_on_error`` is set.
        """
        self.start_time = datetime.datetime.now()
        self.end_time = self.start_time + datetime.timedelta(seconds=max_runtime)
        while not self.terminate:
            try:
                self.model.train()

                # Scalar training stats that should be logged and written to tensorboard later
                stats: Dict[str, float] = {'tr_loss': 0.0}
                # Other scalars to be logged
                misc: Dict[str, float] = {}
                # Hold image tensors for real-time training sample visualization in tensorboard
                images: Dict[str, torch.Tensor] = {}

                running_acc = 0
                running_mean_target = 0
                running_vx_size = 0
                # Batches actually processed in this epoch. The loop below can stop
                # early (max_steps / max_runtime), so epoch means and throughput must
                # be normalized by this count, not by len(self.train_loader).
                n_batches = 0
                timer = Timer()
                for inp, target in self.train_loader:
                    inp, target = inp.to(self.device), target.to(self.device)

                    # forward pass
                    out = self.model(inp)
                    loss = self.criterion(out, target)
                    if torch.isnan(loss):
                        logger.error('NaN loss detected! Aborting training.')
                        raise NaNException

                    # update step
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    # Prevent accidental autograd overheads after optimizer step
                    inp.detach_()
                    target.detach_()
                    out.detach_()
                    loss.detach_()

                    # get training performance
                    stats['tr_loss'] += float(loss)
                    acc = metrics.bin_accuracy(target, out)  # TODO
                    mean_target = target.to(torch.float32).mean()
                    print(f'{self.step:6d}, loss: {loss:.4f}', end='\r')
                    self._tracker.update_timeline([self._timer.t_passed, float(loss), mean_target])

                    # Preserve training batch and network output for later visualization
                    images['inp'] = inp
                    images['target'] = target
                    images['out'] = out
                    # Read the LR from param_groups instead of get_lr() so that
                    # ReduceLROnPlateau (which does not implement get_lr) is supported.
                    misc['learning_rate'] = self.optimizer.param_groups[0]["lr"]
                    # update schedules
                    for sched in self.schedulers.values():
                        # ReduceLROnPlateau.step() needs a metric. The PyTorch docs use
                        # the validation loss; the current training loss is used here.
                        # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
                        if "metrics" in inspect.signature(sched.step).parameters:
                            sched.step(metrics=float(loss))
                        else:
                            sched.step()

                    running_acc += acc
                    running_mean_target += mean_target
                    running_vx_size += inp.numel()
                    n_batches += 1

                    self.step += 1
                    if self.step >= max_steps:
                        logger.info(f'max_steps ({max_steps}) exceeded. Terminating...')
                        self.terminate = True
                        break
                    if datetime.datetime.now() >= self.end_time:
                        logger.info(f'max_runtime ({max_runtime} seconds) exceeded. Terminating...')
                        self.terminate = True
                        break
                if n_batches == 0:
                    raise RuntimeError('train_loader did not yield any batches.')
                stats['tr_accuracy'] = running_acc / n_batches
                stats['tr_loss'] /= n_batches
                misc['tr_speed'] = n_batches / timer.t_passed
                misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx
                mean_target = running_mean_target / n_batches
                if self.valid_dataset is None:
                    stats['val_loss'], stats['val_accuracy'] = float('nan'), float('nan')
                else:
                    valid_stats = self.validate()
                    stats.update(valid_stats)

                # Update history tracker (kind of made obsolete by tensorboard)
                # TODO: Decide what to do with this, now that most things are already in tensorboard.
                # NOTE(review): this compares the step count (batches) with the dataset
                # length (samples), presumably as a proxy for "a previous history entry
                # exists". Confirm the intended condition.
                if self.step // len(self.train_dataset) > 1:
                    tr_loss_gain = self._tracker.history[-1][2] - stats['tr_loss']
                else:
                    tr_loss_gain = 0
                self._tracker.update_history([
                    self.step, self._timer.t_passed, stats['tr_loss'], stats['val_loss'],
                    tr_loss_gain, stats['tr_accuracy'], stats['val_accuracy'], misc['learning_rate'], 0, 0
                ])  # 0's correspond to mom and gradnet (?)
                t = pretty_string_time(self._timer.t_passed)
                loss_smooth = self._tracker.loss._ema

                # Logging to stdout, text log file
                text = "%05i L_m=%.3f, L=%.2f, tr_acc=%05.2f%%, " % (self.step, loss_smooth, stats['tr_loss'], stats['tr_accuracy'])
                text += "val_acc=%05.2f%s, prev=%04.1f, L_diff=%+.1e, " % (stats['val_accuracy'], "%", mean_target * 100, tr_loss_gain)
                text += "LR=%.2e, %.2f it/s, %.2f MVx/s, %s" % (misc['learning_rate'], misc['tr_speed'], misc['tr_speed_vx'], t)
                logger.info(text)

                # Plot tracker stats to pngs in save_path
                self._tracker.plot(self.save_path)

                # Reporting to tensorboard logger
                if self.tb:
                    self.tb_log_scalars(stats, 'stats')
                    self.tb_log_scalars(misc, 'misc')
                    if self.previews_enabled:
                        self.tb_log_preview()
                    self.tb_log_sample_images(images, group='tr_samples')
                    self.tb.writer.flush()

                # Save trained model state
                self.save_model()
                # NaN validation loss never compares smaller, so no '_best' model is
                # written when validation is disabled.
                if stats['val_loss'] < self.best_val_loss:
                    self.best_val_loss = stats['val_loss']
                    self.save_model(suffix='_best')
            except KeyboardInterrupt:
                IPython.embed(header=self._shell_info)
                if self.terminate:
                    # break (not return) so that the '_final' model is still saved.
                    break
            except Exception:
                traceback.print_exc()
                if self.ignore_errors:
                    # Just print the traceback and try to carry on with training.
                    # This can go wrong in unexpected ways, so don't leave the training unattended.
                    pass
                elif self.ipython_on_error:
                    print("\nEntering Command line such that Exception can be "
                          "further inspected by user.\n\n")
                    IPython.embed(header=self._shell_info)
                    if self.terminate:
                        break
                else:
                    raise
        self.save_model(suffix='_final')
示例#3
0
    def run(self, max_steps: int = 1, max_runtime=3600 * 24 * 7) -> None:
        """Drive training epoch by epoch until ``self.terminate`` becomes true.

        One epoch of training is delegated to ``self._train``, which receives
        ``max_steps`` and ``max_runtime``. After every epoch:
        - validation runs when ``self.valid_dataset`` is set;
        - a summary line is logged;
        - tracker plots are written to ``self.save_path``;
        - tensorboard reporting runs if ``self.tb`` is set;
        - a checkpoint is saved, and additionally with suffix ``'_best'``
          whenever the validation loss improves.

        A ``'_final'`` model is saved once the loop ends.
        """
        now = datetime.datetime.now()
        self.start_time = now
        self.end_time = now + datetime.timedelta(seconds=max_runtime)
        while not self.terminate:
            try:
                stats, misc, images = self._train(max_steps, max_runtime)
                self.epoch += 1

                if self.valid_dataset is not None:
                    stats.update(self._validate())
                else:
                    stats['val_loss'] = float('nan')
                # _validate() may not report an accuracy; fill in a placeholder.
                stats.setdefault('val_accuracy', float('nan'))

                # History tracker bookkeeping (largely superseded by tensorboard).
                if self.step // len(self.train_dataset) > 1:
                    tr_loss_gain = self._tracker.history[-1][2] - stats['tr_loss']
                else:
                    tr_loss_gain = 0
                history_row = [
                    self.step, self._timer.t_passed,
                    stats['tr_loss'], stats['val_loss'], tr_loss_gain,
                    stats['tr_accuracy'], stats['val_accuracy'],
                    misc['learning_rate'],
                    0, 0,  # placeholders for mom and gradnet (?)
                ]
                self._tracker.update_history(history_row)
                elapsed_str = pretty_string_time(self._timer.t_passed)
                loss_smooth = self._tracker.loss._ema

                # One summary line for stdout and the text log.
                summary_fmt = ("%05i L_m=%.3f, L=%.2f, tr_acc=%05.2f%%, "
                               "val_acc=%05.2f%s, prev=%04.1f, L_diff=%+.1e, "
                               "LR=%.2e, %.2f it/s, %.2f MVx/s, %s")
                logger.info(summary_fmt % (
                    self.step, loss_smooth, stats['tr_loss'], stats['tr_accuracy'],
                    stats['val_accuracy'], "%", misc['mean_target'] * 100, tr_loss_gain,
                    misc['learning_rate'], misc['tr_speed'], misc['tr_speed_vx'], elapsed_str,
                ))

                # Tracker plots are written as pngs to save_path.
                self._tracker.plot(self.save_path)

                if self.tb:
                    try:
                        self._tb_log_scalars(stats, 'stats')
                        self._tb_log_scalars(misc, 'misc')
                        want_preview = (
                            self.preview_batch is not None
                            and (self.epoch % self.preview_interval == 0 or self.epoch == 1)
                        )
                        if want_preview:
                            # TODO: Also save preview inference results in a (3D) HDF5 file
                            self.preview_plotting_handler(self)
                        self.sample_plotting_handler(self, images, group='tr_samples')
                    except Exception:
                        logger.exception('Error occured while logging to tensorboard:')

                # Checkpoint every epoch; keep a separate copy of the best one.
                # TODO: Support other metrics for determining what's the "best" model?
                self._save_model()
                current_val_loss = stats['val_loss']
                if current_val_loss < self.best_val_loss:
                    self.best_val_loss = current_val_loss
                    self._save_model(suffix='_best')
            except KeyboardInterrupt:
                if not self.ipython_shell:
                    break
                IPython.embed(header=self._shell_info)
                if self.terminate:
                    break
            except Exception as err:
                logger.exception('Unhandled exception during training:')
                if self.ignore_errors:
                    # Keep training despite the error; don't leave such runs unattended.
                    continue
                if not self.ipython_shell:
                    raise err
                print("\nEntering Command line such that Exception can be "
                      "further inspected by user.\n\n")
                IPython.embed(header=self._shell_info)
                if self.terminate:
                    break
        self._save_model(suffix='_final')
示例#4
0
    def predict_proba(self, inp: np.ndarray, bs: int = 10,
                        verbose: bool = False):
        """Run batched inference with ``self.model`` and collect outputs as numpy arrays.

        ``self.normalize_func`` (if not None) is applied to ``inp`` first.

        The model is probed on the first (up to) two samples to infer the
        output shape(s). Float32 result arrays are then pre-allocated and
        filled batch by batch under ``torch.no_grad()``.

        Args:
            inp: Input data, e.g. of shape [N, C, H, W], or a tuple of such
                arrays. Tuple elements are passed to the model as separate
                positional arguments; the sample count is taken from the
                first element.
            bs: Batch size (must be >= 1).
            verbose: If True, print the inference speed and elapsed time.

        Returns:
            A float32 numpy array of shape ``[N, ...]``, or a tuple of such
            arrays if the model returns a tuple.

        Raises:
            ValueError: If ``bs`` is smaller than 1.
        """
        if bs < 1:
            raise ValueError(f'bs must be >= 1, got {bs}.')
        start = time.time()
        if self.normalize_func is not None:
            inp = self.normalize_func(inp)
        multi_input = isinstance(inp, tuple)

        def to_device(arr):
            # Convert a numpy slice to a float32 tensor on the model's device.
            return torch.Tensor(arr).to(torch.float32).to(self.device)

        with torch.no_grad():
            # Probe the output shape(s) on (at most) the first two samples.
            if multi_input:
                probe = self.model(*(to_device(x[:2]) for x in inp))
                n_samples = len(inp[0])
            else:
                probe = self.model(to_device(inp[:2]))
                n_samples = len(inp)
            # Allocate outputs with the sample axis resized to the full input.
            if isinstance(probe, tuple):
                out = tuple(np.zeros([n_samples] + list(p.shape)[1:], dtype=np.float32)
                            for p in probe)
            else:
                out = np.zeros([n_samples] + list(probe.shape)[1:], dtype=np.float32)
            del probe
            # range() with step bs covers every sample exactly once; slicing
            # clamps the last (possibly shorter) batch. Empty input yields no batches.
            for low in range(0, n_samples, bs):
                high = low + bs
                if multi_input:
                    batch = tuple(to_device(x[low:high]) for x in inp)
                    res = self.model(*batch)
                else:
                    batch = to_device(inp[low:high])
                    res = self.model(batch)
                if isinstance(res, tuple):
                    for k, res_k in enumerate(res):
                        out[k][low:high] = res_k.detach().cpu().numpy()
                else:
                    out[low:high] = res.detach().cpu().numpy()
                del batch
                del res
                # Release cached GPU memory between batches (no-op without CUDA).
                torch.cuda.empty_cache()
        if verbose:
            dtime = time.time() - start
            if multi_input:
                inp_el = np.sum([float(np.prod(x.shape)) for x in inp])
            else:
                inp_el = float(np.prod(inp.shape))
            # Guard against a zero-duration measurement on coarse clocks.
            speed = inp_el / dtime / 1e6 if dtime > 0 else float('inf')
            dtime = pretty_string_time(dtime)
            print(f'Inference speed: {speed:.2f} MB or MPix /s, time: {dtime}.')
        return out
示例#5
0
    def run(self, max_steps: int = 1) -> None:
        """Train the triplet network until ``self.step`` reaches ``max_steps``.

        Every sample yielded by ``self.train_loader`` must hold three images
        along axis 1: a reference, a similar and a distant one. The loss is
        ``self.criterion(dA, dB, target)`` plus ``self.alpha`` times the mean
        L1 norm (``norm(1, dim=1)``) of the latent codes.

        If ``self.alpha2 > 0``, a discriminator (``self.model_discr``) is
        additionally trained to tell samples of ``self.latent_distr`` from the
        network's latent codes. The generator loss is then blended as
        ``(1 - alpha2) * loss + alpha2 * L_advreg``.

        After each pass over the training loader:
        - stats are averaged over the batches actually processed;
        - validation runs in roughly 10% of the passes (if
          ``self.valid_dataset`` is set);
        - a summary line is logged and tracker plots are written to
          ``self.save_path``;
        - if ``self.tb`` is set, scalars, sample images and latent histograms
          are sent to tensorboard;
        - a checkpoint is saved.

        A final model is saved once the loop ends.

        Args:
            max_steps: Stop once ``self.step`` reaches this value.
        """

        def log_latent_histogram(points, tag):
            # Plot a histogram of all collected latent values and add it to tensorboard.
            if not points:
                return
            fig, _ = plt.subplots()
            try:
                sns.distplot(np.concatenate(points).flatten())
                fig.canvas.draw()
                # grab the pixel buffer and dump it into a numpy array
                img_data = np.array(fig.canvas.renderer._renderer)
                self.tb.add_figure(tag, plot_image(img_data), global_step=self.step)
            finally:
                # Close this specific figure (plt.close() alone closes whichever is current).
                plt.close(fig)

        while self.step < max_steps:
            try:
                self.model.train()

                # Scalar training stats that should be logged and written to tensorboard later
                stats: Dict[str, float] = {'tr_loss_G': .0, 'tr_loss_D': .0}
                # Other scalars to be logged
                misc: Dict[str, float] = {
                    'G_loss_advreg': .0,
                    'G_loss_tnet': .0,
                    'G_loss_l2': .0,
                    'D_loss_fake': .0,
                    'D_loss_real': .0
                }
                # Hold (numpy) input images for real-time training sample visualization in tensorboard
                images: Dict[str, np.ndarray] = {}

                running_error = 0
                running_mean_target = 0
                running_vx_size = 0
                # Batches actually processed in this pass; the loop may stop early
                # at max_steps, so averages are normalized by this count.
                n_batches = 0
                timer = Timer()
                latent_points_fake = []
                latent_points_real = []
                for inp in self.train_loader:  # ref., pos., neg. samples
                    if inp.size()[1] != 3:
                        raise ValueError(
                            "Data must not contain targets. "
                            "Input data shape is assumed to be "
                            "(N, 3, ch, x, y), where the first two"
                            " images in each sample is the similar"
                            " pair, while the third one is the "
                            "distant one.")
                    inp0 = Variable(inp[:, 0].to(self.device))
                    inp1 = Variable(inp[:, 1].to(self.device))
                    inp2 = Variable(inp[:, 2].to(self.device))
                    self.optimizer.zero_grad()
                    # forward pass
                    dA, dB, z0, z1, z2 = self.model(inp0, inp1, inp2)
                    z_fake_gauss = torch.squeeze(torch.cat((z0, z1, z2), dim=1))
                    # Constant -1 target for self.criterion (presumably a
                    # MarginRankingLoss-style criterion -- TODO confirm).
                    target = torch.FloatTensor(dA.size()).fill_(-1).to(self.device)
                    target = Variable(target)
                    loss = self.criterion(dA, dB, target)
                    # NOTE: despite the name, this is the mean L1 norm of the latent codes.
                    L_l2 = torch.mean(
                        torch.cat((z0.norm(1, dim=1), z1.norm(1, dim=1), z2.norm(1, dim=1)),
                                  dim=0))
                    misc['G_loss_l2'] += self.alpha * float(L_l2)
                    loss = loss + self.alpha * L_l2
                    misc['G_loss_tnet'] += (1 - self.alpha2) * float(loss)  # log actual loss
                    if torch.isnan(loss):
                        logger.error(f'NaN loss detected after {self.step} '
                                     'steps! Aborting training.')
                        raise NaNException

                    # Adversarial part to enforce latent variable distribution
                    # to be Normal / whatever prior is used
                    if self.alpha2 > 0:
                        self.optimizer_discr.zero_grad()
                        # adversarial labels
                        valid = Variable(torch.Tensor(inp0.size()[0], 1).fill_(1.0),
                                         requires_grad=False).to(self.device)
                        fake = Variable(torch.Tensor(inp0.shape[0], 1).fill_(0.0),
                                        requires_grad=False).to(self.device)

                        # --- Generator / TripletNet
                        self.model_discr.eval()
                        # TripletNet latent space should be classified as valid
                        L_advreg = self.criterion_discr(self.model_discr(z_fake_gauss), valid)
                        # average adversarial reg. and triplet-loss
                        loss = (1 - self.alpha2) * loss + self.alpha2 * L_advreg
                        # perform generator step
                        loss.backward()
                        self.optimizer.step()

                        # --- Discriminator
                        self.model.eval()
                        self.model_discr.train()
                        # rebuild graph (model output) to get clean backprop.
                        z_real_gauss = Variable(
                            self.latent_distr(inp0.size()[0], z0.size()[-1] * 3)).to(self.device)
                        _, _, z_fake_gauss0, z_fake_gauss1, z_fake_gauss2 = self.model(
                            inp0, inp1, inp2)
                        z_fake_gauss = torch.squeeze(
                            torch.cat((z_fake_gauss0, z_fake_gauss1, z_fake_gauss2), dim=1))
                        # Compute discriminator outputs and loss
                        L_real_gauss = self.criterion_discr(self.model_discr(z_real_gauss), valid)
                        L_fake_gauss = self.criterion_discr(self.model_discr(z_fake_gauss), fake)
                        L_discr = 0.5 * (L_real_gauss + L_fake_gauss)
                        L_discr.backward()  # Backprop loss
                        self.optimizer_discr.step()  # Apply optimization step
                        self.model.train()  # set back to training mode

                        # report (float() yields plain Python numbers, no graph is kept)
                        stats['tr_loss_D'] += float(L_discr)
                        misc['G_loss_advreg'] += self.alpha2 * float(L_advreg)  # log actual part of advreg
                        misc['D_loss_real'] += float(L_real_gauss)
                        misc['D_loss_fake'] += float(L_fake_gauss)
                        latent_points_real.append(z_real_gauss.detach().cpu().numpy())
                    else:
                        loss.backward()
                        self.optimizer.step()

                    latent_points_fake.append(z_fake_gauss.detach().cpu().numpy())

                    # get training performance
                    loss_val = float(loss)
                    stats['tr_loss_G'] += loss_val
                    # Detached inputs keep the running error from retaining autograd graphs.
                    error = calculate_error(dA.detach(), dB.detach())
                    mean_target = target.to(torch.float32).mean()
                    print(f'{self.step:6d}, loss: {loss_val:.4f}', end='\r')
                    self._tracker.update_timeline([self._timer.t_passed, loss_val, mean_target])

                    # Preserve training batch for later visualization
                    images['inp_ref'] = inp0.cpu().numpy()
                    images['inp_+'] = inp1.cpu().numpy()
                    images['inp_-'] = inp2.cpu().numpy()
                    # Read LRs from param_groups instead of get_lr() so that
                    # ReduceLROnPlateau (which does not implement get_lr) is supported.
                    misc['learning_rate_G'] = self.optimizer.param_groups[0]["lr"]
                    misc['learning_rate_D'] = self.optimizer_discr.param_groups[0]["lr"]
                    # update schedules
                    for sched in self.schedulers.values():
                        # ReduceLROnPlateau.step() needs a metric. The PyTorch docs use
                        # the validation loss; the current training loss is used here.
                        # http://pytorch.org/docs/master/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
                        if "metrics" in inspect.signature(sched.step).parameters:
                            sched.step(metrics=loss_val)
                        else:
                            sched.step()
                    running_error += error
                    running_mean_target += mean_target
                    running_vx_size += inp.numel()
                    n_batches += 1

                    self.step += 1
                    if self.step >= max_steps:
                        break
                if n_batches == 0:
                    raise RuntimeError('train_loader did not yield any batches.')
                stats['tr_err_G'] = float(running_error) / n_batches
                stats['tr_loss_G'] /= n_batches
                stats['tr_loss_D'] /= n_batches
                for key in ('G_loss_advreg', 'G_loss_tnet', 'G_loss_l2',
                            'D_loss_fake', 'D_loss_real'):
                    misc[key] /= n_batches
                misc['tr_speed'] = n_batches / timer.t_passed
                misc['tr_speed_vx'] = running_vx_size / timer.t_passed / 1e6  # MVx
                mean_target = running_mean_target / n_batches
                # Only validate in ~10% of the passes (randint(0, 10) == 1).
                if (self.valid_dataset is None) or (1 != np.random.randint(0, 10)):
                    stats['val_loss_G'], stats['val_err_G'] = float('nan'), float('nan')
                else:
                    stats['val_loss_G'], stats['val_err_G'] = self._validate()
                # TODO: Report more metrics, e.g. dice error

                # Update history tracker (kind of made obsolete by tensorboard)
                # TODO: Decide what to do with this, now that most things are already in tensorboard.
                if self.step // len(self.train_dataset) > 1:
                    tr_loss_gain = self._tracker.history[-1][2] - stats['tr_loss_G']
                else:
                    tr_loss_gain = 0
                self._tracker.update_history([
                    self.step, self._timer.t_passed, stats['tr_loss_G'],
                    stats['val_loss_G'], tr_loss_gain, stats['tr_err_G'],
                    stats['val_err_G'], misc['learning_rate_G'], 0, 0
                ])  # 0's correspond to mom and gradnet (?)
                t = pretty_string_time(self._timer.t_passed)
                loss_smooth = self._tracker.loss._ema

                # Logging to stdout, text log file
                text = "%05i L_m=%.3f, L=%.2f, tr=%05.2f%%, " % (
                    self.step, loss_smooth, stats['tr_loss_G'], stats['tr_err_G'])
                text += "vl=%05.2f%s, prev=%04.1f, L_diff=%+.1e, " % (
                    stats['val_err_G'], "%", mean_target * 100, tr_loss_gain)
                text += "LR=%.2e, %.2f it/s, %.2f MVx/s, %s" % (
                    misc['learning_rate_G'], misc['tr_speed'], misc['tr_speed_vx'], t)
                logger.info(text)

                # Plot tracker stats to pngs in save_path
                self._tracker.plot(self.save_path)

                # Reporting to tensorboard logger (histograms need self.tb as well)
                if self.tb:
                    self._tb_log_scalars(stats, 'stats')
                    self._tb_log_scalars(misc, 'misc')
                    self.tb_log_sample_images(images, group='tr_samples')
                    log_latent_histogram(latent_points_fake, 'latent_distr/latent_fake')
                    log_latent_histogram(latent_points_real, 'latent_distr/latent_real')

                # Save trained model state. A single file name is reused, because
                # per-step file names lead to heaps of large files.
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.save_path, 'model-checkpoint.pth'))
                # TODO: Also save "best" model, not only the latest one, which is often overfitted.
                #       -> "best" in which regard? Lowest validation loss, validation error?
                #          We can't blindly trust these metrics and may have to calculate
                #          additional metrics (with focus on object boundary correctness).
            except KeyboardInterrupt:
                IPython.embed(header=self._shell_info)
                if self.terminate:
                    # break (not return) so that the final model is still saved.
                    break
            except Exception:
                traceback.print_exc()
                if self.ignore_errors:
                    # Just print the traceback and try to carry on with training.
                    # This can go wrong in unexpected ways, so don't leave the training unattended.
                    pass
                elif self.ipython_shell:
                    print("\nEntering Command line such that Exception can be "
                          "further inspected by user.\n\n")
                    IPython.embed(header=self._shell_info)
                    if self.terminate:
                        break
                else:
                    raise
        torch.save(
            self.model.state_dict(),
            os.path.join(self.save_path, f'model-final-{self.step:06d}.pth'))