예제 #1
0
def main(argv=None):
    """Run the Orange Canvas application and return its exit status.

    Parses the command line, configures logging, creates the
    ``CanvasApplication`` and the main window, runs (or loads from cache)
    widget discovery, optionally shows the splash screen and welcome dialog,
    and finally enters the Qt event loop with stdout/stderr tee'd into the
    window's output view.

    Args:
        argv: Full argument vector including the program name. Defaults to
            ``sys.argv`` when ``None``.

    Returns:
        The value returned by ``app.exec_()``, or ``1`` when the event loop
        terminated with an exception.
    """
    if argv is None:
        argv = sys.argv

    usage = "usage: %prog [options] [workflow_file]"
    parser = optparse.OptionParser(usage=usage)

    parser.add_option("--no-discovery",
                      action="store_true",
                      help="Don't run widget discovery "
                      "(use full cache instead)")
    parser.add_option("--force-discovery",
                      action="store_true",
                      help="Force full widget discovery "
                      "(invalidate cache)")
    parser.add_option("--clear-widget-settings",
                      action="store_true",
                      help="Remove stored widget setting")
    parser.add_option("--no-welcome",
                      action="store_true",
                      help="Don't show welcome dialog.")
    parser.add_option("--no-splash",
                      action="store_true",
                      help="Don't show splash screen.")
    parser.add_option("-l",
                      "--log-level",
                      help="Logging level (0, 1, 2, 3, 4)",
                      type="int",
                      default=1)
    parser.add_option("--style",
                      help="QStyle to use",
                      type="str",
                      default=None)
    parser.add_option("--stylesheet",
                      help="Application level CSS style sheet to use",
                      type="str",
                      default="orange.qss")
    parser.add_option("--qt",
                      help="Additional arguments for QApplication",
                      type="str",
                      default=None)

    parser.add_option("--config",
                      help="Configuration namespace",
                      type="str",
                      default="orangecanvas.example")

    # -m canvas orange.widgets
    # -m canvas --config orange.widgets

    (options, args) = parser.parse_args(argv[1:])

    levels = [
        logging.CRITICAL, logging.ERROR, logging.WARN, logging.INFO,
        logging.DEBUG
    ]

    # Clamp the requested level into the valid index range so that an
    # out-of-range value neither raises IndexError nor wraps around
    # (a negative index would silently select a level from the end).
    log_level = min(max(options.log_level, 0), len(levels) - 1)

    # Fix streams before configuring logging (otherwise it will store
    # and write to the old file descriptors)
    fix_win_pythonw_std_stream()

    # Try to fix fonts on OSX Mavericks/Yosemite, ...
    fix_osx_private_font()

    # File handler should always be at least INFO level so we need
    # the application root level to be at least at INFO.
    root_level = min(levels[log_level], logging.INFO)
    rootlogger = logging.getLogger(__package__)
    rootlogger.setLevel(root_level)

    # Standard output stream handler at the requested level
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level=levels[log_level])
    rootlogger.addHandler(stream_handler)

    if options.config is not None:
        try:
            cfg = utils.name_lookup(options.config)
        except (ImportError, AttributeError):
            # Best effort: keep running with the current default
            # configuration, but leave a trace instead of failing silently.
            log.warning("Could not load configuration %r",
                        options.config, exc_info=True)
        else:
            config.set_default(cfg)

            log.info("activating %s", options.config)

    log.info("Starting 'Orange Canvas' application.")

    qt_argv = argv[:1]

    if options.style is not None:
        qt_argv += ["-style", options.style]

    if options.qt is not None:
        qt_argv += shlex.split(options.qt)

    qt_argv += args

    if options.clear_widget_settings:
        log.debug("Clearing widget settings")
        shutil.rmtree(config.widget_settings_dir(), ignore_errors=True)

    if QT_VERSION >= 0x50600:
        CanvasApplication.setAttribute(Qt.AA_UseHighDpiPixmaps)

    log.debug("Starting CanvasApplicaiton with argv = %r.", qt_argv)
    app = CanvasApplication(qt_argv)

    # NOTE: config.init() must be called after the QApplication constructor
    config.init()

    file_handler = logging.FileHandler(filename=os.path.join(
        config.log_dir(), "canvas.log"),
                                       mode="w")

    file_handler.setLevel(root_level)
    rootlogger.addHandler(file_handler)

    # intercept any QFileOpenEvent requests until the main window is
    # fully initialized.
    # NOTE: The QApplication must have the executable ($0) and filename
    # arguments passed in argv otherwise the FileOpen events are
    # triggered for them (this is done by Cocoa, but QApplication filters
    # them out if passed in argv)

    open_requests = []

    def onrequest(url):
        log.info("Received an file open request %s", url)
        open_requests.append(url)

    app.fileOpenRequest.connect(onrequest)

    settings = QSettings()

    stylesheet = options.stylesheet
    stylesheet_string = None

    if stylesheet != "none":
        if os.path.isfile(stylesheet):
            # An existing file path takes precedence over packaged styles.
            with io.open(stylesheet, "r") as f:
                stylesheet_string = f.read()
        else:
            if not os.path.splitext(stylesheet)[1]:
                # no extension
                stylesheet = os.path.extsep.join([stylesheet, "qss"])

            pkg_name = __package__
            resource = "styles/" + stylesheet

            if pkg_resources.resource_exists(pkg_name, resource):
                stylesheet_string = \
                    pkg_resources.resource_string(pkg_name, resource).decode("utf-8")

                base = pkg_resources.resource_filename(pkg_name, "styles")

                # Lines of the form "@prefix: path;" declare Qt search
                # paths; they are registered and then stripped from the
                # style sheet text.
                pattern = re.compile(
                    r"^\s@([a-zA-Z0-9_]+?)\s*:\s*([a-zA-Z0-9_/]+?);\s*$",
                    flags=re.MULTILINE)

                matches = pattern.findall(stylesheet_string)

                for prefix, search_path in matches:
                    QDir.addSearchPath(prefix, os.path.join(base, search_path))
                    log.info("Adding search path %r for prefix, %r",
                             search_path, prefix)

                stylesheet_string = pattern.sub("", stylesheet_string)

            else:
                log.info("%r style sheet not found.", stylesheet)

    # Add the default canvas_icons search path
    dirpath = os.path.abspath(os.path.dirname(__file__))
    QDir.addSearchPath("canvas_icons", os.path.join(dirpath, "icons"))

    canvas_window = CanvasMainWindow()
    canvas_window.setWindowIcon(config.application_icon())

    if stylesheet_string is not None:
        canvas_window.setStyleSheet(stylesheet_string)

    if not options.force_discovery:
        reg_cache = cache.registry_cache()
    else:
        reg_cache = None

    widget_registry = qt.QtWidgetRegistry()
    widget_discovery = config.widget_discovery(widget_registry,
                                               cached_descriptions=reg_cache)

    want_splash = \
        settings.value("startup/show-splash-screen", True, type=bool) and \
        not options.no_splash

    if want_splash:
        pm, rect = config.splash_screen()
        splash_screen = SplashScreen(pixmap=pm, textRect=rect)
        splash_screen.setAttribute(Qt.WA_DeleteOnClose)
        splash_screen.setFont(QFont("Helvetica", 12))
        color = QColor("#FFD39F")

        def show_message(message):
            splash_screen.showMessage(message, color=color)

        widget_registry.category_added.connect(show_message)
        show_splash = splash_screen.show
        close_splash = splash_screen.close
    else:
        show_splash = close_splash = lambda: None

    log.info("Running widget discovery process.")

    cache_filename = os.path.join(config.cache_dir(), "widget-registry.pck")
    if options.no_discovery:
        # Reuse the registry pickled by a previous run instead of
        # discovering widgets again.
        with open(cache_filename, "rb") as f:
            widget_registry = pickle.load(f)
        widget_registry = qt.QtWidgetRegistry(widget_registry)
    else:
        show_splash()
        widget_discovery.run(config.widgets_entry_points())
        close_splash()

        # Store cached descriptions
        cache.save_registry_cache(widget_discovery.cached_descriptions)
        with open(cache_filename, "wb") as f:
            pickle.dump(WidgetRegistry(widget_registry), f)

    set_global_registry(widget_registry)
    canvas_window.set_widget_registry(widget_registry)
    canvas_window.show()
    canvas_window.raise_()

    want_welcome = \
        settings.value("startup/show-welcome-screen", True, type=bool) \
        and not options.no_welcome

    # Process events to make sure the canvas_window layout has
    # a chance to activate (the welcome dialog is modal and will
    # block the event queue, plus we need a chance to receive open file
    # signals when running without a splash screen)
    app.processEvents()

    app.fileOpenRequest.connect(canvas_window.open_scheme_file)

    if want_welcome and not args and not open_requests:
        canvas_window.welcome_dialog()

    elif args:
        log.info("Loading a scheme from the command line argument %r", args[0])
        canvas_window.load_scheme(args[0])
    elif open_requests:
        log.info("Loading a scheme from an `QFileOpenEvent` for %r",
                 open_requests[-1])
        canvas_window.load_scheme(open_requests[-1].toLocalFile())

    # Tee stdout and stderr into Output dock
    output_view = canvas_window.output_view()
    stdout = TextStream()
    stdout.stream.connect(output_view.write)
    if sys.stdout:
        stdout.stream.connect(sys.stdout.write)
        stdout.flushed.connect(sys.stdout.flush)
    stderr = TextStream()
    error_writer = output_view.formated(color=Qt.red)
    stderr.stream.connect(error_writer.write)
    if sys.stderr:
        stderr.stream.connect(sys.stderr.write)
        stderr.flushed.connect(sys.stderr.flush)
    sys.excepthook = ExceptHook(stream=stderr)

    with ExitStack() as stack:
        stack.enter_context(redirect_stdout(stdout))
        stack.enter_context(redirect_stderr(stderr))
        log.info("Entering main event loop.")
        try:
            status = app.exec_()
        except BaseException:
            log.error("Error in main event loop.", exc_info=True)
            # Without this, `status` would be unbound at the return below.
            status = 1

    canvas_window.deleteLater()
    app.processEvents()
    app.flush()
    del canvas_window

    # Collect any cycles before deleting the QApplication instance
    gc.collect()

    del app
    return status
예제 #2
0
def profile_command(args):
    """Profile the optimizers of a loaded configuration and write the timings.

    Loads ``args.config``, builds a lightweight config object holding the
    optimizers and monitors, opens every optimizer data iterator exactly
    once, runs ``profile_optimizer`` and writes the resulting rows to
    ``profile.csv`` inside ``args.outdir``.

    Args:
        args: Parsed command-line arguments; ``outdir`` and ``config`` are
            read here.

    Returns:
        bool: Always ``True``.
    """
    callback.update_status(args)

    progress_file = os.path.join(args.outdir, 'progress.txt')
    configure_progress(progress_file)

    # Plain attribute containers for the pieces of the loaded configuration.
    class TrainConfig:
        pass

    class OptConfig:
        pass

    class MonConfig:
        pass

    config = TrainConfig()
    info = load.load(args.config)

    config.global_config = info.global_config
    config.training_config = info.training_config

    config.optimizers = OrderedDict()
    for opt_name, optimizer in info.optimizers.items():
        opt_entry = OptConfig()
        opt_entry.optimizer = optimizer
        opt_entry.data_iterators = []
        config.optimizers[opt_name] = opt_entry

    config.monitors = OrderedDict()
    for mon_name, monitor in info.monitors.items():
        mon_entry = MonConfig()
        mon_entry.monitor = monitor
        mon_entry.data_iterators = []
        config.monitors[mon_name] = mon_entry

    # The extension module name is the part of the first backend entry
    # before the ':' separator.
    backend_name = config.global_config.default_context.backend[0].split(':')[0]
    ext_module = import_extension_module(backend_name)

    def synchronize():
        return ext_module.synchronize(
            device_id=config.global_config.default_context.device_id)

    result_array = [['time in ms']]

    callback.update_status('processing', True)

    # Profile Optimizer
    with ExitStack() as stack:
        # Each distinct data iterator factory is entered only once and the
        # resulting instance is shared between the optimizers that use it.
        opened_iterators = {}
        for opt_entry in config.optimizers.values():
            for di in opt_entry.optimizer.data_iterators.values():
                if di not in opened_iterators:
                    opened_iterators[di] = stack.enter_context(di())
                opt_entry.data_iterators.append(opened_iterators[di])
        result_array = profile_optimizer(config, result_array, synchronize)

    # Write profiling result
    import csv
    with open(args.outdir + os.sep + 'profile.csv', 'w') as f:
        csv.writer(f, lineterminator='\n').writerows(result_array)

    logger.log(99, 'Profile Completed.')
    progress(None)
    callback.update_status('finished')
    return True
