Exemplo n.º 1
0
        # End-of-epoch bookkeeping: validate, log, checkpoint, then step the
        # LR scheduler and (optionally) the compression scheduler.
        val_loss = evaluate(val_data)
        msglogger.info('-' * 89)
        # Perplexity is reported as exp(val_loss); this assumes evaluate()
        # returns a mean natural-log cross-entropy -- TODO confirm.
        msglogger.info(
            '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.3f} | '
            'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                       val_loss, math.exp(val_loss)))
        msglogger.info('-' * 89)

        # Hand the model to distiller's sparsity logger for this epoch,
        # emitting to both tflogger and pylogger.
        distiller.log_weights_sparsity(model,
                                       epoch,
                                       loggers=[tflogger, pylogger])

        # Tag prefix fixed from the misspelled 'Peformance/Validation/'.
        # Dashboards keyed on the old tag will see these scalars as a new series.
        stats = ('Performance/Validation/',
                 OrderedDict([('Loss', val_loss),
                              ('Perplexity', math.exp(val_loss))]))
        tflogger.log_training_progress(stats, epoch, 0, total=1, freq=1)

        # Unconditionally overwrite args.save with the latest model. This
        # pickles the whole model object, not only a state_dict.
        with open(args.save, 'wb') as f:
            torch.save(model, f)

        # Save the model if the validation loss is the best we've seen so far.
        if val_loss < best_val_loss:
            with open(args.save + ".best", 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        # The scheduler is stepped with the validation metric; presumably a
        # ReduceLROnPlateau-style scheduler -- confirm where it is built.
        lr_scheduler.step(val_loss)

        if compression_scheduler:
            compression_scheduler.on_epoch_end(epoch)

except KeyboardInterrupt:
Exemplo n.º 2
0
        # End-of-epoch bookkeeping: validate, log, checkpoint, then step the
        # LR scheduler and (optionally) the compression scheduler.
        val_loss = evaluate(val_data)
        msglogger.info('-' * 89)
        # Perplexity is reported as exp(val_loss); this assumes evaluate()
        # returns a mean natural-log cross-entropy -- TODO confirm.
        msglogger.info(
            '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.3f} | '
            'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                       val_loss, math.exp(val_loss)))
        msglogger.info('-' * 89)

        # Hand the model to distiller's sparsity logger for this epoch,
        # emitting to both tflogger and pylogger.
        distiller.log_weights_sparsity(model,
                                       epoch,
                                       loggers=[tflogger, pylogger])

        # Validation scalars (loss and perplexity) under the
        # 'Performance/Validation/' tag prefix.
        stats = ('Performance/Validation/',
                 OrderedDict([('Loss', val_loss),
                              ('Perplexity', math.exp(val_loss))]))
        # NOTE(review): the third positional argument is None here; the meaning
        # of that slot depends on the distiller logger version -- confirm.
        tflogger.log_training_progress(stats, epoch, None)

        # Unconditionally overwrite args.save with the latest model. This
        # pickles the whole model object, not only a state_dict.
        with open(args.save, 'wb') as f:
            torch.save(model, f)

        # Save the model if the validation loss is the best we've seen so far.
        if val_loss < best_val_loss:
            with open(args.save + ".best", 'wb') as f:
                torch.save(model, f)
            # best_val_loss is owned by the enclosing (out-of-view) scope.
            best_val_loss = val_loss
        # The scheduler is stepped with the validation metric; presumably a
        # ReduceLROnPlateau-style scheduler -- confirm where it is built.
        lr_scheduler.step(val_loss)

        if compression_scheduler:
            compression_scheduler.on_epoch_end(epoch)

except KeyboardInterrupt: