예제 #1
0
    def __init__(self, **kw):
        """Create an embedded shell and install a non-trapping excepthook.

        A non-None ``user_global_ns`` keyword is rejected by *raising*
        ``DeprecationWarning``. The creation site is recorded in
        ``self._init_location_id``; callers may supply it explicitly via the
        private ``_init_location_id`` keyword, which is removed from ``kw``
        before the remaining keywords go to the base class.
        """
        legacy_ns = kw.get('user_global_ns', None)
        if legacy_ns is not None:
            raise DeprecationWarning(
                "Key word argument `user_global_ns` has been replaced by `user_module` since IPython 4.0."
            )

        location = kw.pop('_init_location_id', None)
        if not location:
            # Frame 1 is the caller of __init__; keep this lookup directly in
            # __init__ so the frame depth stays correct.
            caller = sys._getframe(1)
            location = '%s:%s' % (caller.f_code.co_filename, caller.f_lineno)
        self._init_location_id = location

        super(InteractiveShellEmbed, self).__init__(**kw)

        # Bypass IPython's crash handler so user exceptions are not trapped.
        sys.excepthook = ultratb.FormattedTB(
            color_scheme=self.colors, mode=self.xmode, call_pdb=self.pdb)
예제 #2
0
def crash(crashfile, rerun, debug, ipydebug, dir):
    """Display Nipype crash files.

    For certain crash files, one can rerun a failed node in a temp directory.

    Examples:\n
    nipypecli crash crashfile.pklz\n
    nipypecli crash crashfile.pklz -r -i\n
    """
    from .crash_files import display_crash_file

    # The ipydebug flag is shorthand for choosing the IPython debugger.
    if ipydebug:
        debug = 'ipython'

    if debug == 'ipython':
        # Verbose IPython tracebacks that start the debugger (call_pdb=1).
        import sys
        from IPython.core import ultratb
        hook = ultratb.FormattedTB(mode='Verbose', color_scheme='Linux', call_pdb=1)
        sys.excepthook = hook

    display_crash_file(crashfile, rerun, debug, dir)
예제 #3
0
    def __init__(self, config=None, ipython_dir=None, user_ns=None,
                 user_global_ns=None, custom_exceptions=((),None),
                 usage=None, banner1=None, banner2=None,
                 display_banner=None, exit_msg=''):
        """Create an embedded shell.

        Every argument except ``exit_msg`` is forwarded unchanged to the base
        class constructor. ``exit_msg`` is stored on the instance, the
        ``kill_embedded`` magic is registered, and an ``ultratb.FormattedTB``
        built from the shell's colour/xmode/pdb settings becomes
        ``sys.excepthook``.
        """
        base_kwargs = dict(
            config=config,
            ipython_dir=ipython_dir,
            user_ns=user_ns,
            user_global_ns=user_global_ns,
            custom_exceptions=custom_exceptions,
            usage=usage,
            banner1=banner1,
            banner2=banner2,
            display_banner=display_banner,
        )
        super(InteractiveShellEmbed, self).__init__(**base_kwargs)

        self.exit_msg = exit_msg
        self.define_magic("kill_embedded", kill_embedded)

        # Skip IPython's crash handler so user exceptions are not trapped.
        sys.excepthook = ultratb.FormattedTB(
            color_scheme=self.colors,
            mode=self.xmode,
            call_pdb=self.pdb,
        )
예제 #4
0
파일: embed.py 프로젝트: zzygyx9119/ipython
    def __init__(self, **kw):
        """Create an embedded shell and install a non-trapping excepthook.

        A non-None ``user_global_ns`` keyword triggers a
        ``DeprecationWarning``. Note that it is not removed from ``kw`` here
        and is forwarded to the base class with the other keywords.

        The private ``_call_location_id`` keyword is popped from ``kw``. If it
        is missing or empty, the file and line of the caller of ``__init__``
        are recorded instead.
        """
        if kw.get('user_global_ns', None) is not None:
            # Implicit string concatenation keeps the message free of the run
            # of spaces that a backslash continuation inside the literal
            # embedded. stacklevel=2 attributes the warning to the caller.
            warnings.warn(
                "user_global_ns has been replaced by user_module. The "
                "parameter will be ignored, and removed in IPython 5.0",
                DeprecationWarning,
                stacklevel=2)

        self._call_location_id = kw.pop('_call_location_id', None)

        super(InteractiveShellEmbed, self).__init__(**kw)

        if not self._call_location_id:
            # Frame 1 is the caller of __init__; this must stay in __init__.
            frame = sys._getframe(1)
            self._call_location_id = '%s:%s' % (frame.f_code.co_filename,
                                                frame.f_lineno)
        # don't use the ipython crash handler so that user exceptions aren't
        # trapped
        sys.excepthook = ultratb.FormattedTB(color_scheme=self.colors,
                                             mode=self.xmode,
                                             call_pdb=self.pdb)
예제 #5
0
def setup_exceptionhook(ipython=False):
    """Replace ``sys.excepthook`` with a debugging-oriented handler.

    With ``ipython=True``, an IPython ``FormattedTB`` in ``'Verbose'`` mode is
    installed. Whether it starts the debugger is decided once, from
    ``is_interactive()`` at setup time.

    Otherwise, a plain handler is installed. It prints the traceback and a
    blank line, then runs ``pdb.post_mortem`` if ``is_interactive()`` is true
    when the exception occurs.
    """
    if not ipython:
        def _reproman_pdb_excepthook(exc_type, exc_value, exc_tb):
            import traceback
            traceback.print_exception(exc_type, exc_value, exc_tb)
            print()
            if is_interactive():
                import pdb
                pdb.post_mortem(exc_tb)

        sys.excepthook = _reproman_pdb_excepthook
        return

    from IPython.core import ultratb
    sys.excepthook = ultratb.FormattedTB(
        mode='Verbose',
        call_pdb=is_interactive(),
    )
예제 #6
0
def register_ipython_excepthook(capture_keyboard_interrupt: bool = False) -> None:
    r"""Install a ``sys.excepthook`` that opens the IPython debugger on uncaught exceptions.

    :param capture_keyboard_interrupt: When ``False`` (the default), an uncaught
        :py:exc:`KeyboardInterrupt` goes to the default hook instead of the
        debugger. :py:exc:`bdb.BdbQuit` always goes to the default hook.
        Exact type matches only; subclasses are not passed through.
    """
    from IPython.core import ultratb

    # Context-mode traceback followed by the IPython debugger.
    ipython_hook = ultratb.FormattedTB(mode='Context', color_scheme='Linux', call_pdb=1)

    passthrough: List[Type[BaseException]] = [BdbQuit]
    if not capture_keyboard_interrupt:
        passthrough.append(KeyboardInterrupt)

    def excepthook(exc_type, exc_value, exc_tb):
        if any(exc_type is skipped for skipped in passthrough):
            # Debugger exit / Ctrl+C: let the default hook report it.
            sys.__excepthook__(exc_type, exc_value, exc_tb)
        else:
            ipython_hook(exc_type, exc_value, exc_tb)

    sys.excepthook = excepthook
예제 #7
0
    def run(self):
        """Build the package, then run its unit tests against the build tree.

        The build directories and ``self.test_dir`` are prepended to
        ``sys.path`` while tests are loaded and run. The previous ``sys.path``
        is restored in a ``finally`` clause, so it is also restored when a
        test raises or when the command exits with status 1 on failures.

        If ``self.debug`` is not None, the suite runs via ``tests.debug()``
        with an IPython ``FormattedTB`` excepthook that starts the debugger.
        Otherwise it runs under ``TextTestRunner``.
        """
        import sys
        # Invoke the 'build' command to "build" pure Python modules
        # (ie. copy 'em into the build tree)
        self.run_command('build')

        # remember old sys.path to restore it afterwards
        old_path = sys.path[:]

        # extend sys.path (the last insert ends up first)
        sys.path.insert(0, self.build_purelib)
        sys.path.insert(0, self.build_platlib)
        sys.path.insert(0, self.test_dir)

        try:
            # run tests
            if self.test is not None:
                tests = TestLoader().loadTestsFromNames([self.test])
            else:
                tests = TestLoader().loadTestsFromNames(
                    [self.test_prefix + case for case in self.test_suffixes])

            if self.debug is not None:
                from IPython.core import ultratb
                sys.excepthook = ultratb.FormattedTB(mode='Verbose',
                                                     color_scheme='LightBG',
                                                     call_pdb=1)
                tests.debug()
            else:
                print('Running tests with verbosity %d' % self.verbosity)
                runner = TextTestRunner(verbosity=self.verbosity)
                result = runner.run(tests)

                # Exit script with exitcode 1 if any tests failed
                if result.failures or result.errors:
                    sys.exit(1)
        finally:
            # restore sys.path, including on SystemExit and on errors
            sys.path = old_path[:]
예제 #8
0
    def __init__(self, config=None, ipython_dir=None, user_ns=None,
                 user_module=None, custom_exceptions=((),None),
                 usage=None, banner1=None, banner2=None,
                 display_banner=None, exit_msg=u'', user_global_ns=None):
        """Create an embedded shell.

        ``user_global_ns`` is accepted only for backward compatibility. A
        non-None value emits a ``DeprecationWarning`` and is not passed to the
        base class. All other arguments except ``exit_msg`` are forwarded to
        the base class constructor. ``exit_msg`` is stored on the instance.
        """
        if user_global_ns is not None:
            # Implicit concatenation instead of a backslash continuation
            # inside the literal, which embedded a run of spaces in the
            # message. stacklevel=2 attributes the warning to the caller.
            warnings.warn("user_global_ns has been replaced by user_module. The "
                          "parameter will be ignored.", DeprecationWarning,
                          stacklevel=2)

        super(InteractiveShellEmbed,self).__init__(
            config=config, ipython_dir=ipython_dir, user_ns=user_ns,
            user_module=user_module, custom_exceptions=custom_exceptions,
            usage=usage, banner1=banner1, banner2=banner2,
            display_banner=display_banner
        )

        self.exit_msg = exit_msg

        # don't use the ipython crash handler so that user exceptions aren't
        # trapped
        sys.excepthook = ultratb.FormattedTB(color_scheme=self.colors,
                                             mode=self.xmode,
                                             call_pdb=self.pdb)
예제 #9
0
"""
Using SparkDataFrames and SQL to perform a query on a 20GB file
"""

import sys
from IPython.core import ultratb
from pyspark.sql.session import SparkSession
from pyspark.sql.types import (BooleanType, FloatType, IntegerType, StringType,
                               StructField, StructType)
