コード例 #1
0
class PlotLossesCallback:
    def __init__(self,
                 train_engine: Optional[ignite.engine.Engine] = None,
                 **kwargs):
        """
        :param train_engine: engine holding the global step information; when
            given, its current global step is passed to ``PlotLosses.update``
            as ``current_step``
        :param kwargs: keyword arguments forwarded to ``PlotLosses``
        """
        self.liveplot = PlotLosses(**kwargs)
        self.train_engine = train_engine

    def attach(self, engine: ignite.engine.Engine):
        """
        Attach the callback to an ignite engine: ``store`` is called at the end
        of every epoch and at the end of every iteration.
        """
        engine.add_event_handler(ignite.engine.Events.EPOCH_COMPLETED,
                                 self.store)
        engine.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED,
                                 self.store)

    def store(self, engine: ignite.engine.Engine):
        """Push the engine's computed metrics to livelossplot and redraw.

        Does nothing when ``engine.state`` has no ``metrics`` attribute or the
        metrics mapping is empty.
        """
        metrics = getattr(engine.state, 'metrics', None)
        if not metrics:
            return
        kwargs = {}
        if self.train_engine is not None:
            # Plot against the train engine's global step rather than the
            # attached engine's own counters.
            get_step = global_step_from_engine(self.train_engine)
            kwargs['current_step'] = get_step(
                self.train_engine, self.train_engine.last_event_name)
        # Pass a copy so PlotLosses never holds the engine's own dict.
        self.liveplot.update(dict(metrics), **kwargs)
        self.liveplot.send()
コード例 #2
0
ファイル: generic_keras.py プロジェクト: zwq1230/livelossplot
class _PlotLossesCallback:
    """Base keras callback class for keras and tensorflow.keras"""
    def __init__(self, **kwargs):
        self.liveplot = PlotLosses(**kwargs)

    def on_epoch_end(self, epoch, logs):
        """Send metrics to livelossplot"""
        self.liveplot.update(logs.copy(), epoch)
        self.liveplot.send()
コード例 #3
0
ファイル: pytorch_ignite.py プロジェクト: Aryan05/realtimeplt
class PlotLossesCallback:
    def __init__(self,
                 train_engine: Optional[ignite.engine.Engine] = None,
                 **kwargs):
        """
        Args:
            train_engine: engine with global step information. When given, the ``send`` method is attached
                to its EPOCH_STARTED and COMPLETED events. If None, ``send`` is called at the end of every
                ``store`` call, which may cause warnings and errors when several engines are attached

        Keyword Args:
            **kwargs: keyword args that will be passed to livelossplot PlotLosses class
        """
        self.liveplot = PlotLosses(**kwargs)
        self.train_engine = train_engine
        if self.train_engine is not None:
            # Redraw once per training epoch (and at the end of training)
            # instead of after every store() call.
            self.train_engine.add_event_handler(
                ignite.engine.Events.EPOCH_STARTED, self.send)
            self.train_engine.add_event_handler(ignite.engine.Events.COMPLETED,
                                                self.send)

    def attach(self, engine: ignite.engine.Engine):
        """Attach the callback to an ignite engine: ``store`` is called at the end of each epoch
        and at the end of each iteration.

        Args:
            engine: engine that computes metrics on the end of each epoch and / or on the end of each iteration

        Notes:
            metrics computation plugins have to be attached before this one
        """
        engine.add_event_handler(ignite.engine.Events.EPOCH_COMPLETED,
                                 self.store)
        engine.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED,
                                 self.store)

    def store(self, engine: ignite.engine.Engine):
        """Store computed metrics that will be sent to the main logger.

        Does nothing when ``engine.state`` has no ``metrics`` attribute or the
        metrics mapping is empty.

        Args:
            engine: engine with state.metrics
        """
        metrics = getattr(engine.state, 'metrics', None)
        if not metrics:
            return
        kwargs = {}
        if self.train_engine is not None:
            # Plot against the train engine's global step.
            get_step = global_step_from_engine(self.train_engine)
            kwargs['current_step'] = get_step(
                self.train_engine, self.train_engine.last_event_name)
        # Pass a copy so PlotLosses never holds the engine's own dict.
        self.liveplot.update(dict(metrics), **kwargs)
        if self.train_engine is None:
            # No train engine will trigger send(), so redraw immediately.
            self.send()

    def send(self, _: Optional[ignite.engine.Engine] = None):
        """Redraw the live plot; the engine argument passed by ignite event handlers is ignored."""
        self.liveplot.send()