예제 #3
0
def train_command(args):
    """Entry point of the training command.

    Validates the out-of-core options, loads the configuration, prepares
    optimizer and monitor data iterators (sliced per rank when running with
    more than one process) and runs ``_train``. When the configured number
    of iterations is zero, only the parameters are saved.

    Args:
        args: Parsed command-line arguments. ``ooc_gpu_memory_size`` and
            ``ooc_window_length`` are replaced in place by their numeric
            values when given.

    Returns:
        bool: ``False`` when an option is invalid or a dataset URI is empty,
        otherwise ``True`` (the training outcome itself is reported through
        the log and ``callback.update_status``).
    """
    # Validate and normalise the out-of-core options.
    if args.ooc_gpu_memory_size is not None:
        gpu_memory_size = str_to_num(args.ooc_gpu_memory_size)
        if gpu_memory_size < 0:
            logger.log(
                99,
                f'Fatal error. invalid ooc_gpu_memory_size [{args.ooc_gpu_memory_size}].'
            )
            return False
        args.ooc_gpu_memory_size = gpu_memory_size
    if args.ooc_window_length is not None:
        window_length = str_to_num(args.ooc_window_length)
        if window_length < 0:
            logger.log(
                99,
                f'Fatal error. invalid ooc_window_length [{args.ooc_window_length}].'
            )
            return False
        args.ooc_window_length = window_length

    callback.update_status(args)

    if single_or_rankzero():
        configure_progress(os.path.join(args.outdir, 'progress.txt'))

    info = load.load([args.config],
                     prepare_data_iterator=None,
                     exclude_parameter=True,
                     context=args.context)

    # Refuse to train when any dataset has an empty URI.
    if any(dataset.uri.strip() == '' for dataset in info.datasets.values()):
        logger.log(99, 'Fatal error. Dataset URI is empty.')
        return False

    # Plain attribute containers for the pieces of the loaded configuration.
    class TrainConfig:
        pass

    class OptConfig:
        pass

    class MonConfig:
        pass

    config = TrainConfig()
    config.timelimit = -1
    if args.param:
        # The parameter file may carry optimizer state which has to be
        # restored into the loaded info.
        load_train_state(args.param, info)

    config.timelimit = callback.get_timelimit(args)

    config.global_config = info.global_config
    config.training_config = info.training_config

    if single_or_rankzero():
        logger.log(99, 'Train with contexts {}'.format(available_contexts))

    config.optimizers = OrderedDict()
    for opt_name, optimizer in info.optimizers.items():
        opt_entry = OptConfig()
        opt_entry.optimizer = optimizer
        opt_entry.data_iterators = []
        config.optimizers[opt_name] = opt_entry

    config.monitors = OrderedDict()
    for mon_name, monitor in info.monitors.items():
        mon_entry = MonConfig()
        mon_entry.monitor = monitor
        mon_entry.data_iterators = []
        config.monitors[mon_name] = mon_entry

    # Training
    comm = current_communicator()
    training = config.training_config
    # Each process handles only its share of the iterations of an epoch.
    training.iter_per_epoch //= comm.size if comm else 1
    max_iteration = training.max_epoch * training.iter_per_epoch

    global _save_parameter_info
    _save_parameter_info = {}
    config_ext = os.path.splitext(args.config)[1]
    if config_ext in ('.prototxt', '.nntxt'):
        _save_parameter_info['config'] = args.config
    elif config_ext == '.nnp':
        # Extract the network definition(s) from the archive into outdir.
        with zipfile.ZipFile(args.config, 'r') as nnp:
            for member in nnp.namelist():
                member_ext = os.path.splitext(member)[1]
                if member_ext in ('.nntxt', '.prototxt'):
                    nnp.extract(member, args.outdir)
                    _save_parameter_info['config'] = os.path.join(
                        args.outdir, member)

    result = False
    restart = False
    if max_iteration > 0:
        rng = np.random.RandomState(comm.rank if comm else 0)
        with ExitStack() as stack:

            def open_iterator(di):
                # Enter the iterator context and, with several processes,
                # keep only this rank's slice of the data.
                instance = stack.enter_context(di())
                if comm and comm.size > 1:
                    instance = instance.slice(rng, comm.size, comm.rank)
                return instance

            # Create data_iterator instance only once for each dataset in optimizers
            optimizer_data_iterators = {}
            for opt_entry in config.optimizers.values():
                for di in opt_entry.optimizer.data_iterators.values():
                    if di not in optimizer_data_iterators:
                        optimizer_data_iterators[di] = open_iterator(di)
                    opt_entry.data_iterators.append(
                        optimizer_data_iterators[di])

            # Create data_iterator instance only once for each dataset in monitors
            monitor_data_iterators = {}
            for mon_entry in config.monitors.values():
                for di in mon_entry.monitor.data_iterators.values():
                    if di not in monitor_data_iterators:
                        monitor_data_iterators[di] = open_iterator(di)
                    mon_entry.data_iterators.append(
                        monitor_data_iterators[di])
            monitor_data_iterators.update(optimizer_data_iterators)

            result, restart = _train(args, config)
    else:
        # save parameters without training (0 epoch learning)
        logger.log(99, '0 epoch learning. (Just save parameter.)')
        if single_or_rankzero():
            _save_parameters(args, None, 0, config, True)
        result = True

    if single_or_rankzero() and not restart:
        if result:
            logger.log(99, 'Training Completed.')
            callback.update_status('finished')
        else:
            logger.log(99, 'Training Incompleted.')
            callback.update_status('failed')
    if single_or_rankzero():
        progress(None)
    return True
예제 #4
0
def run_test(in_file, test_spec, global_cfg):
    """Run a single tavern test

    Note that each tavern test can consist of multiple requests (log in,
    create, update, delete, etc).

    The global configuration is copied and used as an initial configuration for
    this test. Any values which are saved from any tests are saved into this
    test block and can be used for formatting in later stages in the test.
    The ``variables`` mapping is copied as well, so variables added for this
    test do not modify ``global_cfg``.

    Args:
        in_file (str): filename containing this test
        test_spec (dict): The specification for this test
        global_cfg (dict): Any global configuration for this test

    Raises:
        TavernException: If any of the tests failed
    """

    # pylint: disable=too-many-locals

    # Initialise test config for this test with the global configuration before
    # starting
    test_block_config = dict(global_cfg)

    # dict() above is only a shallow copy: without copying "variables" too,
    # the assignments below would write into the caller's dictionary and the
    # values would leak into every other test using the same global_cfg.
    test_block_config["variables"] = dict(
        test_block_config.get("variables") or {})

    tavern_box = Box({
        "env_vars": dict(os.environ),
    })

    test_block_config["variables"]["tavern"] = tavern_box

    if not test_spec:
        logger.warning("Empty test block in %s", in_file)
        return

    for included in test_spec.get("includes") or []:
        if "variables" in included:
            formatted_include = format_keys(included["variables"],
                                            {"tavern": tavern_box})
            test_block_config["variables"].update(formatted_include)

    test_block_name = test_spec["test_name"]

    # Strict on body by default
    default_strictness = test_block_config["strict"]

    logger.info("Running test : %s", test_block_name)

    with ExitStack() as stack:
        sessions = get_extra_sessions(test_spec, test_block_config)

        for name, session in sessions.items():
            logger.debug("Entering context for %s", name)
            stack.enter_context(session)

        # Run tests in a path in order
        for stage in test_spec["stages"]:
            if stage.get('skip'):
                continue

            # Reset any override left over from the previous stage
            test_block_config["strict"] = default_strictness

            # Can be overridden per stage
            # NOTE
            # this is hardcoded to check for the 'response' block. In the far
            # future there might not be a response block, but at the moment it
            # is the hardcoded value for any HTTP request.
            response_block = stage.get("response", {})
            if response_block:
                # Precedence: stage-level setting, then test-level setting,
                # then the default from the configuration.
                if response_block.get("strict", None) is not None:
                    stage_strictness = response_block.get("strict", None)
                elif test_spec.get("strict", None) is not None:
                    stage_strictness = test_spec.get("strict", None)
                else:
                    stage_strictness = default_strictness

                logger.debug("Strict key checking for this stage is '%s'",
                             stage_strictness)

                test_block_config["strict"] = stage_strictness
            elif default_strictness:
                logger.debug("Default strictness '%s' ignored for this stage",
                             default_strictness)

            try:
                run_stage(sessions, stage, tavern_box, test_block_config)
            except exceptions.TavernException as e:
                # Attach context so the caller can report which stage failed
                e.stage = stage
                e.test_block_config = test_block_config
                raise

            # A stage marked 'only' ends the test after it has run
            if stage.get('only'):
                break
예제 #5
0
파일: core.py 프로젝트: rterbush/catalyst
    def ingest(name,
               environ=os.environ,
               timestamp=None,
               assets_versions=(),
               show_progress=False):
        """Ingest data for a given bundle.

        Looks up the registered bundle, clamps its session range to the
        bundle's calendar, creates the output and cache directories, builds
        the writers (when the bundle asks for them) and calls the bundle's
        own ``ingest`` function. Afterwards a downgraded copy of the assets
        db is produced for every requested version.

        Parameters
        ----------
        name : str
            The name of the bundle.
        environ : mapping, optional
            The environment variables. By default this is os.environ.
        timestamp : datetime, optional
            The timestamp to use for the load.
            By default this is the current time.
        assets_versions : Iterable[int], optional
            Versions of the assets db to which to downgrade.
        show_progress : bool, optional
            Tell the ingest function to display the progress where possible.

        Raises
        ------
        UnknownBundle
            If ``name`` is not a key of ``bundles``.
        ValueError
            If ``assets_versions`` is non-empty but the bundle does not
            create writers.
        """
        try:
            bundle = bundles[name]
        except KeyError:
            raise UnknownBundle(name)

        calendar = get_calendar(bundle.calendar_name)

        start_session = bundle.start_session
        end_session = bundle.end_session

        # Keep the requested session range inside what the calendar covers.
        if start_session is None or start_session < calendar.first_session:
            start_session = calendar.first_session

        if end_session is None or end_session > calendar.last_session:
            end_session = calendar.last_session

        if timestamp is None:
            timestamp = pd.Timestamp.utcnow()
        # Convert to UTC and then drop the timezone information.
        # NOTE(review): tz_convert presumably requires a tz-aware timestamp
        # from the caller -- confirm.
        timestamp = timestamp.tz_convert('utc').tz_localize(None)

        timestr = to_bundle_ingest_dirname(timestamp)
        cachepath = cache_path(name, environ=environ)
        pth.ensure_directory(pth.data_path([name, timestr], environ=environ))
        pth.ensure_directory(cachepath)
        with dataframe_cache(cachepath, clean_on_failure=False) as cache, \
                ExitStack() as stack:
            # we use `clean_on_failure=False` so that we don't purge the
            # cache directory if the load fails in the middle
            if bundle.create_writers:
                wd = stack.enter_context(
                    working_dir(pth.data_path([], environ=environ)))
                daily_bars_path = wd.ensure_dir(*daily_equity_relative(
                    name,
                    timestr,
                    environ=environ,
                ))
                daily_bar_writer = BcolzDailyBarWriter(
                    daily_bars_path,
                    calendar,
                    start_session,
                    end_session,
                )
                # Do an empty write to ensure that the daily ctables exist
                # when we create the SQLiteAdjustmentWriter below. The
                # SQLiteAdjustmentWriter needs to open the daily ctables so
                # that it can compute the adjustment ratios for the dividends.

                daily_bar_writer.write(())
                minute_bar_writer = BcolzMinuteBarWriter(
                    wd.ensure_dir(*minute_equity_relative(
                        name, timestr, environ=environ)),
                    calendar,
                    start_session,
                    end_session,
                    minutes_per_day=bundle.minutes_per_day,
                )
                assets_db_path = wd.getpath(*asset_db_relative(
                    name,
                    timestr,
                    environ=environ,
                ))
                asset_db_writer = AssetDBWriter(assets_db_path)

                # Entered on the stack so the adjustment writer is exited
                # together with the working directory when the block ends.
                adjustment_db_writer = stack.enter_context(
                    SQLiteAdjustmentWriter(
                        wd.getpath(*adjustment_db_relative(
                            name, timestr, environ=environ)),
                        BcolzDailyBarReader(daily_bars_path),
                        calendar.all_sessions,
                        overwrite=True,
                    ))
            else:
                # The bundle manages its own output: no writers are passed.
                daily_bar_writer = None
                minute_bar_writer = None
                asset_db_writer = None
                adjustment_db_writer = None
                # Downgrading needs `wd` and `assets_db_path`, which only
                # exist when the writers were created above.
                if assets_versions:
                    raise ValueError('Need to ingest a bundle that creates '
                                     'writers in order to downgrade the assets'
                                     ' db.')
            bundle.ingest(
                environ,
                asset_db_writer,
                minute_bar_writer,
                daily_bar_writer,
                adjustment_db_writer,
                calendar,
                start_session,
                end_session,
                cache,
                show_progress,
                pth.data_path([name, timestr], environ=environ),
            )

            # Produce one downgraded copy of the assets db per requested
            # version, going from the highest version to the lowest.
            for version in sorted(set(assets_versions), reverse=True):
                version_path = wd.getpath(*asset_db_relative(
                    name,
                    timestr,
                    environ=environ,
                    db_version=version,
                ))
                with working_file(version_path) as wf:
                    shutil.copy2(assets_db_path, wf.path)
                    downgrade(wf.path, version)