from pyspark.sql.functions import to_timestamp

# ensure error messages are colour coded using IPython's colour scheme
# ("Plain" traceback layout, debugger disabled via call_pdb=False) ----
sys.excepthook = ultratb.FormattedTB(mode="Plain",
                                     color_scheme="Linux",
                                     call_pdb=False)

# construct schema ----
# store column names in one string
schemaString = """trip_id
trip_start
trip_end
trip_seconds
trip_miles
pickup_census_tract
dropoff_census_tract
pickup_community_area
dropoff_community_area
fare
tip
additional_charges
trip_total
예제 #10
0
def enable_ipdb():
    """Route uncaught exceptions to a verbose IPython traceback plus debugger.

    Recipe from the IPython reference docs:
    <http://ipython.readthedocs.io/en/stable/interactive/reference.html#post-mortem-debugging>
    """
    from IPython.core import ultratb

    hook = ultratb.FormattedTB(
        mode='Verbose',
        color_scheme='Linux',
        call_pdb=1,
    )
    sys.excepthook = hook
예제 #11
0
    def run(self, argv=None):
        """Parse ``argv`` and run the selected sub-command.

        Besides dispatching to ``parsed_args.func``, this applies the global
        options read here:
        - trace activation and diagnostic style;
        - profiling (stats dumped to ``langkit.prof``);
        - verbosity and the ``no_ada_api`` flag;
        - a last-chance debugger excepthook when ``parsed_args.debug`` is set;
        - build/install directories and ``list_warnings``.

        Errors from the sub-command are re-raised when ``parsed_args.debug``
        is set. Otherwise they are reported on stderr and the process exits
        with status 1.
        """
        parsed_args = self.args_parser.parse_args(argv)

        for trace in parsed_args.trace:
            print("Trace {} is activated".format(trace))
            Log.enable(trace)

        Diagnostics.set_style(parsed_args.diagnostic_style)

        if parsed_args.profile:
            import cProfile
            import pstats

            pr = cProfile.Profile()
            pr.enable()

        # Set the verbosity
        self.verbosity = parsed_args.verbosity

        self.no_ada_api = parsed_args.no_ada_api

        # If asked to, setup the exception hook as a last-chance handler to
        # invoke a debugger in case of uncaught exception.
        if parsed_args.debug:
            # Try to use IPython's debugger if it is available, otherwise
            # fallback to PDB.
            try:
                # noinspection PyPackageRequirements
                from IPython.core import ultratb
            except ImportError:
                ultratb = None  # To keep PyCharm happy...

                def excepthook(exc_type, exc_value, tb):
                    # Local import: the fallback must not depend on a
                    # module-level `pdb` import existing.
                    import pdb
                    traceback.print_exception(exc_type, exc_value, tb)
                    pdb.post_mortem(tb)

                sys.excepthook = excepthook
            else:
                sys.excepthook = ultratb.FormattedTB(mode='Verbose',
                                                     color_scheme='Linux',
                                                     call_pdb=1)
            del ultratb

        self.dirs.set_build_dir(parsed_args.build_dir)
        # NOTE(review): argparse stores "--install-dir" as "install_dir"; the
        # hyphenated name only matches a positional or explicit dest --
        # confirm against the parser definition.
        install_dir = getattr(parsed_args, 'install-dir', None)
        if install_dir:
            self.dirs.set_install_dir(install_dir)

        if getattr(parsed_args, 'list_warnings', False):
            WarningSet.print_list()
            return

        # noinspection PyBroadException
        try:
            parsed_args.func(parsed_args)

        except DiagnosticError:
            if parsed_args.debug:
                raise
            if parsed_args.verbosity.debug or parsed_args.full_error_traces:
                traceback.print_exc()
            print(col('Errors, exiting', Colors.FAIL), file=sys.stderr)
            sys.exit(1)

        except Exception as e:
            if parsed_args.debug:
                raise
            tb = sys.exc_info()[2]

            # If we have a syntax error, we know for sure the last stack frame
            # points to the code that must be fixed. Otherwise, point to the
            # top-most stack frame that does not belong to Langkit.
            # Check the type rather than the message: only SyntaxError has
            # filename/lineno, and recent Pythons use messages other than
            # 'invalid syntax' (e.g. "expected ':'").
            if isinstance(e, SyntaxError) and e.filename is not None:
                loc = Location(e.filename, e.lineno)
            else:
                loc = extract_library_location(traceback.extract_tb(tb))
            with Context("", loc, "recovery"):
                check_source_language(False, str(e), do_raise=False)

            # Keep Langkit bug "pretty" for users: display the Python stack
            # trace only when requested.
            if parsed_args.verbosity.debug or parsed_args.full_error_traces:
                traceback.print_exc()

            print(col('Internal error! Exiting', Colors.FAIL), file=sys.stderr)
            sys.exit(1)

        finally:
            if parsed_args.profile:
                pr.disable()
                ps = pstats.Stats(pr)
                ps.dump_stats('langkit.prof')
예제 #12
0
 def excepthook(*args, **kwargs):
     """Lazily switch to IPython's verbose, debugger-enabled excepthook.

     Installs an ``ultratb.FormattedTB`` as ``sys.excepthook`` in place of
     whatever hook is currently set. Then forwards the received arguments to
     it and returns its result.
     """
     from IPython.core import ultratb
     replacement = ultratb.FormattedTB(mode='Verbose', color_scheme='Linux', call_pdb=1)
     sys.excepthook = replacement
     return replacement(*args, **kwargs)
예제 #13
0
 def __call__(self, *args, **kwargs):
     """Forward to a lazily created IPython ``FormattedTB``.

     The formatter (Plain mode, ``call_pdb=1``) is built on the first call
     and cached in ``self.instance``.
     """
     if self.instance is None:
         from IPython.core import ultratb
         self.instance = ultratb.FormattedTB(
             mode='Plain', color_scheme='Linux', call_pdb=1)
     handler = self.instance
     return handler(*args, **kwargs)
def set_debugger_org():
    """Install a debugger-enabled IPython ``FormattedTB`` as the excepthook.

    Does nothing while ``sys.excepthook`` still equals the interpreter
    default ``sys.__excepthook__``.
    """
    # NOTE(review): the hook is only replaced when some other hook is already
    # installed -- confirm this condition is intended and not inverted.
    if sys.excepthook == sys.__excepthook__:
        return
    from IPython.core import ultratb
    sys.excepthook = ultratb.FormattedTB(call_pdb=True)
예제 #15
0
       data object

functions:
pick_events -> pick waveforms around ponset of events
calculate_rf -> calculates rf
mout -> move out correction

the rest of the functions are deprecated and have been moved to other scripts
"""

