コード例 #1
0
ファイル: pytorch_ignite.py プロジェクト: joshy/livelossplot
 def __init__(self, metrics_prefix='', **kwargs):
     """Create the callback and its underlying PlotLosses instance.

     Args:
         metrics_prefix: prefix added to each metric name, e.g. ``'val_'``
             for a validation engine.
         **kwargs: keyword arguments forwarded unchanged to ``PlotLosses``.
     """
     plotter = PlotLosses(**kwargs)
     self.liveplot = plotter
     self.metrics_prefix = metrics_prefix
コード例 #2
0
 def __init__(self,
              train_engine: Optional[ignite.engine.Engine] = None,
              **kwargs):
     """Create the callback and its underlying PlotLosses instance.

     Args:
         train_engine: engine carrying the global step information. It is
             only stored on ``self.train_engine``; this constructor does not
             attach any handlers to it.
         **kwargs: keyword arguments forwarded unchanged to ``PlotLosses``.
     """
     self.liveplot = PlotLosses(**kwargs)
     self.train_engine = train_engine
コード例 #3
0
ファイル: pytorch_ignite.py プロジェクト: Aryan05/realtimeplt
    def __init__(self,
                 train_engine: Optional[ignite.engine.Engine] = None,
                 **kwargs):
        """Create the callback and its underlying PlotLosses instance.

        Args:
            train_engine: engine with global step information. When given,
                ``self.send`` is attached to its ``EPOCH_STARTED`` and
                ``COMPLETED`` events. If None, the send method is expected to
                be called at the end of each store call instead, which may
                cause warnings and errors when multiple engines are attached.

        Keyword Args:
            **kwargs: keyword args that will be passed to the livelossplot
                ``PlotLosses`` class.
        """
        self.liveplot = PlotLosses(**kwargs)
        self.train_engine = train_engine
        # Compare against None explicitly: only the absence of an engine
        # should skip handler registration, not the engine's truthiness.
        if self.train_engine is not None:
            for event in (ignite.engine.Events.EPOCH_STARTED,
                          ignite.engine.Events.COMPLETED):
                self.train_engine.add_event_handler(event, self.send)
コード例 #4
0
ファイル: model.py プロジェクト: TheDudeFromCI/VAE-GAN
    def __init__(self, parameters: ModelParameters):
        """Build the data loader, the live loss plot and the VAE-GAN network.

        Args:
            parameters: configuration object. The whole object is passed to
                ``get_dataloader``; the fields read directly here are
                image_size, image_channels, latent_dim, layers_per_size,
                channel_scaling, cuda and print_summary.
        """
        cfg = parameters
        self.parameters = cfg
        self.dataloader = get_dataloader(cfg)

        # Metric groups handed to PlotLosses: the reconstruction loss on its
        # own, and the generator/discriminator losses together.
        loss_groups = {
            'VAE': ['recons_loss'],
            'GAN': ['g_loss', 'd_loss'],
        }
        bokeh_output = BokehPlot(max_cols=2)
        self.liveloss = PlotLosses(outputs=[bokeh_output],
                                   mode='script',
                                   groups=loss_groups)

        network = VAE_GAN(cfg.image_size,
                          cfg.image_channels,
                          cfg.latent_dim,
                          self.dataloader,
                          layers_per_size=cfg.layers_per_size,
                          channel_scaling=cfg.channel_scaling)
        self.vae_gan = network
        network.eval()

        if cfg.cuda:
            network.cuda()

        if cfg.print_summary:
            print('Model Summary:')
            summary(network, depth=10)