예제 #6
0
def open_files(files, **kwargs):
    """A plural form of :func:`open_file`.

    Opens every entry of *files* with :func:`open_file`, forwarding
    *kwargs*, and yields the list of opened objects. All of them are
    released when the generator is resumed or closed.
    """
    with ExitStack() as stack:
        opened = []
        for f in files:
            opened.append(stack.enter_context(open_file(f, **kwargs)))
        yield opened
예제 #7
0
def __nested(context_managers):
    """Enter every context manager in order and yield their results as a tuple.

    All entered managers are exited, in reverse order, when the generator
    is resumed or closed.
    """
    with ExitStack() as stack:
        entered = [stack.enter_context(cm) for cm in context_managers]
        yield tuple(entered)
예제 #8
0
 def db_lock():
     """Return a new, empty ``ExitStack`` instance on every call."""
     lock = ExitStack()
     return lock
예제 #9
0
파일: delegate.py 프로젝트: hc621/ABZipline
 def ctx(self, *args, **kwargs):
     """Enter the ``method_name`` context of every hook and yield the stack.

     For each hook in ``self._hooks`` the attribute named by the enclosing
     ``method_name`` is called with the given arguments and the returned
     context manager is entered on a shared ``ExitStack``.
     """
     with ExitStack() as stack:
         for hook in self._hooks:
             hook_method = getattr(hook, method_name)
             stack.enter_context(hook_method(*args, **kwargs))
         yield stack
예제 #10
0
def stack():
    """Provide a cleanup stack to use in the test (without indentation).

    Callbacks registered on the yielded ``ExitStack`` run once the
    generator is resumed or closed.
    """
    with ExitStack() as cleanup:
        yield cleanup
예제 #11
0
 def prepared_request():
     """Send the request, keeping any file arguments open until it returns."""
     # The file arguments are created against the stack, so whatever they
     # register on it is released when the request has completed.
     with ExitStack() as stack:
         request_args = self._request_args
         request_args.update(self._get_file_arguments(stack))
         response = session.request(**self._request_args)
         return response
예제 #12
0
def do_dedup(sess, tt, chunk):
    """Deduplicate the files behind a chunk of same-size candidate inodes.

    Every inode of *chunk* is opened read-write, hashed with SHA-1 and
    grouped by digest. Within each group the first file is the source and
    the others are compared byte-for-byte with it and then cloned from it.
    Successful deduplications are recorded as ``DedupEvent`` /
    ``DedupEventInode`` rows and committed.

    Inodes that cannot be processed (busy, access denied, moved, in use, or
    changed since they were recorded) are appended to the module-level
    ``skipped`` list; stale inodes are deleted from the session.

    Args:
        sess: Database session used to delete stale inodes and record events.
        tt: Progress/notification object; only ``notify`` is used here.
        chunk: Sequence of inode records to deduplicate against each other.

    Raises:
        IOError: Any path lookup or open error other than the handled
            ENOENT / ETXTBSY / EACCES cases.
    """
    global ofile_soft
    global ofile_hard
    global ofile_reserved
    global fs

    files = []
    fds = []
    fd_names = {}
    fd_inodes = {}
    by_hash = collections.defaultdict(list)

    # XXX I have no justification for doubling the chunk length
    ofile_req = 2 * len(chunk) + ofile_reserved
    if ofile_req > ofile_soft:
        if ofile_req <= ofile_hard:
            resource.setrlimit(resource.RLIMIT_OFILE, (ofile_req, ofile_hard))
            ofile_soft = ofile_req
        else:
            # The chunk cannot be opened within the hard limit: report it,
            # remember the inodes that have pending updates and give up on
            # this chunk instead of running into the limit below.
            tt.notify('Too many duplicates (%d at size %d), '
                      'would bring us over the open files limit (%d, %d).' %
                      (len(chunk), chunk[0].size if chunk else 0,
                       ofile_soft, ofile_hard))
            for inode in chunk:
                if inode.has_updates:
                    skipped.append(inode)
            return

    with ExitStack() as stack:
        for inode in chunk:
            # Open everything rw, we can't pick one for the source side
            # yet because the crypto hash might eliminate it.
            # We may also want to defragment the source.
            try:
                path = lookup_ino_path_one(inode.vol.fd, inode.ino)
            except IOError as e:
                if e.errno == errno.ENOENT:
                    sess.delete(inode)
                    continue
                raise
            try:
                afile = fopenat_rw(inode.vol.fd, path)
            except IOError as e:
                if e.errno == errno.ETXTBSY:
                    # The file contains the image of a running process,
                    # we can't open it in write mode.
                    tt.notify('File %r is busy, skipping' % path)
                    skipped.append(inode)
                    continue
                elif e.errno == errno.EACCES:
                    # Could be SELinux or immutability
                    tt.notify('Access denied on %r, skipping' % path)
                    skipped.append(inode)
                    continue
                elif e.errno == errno.ENOENT:
                    # The file was moved or unlinked by a racing process
                    tt.notify('File %r may have moved, skipping' % path)
                    skipped.append(inode)
                    continue
                raise

            # Register the file for closing right away so that an error
            # raised for a later inode does not leak the ones already open.
            stack.enter_context(closing(afile))

            # It's not completely guaranteed we have the right inode,
            # there may still be race conditions at this point.
            # Gets re-checked below (tell and fstat).
            fd = afile.fileno()
            fd_inodes[fd] = inode
            fd_names[fd] = path
            files.append(afile)
            fds.append(fd)

        # Enter this context last
        immutability = stack.enter_context(ImmutableFDs(fds))

        for afile in files:
            fd = afile.fileno()
            inode = fd_inodes[fd]
            if fd in immutability.fds_in_write_use:
                tt.notify('File %r is in use, skipping' % fd_names[fd])
                skipped.append(inode)
                continue
            hasher = hashlib.sha1()
            for buf in iter(lambda: afile.read(BUFSIZE), b''):
                hasher.update(buf)

            # Gets rid of a race condition
            st = os.fstat(fd)
            if st.st_ino != inode.ino:
                skipped.append(inode)
                continue
            if st.st_dev != inode.vol.st_dev:
                skipped.append(inode)
                continue

            # After reading to the end, the position is the actual size.
            size = afile.tell()
            if size != inode.size:
                if size < inode.vol.size_cutoff:
                    # if we didn't delete this inode, it would cause
                    # spurious comm groups in all future invocations.
                    sess.delete(inode)
                else:
                    skipped.append(inode)
                continue

            by_hash[hasher.digest()].append(afile)

        # values() instead of itervalues(): available on Python 2 and 3.
        for fileset in by_hash.values():
            if len(fileset) < 2:
                continue
            sfile = fileset[0]
            sfd = sfile.fileno()
            sname = fd_names[sfd]
            # The source is deliberately not defragmented: defragmentation
            # can unshare extents and can also disable compression as a
            # side-effect.
            dfiles = fileset[1:]
            dfiles_successful = []
            for dfile in dfiles:
                dfd = dfile.fileno()
                dname = fd_names[dfd]
                if not cmp_files(sfile, dfile):
                    # Probably a bug since we just used a crypto hash
                    tt.notify('Files differ: %r %r' % (sname, dname))
                    assert False, (sname, dname)
                    continue
                if clone_data(dest=dfd, src=sfd, check_first=True):
                    tt.notify('Deduplicated: %r %r' % (sname, dname))
                    dfiles_successful.append(dfile)
                else:
                    tt.notify(
                        'Did not deduplicate (same extents): %r %i %r %i' %
                        (sname, fd_inodes[sfd].ino, dname, fd_inodes[dfd].ino))
            if dfiles_successful:
                # Use the source inode's size, which was verified against
                # the file above, rather than whichever inode a previous
                # loop happened to leave behind.
                evt = DedupEvent(fs=fs,
                                 item_size=fd_inodes[sfd].size,
                                 created=system_now())
                sess.add(evt)
                for afile in [sfile] + dfiles_successful:
                    inode = fd_inodes[afile.fileno()]
                    evti = DedupEventInode(event=evt,
                                           ino=inode.ino,
                                           vol=inode.vol)
                    sess.add(evti)
                sess.commit()