import sys

from IPython.core import ultratb

# Runs at import time, before the remaining imports: uncaught exceptions get
# a verbose IPython traceback (LightBG colours) with call_pdb enabled.
sys.excepthook = ultratb.FormattedTB(mode='Verbose',
                                     color_scheme='LightBG',
                                     call_pdb=1)

import logging
import os.path
import calendar
import numpy as np
from obspy.core import UTCDateTime
from sito import data as mod_data, events, imaging, read, util, Stream
import pylab as plt
import matplotlib as mpl

exha = util.exha  # short module-level alias for sito.util.exha
log = logging.getLogger('script_receiver')
# Colour names; presumably the top/bottom colours used in plots -- TODO confirm
top = 'black'
bot = 'gray'
예제 #16
0
import h5py
import json
import numpy as np
import tensorflow as tf
from termcolor import colored, cprint
import tensorflow.contrib.slim as slim

from config import config, loadDatasetConfig, parseArgs
from preprocess import Preprocesser, bold, bcolored, writeline, writelist
from model import MACnet
from collections import defaultdict

from IPython.core import ultratb

# Uncaught exceptions: plain-layout IPython traceback, then the debugger
# (call_pdb=1).
# NOTE(review): no `import sys` is visible in this snippet's imports --
# confirm it is imported earlier in the file.
sys.excepthook = ultratb.FormattedTB(mode='Plain',
                                     color_scheme='Linux',
                                     call_pdb=1)

############################################# loggers #############################################


# Writes log header to file
def logInit():
    with open(config.logFile(), "a+") as outFile:
        writeline(outFile, config.expName)
        headers = [
            "epoch", "trainAcc", "valAcc", 'testAcc', "trainLoss", "valLoss",
            'testLoss'
        ]
        if config.evalTrain:
            headers += ["evalTrainAcc", "evalTrainLoss"]
예제 #17
0
import os
import json

import pdb
from IPython.core import ultratb
import sys
# Uncaught exceptions: IPython FormattedTB with its default mode and colour
# scheme, debugger enabled.
sys.excepthook = ultratb.FormattedTB(call_pdb=True)

# Attribute vocabularies; presumably the colour/material/shape names of
# objects in V-CLEVR scenes -- TODO confirm against the dataset.
COLORS = ['gray', 'red', 'blue', 'green', 'brown', 'yellow', 'cyan', 'purple']
MATERIALS = ['metal', 'rubber']
SHAPES = ['sphere', 'cylinder', 'cube']


class Executor():
    """Symbolic program executor for V-CLEVR questions"""
    def __init__(self, sim):
        self._set_sim(sim)
        self._register_modules()

    def run(self, pg, debug=False):
        exe_stack = []
        for m in pg:
            #if m=="unique":
            #    import pdb
            #     pdb.set_trace()
            if m in ['<END>', '<NULL>']:
                break
            if m not in ['<START>']:
                if m not in self.modules:
                    exe_stack.append(m)
                else:
예제 #18
0
def interactive_debugger():
    """Send uncaught exceptions to a verbose IPython traceback plus debugger."""
    from IPython.core import ultratb

    options = dict(mode='Verbose', color_scheme='Linux', call_pdb=1)
    sys.excepthook = ultratb.FormattedTB(**options)
예제 #19
0
 def wrappers(*args, **kwargs):
     """Call ``func``; on an ``errors`` exception, show a traceback and debug.

     ``func`` and ``errors`` come from the enclosing scope (not visible
     here). When an ``errors`` exception is caught, it is not re-raised and
     the wrapper returns ``None``.
     """
     # NOTE(review): swallowing the exception after debugging may be
     # unintended -- confirm callers expect None in that case.
     try:
         result = func(*args, **kwargs)
     except errors:
         # Called with no arguments; presumably FormattedTB then reports the
         # exception currently being handled -- confirm.
         post_mortem = ultratb.FormattedTB(mode='Context', color_scheme='LightBG', call_pdb=1)
         post_mortem()
     else:
         return result