コード例 #4
0
class _PlotLossesCallback:
    """Base keras callback class for keras and tensorflow.keras"""
    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: keyword arguments that will be passed to PlotLosses constructor
        """
        self.liveplot = PlotLosses(**kwargs)

    def on_epoch_end(self, epoch: int, logs: Dict[str, float]):
        """Send metrics to livelossplot
        Args:
            epoch: epoch number
            logs: metrics with values
        """
        self.liveplot.update(logs.copy(), epoch)
        self.liveplot.send()
コード例 #5
0
ファイル: pytorch_ignite.py プロジェクト: joshy/livelossplot
class PlotLossesCallback:
    def __init__(self, metrics_prefix='', **kwargs):
        """
        :param metrics_prefix: string prepended to every metric name - e.g. pass val_ for a validation engine
        :param kwargs: keyword arguments forwarded to PlotLosses
        """
        self.liveplot = PlotLosses(**kwargs)
        self.metrics_prefix = metrics_prefix

    def attach(self, engine):
        """Register on_epoch_end so it runs whenever the engine completes an epoch"""
        epoch_completed = ignite.engine.Events.EPOCH_COMPLETED
        engine.add_event_handler(epoch_completed, self.on_epoch_end)

    def on_epoch_end(self, engine):
        """Copy engine.state.metrics under prefixed names into the live plot and redraw it"""
        prefix = self.metrics_prefix
        prefixed = {
            '{}{}'.format(prefix, name): value
            for name, value in engine.state.metrics.items()
        }
        self.liveplot.update(prefixed)
        self.liveplot.send()
コード例 #6
0
    def train(self, dataset, val_dataset=None, epochs=int(3e4), n_itr=100):
        """Train the GAN, resuming from the files saved under ``self.save_path``.

        On start it loads the fixed latent batch ``z`` (creating and saving a
        new one if missing) and the per-epoch losses history (empty if missing
        or unreadable). It replays that history into a live plot and continues
        from epoch ``len(losses_list)``.

        Each epoch it:

        - runs up to ``n_itr`` training iterations, each with ``self.n_critic``
          discriminator steps followed by generator steps;
        - optionally runs up to ``n_itr // 5`` discriminator validation
          iterations;
        - appends and saves the epoch losses, and saves generator and
          discriminator weights;
        - keeps extra weight checkpoints every 1000 epochs from epoch 20000 on;
        - writes a sample image grid every 5 epochs.

        Args:
            dataset: iterable of training batches, each indexable by ``'images'``
            val_dataset: optional iterable of validation batches, each
                indexable by ``'images'``; validation is skipped when falsy
            epochs: one past the index of the last epoch to run
            n_itr: maximum number of training iterations per epoch
        """
        # Fixed latent batch for the sample grids: created and saved on the
        # first run, reloaded on later runs.
        try:
            z = tf.constant(
                np.load(f'{self.save_path}/{self.model_name}_z.npy'))
        except FileNotFoundError:
            z = tf.constant(random.normal((self.batch_size, 1, 1, self.z_dim)))
            os.makedirs(self.save_path, exist_ok=True)
            np.save(f'{self.save_path}/{self.model_name}_z', z.numpy())

        losses_path = f'{self.save_path}/{self.model_name}_losses_list.pkl'

        liveplot = PlotLosses()
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # histories from a trusted save_path.
        try:
            with open(losses_path, 'rb') as f:
                losses_list = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError):
            # Missing, unreadable, empty or malformed history: start afresh.
            losses_list = []

        # Replay the saved history so the plot also covers earlier runs.
        for i, losses in enumerate(losses_list):
            liveplot.update(losses, i)

        # Resume right after the last recorded epoch.
        start_epoch = len(losses_list)

        g_train_loss = metrics.Mean()
        d_train_loss = metrics.Mean()
        d_val_loss = metrics.Mean()

        for epoch in range(start_epoch, epochs):
            train_bar = pbar(n_itr, epoch, epochs)
            for itr_c, batch in zip(range(n_itr), dataset):
                if train_bar.n >= n_itr:
                    break

                for _ in range(self.n_critic):
                    d_loss = self.train_d(batch['images'])
                    d_train_loss(d_loss)

                g_loss = self.train_g()
                g_train_loss(g_loss)
                # NOTE(review): a second generator step is taken here and its
                # loss is not recorded; confirm this is intentional and not a
                # duplicated line.
                self.train_g()

                train_bar.postfix['g_loss'] = f'{g_train_loss.result():6.3f}'
                train_bar.postfix['d_loss'] = f'{d_train_loss.result():6.3f}'
                # NOTE(review): if pbar returns a tqdm bar, update(n) *adds* n,
                # so the bar advances by 0, 1, 2, ... and the `n >= n_itr`
                # check above ends the epoch after far fewer than n_itr
                # iterations; update(1) may have been intended - confirm.
                train_bar.update(n=itr_c)

            train_bar.close()

            if val_dataset:
                val_bar = vbar(n_itr // 5, epoch, epochs)
                for itr_c, batch in zip(range(n_itr // 5), val_dataset):
                    if val_bar.n >= n_itr // 5:
                        break

                    d_val_l = self.val_d(batch['images'])
                    d_val_loss(d_val_l)

                    val_bar.postfix[
                        'd_val_loss'] = f'{d_val_loss.result():6.3f}'
                    val_bar.update(n=itr_c)
                val_bar.close()

            losses = {
                'g_loss': g_train_loss.result(),
                'd_loss': d_train_loss.result(),
                'd_val_loss': d_val_loss.result()
            }
            losses_list += [losses]
            with open(losses_path, 'wb') as f:
                pickle.dump(losses_list, f)
            liveplot.update(losses, epoch)
            liveplot.send()

            # Reset the running means before the next epoch.
            g_train_loss.reset_states()
            d_train_loss.reset_states()
            d_val_loss.reset_states()
            # The bars are closed above and rebound every epoch, so there is
            # no `del` here: val_bar is unbound when val_dataset is falsy.

            self.G.save_weights(
                filepath=f'{self.save_path}/{self.model_name}_generator')
            self.D.save_weights(
                filepath=f'{self.save_path}/{self.model_name}_discriminator')

            # Keep additional numbered checkpoints late in training.
            if epoch >= int(2e4) and epoch % 1000 == 0:
                self.G.save_weights(
                    filepath=
                    f'{self.save_path}/{self.model_name}_generator{epoch}')
                self.D.save_weights(
                    filepath=
                    f'{self.save_path}/{self.model_name}_discriminator{epoch}'
                )

            if epoch % 5 == 0:
                samples = self.generate_samples(z)
                image_grid = img_merge(samples, n_rows=6).squeeze()
                img_path = f'./images/{self.model_name}'
                os.makedirs(img_path, exist_ok=True)
                save_image_grid(image_grid,
                                epoch + 1,
                                self.model_name,
                                output_dir=img_path)
コード例 #7
0
ファイル: model.py プロジェクト: TheDudeFromCI/VAE-GAN
class Model:
    """Own a VAE_GAN together with its dataloader, live loss plot and per-epoch hooks."""

    def __init__(self, parameters: ModelParameters):
        """
        Args:
            parameters: settings used to build the dataloader and the VAE_GAN,
                and to control training output (info printing, snapshots,
                model saving, loss plotting, user epoch callback)
        """
        self.parameters = parameters
        self.dataloader = get_dataloader(parameters)

        # Latest per-epoch losses, overwritten by _epoch_callback. Initialised
        # here (with the callback's defaults) so they can be read before the
        # first epoch has finished.
        self.recons_loss = 0
        self.kld_loss = 0
        self.g_loss = 0
        self.d_loss = 0

        # Plot the VAE loss and the two GAN losses in separate groups.
        groups = {'VAE': ['recons_loss'], 'GAN': ['g_loss', 'd_loss']}
        self.liveloss = PlotLosses(outputs=[BokehPlot(max_cols=2)],
                                   mode='script',
                                   groups=groups)

        self.vae_gan = VAE_GAN(parameters.image_size,
                               parameters.image_channels,
                               parameters.latent_dim,
                               self.dataloader,
                               layers_per_size=parameters.layers_per_size,
                               channel_scaling=parameters.channel_scaling)
        # train() calls vae_gan.train() and restores eval() when it is done.
        self.vae_gan.eval()

        if parameters.cuda:
            self.vae_gan.cuda()

        if parameters.print_summary:
            print('Model Summary:')
            summary(self.vae_gan, depth=10)

    def _epoch_callback(self,
                        epoch,
                        recons_loss=0,
                        kld_loss=0,
                        g_loss=0,
                        d_loss=0):
        """Handle the end of an epoch reported by train_vae / train_dual.

        Records the losses on the instance, then, depending on
        ``self.parameters``, prints progress, saves VAE/GAN snapshots, saves
        the model, updates the live loss plot and forwards the losses to the
        user's ``epoch_callback``.

        Args:
            epoch: index of the finished epoch
            recons_loss: reconstruction loss
            kld_loss: KLD loss term; added to ``recons_loss`` in the plotted
                ``'recons_loss'`` series
            g_loss: generator loss
            d_loss: discriminator loss
        """
        self.recons_loss = recons_loss
        self.kld_loss = kld_loss
        self.g_loss = g_loss
        self.d_loss = d_loss

        if self.parameters.print_info:
            # NOTE: prints recons_loss only; the plot shows recons + kld.
            print('Finished epoch {} with a loss of {:.4f}.'.format(
                epoch, recons_loss))

        if self.parameters.save_snapshots:
            save_vae_snapshot(self.vae_gan.vae, self.dataloader, epoch,
                              self.parameters.cuda)
            save_gan_snapshot(self.vae_gan.gan, epoch)

        if self.parameters.save_model:
            save_model(self.vae_gan, epoch)

        if self.parameters.plot_loss:
            logs = {
                'recons_loss': recons_loss + kld_loss,
                'g_loss': g_loss,
                'd_loss': d_loss
            }

            self.liveloss.update(logs, current_step=epoch)
            self.liveloss.send()

        if self.parameters.epoch_callback:
            self.parameters.epoch_callback(epoch, recons_loss, kld_loss,
                                           g_loss, d_loss)

    def train(self):
        """Run optional VAE pre-training, then dual VAE/GAN training.

        Pre-training runs only when ``vae_pretraining_epochs > 0``. Dual
        training continues the epoch numbering after the pre-training epochs.
        Both phases report each epoch to ``_epoch_callback``. The network is
        put back in eval() mode afterwards.
        """
        self.vae_gan.train()

        if self.parameters.vae_pretraining_epochs > 0:
            if self.parameters.print_info:
                print('Pre-training VAE.')

            self.vae_gan.train_vae(
                epochs=self.parameters.vae_pretraining_epochs,
                epoch_callback=self._epoch_callback,
                print_info=self.parameters.print_info)

        if self.parameters.print_info:
            print('Converting to dual-training.')

        self.vae_gan.train_dual(
            epochs=self.parameters.epochs,
            epoch_offset=self.parameters.vae_pretraining_epochs,
            epoch_callback=self._epoch_callback,
            print_info=self.parameters.print_info)

        self.vae_gan.eval()