예제 #13
0
def dedup_tracked1(sess, tt, ofile_reserved, query, fs, skipped):
    """Walk the candidate groups in ``query`` and deduplicate matching files.

    For every ``comm1`` group the mini hash of each inode is recomputed; for
    every ``comm2`` group below it the fiemap hash is recomputed; when a
    ``comm2`` group has a ``comm3`` group, its inodes are opened read-write,
    hashed with SHA-1 inside an ``ImmutableFDs`` context, compared with
    ``cmp_files`` and cloned with ``clone_data``.  Every set with at least
    one successful clone is recorded as a ``DedupEvent`` plus one
    ``DedupEventInode`` per participating file, and committed.

    Args:
        sess: Database session; ``execute``, ``flush``, ``delete``, ``add``,
            ``commit`` and ``identity_map`` are used (presumably a SQLAlchemy
            session on SQLite -- confirm against the caller).
        tt: Progress reporter exposing ``update``, ``notify`` and ``format``.
        ofile_reserved: Number of file descriptors added on top of the
            per-group estimate when checking the open files limit.
        query: Iterable of ``comm1`` groups.
        fs: Passed through as ``fs`` to every ``DedupEvent`` created.
        skipped: List that receives the inodes that could not be processed
            (busy, access denied, moved, in write use, changed under us, or
            belonging to a group that would exceed the open files limit).

    Returns:
        None.  A summary of the potential space gain of each pass is sent to
        ``tt.notify``.
    """
    space_gain1 = space_gain2 = space_gain3 = 0
    ofile_soft, ofile_hard = resource.getrlimit(resource.RLIMIT_OFILE)

    # Hopefully close any files we left around
    gc.collect()

    # The log can cause frequent commits, we don't mind losing them in
    # a crash (no need for durability). SQLite is in WAL mode, so this pragma
    # should disable most commit-time fsync calls without compromising
    # consistency.
    sess.execute('PRAGMA synchronous=NORMAL;')

    for comm1 in query:
        # Flush periodically so the session's identity map stays small.
        if len(sess.identity_map) > 300:
            sess.flush()

        # Pass 1: every inode but one in the group is a potential saving.
        space_gain1 += comm1.size * (comm1.inode_count - 1)
        tt.update(comm1=comm1)
        for inode in comm1.inodes:
            # XXX Need to cope with deleted inodes.
            # We cannot find them in the search-new pass, not without doing
            # some tracking of directory modifications to poke updated
            # directories to find removed elements.

            # rehash everytime for now
            # I don't know enough about how inode transaction numbers are
            # updated (as opposed to extent updates) to be able to actually
            # cache the result
            try:
                path = lookup_ino_path_one(inode.vol.fd, inode.ino)
            except IOError as e:
                if e.errno != errno.ENOENT:
                    raise
                # We have a stale record for a removed inode
                # XXX If an inode number is reused and the second instance
                # is below the size cutoff, we won't update the .size
                # attribute and we won't get an IOError to notify us
                # either.  Inode reuse does happen (with and without
                # inode_cache), so this branch isn't enough to rid us of
                # all stale entries.  We can also get into trouble with
                # regular file inodes being replaced by some other kind of
                # inode.
                sess.delete(inode)
                continue
            with closing(fopenat(inode.vol.fd, path)) as rfile:
                inode.mini_hash_from_file(rfile)

        # Pass 2: same procedure, keyed on the fiemap hash.
        for comm2 in comm1.comm2:
            space_gain2 += comm2.size * (comm2.inode_count - 1)
            tt.update(comm2=comm2)
            for inode in comm2.inodes:
                try:
                    path = lookup_ino_path_one(inode.vol.fd, inode.ino)
                except IOError as e:
                    if e.errno != errno.ENOENT:
                        raise
                    # Stale record for a removed inode, drop it.
                    sess.delete(inode)
                    continue
                with closing(fopenat(inode.vol.fd, path)) as rfile:
                    inode.fiemap_hash_from_file(rfile)

            if not comm2.comm3:
                continue

            # Pass 3.  The single-element unpacking raises ValueError unless
            # comm2.comm3 holds exactly one group.
            comm3, = comm2.comm3
            count3 = comm3.inode_count
            space_gain3 += comm3.size * (count3 - 1)
            tt.update(comm3=comm3)
            # Per-group bookkeeping: open file objects, their descriptors,
            # and descriptor -> path / inode record lookups.
            files = []
            fds = []
            fd_names = {}
            fd_inodes = {}
            by_hash = collections.defaultdict(list)

            # XXX I have no justification for doubling count3
            ofile_req = 2 * count3 + ofile_reserved
            if ofile_req > ofile_soft:
                if ofile_req <= ofile_hard:
                    # Raise the soft limit just enough; the hard limit is
                    # left untouched.
                    resource.setrlimit(resource.RLIMIT_OFILE,
                                       (ofile_req, ofile_hard))
                    ofile_soft = ofile_req
                else:
                    tt.notify(
                        'Too many duplicates (%d at size %d), '
                        'would bring us over the open files limit (%d, %d).' %
                        (count3, comm3.size, ofile_soft, ofile_hard))
                    for inode in comm3.inodes:
                        if inode.has_updates:
                            skipped.append(inode)
                    continue

            for inode in comm3.inodes:
                # Open everything rw, we can't pick one for the source side
                # yet because the crypto hash might eliminate it.
                # We may also want to defragment the source.
                try:
                    path = lookup_ino_path_one(inode.vol.fd, inode.ino)
                except IOError as e:
                    if e.errno == errno.ENOENT:
                        sess.delete(inode)
                        continue
                    raise
                try:
                    afile = fopenat_rw(inode.vol.fd, path)
                except IOError as e:
                    if e.errno == errno.ETXTBSY:
                        # The file contains the image of a running process,
                        # we can't open it in write mode.
                        tt.notify('File %r is busy, skipping' % path)
                        skipped.append(inode)
                        continue
                    elif e.errno == errno.EACCES:
                        # Could be SELinux or immutability
                        tt.notify('Access denied on %r, skipping' % path)
                        skipped.append(inode)
                        continue
                    elif e.errno == errno.ENOENT:
                        # The file was moved or unlinked by a racing process
                        tt.notify('File %r may have moved, skipping' % path)
                        skipped.append(inode)
                        continue
                    raise

                # It's not completely guaranteed we have the right inode,
                # there may still be race conditions at this point.
                # Gets re-checked below (tell and fstat).
                fd = afile.fileno()
                fd_inodes[fd] = inode
                fd_names[fd] = path
                files.append(afile)
                fds.append(fd)

            # NOTE(review): the files are only registered for closing here;
            # if the opening loop above raises, the files it already opened
            # are not closed by this stack.
            with ExitStack() as stack:
                for afile in files:
                    stack.enter_context(closing(afile))
                # Enter this context last
                immutability = stack.enter_context(ImmutableFDs(fds))

                for afile in files:
                    fd = afile.fileno()
                    inode = fd_inodes[fd]
                    if fd in immutability.fds_in_write_use:
                        tt.notify('File %r is in use, skipping' % fd_names[fd])
                        skipped.append(inode)
                        continue
                    # Hash the whole file in BUFSIZE chunks; read() returns
                    # b'' at end of file, which stops the iterator.
                    hasher = hashlib.sha1()
                    for buf in iter(lambda: afile.read(BUFSIZE), b''):
                        hasher.update(buf)

                    # Gets rid of a race condition
                    st = os.fstat(fd)
                    if st.st_ino != inode.ino:
                        skipped.append(inode)
                        continue
                    if st.st_dev != inode.vol.st_dev:
                        skipped.append(inode)
                        continue

                    # Position after reading to EOF; presumably the file
                    # size, assuming the file was opened at offset 0.
                    size = afile.tell()
                    if size != comm3.size:
                        if size < inode.vol.size_cutoff:
                            # if we didn't delete this inode, it would cause
                            # spurious comm groups in all future invocations.
                            sess.delete(inode)
                        else:
                            skipped.append(inode)
                        continue

                    by_hash[hasher.digest()].append(afile)

                # dict.itervalues() exists on Python 2 only.
                for fileset in by_hash.itervalues():
                    if len(fileset) < 2:
                        continue
                    # The first file of each set is the clone source, the
                    # others are the destinations.
                    sfile = fileset[0]
                    sfd = sfile.fileno()
                    # Commented out, defragmentation can unshare extents.
                    # It can also disable compression as a side-effect.
                    if False:
                        defragment(sfd)
                    dfiles = fileset[1:]
                    dfiles_successful = []
                    for dfile in dfiles:
                        dfd = dfile.fileno()
                        sname = fd_names[sfd]
                        dname = fd_names[dfd]
                        if not cmp_files(sfile, dfile):
                            # Probably a bug since we just used a crypto hash
                            tt.notify('Files differ: %r %r' % (sname, dname))
                            # The continue below is only reached when
                            # assertions are disabled (python -O).
                            assert False, (sname, dname)
                            continue
                        if clone_data(dest=dfd, src=sfd, check_first=True):
                            tt.notify('Deduplicated: %r %r' % (sname, dname))
                            dfiles_successful.append(dfile)
                        else:
                            tt.notify(
                                'Did not deduplicate (same extents): %r %r' %
                                (sname, dname))
                    if dfiles_successful:
                        # Record one event for the set, with one row for the
                        # source and one per successfully cloned destination.
                        evt = DedupEvent(fs=fs,
                                         item_size=comm3.size,
                                         created=system_now())
                        sess.add(evt)
                        for afile in [sfile] + dfiles_successful:
                            inode = fd_inodes[afile.fileno()]
                            evti = DedupEventInode(event=evt,
                                                   ino=inode.ino,
                                                   vol=inode.vol)
                            sess.add(evti)
                        sess.commit()

    tt.format(None)
    tt.notify('Potential space gain: pass 1 %d, pass 2 %d pass 3 %d' %
              (space_gain1, space_gain2, space_gain3))
    # Restore fsync so that the final commit (in dedup_tracked)
    # will be durable.
    sess.commit()
    sess.execute('PRAGMA synchronous=FULL;')
예제 #14
0
def run_test(in_file, test_spec, global_cfg):
    """Run a single tavern test

    Note that each tavern test can consist of multiple requests (log in,
    create, update, delete, etc).

    The global configuration is copied and used as an initial configuration for
    this test. Any values which are saved from any tests are saved into this
    test block and can be used for formatting in later stages in the test.
    Neither ``global_cfg`` nor its ``"variables"`` mapping is modified.

    Args:
        in_file (str): filename containing this test
        test_spec (dict): The specification for this test
        global_cfg (dict): Any global configuration for this test

    Raises:
        TavernException: If any of the tests failed
    """

    # pylint: disable=too-many-locals

    # Initialise test config for this test with the global configuration before
    # starting
    test_block_config = dict(global_cfg)

    # dict() only copies the top level, so "variables" would still be the very
    # same mapping object as in global_cfg. Copy it as well, otherwise the
    # tavern box, included variables and values saved by verifiers would leak
    # into the global configuration and from there into every later test.
    test_block_config["variables"] = dict(
        test_block_config.get("variables") or {})

    tavern_box = Box({
        "env_vars": dict(os.environ),
    })

    test_block_config["variables"]["tavern"] = tavern_box

    if not test_spec:
        logger.warning("Empty test block in %s", in_file)
        return

    if test_spec.get("includes"):
        for included in test_spec["includes"]:
            if "variables" in included:
                # Included variables may only refer to the tavern box
                # (e.g. environment variables) when being formatted
                formatted_include = format_keys(included["variables"],
                                                {"tavern": tavern_box})
                test_block_config["variables"].update(formatted_include)

    test_block_name = test_spec["test_name"]

    logger.info("Running test : %s", test_block_name)

    # The stack closes every extra session when the test finishes or fails
    with ExitStack() as stack:
        sessions = get_extra_sessions(test_spec)

        for name, session in sessions.items():
            logger.debug("Entering context for %s", name)
            stack.enter_context(session)

        # Run tests in a path in order
        for stage in test_spec["stages"]:
            name = stage["name"]

            try:
                r = get_request_type(stage, test_block_config, sessions)
            except exceptions.MissingFormatError:
                log_fail(stage, None, None)
                raise

            # Expose the request variables for formatting the expected
            # response; they are removed again at the end of the stage
            tavern_box.update(request_vars=r.request_vars)

            try:
                expected = get_expected(stage, test_block_config, sessions)
            except exceptions.TavernException:
                log_fail(stage, None, None)
                raise

            delay(stage, "before")

            logger.info("Running stage : %s", name)

            try:
                response = r.run()
            except exceptions.TavernException:
                log_fail(stage, None, expected)
                raise

            verifiers = get_verifiers(stage, test_block_config, sessions,
                                      expected)

            for v in verifiers:
                try:
                    saved = v.verify(response)
                except exceptions.TavernException:
                    log_fail(stage, v, expected)
                    raise
                else:
                    # Values saved by this stage become available to the
                    # following stages of this test only
                    test_block_config["variables"].update(saved)

            log_pass(stage, verifiers)

            tavern_box.pop("request_vars")
            delay(stage, "after")
예제 #15
0
 def settings(*args, **kwargs):
     """Yield the results of entering ``hide('running')`` and ``msg(txt)``.

     Both context managers are created first and then entered in order on a
     single ``ExitStack``, which unwinds them once the generator is resumed
     or closed. ``args`` and ``kwargs`` are accepted but not used, and
     ``txt`` comes from a scope that is not part of this function.
     """
     with ExitStack() as stack:
         managers = [hide('running'), msg(txt)]
         entered = []
         for cm in managers:
             entered.append(stack.enter_context(cm))
         yield tuple(entered)
예제 #16
0
def run_test(in_file, test_spec, global_cfg):
    """Run a single tavern test

    Note that each tavern test can consist of multiple requests (log in,
    create, update, delete, etc).

    The global configuration is copied and used as an initial configuration for
    this test. Any values which are saved from any tests are saved into this
    test block and can be used for formatting in later stages in the test.
    Neither ``global_cfg`` nor its ``"variables"`` mapping is modified.

    Args:
        in_file (str): filename containing this test
        test_spec (dict): The specification for this test
        global_cfg (dict): Any global configuration for this test

    Raises:
        TavernException: If a stage failed. The failing stage and the test
            block configuration are attached to the exception as ``stage``
            and ``test_block_config`` before it is re-raised.
        exceptions.DuplicateStageDefinitionError: If two included stages
            share the same id
    """

    # pylint: disable=too-many-locals

    # Initialise test config for this test with the global configuration before
    # starting
    test_block_config = dict(global_cfg)

    # dict() only copies the top level, so "variables" would still be the very
    # same mapping object as in global_cfg. Copy it as well, otherwise the
    # tavern box and included variables would leak into the global
    # configuration and from there into every later test.
    test_block_config["variables"] = dict(
        test_block_config.get("variables") or {})

    tavern_box = Box({
        "env_vars": dict(os.environ),
    })

    if not test_spec:
        logger.warning("Empty test block in %s", in_file)
        return

    def stage_ids(s):
        return [i["id"] for i in s]

    available_stages = test_block_config.get("stages", [])
    included_stages = []

    if test_spec.get("includes"):
        # The global stages do not change in here, collect their ids once
        global_stage_ids = stage_ids(available_stages)

        # Need to do this separately here so there is no confusion between global and included stages
        for included in test_spec["includes"]:
            for stage in included.get("stages", []):
                if stage["id"] in global_stage_ids:
                    msg = "Stage id '{}' defined in stage-included test which was already defined in global configuration - this will be an error in future!".format(
                        stage["id"])
                    logger.warning(msg)
                    warnings.warn(msg, FutureWarning)

        for included in test_spec["includes"]:
            if "variables" in included:
                # Included variables may only refer to the tavern box
                # (e.g. environment variables) when being formatted
                formatted_include = format_keys(included["variables"],
                                                {"tavern": tavern_box})
                test_block_config["variables"].update(formatted_include)

            for stage in included.get("stages", []):
                if stage["id"] in stage_ids(included_stages):
                    raise exceptions.DuplicateStageDefinitionError(
                        "Stage with specified id already defined: {}".format(
                            stage["id"]))
                included_stages.append(stage)

    # Included stages come last, so they win over a global stage with the
    # same id (this is what the FutureWarning above is about)
    available_stages = {s["id"]: s for s in available_stages + included_stages}

    test_block_config["variables"]["tavern"] = tavern_box

    test_block_name = test_spec["test_name"]

    # Strict on body by default
    default_strictness = test_block_config["strict"]

    logger.info("Running test : %s", test_block_name)

    # The stack closes every extra session when the test finishes or fails
    with ExitStack() as stack:
        test_spec["stages"] = _resolve_test_stages(test_spec, available_stages)
        sessions = get_extra_sessions(test_spec, test_block_config)

        for name, session in sessions.items():
            logger.debug("Entering context for %s", name)
            stack.enter_context(session)

        # Run tests in a path in order
        for stage in test_spec["stages"]:
            if stage.get('skip'):
                continue

            # Reset whatever the previous stage overrode
            test_block_config["strict"] = default_strictness

            # Can be overridden per stage
            # NOTE
            # this is hardcoded to check for the 'response' block. In the far
            # future there might not be a response block, but at the moment it
            # is the hardcoded value for any HTTP request.
            if stage.get("response", {}):
                # Precedence: stage response block, then test, then default
                if stage.get("response").get("strict", None) is not None:
                    stage_strictness = stage.get("response").get(
                        "strict", None)
                elif test_spec.get("strict", None) is not None:
                    stage_strictness = test_spec.get("strict", None)
                else:
                    stage_strictness = default_strictness

                logger.debug("Strict key checking for this stage is '%s'",
                             stage_strictness)

                test_block_config["strict"] = stage_strictness
            elif default_strictness:
                logger.debug("Default strictness '%s' ignored for this stage",
                             default_strictness)

            # Wrap run_stage with retry helper
            run_stage_with_retries = retry(stage)(run_stage)

            try:
                run_stage_with_retries(sessions, stage, tavern_box,
                                       test_block_config)
            except exceptions.TavernException as e:
                # Give the caller the context of the failure
                e.stage = stage
                e.test_block_config = test_block_config
                raise

            # A stage marked 'only' ends the test after it has run
            if stage.get('only'):
                break