예제 #20
0
def activate_live_debugging():
    """Enable post-mortem debugging through IPython.

    Logs an info message, then installs a verbose ``FormattedTB`` with
    ``call_pdb=1`` as ``sys.excepthook``.
    """
    _logger.info("Activating live debugging...")
    from IPython.core import ultratb

    verbose_tb = ultratb.FormattedTB(
        mode='Verbose',
        color_scheme='Linux',
        call_pdb=1,
    )
    sys.excepthook = verbose_tb
예제 #21
0
    def __init__(self, **kwargs) -> None:
        """Populate argument values from ``kwargs`` and then ``sys.argv``.

        Keyword arguments are assigned as attributes first. Then the tokens
        in ``sys.argv[1:]`` are parsed:

        * ``--name value`` sets attribute ``name`` (dashes become
          underscores). The value is converted according to the attribute's
          annotation, or the type of its current value when unannotated.
        * ``--name`` / ``--no-name`` switch an ``Arguments.Switch`` on / off.
        * An unknown ``--name`` raises ``ValueError``. Tokens not starting
          with ``--`` are logged as unrecognized and skipped.

        If ``self.ipdb`` is truthy, an IPython excepthook that starts the
        debugger is installed. Finally ``preprocess``, ``_validate`` and
        ``postprocess`` are called in that order.

        Raises:
            ValueError: unknown option, missing value, or a value that cannot
                be converted to the declared type.
            AssertionError: via ``assert`` for a ``none`` value on a
                non-nullable argument, an invalid choice, or a missing
                required path (these checks vanish under ``python -O``).
        """
        self._check_types()

        for k, v in kwargs.items():
            setattr(self, k, v)

        # TODO: Add non-null checks
        # TODO: Add "no-" prefix stuff for switches
        # TODO: Generate help by inspecting comments

        i = 1
        while i < len(sys.argv):
            arg: str = sys.argv[i]
            if arg.startswith('--'):
                argname = arg[2:].replace('-', '_')
                # `--no-<name>` turns a switch off, unless an attribute
                # literally named `no_<name>` exists.
                if argname.startswith('no_') and not hasattr(
                        self, argname) and hasattr(self, argname[3:]):
                    attr = getattr(self, argname[3:])
                    if isinstance(attr, Arguments.Switch):
                        attr._value = False
                        i += 1
                        continue

                if hasattr(self, argname):
                    attr = getattr(self, argname)
                    if isinstance(attr, Arguments.Switch):
                        attr._value = True
                        i += 1
                        continue

                    typ = self.__annotations__.get(argname, type(attr))
                    nullable = False
                    # TODO: hacks here
                    if hasattr(
                            typ,
                            '__origin__') and typ.__origin__ == Union and type(
                                None) in typ.__args__:
                        # hacky check of whether `typ` is `Optional`
                        # NOTE(review): if custom_types.NoneType is
                        # type(None), this filter never excludes anything and
                        # the first Union member is picked -- confirm.
                        nullable = True
                        typ = next(t for t in typ.__args__
                                   if not isinstance(t, custom_types.NoneType)
                                   )  # type: ignore
                    if i + 1 >= len(sys.argv):
                        # Report a missing value clearly instead of IndexError.
                        raise ValueError(
                            f"Missing value for argument '{argname}'")
                    argval: str = sys.argv[i + 1]
                    if argval.lower() == 'none':
                        if nullable:
                            val = None
                        else:
                            assert typ is str or is_choices(typ), \
                                f"Cannot assign None to non-nullable, non-str argument '{argname}'"
                            val = argval
                    elif isinstance(typ,
                                    custom_types.NoneType):  # type: ignore
                        val = None  # just to suppress "ref before assign" warning
                        try:
                            # priority: low -> high
                            for target_typ in [str, float, int]:
                                val = target_typ(argval)
                        except ValueError:
                            pass
                    elif typ is str:
                        val = argval
                    elif isinstance(
                            typ,
                            custom_types.Path) or typ is custom_types.Path:
                        val = Path(argval)
                        if isinstance(typ, custom_types.Path) and typ.exists:
                            assert val.exists(), ValueError(
                                f"Argument '{argname}' requires an existing path, "
                                f"but '{argval}' does not exist")
                    elif is_choices(typ):
                        val = argval
                        assert val in typ.__values__, f"Invalid value '{val}' for argument '{arg}', " \
                            f"available choices are: {typ.__values__}"
                    elif isinstance(typ, type) and issubclass(typ, Arguments.Enum):
                        # experimental support for custom enum
                        # (argument order fixed: `typ` must be the subclass;
                        # the isinstance guard lets non-class annotations such
                        # as List[int] fall through to literal_eval below)
                        try:
                            # noinspection PyCallingNonCallable
                            val = typ(argval)
                        except ValueError:
                            valid_args = {x.value for x in typ}
                            raise ValueError(
                                f"Invalid value '{argval}' for argument '{argname}', "
                                f"available choices are: {valid_args}"
                            ) from None

                    elif typ is bool:
                        val = argval in ['true', '1', 'True', 'y', 'yes']
                    else:
                        try:
                            val = ast.literal_eval(argval)
                        except ValueError:
                            raise ValueError(
                                f"Invalid value '{argval}' for argument '{argname}'"
                            ) from None
                    setattr(self, argname, val)
                    i += 2
                else:
                    raise ValueError(f"Invalid argument: '{arg}'")
            else:
                Logging.warn(f"Unrecognized command line argument: '{arg}'")
                i += 1

        if self.ipdb:
            # enter IPython debugger on exception
            from IPython.core import ultratb
            sys.excepthook = ultratb.FormattedTB(mode='Context',
                                                 color_scheme='Linux',
                                                 call_pdb=1)

        self.preprocess()
        self._validate()
        self.postprocess()
