示例#1
0
    def eval(self,
             val_loader,
             dataset_name,
             savedir=None,
             loss_key=None,
             **kwargs):
        """Evaluate the model on ``val_loader`` and return the collected meters.

        Every batch is handed to ``self.model.eval`` inside ``torch.no_grad()``
        and whatever it returns is folded into a ``util.AverageMeters``.
        Unless ``self.opt.no_log`` is set, the meters are written with the tag
        ``join('eval', dataset_name)`` at ``self.epoch``.

        When ``loss_key`` is given and ``meters[loss_key]`` is lower than
        ``self.best_val_loss``, that value becomes the new best and the model
        is saved with label ``best_<loss_key>_<dataset_name>``.
        NOTE(review): ``self.best_val_loss`` must already be initialised
        elsewhere — confirm in the class constructor.

        Args:
            val_loader: iterable of batches; must also support ``len()``.
            dataset_name: used in the log tag and in the checkpoint label.
            savedir: forwarded to ``model.eval``.
            loss_key: metric name used to track the best model, or ``None``
                to skip best-model tracking.
            **kwargs: forwarded to ``model.eval``.

        Returns:
            The ``util.AverageMeters`` holding the accumulated metrics.
        """
        meters = util.AverageMeters()
        net, options = self.model, self.opt

        with torch.no_grad():
            for step, batch in enumerate(val_loader):
                meters.update(net.eval(batch, savedir=savedir, **kwargs))
                util.progress_bar(step, len(val_loader), str(meters))

        if not options.no_log:
            util.write_loss(self.writer, join('eval', dataset_name),
                            meters, self.epoch)

        # Best-model tracking is opt-in.
        if loss_key is None:
            return meters

        current = meters[loss_key]
        if current < self.best_val_loss:
            self.best_val_loss = current
            print('saving the best model at the end of epoch %d, iters %d' %
                  (self.epoch, self.iterations))
            net.save(label='best_{}_{}'.format(loss_key, dataset_name))
        return meters
示例#2
0
 def test(self, test_loader, savedir=None, **kwargs):
     """Run ``self.model.test`` on every batch of ``test_loader``.

     Inference runs inside ``torch.no_grad()``. Nothing is collected or
     returned here; any outputs are whatever ``model.test`` produces (it
     receives ``savedir``).

     Args:
         test_loader: iterable of batches; must also support ``len()``.
         savedir: forwarded to ``model.test``.
         **kwargs: forwarded to ``model.test``.
     """
     model = self.model
     with torch.no_grad():
         for i, data in enumerate(test_loader):
             model.test(data, savedir=savedir, **kwargs)
             util.progress_bar(i, len(test_loader))
示例#3
0
文件: engine.py 项目: szWingLee/ELD
    def train(self, train_loader, **kwargs):
        """Run one training epoch over ``train_loader``.

        For each batch: feed it to the model, take one optimisation step,
        accumulate ``model.get_current_errors()`` into running meters and show
        a progress bar. Unless ``opt.no_log`` is set, the meters are written
        every iteration and current visuals are periodically pushed to
        ``self.visualizer``.

        After the loop, ``self.epoch`` is incremented. Unless logging is
        disabled, a checkpoint is saved every ``opt.save_epoch_freq`` epochs,
        plus a ``'latest'`` checkpoint. Finally ``model.update_learning_rate()``
        is called.

        Args:
            train_loader: iterable of batches; must also support ``len()``.
            **kwargs: forwarded to ``model.optimize_parameters``.
        """
        print('\nEpoch: %d' % self.epoch)
        avg_meters = util.AverageMeters()
        opt = self.opt
        model = self.model
        epoch = self.epoch

        epoch_start_time = time.time()
        for i, data in enumerate(train_loader):
            iterations = self.iterations

            model.set_input(data, mode='train')
            model.optimize_parameters(**kwargs)

            errors = model.get_current_errors()
            avg_meters.update(errors)
            util.progress_bar(i, len(train_loader), str(avg_meters))

            if not opt.no_log:
                util.write_loss(self.writer, 'train', avg_meters, iterations)

                # Test display_id first so display_freq is never used as a
                # modulus when the visualizer is disabled.
                if opt.display_id != 0 and iterations % opt.display_freq == 0:
                    save_result = iterations % opt.update_html_freq == 0
                    self.visualizer.display_current_results(
                        model.get_current_visuals(), epoch, save_result)

            self.iterations += 1

        self.epoch += 1

        if not opt.no_log:
            # self.epoch was already incremented above, so these messages and
            # the save-frequency check use the post-increment value.
            if self.epoch % opt.save_epoch_freq == 0:
                print('saving the model at epoch %d, iters %d' %
                      (self.epoch, self.iterations))
                model.save()

            print('saving the latest model at the end of epoch %d, iters %d' %
                  (self.epoch, self.iterations))
            model.save(label='latest')

            print('Time Taken: %d sec' % (time.time() - epoch_start_time))

        model.update_learning_rate()