예제 #17
0
    def ingest(name, environ=os.environ, timestamp=None, show_progress=False):
        """Ingest data for a given bundle.

        Parameters
        ----------
        name : str
            The name of the bundle.
        environ : mapping, optional
            The environment variables. By default this is os.environ.
        timestamp : datetime, optional
            The timestamp to use for the load.
            By default this is the current time.
        show_progress : bool, optional
            Tell the ingest function to display the progress where possible.

        Raises
        ------
        UnknownBundle
            If no bundle is registered under ``name``.
        """
        try:
            bundle = bundles[name]
        except KeyError:
            raise UnknownBundle(name)

        # Normalise to a naive UTC timestamp; it names the ingestion directory.
        timestamp = (
            pd.Timestamp.utcnow() if timestamp is None else timestamp
        ).tz_convert('utc').tz_localize(None)
        timestr = to_bundle_ingest_dirname(timestamp)
        cachepath = cache_path(name, environ=environ)
        pth.ensure_directory(pth.data_path([name, timestr], environ=environ))
        pth.ensure_directory(cachepath)

        # `clean_on_failure=False` keeps the cache directory around when the
        # load fails in the middle.
        with dataframe_cache(cachepath, clean_on_failure=False) as cache:
            with ExitStack() as stack:

                def entered_path(working_location):
                    # Enter the working file/dir on the stack, return its path.
                    return stack.enter_context(working_location).path

                daily_bar_writer = None
                minute_bar_writer = None
                asset_db_writer = None
                adjustment_db_writer = None

                if bundle.create_writers:
                    daily_bars_path = entered_path(working_dir(
                        daily_equity_path(name, timestr, environ=environ)))
                    daily_bar_writer = BcolzDailyBarWriter(
                        daily_bars_path,
                        bundle.calendar,
                    )
                    # The adjustment writer created further down opens the
                    # daily ctables to compute the dividend adjustment
                    # ratios, so make sure they exist with an empty write.
                    daily_bar_writer.write(())

                    first_session = bundle.calendar[0]
                    minute_bars_path = entered_path(working_dir(
                        minute_equity_path(name, timestr, environ=environ)))
                    minute_bar_writer = BcolzMinuteBarWriter(
                        first_session,
                        minute_bars_path,
                        bundle.opens,
                        bundle.closes,
                        minutes_per_day=bundle.minutes_per_day,
                    )

                    asset_db_file = entered_path(working_file(
                        asset_db_path(name, timestr, environ=environ)))
                    asset_db_writer = AssetDBWriter(asset_db_file)

                    adjustment_db_file = entered_path(working_file(
                        adjustment_db_path(name, timestr, environ=environ)))
                    adjustment_db_writer = SQLiteAdjustmentWriter(
                        adjustment_db_file,
                        BcolzDailyBarReader(daily_bars_path),
                        bundle.calendar,
                        overwrite=True,
                    )

                bundle.ingest(
                    environ,
                    asset_db_writer,
                    minute_bar_writer,
                    daily_bar_writer,
                    adjustment_db_writer,
                    bundle.calendar,
                    cache,
                    show_progress,
                    pth.data_path([name, timestr], environ=environ),
                )
예제 #18
0
    def save_processed_models(cls,
                              processed_forms,
                              cases=None,
                              stock_result=None):
        """Save the processed forms, cases and ledger values, then publish.

        One ``transaction.atomic`` block is opened for every database touched
        by the models and all saves happen inside them. On a
        ``DatabaseError`` the primary keys of all models (and of their
        tracked ``create_models``) are reset to ``None`` before the error is
        re-raised. Failures while publishing to Kafka are re-raised as
        ``KafkaPublishingError``.
        """
        # Every database that takes part in this save
        db_names = {processed_forms.submitted.db}
        if processed_forms.deprecated:
            db_names.add(processed_forms.deprecated.db)
        if cases:
            db_names.update(case.db for case in cases)
        if stock_result:
            db_names.update(
                ledger_value.db
                for ledger_value in stock_result.models_to_save)

        ledger_models = stock_result.models_to_save if stock_result else []
        all_models = filter(
            None,
            chain(processed_forms, cases or [], ledger_models))

        try:
            with ExitStack() as stack:
                for db_name in db_names:
                    stack.enter_context(transaction.atomic(db_name))

                # Save deprecated form first to avoid ID conflicts
                if processed_forms.deprecated:
                    FormAccessorSQL.update_form(processed_forms.deprecated,
                                                publish_changes=False)

                FormAccessorSQL.save_new_form(processed_forms.submitted)

                if cases:
                    for case in cases:
                        CaseAccessorSQL.save_case(case)

                if stock_result:
                    LedgerAccessorSQL.save_ledger_values(
                        stock_result.models_to_save, stock_result)

            # Runs after the atomic blocks above have been exited
            if cases and toggles.SORT_OUT_OF_ORDER_FORM_SUBMISSIONS_SQL.enabled(
                    processed_forms.submitted.domain,
                    toggles.NAMESPACE_DOMAIN):
                for case in cases:
                    strategy = SqlCaseUpdateStrategy(case)
                    if strategy.reconcile_transactions_if_necessary():
                        CaseAccessorSQL.save_case(case)
        except DatabaseError:
            # Clear the primary keys of everything we tried to save
            for model in all_models:
                setattr(model, model._meta.pk.attname, None)
                for tracked in model.create_models:
                    setattr(tracked, tracked._meta.pk.attname, None)
            raise

        try:
            cls.publish_changes_to_kafka(processed_forms, cases, stock_result)
        except Exception as e:
            raise KafkaPublishingError(e)
예제 #19
0
    def transform(self, stream_in):
        """
        Main generator work loop.

        Consumes ``(date, snapshot)`` pairs from ``stream_in``, yields the
        messages produced by ``_process_snapshot`` for every snapshot at or
        after ``algo_start`` and finally the message returned by
        ``perf_tracker.handle_simulation_end()``.
        """
        # Session bounds, advanced as market closes are reached
        session_open = self.algo.perf_tracker.market_open
        session_close = self.algo.perf_tracker.market_close

        # inject the current algo
        # snapshot time to any log record generated.
        with ExitStack() as contexts:
            contexts.enter_context(self.processor)
            contexts.enter_context(ZiplineAPI(self.algo))

            frequency = self.sim_params.data_frequency

            self._call_before_trading_start(session_open)

            for date, snapshot in stream_in:

                self.simulation_dt = date
                self.on_dt_changed(date)

                if date < self.algo_start:
                    # Warmup period: keep the universe current, but emit no
                    # perf messages and send no snapshot to handle_data.
                    for event in snapshot:
                        if event.type == DATASOURCE_TYPE.SPLIT:
                            self.algo.blotter.process_split(event)
                        elif event.type == DATASOURCE_TYPE.TRADE:
                            self.update_universe(event)
                            self.algo.perf_tracker.process_trade(event)
                        elif event.type == DATASOURCE_TYPE.CUSTOM:
                            self.update_universe(event)
                    continue

                # Perf messages are only emitted if the snapshot contained
                # a benchmark event.
                for message in self._process_snapshot(
                        date,
                        snapshot,
                        self.algo.instant_fill):
                    yield message

                if date == session_close:
                    # When emitting minutely, before_trading_start has to be
                    # called before the next trading day begins
                    if session_close <= self.algo.perf_tracker.last_close:
                        more_sessions_follow = \
                            session_close < self.algo.perf_tracker.last_close
                        try:
                            session_open, session_close = \
                                self.env.next_open_and_close(session_close)
                        except NoFurtherDataError:
                            # End of backtest history: leave the market
                            # close where it is.
                            pass

                        if more_sessions_follow:
                            self._call_before_trading_start(session_open)

                elif frequency == 'daily':
                    upcoming_day = self.env.next_trading_day(date)

                    if upcoming_day is not None and \
                            upcoming_day < self.algo.perf_tracker.last_close:
                        self._call_before_trading_start(upcoming_day)

                self.algo.portfolio_needs_update = True
                self.algo.account_needs_update = True
                self.algo.performance_needs_update = True

            yield self.algo.perf_tracker.handle_simulation_end()
예제 #20
0
    def transform(self):
        """
        Main generator work loop.

        Iterates over ``self.clock`` and dispatches on the action of every
        tick, yielding capital change packets, daily and minute messages and,
        once the clock is exhausted, the message returned by
        ``perf_tracker.handle_simulation_end()``.
        """
        algo = self.algo
        emission_rate = algo.perf_tracker.emission_rate

        # The default arguments below bind their objects at definition time,
        # so the helpers do not read them from ``self`` later on.
        def every_bar(dt_to_use,
                      current_data=self.current_data,
                      handle_data=algo.event_manager.handle_data):
            # Runs on every tick (minute or day).
            algo.on_dt_changed(dt_to_use)

            for change in calculate_minute_capital_changes(dt_to_use):
                yield change

            self.simulation_dt = dt_to_use

            blotter = algo.blotter
            perf_tracker = algo.perf_tracker

            # Transactions and commissions resulting from orders placed in
            # the last bar
            transactions, commissions, closed_orders = \
                blotter.get_transactions(current_data)

            blotter.prune_orders(closed_orders)

            for txn in transactions:
                perf_tracker.process_transaction(txn)
                # The transaction modified its order, record that as well
                perf_tracker.process_order(blotter.orders[txn.order_id])

            for commission in commissions or ():
                perf_tracker.process_commission(commission)

            handle_data(algo, current_data, dt_to_use)

            # Take the new orders (cancelled ones included) off the blotter
            # and reset its list.
            placed_orders = blotter.new_orders
            blotter.new_orders = []

            # Recording them here ties them to the perf period in which
            # they were placed.
            for placed in placed_orders or ():
                perf_tracker.process_order(placed)

            algo.portfolio_needs_update = True
            algo.account_needs_update = True
            algo.performance_needs_update = True

        def once_a_day(midnight_dt,
                       current_data=self.current_data,
                       data_portal=self.data_portal):

            perf_tracker = algo.perf_tracker

            # Look the positions up before the date changes so that prices
            # are fetched for trading close instead of midnight
            positions = perf_tracker.position_tracker.positions
            position_assets = algo.asset_finder.retrieve_all(positions)

            # set all the timestamps
            self.simulation_dt = midnight_dt
            algo.on_dt_changed(midnight_dt)

            # Capital changes that came overnight
            overnight_changes = algo.calculate_capital_changes(
                midnight_dt, emission_rate=emission_rate, is_interday=True)
            for change in overnight_changes:
                yield change

            # Expired assets are only cleaned up once the clock has rolled
            # over to the next day.
            self._cleanup_expired_assets(midnight_dt, position_assets)

            # Splits matter for held positions and for open orders.
            assets_we_care_about = \
                viewkeys(perf_tracker.position_tracker.positions) | \
                viewkeys(algo.blotter.open_orders)
            if not assets_we_care_about:
                return

            splits = data_portal.get_splits(assets_we_care_about,
                                            midnight_dt)
            if splits:
                algo.blotter.process_splits(splits)
                perf_tracker.position_tracker.handle_splits(splits)

        def handle_benchmark(date, benchmark_source=self.benchmark_source):
            algo.perf_tracker.all_benchmark_returns[date] = \
                benchmark_source.get_value(date)

        def on_exit():
            # Drop the references to algo, data portal, et al to break
            # cycles and get deterministic cleanup of these objects when
            # the simulation finishes.
            self.algo = None
            self.benchmark_source = self.current_data = self.data_portal = None

        with ExitStack() as stack:
            stack.callback(on_exit)
            stack.enter_context(self.processor)
            stack.enter_context(ZiplineAPI(self.algo))

            # Decided once: intraday frequencies get the real hooks, any
            # other frequency gets no-ops.
            intraday = algo.data_frequency in {'minute', '5-minute'}

            def execute_order_cancellation_policy():
                if intraday:
                    algo.blotter.execute_cancel_policy(SESSION_END)

            def calculate_minute_capital_changes(dt):
                if not intraday:
                    return []
                # Capital changes that came between the last and the
                # current minute
                return algo.calculate_capital_changes(
                    dt, emission_rate=emission_rate, is_interday=False)

            for dt, action in self.clock:
                if action == BAR:
                    for packet in every_bar(dt):
                        yield packet
                elif action == SESSION_START:
                    for packet in once_a_day(dt):
                        yield packet
                elif action == SESSION_END:
                    # End of the session.
                    if emission_rate == 'daily':
                        handle_benchmark(normalize_date(dt))
                    execute_order_cancellation_policy()

                    yield self._get_daily_message(dt, algo, algo.perf_tracker)
                elif action == BEFORE_TRADING_START_BAR:
                    self.simulation_dt = dt
                    algo.on_dt_changed(dt)
                    algo.before_trading_start(self.current_data)
                elif action == MINUTE_END:
                    handle_benchmark(dt)
                    yield self._get_minute_message(
                        dt, algo, algo.perf_tracker)

        # ``self.algo`` has been cleared by on_exit at this point; the local
        # ``algo`` still refers to the algorithm.
        yield algo.perf_tracker.handle_simulation_end()