예제 #22
0
from prompt_toolkit.history import FileHistory
from prompt_toolkit.layout.processors import HighlightMatchingBracketProcessor
from prompt_toolkit.lexers import PygmentsLexer
from prompt_toolkit.styles import style_from_pygments_cls
from pygments.lexers import Python3Lexer
from pygments.styles.friendly import FriendlyStyle
from pygments.styles.monokai import MonokaiStyle
from pygments_style_monokailight.monokailight import MonokaiLightStyle
import requests
import qhue

from philips_hue.color import mired, rgb_to_xybri

# Use IPython's FormattedTB for uncaught exceptions when IPython is
# installed; without IPython the existing excepthook is left in place.
try:
    from IPython.core import ultratb
    sys.excepthook = ultratb.FormattedTB()
except ImportError:
    pass  # IPython is optional here


class BGColor(Enum):
    """Represents our state of knowledge about the terminal background color."""
    UNKNOWN = 0  # background colour could not be determined
    DARK = -1  # terminal has a dark background
    LIGHT = 1  # terminal has a light background


def get_bg_color():
    """Returns :data:`BGColor.UNKNOWN` if the terminal background color could
    not be determined; otherwise, :data:`BGColor.LIGHT` if the terminal has a
    light background and :data:`BGColor.DARK` if it has a dark background."""
예제 #23
0
파일: migrate.py 프로젝트: DarkDare/soledad
def _enable_pdb():
    """Install a verbose IPython excepthook that starts the debugger."""
    import sys
    from IPython.core import ultratb

    verbose_hook = ultratb.FormattedTB(
        mode='Verbose', color_scheme='Linux', call_pdb=1)
    sys.excepthook = verbose_hook
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import sys
from pathlib import Path

import click
from IPython.core import ultratb

import justcause

# Fall back to the IPython debugger on uncaught errors: verbose traceback,
# then call_pdb=1.
sys.excepthook = ultratb.FormattedTB(mode="Verbose", color_scheme="Linux", call_pdb=1)

_logger = logging.getLogger(__name__)  # module-level logger


@click.command()
@click.option(
    "-c",
    "--config",
    "cfg_path",
    required=True,
    type=click.Path(exists=True),
    help="path to config file",
)
@click.option("--quiet", "log_level", flag_value=logging.WARNING, default=True)
@click.option("-v", "--verbose", "log_level", flag_value=logging.INFO)
@click.option("-vv", "--very-verbose", "log_level", flag_value=logging.DEBUG)
@click.version_option(justcause.__version__)
def main(cfg_path: Path, log_level: int):
def set_debugger_org_frc():
    """Unconditionally install IPython's FormattedTB with the debugger enabled."""
    from IPython.core import ultratb
    hook = ultratb.FormattedTB(call_pdb=True)
    sys.excepthook = hook
예제 #26
0
        states = pool.get_states()
        for state in states:
            total.update(state.results)
    print(
        f"Train: {sum(v for k, v in total.items() if k.startswith('train'))}")
    print(
        f"Valid: {sum(v for k, v in total.items() if k.startswith('valid'))}")
    print(f"Test:  {sum(v for k, v in total.items() if k.startswith('test'))}")


if __name__ == "__main__":
    import sys
    from IPython.core import ultratb

    # Command-line interface.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dump-dir", type=Path, required=True)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--seed", type=int, default=1726)
    parser.add_argument(
        "--split-ratio", type=float, default=0.8,
        help="ratio of training set.")
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--njobs", type=int, default=6)
    parser.add_argument("--legacy", action="store_true", help="TODO")

    # Context-mode IPython traceback plus debugger for uncaught exceptions.
    sys.excepthook = ultratb.FormattedTB(
        mode="Context", color_scheme="Linux", call_pdb=1)

    cli_args = parser.parse_args()
    main(cli_args)
예제 #27
0
import threading
import time
import random

# import stringio

import re
import configparser
from signal import signal, SIGWINCH, SIGKILL, SIGTERM

from IPython.core.debugger import Tracer
from IPython.core import ultratb

# Uncaught exceptions: verbose IPython traceback with ostream set to the
# original stdout (sys.__stdout__), then the debugger.
# NOTE(review): no `import sys` is visible in this snippet's imports --
# confirm it is imported earlier in the file.
sys.excepthook = ultratb.FormattedTB(mode='Verbose',
                                     color_scheme='Linux',
                                     call_pdb=True,
                                     ostream=sys.__stdout__)

from colorlog import ColoredFormatter

import logging

from gettext import gettext as _

import traceback
from functools import wraps
import Queue


def setup_logger():
    """Return a logger with a default ColoredFormatter."""