示例#4
0
    async def _level(
            self,
            ctx,
            *,
            user: typing.Optional[converters.BetterMemberConverter] = None):
        """Send an embed summarising the level and XP of ``user``.

        ``user`` defaults to the command invoker. The embed shows the level,
        current XP against the XP required for that level
        (``core.LEVEL_FORMULA(level)``) with a percentage and a progress bar.
        When the stored ``xp_multiplier`` exceeds 1, an "XP Multiplier" field
        shows the bonus part (multiplier - 1) as a percentage. After replying,
        the command author (not ``user``) is awarded 1 XP.
        """
        await ctx.cd()
        target = user if user else ctx.author

        embed = discord.Embed(timestamp=ctx.now, color=core.COLOR)
        embed.set_author(name=str(target), icon_url=target.avatar_url)

        level, xp = await ctx.bot.db.get_level(target)
        stored_multiplier = await ctx.db.get("users", target, "xp_multiplier")
        bonus = stored_multiplier - 1
        required = core.LEVEL_FORMULA(level)
        ratio = xp / required
        percent = round(ratio * 100, 1)
        bar = util.progress_bar(**core.PROGRESS_BAR, ratio=ratio, length=8)

        if bonus > 0:
            embed.add_field(name="XP Multiplier",
                            value=f"{bonus*100:.2f}%",
                            inline=False)
        embed.add_field(name=f"Level {level:,}",
                        value=f"{xp:,}/{required:,} XP ({percent}%)\n\n{bar}")

        await ctx.send(embed, embed_perms=True)
        # XP goes to whoever ran the command, regardless of whose level was shown.
        await ctx.bot.db.add_xp(ctx.author, 1)
示例#5
0
    def train(self, train_loader, **kwargs):
        """Run one training epoch over ``train_loader``.

        For each batch: feed it to the model, take one optimisation step,
        accumulate ``model.get_current_errors()`` into running meters and show
        a progress bar (``opt`` is passed through to ``util.progress_bar``).
        Unless ``opt.no_log`` is set, the meters are written every iteration
        and current visuals are periodically pushed to ``self.visualizer``.

        After the loop, ``self.epoch`` is incremented. Unless logging is
        disabled, a checkpoint is saved every ``opt.save_epoch_freq`` epochs,
        plus a ``'latest'`` checkpoint. Finally ``train_loader.reset()`` is
        called.

        Args:
            train_loader: iterable of batches that supports ``len()`` and
                exposes a ``reset()`` method.
            **kwargs: forwarded to ``model.optimize_parameters``.
        """
        print('\nEpoch: %d' % self.epoch)
        avg_meters = util.AverageMeters()
        opt = self.opt
        model = self.model
        epoch = self.epoch

        epoch_start_time = time.time()
        for i, data in enumerate(train_loader):
            iterations = self.iterations

            model.set_input(data, mode='train')
            model.optimize_parameters(**kwargs)

            errors = model.get_current_errors()
            avg_meters.update(errors)
            util.progress_bar(i, len(train_loader), str(avg_meters), opt)

            if not opt.no_log:
                util.write_loss(self.writer, 'train', avg_meters, iterations)

                # Test display_id first so display_freq is never used as a
                # modulus when the visualizer is disabled.
                if opt.display_id != 0 and iterations % opt.display_freq == 0:
                    save_result = iterations % opt.update_html_freq == 0
                    self.visualizer.display_current_results(
                        model.get_current_visuals(), epoch, save_result)

            self.iterations += 1

        self.epoch += 1

        if not opt.no_log:
            # self.epoch was already incremented above, so these messages and
            # the save-frequency check use the post-increment value.
            if self.epoch % opt.save_epoch_freq == 0:
                print('saving the model at epoch %d, iters %d' %
                      (self.epoch, self.iterations))
                model.save()

            print('saving the latest model at the end of epoch %d, iters %d' %
                  (self.epoch, self.iterations))
            model.save(label='latest')

            print('Time Taken: %d sec' % (time.time() - epoch_start_time))

        # NOTE(review): no learning-rate update happens at epoch end in this
        # variant — confirm that is intentional.
        # Presumably rewinds a custom loader for the next epoch; TODO confirm
        # against the loader implementation.
        train_loader.reset()