예제 #21
0
def compare_with_cpu_command(args):
    """Load two network configs and write an optimizer comparison as CSV.

    ``args.config`` is loaded inside the ``'current'`` parameter scope and
    ``args.config2`` inside the ``'cpu'`` parameter scope. Both are wrapped
    in lightweight config objects, their optimizer data iterators are
    opened, and ``compare_optimizer`` fills the result table, which is then
    written to ``compare_with_cpu.csv`` in ``args.outdir``.

    Args:
        args: Parsed command-line arguments providing ``outdir``, ``config``
            and ``config2``.
    """
    configure_progress(os.path.join(args.outdir, 'progress.txt'))

    class TrainConfig:
        pass

    class OptConfig:
        pass

    class MonConfig:
        pass

    def _wrap_entries(holder_cls, attr_name, entries):
        # Wrap each loaded entry in a holder whose data iterator is still
        # unset; insertion order of ``entries`` is preserved.
        wrapped = OrderedDict()
        for key, value in entries.items():
            holder = holder_cls()
            setattr(holder, attr_name, value)
            holder.data_iterator = None
            wrapped[key] = holder
        return wrapped

    def _load_config(scope_name, config_file):
        # Load one config file under its own parameter scope and return the
        # wrapped configuration together with that scope's parameters.
        with nn.parameter_scope(scope_name):
            loaded = load.load([config_file])
            scope_parameters = get_parameters(grad_only=False)

        wrapped_config = TrainConfig()
        wrapped_config.global_config = loaded.global_config
        wrapped_config.training_config = loaded.training_config
        wrapped_config.optimizers = _wrap_entries(
            OptConfig, 'optimizer', loaded.optimizers)
        wrapped_config.monitors = _wrap_entries(
            MonConfig, 'monitor', loaded.monitors)
        return wrapped_config, scope_parameters

    # Load config with current context
    config, parameters = _load_config('current', args.config)

    # Load config with cpu context
    config_cpu, cpu_parameters = _load_config('cpu', args.config2)

    # Header row of the result table.
    result_array = [['1-Correl']]

    # Profile Optimizer: keep every data iterator open for the duration of
    # the comparison and close them all afterwards.
    with ExitStack() as stack:
        for train_config in (config, config_cpu):
            for opt_config in train_config.optimizers.values():
                opt_config.data_iterator = stack.enter_context(
                    opt_config.optimizer.data_iterator())
        result_array = compare_optimizer(config, parameters, config_cpu,
                                         cpu_parameters, result_array)

    # Write profiling result
    import csv
    csv_path = args.outdir + os.sep + 'compare_with_cpu.csv'
    with open(csv_path, 'w') as f:
        csv.writer(f, lineterminator='\n').writerows(result_array)

    logger.log(99, 'Compare with CPU Completed.')
    progress(None)
예제 #22
0
파일: __main__.py 프로젝트: varialus/bedup
def vol_cmd(args):
    """Run a volume-level command: ``scan``, ``dedup`` or ``reset``.

    The deprecated ``dedup-vol`` command is rewritten to ``dedup`` with
    defragmentation disabled. Volumes are selected either implicitly (all
    writable volumes) or through ``args.filter``, whose entries may be

    * ``vol:/path`` -- a single volume path, loaded without recursion,
    * ``/path`` -- a device (when the real path is under ``/dev/``) or a
      volume path loaded recursively,
    * anything else -- parsed as a filesystem UUID in hex form.

    Parameters
    ----------
    args
        Parsed command-line arguments. Reads ``command``, ``filter``,
        ``size_cutoff`` and, depending on the command, ``flush``,
        ``groupby`` and ``defrag``.

    Returns
    -------
    int or None
        ``1`` when the request is rejected (``reset`` without an explicit
        filter, an unrecognised filter, or a path that is not a btrfs
        volume); ``None`` when the command ran to completion.
    """
    if args.command == 'dedup-vol':
        sys.stderr.write(
            "The dedup-vol command is deprecated, please use dedup.\n")
        args.command = 'dedup'
        args.defrag = False
    elif args.command == 'reset' and not args.filter:
        sys.stderr.write("You need to list volumes explicitly.\n")
        return 1

    with ExitStack() as stack:
        tt = stack.enter_context(closing(TermTemplate()))
        # Adds about 1s to cold startup
        sess = get_session(args)
        whole_fs = WholeFS(sess, size_cutoff=args.size_cutoff)
        stack.enter_context(closing(whole_fs))

        if not args.filter:
            vols = whole_fs.load_all_writable_vols(tt)
        else:
            # An ordered mapping used as an ordered set: volumes matched by
            # several filters are kept once, in first-seen order.
            vols = OrderedDict()
            for filt in args.filter:
                if filt.startswith('vol:/'):
                    # Drop the 'vol:' prefix but keep the leading slash.
                    volpath = filt[4:]
                    try:
                        filt_vols = whole_fs.load_vols([volpath],
                                                       tt,
                                                       recurse=False)
                    except NotAVolume:
                        sys.stderr.write(
                            'Path doesn\'t point to a btrfs volume: %r\n' %
                            (volpath, ))
                        return 1
                elif filt.startswith('/'):
                    if os.path.realpath(filt).startswith('/dev/'):
                        filt_vols = whole_fs.load_vols_for_device(filt, tt)
                    else:
                        volpath = filt
                        try:
                            filt_vols = whole_fs.load_vols([volpath],
                                                           tt,
                                                           recurse=True)
                        except NotAVolume:
                            sys.stderr.write(
                                'Path doesn\'t point to a btrfs volume: %r\n' %
                                (volpath, ))
                            return 1
                else:
                    try:
                        uuid = UUID(hex=filt)
                    except ValueError:
                        sys.stderr.write('Filter format not recognised: %r\n' %
                                         filt)
                        return 1
                    filt_vols = whole_fs.load_vols_for_fs(
                        whole_fs.get_fs(uuid), tt)
                for vol in filt_vols:
                    vols[vol] = True

        # XXX should group by mountpoint instead.
        # Only a problem when called with volume names instead of an fs filter.
        vols_by_fs = defaultdict(list)

        if args.command == 'reset':
            for vol in vols:
                if user_confirmation(
                        'Reset tracking status of {}?'.format(vol), False):
                    reset_vol(sess, vol)
                    print('Reset of {} done'.format(vol))

        if args.command in ('scan', 'dedup'):
            set_idle_priority()
            for vol in vols:
                if args.flush:
                    tt.format('{elapsed} Flushing %s' % (vol, ))
                    syncfs(vol.fd)
                    tt.format(None)
                track_updated_files(sess, vol, tt)
                vols_by_fs[vol.fs].append(vol)

        if args.command == 'dedup':
            if args.groupby == 'vol':
                for vol in vols:
                    tt.notify('Deduplicating volume %s' % vol)
                    dedup_tracked(sess, [vol], tt, defrag=args.defrag)
            elif args.groupby == 'mpoint':
                # items() rather than iteritems(): the latter does not exist
                # on Python 3 dicts, while items() works on both versions.
                for fs, volset in vols_by_fs.items():
                    tt.notify('Deduplicating filesystem %s' % fs)
                    dedup_tracked(sess, volset, tt, defrag=args.defrag)
            else:
                assert False, args.groupby

        # For safety only.
        # The methods we call from the tracking module are expected to commit.
        sess.commit()
예제 #23
0
 def __enter__(self):
     """Enter one tracker per entry of ``self.trackers``.

     Each tracker factory is called with ``(self.client, self.metric)`` and
     the resulting context manager is entered. On success the entered
     trackers are owned by ``self.stack`` (an ``ExitStack``), which is
     expected to be closed later to exit them.

     If any tracker fails to enter, the trackers entered so far are exited
     again before the exception propagates; otherwise they would leak,
     because ``__exit__`` is not invoked when ``__enter__`` raises.
     """
     with ExitStack() as stack:
         for tracker in self.trackers:
             stack.enter_context(tracker(self.client, self.metric))
         # Everything entered cleanly: move the registered exits to a fresh
         # stack so that leaving this ``with`` block does not run them.
         self.stack = stack.pop_all()
예제 #24
0
파일: __main__.py 프로젝트: andgein/pony
from docopt import docopt
import json
import requests

from .util import resolve_name
from .diagram import db_to_diagram

import os
from pony.py23compat import PY2
if PY2:
    from contextlib2 import ExitStack
else:
    from contextlib import ExitStack

# Script body: parse the command line, resolve the database object and build
# its diagram. The ExitStack lets the debugger hook below be installed only
# when requested.
with ExitStack() as stack:
    # Opt-in debugging: when the DEBUG environment variable is set to a
    # non-empty value, ipdb is imported lazily and its
    # launch_ipdb_on_exception() context manager is kept active for the rest
    # of the block.
    if os.environ.get('DEBUG'):
        import ipdb
        stack.enter_context(ipdb.launch_ipdb_on_exception())

    # The module docstring doubles as the docopt usage specification.
    opts = docopt(__doc__)
    # The --export option is recognised but not implemented.
    if opts['--export']:
        raise NotImplementedError

    name = opts['<name>']
    path = opts['<path_to_db>']
    # NOTE(review): resolve_name presumably imports the object named by the
    # given path; its behaviour is defined in .util -- confirm there.
    db = resolve_name(path)
    diagram = db_to_diagram(db)

    login = opts['--login']
    password = opts['--password']