예제 #28
0
    def run(self, argv=None):
        """Parse ``argv`` and run the selected sub-command.

        Applies the global options read here:
        - parsable-error output;
        - profiling;
        - a last-chance debugger excepthook when ``parsed_args.debug`` is set;
        - build/install directories;
        - optional code coverage for ``do_generate``.

        Errors from the sub-command are re-raised when ``parsed_args.debug``
        is set. Otherwise they are reported on stderr and the process exits
        with status 1.
        """
        parsed_args = self.args_parser.parse_args(argv)

        from langkit import diagnostics
        diagnostics.EMIT_PARSABLE_ERRORS = parsed_args.parsable_errors

        # NOTE(review): nothing in this method disables `pr` or stops `cov`
        # -- confirm that happens elsewhere.
        if parsed_args.profile:
            import cProfile
            import pstats

            pr = cProfile.Profile()
            pr.enable()

        # If asked to, setup the exception hook as a last-chance handler to
        # invoke a debugger in case of uncaught exception.
        if parsed_args.debug:
            # Try to use IPython's debugger if it is available, otherwise
            # fallback to PDB.
            try:
                # noinspection PyPackageRequirements
                from IPython.core import ultratb
            except ImportError:
                ultratb = None  # To keep PyCharm happy...

                def excepthook(exc_type, exc_value, tb):
                    # Import locally: `traceback` is a local of run() (imported
                    # further down), so as a closure variable it would still be
                    # unbound here; `pdb` may not be imported at module level.
                    import pdb
                    import traceback
                    traceback.print_exception(exc_type, exc_value, tb)
                    pdb.post_mortem(tb)

                sys.excepthook = excepthook
            else:
                sys.excepthook = ultratb.FormattedTB(mode='Verbose',
                                                     color_scheme='Linux',
                                                     call_pdb=1)
            del ultratb

        self.dirs.set_build_dir(parsed_args.build_dir)
        install_dir = getattr(parsed_args, 'install-dir', None)
        if install_dir:
            self.dirs.set_install_dir(install_dir)

        # Compute code coverage in the code generator if asked to
        if parsed_args.func == self.do_generate and parsed_args.coverage:
            try:
                cov = Coverage(self.dirs)
            except Exception:
                import traceback
                print >> sys.stderr, 'Coverage not available:'
                # print_exc() reports the exception being handled; its first
                # positional parameter is `limit`, not an exception object.
                traceback.print_exc()
                sys.exit(1)

            cov.start()
        else:
            cov = None

        # noinspection PyBroadException
        try:
            parsed_args.func(parsed_args)
        except DiagnosticError:
            if parsed_args.debug:
                raise
            print >> sys.stderr, col('Errors, exiting', Colors.FAIL)
            sys.exit(1)
        except Exception as e:
            if parsed_args.debug:
                raise
            import traceback
            tb = sys.exc_info()[2]
            # Check the type instead of e.args[0]: that raised IndexError for
            # exceptions without arguments, and only SyntaxError carries
            # filename/lineno.
            if isinstance(e, SyntaxError) and e.filename is not None:
                loc = Location(e.filename, e.lineno, "")
            else:
                loc = extract_library_location(traceback.extract_tb(tb))
            with Context("", loc, "recovery"):
                check_source_language(False, str(e), do_raise=False)
            if parsed_args.verbosity.debug:
                traceback.print_exc()

            print >> sys.stderr, col('Internal error! Exiting', Colors.FAIL)
            sys.exit(1)
예제 #29
0
import time
from typing import List

import torch
import torch.utils.data as data_utils
from IPython.core import ultratb

from ImageTranslate import dataset
from ImageTranslate.option_parser import get_img_options_parser
from ImageTranslate.sen_sim import SenSim
from ImageTranslate.seq2seq import Seq2Seq
from ImageTranslate.textprocessor import TextProcessor
from ImageTranslate.train_image_mt import ImageMTTrainer
from ImageTranslate.utils import build_optimizer, backward

# Verbose IPython tracebacks for uncaught exceptions, debugger disabled
# (call_pdb=False).
# NOTE(review): no `import sys` is visible in this snippet's imports --
# confirm it is imported earlier in the file.
sys.excepthook = ultratb.FormattedTB(mode='Verbose', color_scheme='Linux', call_pdb=False)


class SenSimTrainer(ImageMTTrainer):
    def train_epoch(self, step: int, saving_path: str = None,
                    mt_dev_iter: List[data_utils.DataLoader] = None,
                    mt_train_iter: List[data_utils.DataLoader] = None, max_step: int = 300000,
                    src_neg_iter: data_utils.DataLoader = None, dst_neg_iter: data_utils.DataLoader = None,
                    **kwargs):
        "Standard Training and Logging Function"
        start = time.time()
        total_tokens, total_loss, tokens, cur_loss = 0, 0, 0, 0
        cur_loss = 0

        batch_zip, shortest = self.get_batch_zip(None, None, mt_train_iter)