コード例 #5
0
    def train(self, dataset, val_dataset=None, epochs=int(3e4), n_itr=100):
        """Train the GAN, resuming from state saved under ``self.save_path``.

        Each epoch does the following:

        * runs training steps: ``self.n_critic`` discriminator updates plus
          generator updates per batch, for at most ``n_itr`` batches;
        * runs an optional validation pass over at most ``n_itr // 5``
          batches;
        * appends the epoch's mean losses to a pickled history and updates
          the live plot;
        * saves generator/discriminator weights.

        Every 5th epoch a sample image grid is written under
        ``./images/<model_name>``.

        Args:
            dataset: iterable of training batches; each batch is indexed with
                ``'images'``.
            val_dataset: optional iterable of validation batches with the same
                layout; the validation pass is skipped when it is falsy.
            epochs: total number of epochs. Training resumes at the number of
                epochs already recorded in the saved loss history.
            n_itr: maximum number of training iterations per epoch.
        """
        prefix = f'{self.save_path}/{self.model_name}'
        losses_path = f'{prefix}_losses_list.pkl'

        try:
            z = tf.constant(np.load(f'{prefix}_z.npy'))
        except FileNotFoundError:
            # First run: draw a latent batch and persist it so that resumed
            # runs feed the same z to generate_samples.
            z = tf.constant(random.normal((self.batch_size, 1, 1, self.z_dim)))
            os.makedirs(self.save_path, exist_ok=True)
            np.save(f'{prefix}_z', z.numpy())

        liveplot = PlotLosses()
        # A missing, empty or unreadable history file means "start from
        # scratch". Only those failures are caught, so Ctrl-C and genuine
        # bugs still propagate.
        # NOTE(review): unpickling can execute arbitrary code; only load
        # history files written by this program.
        try:
            with open(losses_path, 'rb') as fh:
                losses_list = pickle.load(fh)
        except (OSError, EOFError, pickle.UnpicklingError):
            losses_list = []

        # Replay the saved history so the plot shows earlier epochs too.
        for i, past_losses in enumerate(losses_list):
            liveplot.update(past_losses, i)

        start_epoch = len(losses_list)
        n_val_itr = n_itr // 5

        g_train_loss = metrics.Mean()
        d_train_loss = metrics.Mean()
        d_val_loss = metrics.Mean()

        for epoch in range(start_epoch, epochs):
            train_bar = pbar(n_itr, epoch, epochs)
            for itr_c, batch in zip(range(n_itr), dataset):
                # NOTE(review): the bar is advanced by ``update(n=itr_c)``, so
                # if pbar returns a tqdm bar (relative update) ``train_bar.n``
                # grows as 0+1+...+itr_c and this break fires long before
                # n_itr steps. Confirm pbar's semantics; ``update(1)`` may be
                # what was intended.
                if train_bar.n >= n_itr:
                    break

                for _ in range(self.n_critic):
                    d_loss = self.train_d(batch['images'])
                    d_train_loss(d_loss)

                g_loss = self.train_g()
                g_train_loss(g_loss)
                # NOTE(review): second generator step whose loss is not
                # recorded; confirm this extra update is intentional.
                self.train_g()

                train_bar.postfix['g_loss'] = f'{g_train_loss.result():6.3f}'
                train_bar.postfix['d_loss'] = f'{d_train_loss.result():6.3f}'
                train_bar.update(n=itr_c)

            train_bar.close()

            if val_dataset:
                val_bar = vbar(n_val_itr, epoch, epochs)
                for itr_c, batch in zip(range(n_val_itr), val_dataset):
                    if val_bar.n >= n_val_itr:
                        break

                    d_val_l = self.val_d(batch['images'])
                    d_val_loss(d_val_l)

                    val_bar.postfix[
                        'd_val_loss'] = f'{d_val_loss.result():6.3f}'
                    val_bar.update(n=itr_c)
                val_bar.close()

            losses = {
                'g_loss': g_train_loss.result(),
                'd_loss': d_train_loss.result(),
                'd_val_loss': d_val_loss.result(),
            }
            losses_list.append(losses)
            # Rewrite the full history every epoch; ``with`` guarantees the
            # file is flushed and closed before training continues.
            with open(losses_path, 'wb') as fh:
                pickle.dump(losses_list, fh)
            liveplot.update(losses, epoch)
            liveplot.send()

            g_train_loss.reset_states()
            d_train_loss.reset_states()
            d_val_loss.reset_states()

            self.G.save_weights(filepath=f'{prefix}_generator')
            self.D.save_weights(filepath=f'{prefix}_discriminator')

            # From epoch 20000 on, also keep a numbered snapshot every 1000
            # epochs.
            if epoch >= int(2e4) and epoch % 1000 == 0:
                self.G.save_weights(filepath=f'{prefix}_generator{epoch}')
                self.D.save_weights(filepath=f'{prefix}_discriminator{epoch}')

            if epoch % 5 == 0:
                samples = self.generate_samples(z)
                image_grid = img_merge(samples, n_rows=6).squeeze()
                img_path = f'./images/{self.model_name}'
                os.makedirs(img_path, exist_ok=True)
                save_image_grid(image_grid,
                                epoch + 1,
                                self.model_name,
                                output_dir=img_path)
コード例 #6
0
ファイル: generic_keras.py プロジェクト: zwq1230/livelossplot
 def __init__(self, **kwargs):
     """Create the underlying PlotLosses instance.

     Args:
         **kwargs: keyword arguments forwarded unchanged to ``PlotLosses``.
     """
     plot_losses = PlotLosses(**kwargs)
     self.liveplot = plot_losses
コード例 #7
0
 def __init__(self, **kwargs):
     """Set up live loss plotting.

     Args:
         **kwargs: forwarded as-is to the ``PlotLosses`` constructor; the
             resulting instance is stored on ``self.liveplot``.
     """
     options = dict(kwargs)
     self.liveplot = PlotLosses(**options)