예제 #25
0
    def ingest(name, environ=os.environ, timestamp=None):
        """Ingest data for a given bundle.

        Parameters
        ----------
        name : str
            Name under which the bundle is registered in ``bundles``.
        environ : mapping, optional
            Environment variables used to resolve the data and cache paths.
            By default this is os.environ.
        timestamp : datetime, optional
            Timezone-aware timestamp identifying this ingestion; it is
            converted to UTC and used to name the output directory.
            By default this is the current time.

        Raises
        ------
        UnknownBundle
            If ``name`` is not a key of ``bundles``.
        """
        try:
            bundle = bundles[name]
        except KeyError:
            raise UnknownBundle(name)

        calendar = get_calendar(bundle.calendar_name)

        # Clamp the bundle's session range to the calendar: a missing bound,
        # or one lying outside the calendar, is replaced by the calendar's.
        start_session, end_session = bundle.start_session, bundle.end_session
        if start_session is None or start_session < calendar.first_session:
            start_session = calendar.first_session
        if end_session is None or end_session > calendar.last_session:
            end_session = calendar.last_session

        # Normalise to a naive UTC timestamp.
        ingest_time = pd.Timestamp.utcnow() if timestamp is None else timestamp
        ingest_time = ingest_time.tz_convert('utc').tz_localize(None)

        timestr = to_bundle_ingest_dirname(ingest_time)
        cachepath = cache_path(name, environ=environ)
        required_dirs = (
            pth.data_path([name, timestr], environ=environ),
            cachepath,
        )
        for directory in required_dirs:
            pth.ensure_directory(directory)

        # Writers stay ``None`` unless the bundle asks for them to be built.
        daily_bar_writer = minute_bar_writer = None
        asset_db_writer = adjustment_db_writer = None

        # `clean_on_failure=False` keeps the cache directory intact if the
        # load fails part-way through.
        with dataframe_cache(cachepath, clean_on_failure=False) as cache:
            with ExitStack() as stack:
                if bundle.create_writers:
                    workdir = stack.enter_context(
                        working_dir(pth.data_path([], environ=environ)))

                    daily_bars_path = workdir.ensure_dir(
                        *daily_equity_relative(name, timestr))
                    daily_bar_writer = BcolzDailyBarWriter(
                        daily_bars_path,
                        calendar,
                        start_session,
                        end_session,
                    )
                    # An empty write makes sure the daily ctables exist
                    # before the SQLiteAdjustmentWriter below opens them to
                    # compute the adjustment ratios for the dividends.
                    daily_bar_writer.write(())

                    minute_bars_path = workdir.ensure_dir(
                        *minute_equity_relative(name, timestr))
                    minute_bar_writer = BcolzMinuteBarWriter(
                        minute_bars_path,
                        calendar,
                        start_session,
                        end_session,
                        minutes_per_day=bundle.minutes_per_day,
                    )

                    asset_db_writer = AssetDBWriter(
                        workdir.getpath(*asset_db_relative(name, timestr)))

                    adjustments_path = workdir.getpath(
                        *adjustment_db_relative(name, timestr))
                    adjustment_db_writer = stack.enter_context(
                        SQLiteAdjustmentWriter(
                            adjustments_path,
                            BcolzDailyBarReader(daily_bars_path),
                            overwrite=True,
                        ))

                bundle.ingest(
                    environ,
                    asset_db_writer,
                    minute_bar_writer,
                    daily_bar_writer,
                    adjustment_db_writer,
                    calendar,
                    start_session,
                    end_session,
                    cache,
                    pth.data_path([name, timestr], environ=environ),
                )
예제 #26
0
    def main(self, orings):
        """Read all input rings in lockstep and drive per-gulp processing.

        For every tuple of input sequences (one per ring in ``self.irings``)
        this calls ``self._on_sequence`` to obtain the output headers and
        optional input slices, resizes the inputs, opens the output
        sequences and then, for every tuple of input spans, reserves output
        spans, calls ``self._on_data`` and commits the outputs. The time
        spent in the acquire, reserve and process phases of each gulp is
        published through ``self.perf_proclog``.

        Both loops stop as soon as ``self.shutdown_event`` is set. When the
        outer loop stops this way ``self._on_sequence_end`` is not called
        for the sequences just read; when only the inner loop stops it
        still is.

        Parameters
        ----------
        orings
            Output rings, passed through to ``self.begin_sequences``.
        """
        for iseqs in izip(
                *
            [iring.read(guarantee=self.guarantee) for iring in self.irings]):
            if self.shutdown_event.is_set():
                break
            # Publish the header of each new input sequence to its proclog.
            for i, iseq in enumerate(iseqs):
                self.sequence_proclogs[i].update(iseq.header)
            oheaders, islices = self._on_sequence(iseqs)
            # Output headers lacking a 'time_tag' get the running sequence
            # counter as a default.
            for ohdr in oheaders:
                if 'time_tag' not in ohdr:
                    ohdr['time_tag'] = self._seq_count
            self._seq_count += 1

            # Allow passing None to mean slice(gulp_nframe)
            if islices is None:
                islices = [None] * len(self.irings)
            # Per-input default gulp size: the block-level setting when it is
            # truthy, otherwise the value carried in the sequence header.
            default_igulp_nframes = [
                self.gulp_nframe or iseq.header['gulp_nframe']
                for iseq in iseqs
            ]
            # Fill in the default slice for every input left unspecified.
            islices = [
                islice or slice(igulp_nframe)
                for (islice,
                     igulp_nframe) in zip(islices, default_igulp_nframes)
            ]

            # NOTE(review): _span_slice is defined elsewhere; the code below
            # does arithmetic on .start and .stop of its results, so it is
            # presumably expected to fill in concrete values -- confirm.
            islices = [_span_slice(slice_) for slice_ in islices]
            for iseq, islice in zip(iseqs, islices):
                # Without an explicit buffer factor, use 1 when the block
                # owning the source ring is fused with this one; otherwise
                # pass None through to resize().
                if self.buffer_factor is None:
                    src_block = iseq.ring.owner
                    if src_block is not None and self.is_fused_with(src_block):
                        buffer_factor = 1
                    else:
                        buffer_factor = None
                else:
                    buffer_factor = self.buffer_factor
                iseq.resize(gulp_nframe=(islice.stop - islice.start),
                            buf_nframe=self.buffer_nframe,
                            buffer_factor=buffer_factor)

            # Number of frames read per gulp from each input.
            igulp_nframes = [islice.stop - islice.start for islice in islices]

            # Output sequences are registered on oseq_stack and therefore
            # released when this block is left.
            with ExitStack() as oseq_stack:
                oseqs = self.begin_sequences(oseq_stack, orings, oheaders,
                                             igulp_nframes)
                prev_time = time.time()
                for ispans in izip(*[
                        iseq.read(islice.stop -
                                  islice.start, islice.step, islice.start)
                        for (iseq, islice) in zip(iseqs, islices)
                ]):
                    if self.shutdown_event.is_set():
                        break
                    # Time spent waiting for this set of input spans.
                    cur_time = time.time()
                    acquire_time = cur_time - prev_time
                    prev_time = cur_time
                    with ExitStack() as ospan_stack:
                        ospans = self.reserve_spans(ospan_stack, oseqs, ispans)
                        # Time spent reserving the output spans.
                        cur_time = time.time()
                        reserve_time = cur_time - prev_time
                        prev_time = cur_time
                        # *TODO: See if can fuse together multiple on_data calls here before
                        #          calling stream_synchronize().
                        #        Consider passing .data instead of rings here
                        ostrides = self._on_data(ispans, ospans)
                        # TODO: // Default to not spinning the CPU: cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync);
                        bf.device.stream_synchronize()
                        # Allow returning None to indicate complete consumption
                        if ostrides is None:
                            ostrides = [ospan.nframe for ospan in ospans]
                        # A None entry likewise means "the whole span".
                        ostrides = [
                            ostride if ostride is not None else ospan.nframe
                            for (ostride, ospan) in zip(ostrides, ospans)
                        ]
                        for ospan, ostride in zip(ospans, ostrides):
                            ospan.commit(ostride)
                    # Measured after ospan_stack has been closed, so this
                    # covers _on_data, the stream synchronisation, the
                    # commits and the release of the output spans.
                    cur_time = time.time()
                    process_time = cur_time - prev_time
                    prev_time = cur_time
                    self.perf_proclog.update({
                        'acquire_time': acquire_time,
                        'reserve_time': reserve_time,
                        'process_time': process_time
                    })
            self._on_sequence_end(iseqs)
예제 #27
0
    def transform(self):
        """
        Main generator work loop.

        Iterates over ``self.clock`` and dispatches on each
        ``(dt, action)`` pair:

        * ``BAR`` -- run ``every_bar`` (fills, commissions, ``handle_data``,
          new orders).
        * ``DAY_START`` -- run ``once_a_day`` (overnight capital changes,
          expired assets, splits, ``before_trading_start``).
        * ``DAY_END`` -- record the benchmark value when the emission rate
          is ``'daily'``, run the cancellation policy and yield the daily
          message.
        * ``MINUTE_END`` -- record the benchmark value and yield the minute
          message.

        After the clock is exhausted and the ``ExitStack`` has been closed,
        the message returned by ``handle_simulation_end()`` is yielded last.
        """
        algo = self.algo

        # current_data and handle_data are bound as default arguments, i.e.
        # captured when the function is defined; resetting
        # self.current_data in on_exit() does not affect them.
        def every_bar(dt_to_use, current_data=self.current_data,
                      handle_data=algo.event_manager.handle_data):
            # called every tick (minute or day).

            # process_minute_capital_changes is defined further down, inside
            # the ``with`` block, before the clock loop first calls this.
            if dt_to_use in algo.capital_changes:
                process_minute_capital_changes(dt_to_use)

            self.simulation_dt = dt_to_use
            algo.on_dt_changed(dt_to_use)

            blotter = algo.blotter
            perf_tracker = algo.perf_tracker

            # handle any transactions and commissions coming out new orders
            # placed in the last bar
            new_transactions, new_commissions, closed_orders = \
                blotter.get_transactions(current_data)

            blotter.prune_orders(closed_orders)

            for transaction in new_transactions:
                perf_tracker.process_transaction(transaction)

                # since this order was modified, record it
                order = blotter.orders[transaction.order_id]
                perf_tracker.process_order(order)

            if new_commissions:
                for commission in new_commissions:
                    perf_tracker.process_commission(commission)

            handle_data(algo, current_data, dt_to_use)

            # grab any new orders from the blotter, then clear the list.
            # this includes cancelled orders.
            new_orders = blotter.new_orders
            blotter.new_orders = []

            # if we have any new orders, record them so that we know
            # in what perf period they were placed.
            if new_orders:
                for new_order in new_orders:
                    perf_tracker.process_order(new_order)

            # Flag the algorithm's cached portfolio, account and performance
            # views as needing a refresh.
            self.algo.portfolio_needs_update = True
            self.algo.account_needs_update = True
            self.algo.performance_needs_update = True

        # As above, current_data and data_portal are captured at definition
        # time through the default arguments.
        def once_a_day(midnight_dt, current_data=self.current_data,
                       data_portal=self.data_portal):

            perf_tracker = algo.perf_tracker

            if midnight_dt in algo.capital_changes:
                # process any capital changes that came overnight
                change = algo.capital_changes[midnight_dt]
                log.info('Processing capital change of %s at %s' %
                         (change, midnight_dt))
                perf_tracker.process_capital_changes(change, is_interday=True)

            # Get the positions before updating the date so that prices are
            # fetched for trading close instead of midnight
            positions = algo.perf_tracker.position_tracker.positions
            position_assets = algo.asset_finder.retrieve_all(positions)

            # set all the timestamps
            self.simulation_dt = midnight_dt
            algo.on_dt_changed(midnight_dt)

            # we want to wait until the clock rolls over to the next day
            # before cleaning up expired assets.
            self._cleanup_expired_assets(midnight_dt, position_assets)

            # handle any splits that impact any positions or any open orders.
            assets_we_care_about = \
                viewkeys(perf_tracker.position_tracker.positions) | \
                viewkeys(algo.blotter.open_orders)

            if assets_we_care_about:
                splits = data_portal.get_splits(assets_we_care_about,
                                                midnight_dt)
                if splits:
                    algo.blotter.process_splits(splits)
                    perf_tracker.position_tracker.handle_splits(splits)

            # call before trading start
            algo.before_trading_start(current_data)

        # Store the benchmark source's value for ``date`` in the perf
        # tracker's benchmark returns.
        def handle_benchmark(date, benchmark_source=self.benchmark_source):
            algo.perf_tracker.all_benchmark_returns[date] = \
                benchmark_source.get_value(date)

        # Registered on the ExitStack below: drops this object's references
        # to the benchmark source, current data and data portal when the
        # stack unwinds.
        def on_exit():
            self.benchmark_source = self.current_data = self.data_portal = None

        with ExitStack() as stack:
            stack.callback(on_exit)
            stack.enter_context(self.processor)
            stack.enter_context(ZiplineAPI(self.algo))

            # Only minute-frequency runs execute the cancellation policy at
            # day end and process intraday capital changes; otherwise both
            # hooks are no-ops.
            if algo.data_frequency == 'minute':
                def execute_order_cancellation_policy():
                    algo.blotter.execute_cancel_policy(DAY_END)

                def process_minute_capital_changes(dt):
                    # If we are running daily emission, prices won't
                    # necessarily be synced at the end of every minute, and we
                    # need the up-to-date prices for capital change
                    # calculations. We want to sync the prices as of the
                    # last market minute, and this is okay from a data portal
                    # perspective as we have technically not "advanced" to the
                    # current dt yet.
                    algo.perf_tracker.position_tracker.sync_last_sale_prices(
                        self.env.previous_market_minute(dt),
                        False,
                        self.data_portal
                    )

                    # process any capital changes that came between the last
                    # and current minutes
                    change = algo.capital_changes[dt]
                    log.info('Processing capital change of %s at %s' %
                             (change, dt))
                    algo.perf_tracker.process_capital_changes(
                        change,
                        is_interday=False
                    )
            else:
                def execute_order_cancellation_policy():
                    pass

                def process_minute_capital_changes(dt):
                    pass

            # Main event loop: one iteration per clock event.
            for dt, action in self.clock:
                if action == BAR:
                    every_bar(dt)
                elif action == DAY_START:
                    once_a_day(dt)
                elif action == DAY_END:
                    # End of the day.
                    if algo.perf_tracker.emission_rate == 'daily':
                        handle_benchmark(normalize_date(dt))
                    execute_order_cancellation_policy()

                    yield self._get_daily_message(dt, algo, algo.perf_tracker)
                elif action == MINUTE_END:
                    handle_benchmark(dt)
                    minute_msg = \
                        self._get_minute_message(dt, algo, algo.perf_tracker)

                    yield minute_msg

        # Emitted after the stack has been closed; uses the local ``algo``
        # reference taken at the top of the method.
        risk_message = algo.perf_tracker.handle_simulation_end()
        yield risk_message