예제 #30
0
    def __init__(self, *args, **kwargs) -> None:
        """Populate argument values from ``kwargs`` and a command line.

        Args:
            *args: at most one positional argument, an argv-style list to
                parse (element 0 is skipped). Defaults to ``sys.argv``.
            **kwargs: attribute values assigned before parsing.

        Parsing rules:
        * ``--name value`` sets attribute ``name`` (dashes become
          underscores), converted according to the type returned by
          ``_get_arg_type``.
        * ``--name`` / ``--no-name`` switch an ``Arguments.Switch`` on / off.

        After parsing:
        1. If ``self.pdb`` is truthy, an IPython excepthook that starts the
           debugger is installed (``KeyboardInterrupt`` goes to the default
           hook).
        2. ``preprocess`` runs, then non-nullable attributes are checked for
           ``None``.
        3. ``_validate`` and ``postprocess`` run.
        4. Switches are converted to ``bool``, and ``str`` values of
           ``Path``-annotated attributes to ``Path``.

        Raises:
            ValueError: too many positional arguments, unknown option,
                missing or unconvertible value, or ``None`` for a
                non-nullable attribute.
        """
        self._check_types()

        for k, v in kwargs.items():
            setattr(self, k, v)

        # TODO: Add non-null checks
        # TODO: Add "no-" prefix stuff for switches
        # TODO: Generate help by inspecting comments

        if len(args) == 0:
            argv = sys.argv
        elif len(args) == 1:
            argv = args[0]
        else:
            raise ValueError(
                f"Argument class takes zero or one positional arguments but {len(args)} were given"
            )
        i = 1
        while i < len(argv):
            arg: str = argv[i]
            if arg.startswith('--'):
                argname = arg[2:].replace('-', '_')
                # `--no-<name>` turns a switch off, unless an attribute
                # literally named `no_<name>` exists.
                if argname.startswith('no_') and not hasattr(
                        self, argname) and hasattr(self, argname[3:]):
                    attr = getattr(self, argname[3:])
                    if isinstance(attr, Arguments.Switch):
                        attr._value = False
                        i += 1
                        continue

                if hasattr(self, argname):
                    attr = getattr(self, argname)
                    if isinstance(attr, Arguments.Switch):
                        attr._value = True
                        i += 1
                        continue

                    nullable, typ = self._get_arg_type(argname)
                    if i + 1 >= len(argv):
                        # Report a missing value clearly instead of IndexError.
                        raise ValueError(
                            f"Missing value for argument '{argname}'")
                    argval: str = argv[i + 1]
                    if argval.lower() == 'none':
                        if nullable:
                            val = None
                        else:
                            assert typ is str or is_choices(typ), \
                                f"Cannot assign None to non-nullable, non-str argument '{argname}'"
                            val = argval
                    elif isinstance(typ,
                                    custom_types.NoneType):  # type: ignore
                        val = None  # just to suppress "ref before assign" warning
                        try:
                            # priority: low -> high
                            for target_typ in [str, float, int]:
                                val = target_typ(argval)
                        except ValueError:
                            pass
                    elif typ is str:
                        val = argval
                    elif isinstance(
                            typ,
                            custom_types.Path) or typ is custom_types.Path:
                        val = Path(argval)
                        if isinstance(typ, custom_types.Path) and typ.exists:
                            assert val.exists(), ValueError(
                                f"Argument '{argname}' requires an existing path, "
                                f"but '{argval}' does not exist")
                    elif is_choices(typ):
                        val = argval
                        assert val in typ.__values__, f"Invalid value '{val}' for argument '{arg}', " \
                                                      f"available choices are: {typ.__values__}"
                    elif isinstance(typ, type) and issubclass(typ, Arguments.Enum):
                        # experimental support for custom enum
                        # (the isinstance guard stops issubclass from raising
                        # TypeError on non-class annotations such as
                        # List[int], which fall through to literal_eval)
                        try:
                            # noinspection PyCallingNonCallable
                            val = typ(argval)
                        except ValueError:
                            valid_args = {x.value for x in typ}
                            raise ValueError(
                                f"Invalid value '{argval}' for argument '{argname}', "
                                f"available choices are: {valid_args}"
                            ) from None

                    elif typ is bool:
                        val = argval in ['true', '1', 'True', 'y', 'yes']
                    else:
                        try:
                            val = ast.literal_eval(argval)
                        except ValueError:
                            raise ValueError(
                                f"Invalid value '{argval}' for argument '{argname}'"
                            ) from None
                    setattr(self, argname, val)
                    i += 2
                else:
                    raise ValueError(f"Invalid argument: '{arg}'")
            else:
                Logging.warn(f"Unrecognized command line argument: '{arg}'")
                i += 1

        if self.pdb:
            # enter IPython debugger on exception
            from IPython.core import ultratb
            ipython_hook = ultratb.FormattedTB(mode='Context',
                                               color_scheme='Linux',
                                               call_pdb=1)

            def excepthook(exc_type, exc_value, exc_tb):
                if exc_type is KeyboardInterrupt:
                    # don't capture keyboard interrupts (Ctrl+C)
                    sys.__excepthook__(exc_type, exc_value, exc_tb)
                else:
                    ipython_hook(exc_type, exc_value, exc_tb)

            sys.excepthook = excepthook

        self.preprocess()

        # check whether non-optional attributes are none
        for arg in dir(self):
            if not arg.startswith('_') and arg not in self._reserved_keys:
                attr = getattr(self, arg)
                nullable, _ = self._get_arg_type(arg)
                if attr is None and not nullable:
                    raise ValueError(f"argument '{arg}' cannot be none")

        self._validate()
        self.postprocess()

        # convert switches to bool
        for arg in dir(self):
            if not arg.startswith('_') and arg not in self._reserved_keys:
                attr = getattr(self, arg)
                typ = self.__annotations__.get(arg, None)
                if isinstance(attr, Arguments.Switch):
                    # noinspection PyProtectedMember
                    setattr(self, arg, bool(attr))
                if isinstance(typ, type) and issubclass(
                        typ, Path) and isinstance(attr, str):
                    setattr(self, arg, Path(attr))