예제 #28
0
def settings(*args, **kwargs):
    """
    Combine context managers and/or temporary ``env`` overrides in one block.

    `settings` does two jobs, which may be mixed freely in a single call:

    * Keyword arguments temporarily override/update ``env``, e.g.
      ``with settings(user='******'):``. Original values, if any, are put
      back once the ``with`` block closes.

        * ``clean_revert`` is a keyword argument with special meaning for
          ``settings`` (see below) rather than an ordinary ``env`` key. In
          this implementation it is handed to ``_setenv`` together with the
          other keyword arguments.

    * Positional arguments must be other context managers, e.g.
      ``with settings(hide('stderr'), show('stdout')):``. They are entered
      in the order given through an `ExitStack`_ and exited in reverse order
      when the block closes. When keyword arguments are present, the
      ``env`` override is entered last.

    .. _ExitStack: https://docs.python.org/3/library/contextlib.html#contextlib.ExitStack

    The value produced for the ``with`` block is a tuple holding the result
    of entering each manager, in order.

    A combined example shows why this is handy::

        def my_task():
            with settings(
                hide('warnings', 'running', 'stdout', 'stderr'),
                warn_only=True
            ):
                if run('ls /etc/lsb-release'):
                    return 'Ubuntu'
                elif run('ls /etc/redhat-release'):
                    return 'RedHat'

    The task above executes `run` statements that only warn, rather than
    abort, when ``ls`` fails, while all output -- the warning included -- is
    hidden from the user. The net effect is a completely silent task that
    lets the caller work out what kind of system the remote host is without
    the output that would normally appear.

    So `settings` can set any combination of environment variables together
    with hiding (or showing) specific output levels, or together with any
    other Fabric feature that is implemented as a context manager.

    With ``clean_revert`` set to ``True``, keys that were altered inside the
    nested block are **not** reverted; only keys whose values still equal
    the ones given are. Normal operation looks like this::

        # Before the block, env.parallel defaults to False, host_string to None
        with settings(parallel=True, host_string='myhost'):
            # env.parallel is True
            # env.host_string is 'myhost'
            env.host_string = 'otherhost'
            # env.host_string is now 'otherhost'
        # Outside the block:
        # * env.parallel is False again
        # * env.host_string is None again

    The change made to ``env.host_string`` inside the block is undone, which
    is not always wanted. With ``clean_revert`` it survives::

        # Before the block, env.parallel defaults to False, host_string to None
        with settings(parallel=True, host_string='myhost', clean_revert=True):
            # env.parallel is True
            # env.host_string is 'myhost'
            env.host_string = 'otherhost'
            # env.host_string is now 'otherhost'
        # Outside the block:
        # * env.parallel is False again
        # * env.host_string remains 'otherhost'

    Keys that did not exist in ``env`` before `settings` was used are also
    kept when ``clean_revert`` is active; when it is ``False`` they are
    removed as the block exits.

    .. versionadded:: 1.4.1
        The ``clean_revert`` kwarg.
    """
    # The env override, when requested, goes after the positional managers.
    env_override = [_setenv(kwargs)] if kwargs else []
    pending = list(args) + env_override

    with ExitStack() as stack:
        entered = []
        for manager in pending:
            entered.append(stack.enter_context(manager))
        yield tuple(entered)
예제 #29
0
    def transform(self):
        """
        Main generator work loop.

        Iterates over ``self.clock`` and dispatches on each
        ``(dt, action)`` pair:

        * ``BAR`` -- yield any intraday capital-change packets, then process
          fills and commissions, call ``handle_data`` and record new orders.
        * ``SESSION_START`` -- yield any overnight capital-change packets,
          then notify the metrics tracker of the market open and apply
          splits.
        * ``SESSION_END`` -- clean up expired assets, run the cancellation
          policy, validate account controls and yield the daily message.
        * ``BEFORE_TRADING_START_BAR`` -- call the algorithm's
          ``before_trading_start``.
        * ``MINUTE_END`` -- yield the minute message.

        Once the clock is exhausted, the message returned by the metrics
        tracker's ``handle_simulation_end`` is yielded last.
        """
        algo = self.algo
        metrics_tracker = algo.metrics_tracker
        emission_rate = metrics_tracker.emission_rate

        # current_data and handle_data are captured at definition time via
        # the default arguments, so the reset done by on_exit() does not
        # affect them.
        def every_bar(dt_to_use,
                      current_data=self.current_data,
                      handle_data=algo.event_manager.handle_data):
            # calculate_minute_capital_changes is defined further down,
            # inside the ``with`` block, before the clock loop starts.
            for capital_change in calculate_minute_capital_changes(dt_to_use):
                yield capital_change

            self.simulation_dt = dt_to_use
            # called every tick (minute or day).
            algo.on_dt_changed(dt_to_use)

            blotter = algo.blotter

            # handle any transactions and commissions coming out new orders
            # placed in the last bar
            new_transactions, new_commissions, closed_orders = \
                blotter.get_transactions(current_data)

            blotter.prune_orders(closed_orders)

            for transaction in new_transactions:
                metrics_tracker.process_transaction(transaction)

                # since this order was modified, record it
                order = blotter.orders[transaction.order_id]
                metrics_tracker.process_order(order)

            for commission in new_commissions:
                metrics_tracker.process_commission(commission)

            handle_data(algo, current_data, dt_to_use)

            # grab any new orders from the blotter, then clear the list.
            # this includes cancelled orders.
            new_orders = blotter.new_orders
            blotter.new_orders = []

            # if we have any new orders, record them so that we know
            # in what perf period they were placed.
            for new_order in new_orders:
                metrics_tracker.process_order(new_order)

        def once_a_day(midnight_dt,
                       current_data=self.current_data,
                       data_portal=self.data_portal):
            # process any capital changes that came overnight
            for capital_change in algo.calculate_capital_changes(
                    midnight_dt, emission_rate=emission_rate,
                    is_interday=True):
                yield capital_change

            # set all the timestamps
            self.simulation_dt = midnight_dt
            algo.on_dt_changed(midnight_dt)

            metrics_tracker.handle_market_open(
                midnight_dt,
                algo.data_portal,
            )

            # handle any splits that impact any positions or any open orders.
            assets_we_care_about = (viewkeys(metrics_tracker.positions)
                                    | viewkeys(algo.blotter.open_orders))

            if assets_we_care_about:
                splits = data_portal.get_splits(assets_we_care_about,
                                                midnight_dt)
                if splits:
                    algo.blotter.process_splits(splits)
                    metrics_tracker.handle_splits(splits)

        def on_exit():
            # Remove references to algo, data portal, et al to break cycles
            # and ensure deterministic cleanup of these objects when the
            # simulation finishes.
            self.algo = None
            self.benchmark_source = self.current_data = self.data_portal = None

        with ExitStack() as stack:
            stack.callback(on_exit)
            stack.enter_context(self.processor)
            stack.enter_context(ZiplineAPI(self.algo))

            # Only minute-frequency runs execute the cancellation policy at
            # session end and compute intraday capital changes; otherwise
            # both hooks do nothing.
            if algo.data_frequency == 'minute':

                def execute_order_cancellation_policy():
                    algo.blotter.execute_cancel_policy(SESSION_END)

                def calculate_minute_capital_changes(dt):
                    # process any capital changes that came between the last
                    # and current minutes
                    return algo.calculate_capital_changes(
                        dt, emission_rate=emission_rate, is_interday=False)
            else:

                def execute_order_cancellation_policy():
                    pass

                def calculate_minute_capital_changes(dt):
                    return []

            # Main event loop: one iteration per clock event. (A leftover
            # debugging print of every event was removed from here.)
            for dt, action in self.clock:
                if action == BAR:
                    for capital_change_packet in every_bar(dt):
                        yield capital_change_packet
                elif action == SESSION_START:
                    for capital_change_packet in once_a_day(dt):
                        yield capital_change_packet
                elif action == SESSION_END:
                    # End of the session.
                    positions = metrics_tracker.positions
                    position_assets = algo.asset_finder.retrieve_all(positions)
                    self._cleanup_expired_assets(dt, position_assets)

                    execute_order_cancellation_policy()
                    algo.validate_account_controls()
                    yield self._get_daily_message(dt, algo, metrics_tracker)
                elif action == BEFORE_TRADING_START_BAR:
                    self.simulation_dt = dt
                    algo.on_dt_changed(dt)
                    algo.before_trading_start(self.current_data)
                elif action == MINUTE_END:
                    minute_msg = self._get_minute_message(
                        dt,
                        algo,
                        metrics_tracker,
                    )
                    yield minute_msg

            # Still inside the ``with`` block: on_exit() has not run yet, so
            # self.data_portal is still available here.
            risk_message = metrics_tracker.handle_simulation_end(
                self.data_portal, )
            yield risk_message
예제 #30
0
    def transform(self, stream_in):
        """
        Main simulation loop, written as a generator.

        ``stream_in`` is iterated as ``(date, snapshot)`` pairs, each
        snapshot being an iterable of events.  Snapshots dated before
        ``self.algo_start`` only update internal state (splits, universe,
        perf tracker, history container) and produce no output.  Later
        snapshots are passed to ``_process_snapshot`` and every item it
        returns is yielded.  Once the stream is exhausted, the result of
        ``perf_tracker.handle_simulation_end()`` is yielded last.
        """
        # Open/close of the session currently being simulated.
        tracker = self.algo.perf_tracker
        session_open = tracker.market_open
        session_close = tracker.market_close

        # Inject the current algo snapshot time into any log record
        # generated.  The ExitStack keeps both context managers active for
        # the whole loop and releases them when the generator finishes.
        with ExitStack() as contexts:
            contexts.enter_context(self.processor)
            contexts.enter_context(ZiplineAPI(self.algo))

            frequency = self.sim_params.data_frequency
            self._call_before_trading_start(session_open)

            # Each snapshot bundles the events (trades, splits, custom
            # data, ...) that share one timestamp.
            for dt, events in stream_in:
                # Advance the simulation clock.
                self.simulation_dt = dt
                self.on_dt_changed(dt)

                if dt < self.algo_start:
                    # Still in the warmup period: use the events to update
                    # our universe, but don't yield any perf messages and
                    # don't send a snapshot to handle_data.
                    for event in events:
                        kind = event.type
                        if kind == DATASOURCE_TYPE.SPLIT:
                            self.algo.blotter.process_split(event)
                        elif kind == DATASOURCE_TYPE.TRADE:
                            self.update_universe(event)
                            self.algo.perf_tracker.process_trade(event)
                        elif kind == DATASOURCE_TYPE.CUSTOM:
                            self.update_universe(event)

                    history = self.algo.history_container
                    if history:
                        history.update(self.current_data, dt)
                    continue

                # The simulation proper has started: process the snapshot.
                # Perf messages are only emitted if the snapshot contained
                # a benchmark event.
                perf_messages = self._process_snapshot(
                    dt,
                    events,
                    self.algo.instant_fill,
                )
                for message in perf_messages:
                    yield message

                # When emitting minutely, we need to call
                # before_trading_start before the next trading day begins.
                if dt == session_close:
                    final_close = self.algo.perf_tracker.last_close
                    if session_close <= final_close:
                        # Decide this before session_close is advanced.
                        has_later_session = session_close < final_close
                        try:
                            session_open, session_close = \
                                self.env.next_open_and_close(session_close)
                        except NoFurtherDataError:
                            # At the end of backtest history: skip
                            # advancing the market close.
                            pass

                        if has_later_session:
                            self._call_before_trading_start(session_open)

                elif frequency == 'daily':
                    upcoming_day = self.env.next_trading_day(dt)

                    # Only when there is a next trading day and it falls
                    # before the perf tracker's last close.
                    if upcoming_day is not None and \
                            upcoming_day < self.algo.perf_tracker.last_close:
                        self._call_before_trading_start(upcoming_day)

                # Flag the algo's cached portfolio/account/performance
                # state as needing an update.
                algo = self.algo
                algo.portfolio_needs_update = True
                algo.account_needs_update = True
                algo.performance_needs_update = True

            # The stream is exhausted: emit the end-of-simulation message
            # while the context managers are still active.
            yield self.algo.perf_tracker.handle_simulation_end()